Skip to content

Commit fc56a3c

Browse files
committed
[NVTPX] Copy kernel arguments as byte array
Ensures that struct padding is not skipped, as it may contain actual data if the struct is really a union. Fixes #53710
1 parent 7a4b320 commit fc56a3c

File tree

3 files changed

+62
-35
lines changed

3 files changed

+62
-35
lines changed

llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -626,10 +626,17 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
626626
// Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
627627
// addrspacecast preserves alignment. Since params are constant, this load
628628
// is definitely not volatile.
629+
const auto StructBytes = *AllocA->getAllocationSize(DL);
630+
const auto ChunkBytes = (StructBytes % 8 == 0) ? 8 :
631+
(StructBytes % 4 == 0) ? 4 :
632+
(StructBytes % 2 == 0) ? 2 : 1;
633+
Type *ChunkType = Type::getIntNTy(Func->getContext(), 8 * ChunkBytes);
634+
Type *OpaqueType = ArrayType::get(ChunkType, StructBytes / ChunkBytes);
629635
LoadInst *LI =
630-
new LoadInst(StructType, ArgInParam, Arg->getName(),
636+
new LoadInst(OpaqueType, ArgInParam, Arg->getName(),
631637
/*isVolatile=*/false, AllocA->getAlign(), FirstInst);
632-
new StoreInst(LI, AllocA, FirstInst);
638+
new StoreInst(LI, AllocA,
639+
/*isVolatile=*/false, AllocA->getAlign(), FirstInst);
633640
}
634641
}
635642

llvm/test/CodeGen/NVPTX/lower-args.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ target triple = "nvptx64-nvidia-cuda"
1414
; COMMON-LABEL: load_alignment
1515
define void @load_alignment(ptr nocapture readonly byval(%class.outer) align 8 %arg) {
1616
entry:
17-
; IR: load %class.outer, ptr addrspace(101)
17+
; IR: load [3 x i64], ptr addrspace(101)
1818
; IR-SAME: align 8
1919
; PTX: ld.param.u64
2020
; PTX-NOT: ld.param.u8

llvm/test/CodeGen/NVPTX/lower-byval-args.ll

Lines changed: 52 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ define dso_local void @read_only_gep_asc0(ptr nocapture noundef writeonly %out,
8888
; COMMON-NEXT: [[ENTRY:.*:]]
8989
; COMMON-NEXT: [[S3:%.*]] = alloca [[STRUCT_S]], align 4
9090
; COMMON-NEXT: [[S4:%.*]] = addrspacecast ptr [[S]] to ptr addrspace(101)
91-
; COMMON-NEXT: [[S5:%.*]] = load [[STRUCT_S]], ptr addrspace(101) [[S4]], align 4
92-
; COMMON-NEXT: store [[STRUCT_S]] [[S5]], ptr [[S3]], align 4
91+
; COMMON-NEXT: [[S5:%.*]] = load [1 x i64], ptr addrspace(101) [[S4]], align 4
92+
; COMMON-NEXT: store [1 x i64] [[S5]], ptr [[S3]], align 4
9393
; COMMON-NEXT: [[OUT1:%.*]] = addrspacecast ptr [[OUT]] to ptr addrspace(1)
9494
; COMMON-NEXT: [[OUT2:%.*]] = addrspacecast ptr addrspace(1) [[OUT1]] to ptr
9595
; COMMON-NEXT: [[B:%.*]] = getelementptr inbounds nuw i8, ptr [[S3]], i64 4
@@ -115,8 +115,8 @@ define dso_local void @escape_ptr(ptr nocapture noundef readnone %out, ptr nound
115115
; COMMON-NEXT: [[ENTRY:.*:]]
116116
; COMMON-NEXT: [[S3:%.*]] = alloca [[STRUCT_S]], align 4
117117
; COMMON-NEXT: [[S4:%.*]] = addrspacecast ptr [[S]] to ptr addrspace(101)
118-
; COMMON-NEXT: [[S5:%.*]] = load [[STRUCT_S]], ptr addrspace(101) [[S4]], align 4
119-
; COMMON-NEXT: store [[STRUCT_S]] [[S5]], ptr [[S3]], align 4
118+
; COMMON-NEXT: [[S5:%.*]] = load [1 x i64], ptr addrspace(101) [[S4]], align 4
119+
; COMMON-NEXT: store [1 x i64] [[S5]], ptr [[S3]], align 4
120120
; COMMON-NEXT: [[OUT1:%.*]] = addrspacecast ptr [[OUT]] to ptr addrspace(1)
121121
; COMMON-NEXT: [[OUT2:%.*]] = addrspacecast ptr addrspace(1) [[OUT1]] to ptr
122122
; COMMON-NEXT: call void @_Z6escapePv(ptr noundef nonnull [[S3]])
@@ -134,8 +134,8 @@ define dso_local void @escape_ptr_gep(ptr nocapture noundef readnone %out, ptr n
134134
; COMMON-NEXT: [[ENTRY:.*:]]
135135
; COMMON-NEXT: [[S3:%.*]] = alloca [[STRUCT_S]], align 4
136136
; COMMON-NEXT: [[S4:%.*]] = addrspacecast ptr [[S]] to ptr addrspace(101)
137-
; COMMON-NEXT: [[S5:%.*]] = load [[STRUCT_S]], ptr addrspace(101) [[S4]], align 4
138-
; COMMON-NEXT: store [[STRUCT_S]] [[S5]], ptr [[S3]], align 4
137+
; COMMON-NEXT: [[S5:%.*]] = load [1 x i64], ptr addrspace(101) [[S4]], align 4
138+
; COMMON-NEXT: store [1 x i64] [[S5]], ptr [[S3]], align 4
139139
; COMMON-NEXT: [[OUT1:%.*]] = addrspacecast ptr [[OUT]] to ptr addrspace(1)
140140
; COMMON-NEXT: [[OUT2:%.*]] = addrspacecast ptr addrspace(1) [[OUT1]] to ptr
141141
; COMMON-NEXT: [[B:%.*]] = getelementptr inbounds nuw i8, ptr [[S3]], i64 4
@@ -155,8 +155,8 @@ define dso_local void @escape_ptr_store(ptr nocapture noundef writeonly %out, pt
155155
; COMMON-NEXT: [[ENTRY:.*:]]
156156
; COMMON-NEXT: [[S3:%.*]] = alloca [[STRUCT_S]], align 4
157157
; COMMON-NEXT: [[S4:%.*]] = addrspacecast ptr [[S]] to ptr addrspace(101)
158-
; COMMON-NEXT: [[S5:%.*]] = load [[STRUCT_S]], ptr addrspace(101) [[S4]], align 4
159-
; COMMON-NEXT: store [[STRUCT_S]] [[S5]], ptr [[S3]], align 4
158+
; COMMON-NEXT: [[S5:%.*]] = load [1 x i64], ptr addrspace(101) [[S4]], align 4
159+
; COMMON-NEXT: store [1 x i64] [[S5]], ptr [[S3]], align 4
160160
; COMMON-NEXT: [[OUT1:%.*]] = addrspacecast ptr [[OUT]] to ptr addrspace(1)
161161
; COMMON-NEXT: [[OUT2:%.*]] = addrspacecast ptr addrspace(1) [[OUT1]] to ptr
162162
; COMMON-NEXT: store ptr [[S3]], ptr [[OUT2]], align 8
@@ -174,8 +174,8 @@ define dso_local void @escape_ptr_gep_store(ptr nocapture noundef writeonly %out
174174
; COMMON-NEXT: [[ENTRY:.*:]]
175175
; COMMON-NEXT: [[S3:%.*]] = alloca [[STRUCT_S]], align 4
176176
; COMMON-NEXT: [[S4:%.*]] = addrspacecast ptr [[S]] to ptr addrspace(101)
177-
; COMMON-NEXT: [[S5:%.*]] = load [[STRUCT_S]], ptr addrspace(101) [[S4]], align 4
178-
; COMMON-NEXT: store [[STRUCT_S]] [[S5]], ptr [[S3]], align 4
177+
; COMMON-NEXT: [[S5:%.*]] = load [1 x i64], ptr addrspace(101) [[S4]], align 4
178+
; COMMON-NEXT: store [1 x i64] [[S5]], ptr [[S3]], align 4
179179
; COMMON-NEXT: [[OUT1:%.*]] = addrspacecast ptr [[OUT]] to ptr addrspace(1)
180180
; COMMON-NEXT: [[OUT2:%.*]] = addrspacecast ptr addrspace(1) [[OUT1]] to ptr
181181
; COMMON-NEXT: [[B:%.*]] = getelementptr inbounds nuw i8, ptr [[S3]], i64 4
@@ -195,8 +195,8 @@ define dso_local void @escape_ptrtoint(ptr nocapture noundef writeonly %out, ptr
195195
; COMMON-NEXT: [[ENTRY:.*:]]
196196
; COMMON-NEXT: [[S3:%.*]] = alloca [[STRUCT_S]], align 4
197197
; COMMON-NEXT: [[S4:%.*]] = addrspacecast ptr [[S]] to ptr addrspace(101)
198-
; COMMON-NEXT: [[S5:%.*]] = load [[STRUCT_S]], ptr addrspace(101) [[S4]], align 4
199-
; COMMON-NEXT: store [[STRUCT_S]] [[S5]], ptr [[S3]], align 4
198+
; COMMON-NEXT: [[S5:%.*]] = load [1 x i64], ptr addrspace(101) [[S4]], align 4
199+
; COMMON-NEXT: store [1 x i64] [[S5]], ptr [[S3]], align 4
200200
; COMMON-NEXT: [[OUT1:%.*]] = addrspacecast ptr [[OUT]] to ptr addrspace(1)
201201
; COMMON-NEXT: [[OUT2:%.*]] = addrspacecast ptr addrspace(1) [[OUT1]] to ptr
202202
; COMMON-NEXT: [[I:%.*]] = ptrtoint ptr [[S3]] to i64
@@ -232,8 +232,8 @@ define dso_local void @memcpy_to_param(ptr nocapture noundef readonly %in, ptr n
232232
; COMMON-NEXT: [[ENTRY:.*:]]
233233
; COMMON-NEXT: [[S3:%.*]] = alloca [[STRUCT_S]], align 4
234234
; COMMON-NEXT: [[S4:%.*]] = addrspacecast ptr [[S]] to ptr addrspace(101)
235-
; COMMON-NEXT: [[S5:%.*]] = load [[STRUCT_S]], ptr addrspace(101) [[S4]], align 4
236-
; COMMON-NEXT: store [[STRUCT_S]] [[S5]], ptr [[S3]], align 4
235+
; COMMON-NEXT: [[S5:%.*]] = load [1 x i64], ptr addrspace(101) [[S4]], align 4
236+
; COMMON-NEXT: store [1 x i64] [[S5]], ptr [[S3]], align 4
237237
; COMMON-NEXT: [[IN1:%.*]] = addrspacecast ptr [[IN]] to ptr addrspace(1)
238238
; COMMON-NEXT: [[IN2:%.*]] = addrspacecast ptr addrspace(1) [[IN1]] to ptr
239239
; COMMON-NEXT: tail call void @llvm.memcpy.p0.p0.i64(ptr [[S3]], ptr [[IN2]], i64 16, i1 true)
@@ -251,8 +251,8 @@ define dso_local void @copy_on_store(ptr nocapture noundef readonly %in, ptr noc
251251
; COMMON-NEXT: [[BB:.*:]]
252252
; COMMON-NEXT: [[S3:%.*]] = alloca [[STRUCT_S]], align 4
253253
; COMMON-NEXT: [[S4:%.*]] = addrspacecast ptr [[S]] to ptr addrspace(101)
254-
; COMMON-NEXT: [[S5:%.*]] = load [[STRUCT_S]], ptr addrspace(101) [[S4]], align 4
255-
; COMMON-NEXT: store [[STRUCT_S]] [[S5]], ptr [[S3]], align 4
254+
; COMMON-NEXT: [[S5:%.*]] = load [1 x i64], ptr addrspace(101) [[S4]], align 4
255+
; COMMON-NEXT: store [1 x i64] [[S5]], ptr [[S3]], align 4
256256
; COMMON-NEXT: [[IN1:%.*]] = addrspacecast ptr [[IN]] to ptr addrspace(1)
257257
; COMMON-NEXT: [[IN2:%.*]] = addrspacecast ptr addrspace(1) [[IN1]] to ptr
258258
; COMMON-NEXT: [[I:%.*]] = load i32, ptr [[IN2]], align 4
@@ -273,12 +273,12 @@ define void @test_select(ptr byval(i32) align 4 %input1, ptr byval(i32) %input2,
273273
; SM_60-NEXT: [[OUT8:%.*]] = addrspacecast ptr addrspace(1) [[OUT7]] to ptr
274274
; SM_60-NEXT: [[INPUT24:%.*]] = alloca i32, align 4
275275
; SM_60-NEXT: [[INPUT25:%.*]] = addrspacecast ptr [[INPUT2]] to ptr addrspace(101)
276-
; SM_60-NEXT: [[INPUT26:%.*]] = load i32, ptr addrspace(101) [[INPUT25]], align 4
277-
; SM_60-NEXT: store i32 [[INPUT26]], ptr [[INPUT24]], align 4
276+
; SM_60-NEXT: [[INPUT26:%.*]] = load [1 x i32], ptr addrspace(101) [[INPUT25]], align 4
277+
; SM_60-NEXT: store [1 x i32] [[INPUT26]], ptr [[INPUT24]], align 4
278278
; SM_60-NEXT: [[INPUT11:%.*]] = alloca i32, align 4
279279
; SM_60-NEXT: [[INPUT12:%.*]] = addrspacecast ptr [[INPUT1]] to ptr addrspace(101)
280-
; SM_60-NEXT: [[INPUT13:%.*]] = load i32, ptr addrspace(101) [[INPUT12]], align 4
281-
; SM_60-NEXT: store i32 [[INPUT13]], ptr [[INPUT11]], align 4
280+
; SM_60-NEXT: [[INPUT13:%.*]] = load [1 x i32], ptr addrspace(101) [[INPUT12]], align 4
281+
; SM_60-NEXT: store [1 x i32] [[INPUT13]], ptr [[INPUT11]], align 4
282282
; SM_60-NEXT: [[PTRNEW:%.*]] = select i1 [[COND]], ptr [[INPUT11]], ptr [[INPUT24]]
283283
; SM_60-NEXT: [[VALLOADED:%.*]] = load i32, ptr [[PTRNEW]], align 4
284284
; SM_60-NEXT: store i32 [[VALLOADED]], ptr [[OUT8]], align 4
@@ -313,12 +313,12 @@ define void @test_select_write(ptr byval(i32) align 4 %input1, ptr byval(i32) %i
313313
; COMMON-NEXT: [[OUT8:%.*]] = addrspacecast ptr addrspace(1) [[OUT7]] to ptr
314314
; COMMON-NEXT: [[INPUT24:%.*]] = alloca i32, align 4
315315
; COMMON-NEXT: [[INPUT25:%.*]] = addrspacecast ptr [[INPUT2]] to ptr addrspace(101)
316-
; COMMON-NEXT: [[INPUT26:%.*]] = load i32, ptr addrspace(101) [[INPUT25]], align 4
317-
; COMMON-NEXT: store i32 [[INPUT26]], ptr [[INPUT24]], align 4
316+
; COMMON-NEXT: [[INPUT26:%.*]] = load [1 x i32], ptr addrspace(101) [[INPUT25]], align 4
317+
; COMMON-NEXT: store [1 x i32] [[INPUT26]], ptr [[INPUT24]], align 4
318318
; COMMON-NEXT: [[INPUT11:%.*]] = alloca i32, align 4
319319
; COMMON-NEXT: [[INPUT12:%.*]] = addrspacecast ptr [[INPUT1]] to ptr addrspace(101)
320-
; COMMON-NEXT: [[INPUT13:%.*]] = load i32, ptr addrspace(101) [[INPUT12]], align 4
321-
; COMMON-NEXT: store i32 [[INPUT13]], ptr [[INPUT11]], align 4
320+
; COMMON-NEXT: [[INPUT13:%.*]] = load [1 x i32], ptr addrspace(101) [[INPUT12]], align 4
321+
; COMMON-NEXT: store [1 x i32] [[INPUT13]], ptr [[INPUT11]], align 4
322322
; COMMON-NEXT: [[PTRNEW:%.*]] = select i1 [[COND]], ptr [[INPUT11]], ptr [[INPUT24]]
323323
; COMMON-NEXT: store i32 1, ptr [[PTRNEW]], align 4
324324
; COMMON-NEXT: ret void
@@ -337,12 +337,12 @@ define void @test_phi(ptr byval(%struct.S) align 4 %input1, ptr byval(%struct.S)
337337
; SM_60-NEXT: [[INOUT8:%.*]] = addrspacecast ptr addrspace(1) [[INOUT7]] to ptr
338338
; SM_60-NEXT: [[INPUT24:%.*]] = alloca [[STRUCT_S]], align 8
339339
; SM_60-NEXT: [[INPUT25:%.*]] = addrspacecast ptr [[INPUT2]] to ptr addrspace(101)
340-
; SM_60-NEXT: [[INPUT26:%.*]] = load [[STRUCT_S]], ptr addrspace(101) [[INPUT25]], align 8
341-
; SM_60-NEXT: store [[STRUCT_S]] [[INPUT26]], ptr [[INPUT24]], align 4
340+
; SM_60-NEXT: [[INPUT26:%.*]] = load [1 x i64], ptr addrspace(101) [[INPUT25]], align 8
341+
; SM_60-NEXT: store [1 x i64] [[INPUT26]], ptr [[INPUT24]], align 8
342342
; SM_60-NEXT: [[INPUT11:%.*]] = alloca [[STRUCT_S]], align 4
343343
; SM_60-NEXT: [[INPUT12:%.*]] = addrspacecast ptr [[INPUT1]] to ptr addrspace(101)
344-
; SM_60-NEXT: [[INPUT13:%.*]] = load [[STRUCT_S]], ptr addrspace(101) [[INPUT12]], align 4
345-
; SM_60-NEXT: store [[STRUCT_S]] [[INPUT13]], ptr [[INPUT11]], align 4
344+
; SM_60-NEXT: [[INPUT13:%.*]] = load [1 x i64], ptr addrspace(101) [[INPUT12]], align 4
345+
; SM_60-NEXT: store [1 x i64] [[INPUT13]], ptr [[INPUT11]], align 4
346346
; SM_60-NEXT: br i1 [[COND]], label %[[FIRST:.*]], label %[[SECOND:.*]]
347347
; SM_60: [[FIRST]]:
348348
; SM_60-NEXT: [[PTR1:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr [[INPUT11]], i32 0, i32 0
@@ -402,12 +402,12 @@ define void @test_phi_write(ptr byval(%struct.S) align 4 %input1, ptr byval(%str
402402
; COMMON-NEXT: [[BB:.*:]]
403403
; COMMON-NEXT: [[INPUT24:%.*]] = alloca [[STRUCT_S]], align 8
404404
; COMMON-NEXT: [[INPUT25:%.*]] = addrspacecast ptr [[INPUT2]] to ptr addrspace(101)
405-
; COMMON-NEXT: [[INPUT26:%.*]] = load [[STRUCT_S]], ptr addrspace(101) [[INPUT25]], align 8
406-
; COMMON-NEXT: store [[STRUCT_S]] [[INPUT26]], ptr [[INPUT24]], align 4
405+
; COMMON-NEXT: [[INPUT26:%.*]] = load [1 x i64], ptr addrspace(101) [[INPUT25]], align 8
406+
; COMMON-NEXT: store [1 x i64] [[INPUT26]], ptr [[INPUT24]], align 8
407407
; COMMON-NEXT: [[INPUT11:%.*]] = alloca [[STRUCT_S]], align 4
408408
; COMMON-NEXT: [[INPUT12:%.*]] = addrspacecast ptr [[INPUT1]] to ptr addrspace(101)
409-
; COMMON-NEXT: [[INPUT13:%.*]] = load [[STRUCT_S]], ptr addrspace(101) [[INPUT12]], align 4
410-
; COMMON-NEXT: store [[STRUCT_S]] [[INPUT13]], ptr [[INPUT11]], align 4
409+
; COMMON-NEXT: [[INPUT13:%.*]] = load [1 x i64], ptr addrspace(101) [[INPUT12]], align 4
410+
; COMMON-NEXT: store [1 x i64] [[INPUT13]], ptr [[INPUT11]], align 4
411411
; COMMON-NEXT: br i1 [[COND]], label %[[FIRST:.*]], label %[[SECOND:.*]]
412412
; COMMON: [[FIRST]]:
413413
; COMMON-NEXT: [[PTR1:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr [[INPUT11]], i32 0, i32 0
@@ -437,6 +437,26 @@ merge: ; preds = %second, %first
437437
ret void
438438
}
439439

440+
%union.U = type { %struct.P }
441+
%struct.P = type { i8, i32 }
442+
443+
; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: readwrite)
444+
define dso_local void @padding(ptr nocapture noundef readnone %out, ptr noundef byval(%union.U) align 4 %s) local_unnamed_addr #0 {
445+
; COMMON-LABEL: define dso_local void @padding(
446+
; COMMON-SAME: ptr nocapture noundef readnone [[OUT:%.*]], ptr noundef byval([[UNION_U:%.*]]) align 4 [[S:%.*]]) local_unnamed_addr #[[ATTR0]] {
447+
; COMMON-NEXT: [[ENTRY:.*:]]
448+
; COMMON-NEXT: [[S1:%.*]] = alloca [[UNION_U]], align 4
449+
; COMMON-NEXT: [[S2:%.*]] = addrspacecast ptr [[S]] to ptr addrspace(101)
450+
; COMMON-NEXT: [[S3:%.*]] = load [1 x i64], ptr addrspace(101) [[S2]], align 4
451+
; COMMON-NEXT: store [1 x i64] [[S3]], ptr [[S1]], align 4
452+
; COMMON-NEXT: call void @_Z6escapePv(ptr noundef nonnull [[S1]])
453+
; COMMON-NEXT: ret void
454+
;
455+
entry:
456+
call void @_Z6escapePv(ptr noundef nonnull %s) #0
457+
ret void
458+
}
459+
440460
attributes #0 = { mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: readwrite) "no-trapping-math"="true" "target-cpu"="sm_60" "target-features"="+ptx78,+sm_60" "uniform-work-group-size"="true" }
441461
attributes #1 = { nocallback nofree nounwind willreturn memory(argmem: readwrite) }
442462
attributes #2 = { nocallback nofree nounwind willreturn memory(argmem: write) }

0 commit comments

Comments
 (0)