Skip to content

Commit 00a2350

Browse files
committed
[flang][cuda] Fix GPULaunchKernelConversion to generate correct kernel launch parameters
1 parent 4d06623 commit 00a2350

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

flang/lib/Optimizer/Transforms/CUFGPUToLLVMConversion.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -58,7 +58,7 @@ static mlir::Value createKernelArgArray(mlir::Location loc,
58 | 58 |   loc, ptrTy, structTy, argStruct, mlir::ArrayRef<mlir::Value>({indice}));
59 | 59 |   rewriter.create<LLVM::StoreOp>(loc, arg, structMember);
60 | 60 |   mlir::Value arrayMember = rewriter.create<LLVM::GEPOp>(
61 |    | - loc, ptrTy, structTy, argArray, mlir::ArrayRef<mlir::Value>({indice}));
   | 61 | + loc, ptrTy, ptrTy, argArray, mlir::ArrayRef<mlir::Value>({indice}));
62 | 62 |   rewriter.create<LLVM::StoreOp>(loc, structMember, arrayMember);
63 | 63 |   }
64 | 64 |   return argArray;

flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -99,9 +99,16 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<i1, dense<8> : ve
 99 |  99 |   }
100 | 100 |
101 | 101 |   // CHECK-LABEL: _QMmod1Phost_sub
102 |     | -
    | 102 | + // CHECK: %[[STRUCT:.*]] = llvm.alloca %{{.*}} x !llvm.struct<(ptr)> : (i32) -> !llvm.ptr
    | 103 | + // CHECK: %[[PARAMS:.*]] = llvm.alloca %{{.*}} x !llvm.ptr : (i32) -> !llvm.ptr
    | 104 | + // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
    | 105 | + // CHECK: %[[STRUCT_PTR:.*]] = llvm.getelementptr %[[STRUCT]][%[[ZERO]]] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<(ptr)>
    | 106 | + // CHECK: llvm.store %{{.*}}, %[[STRUCT_PTR]] : !llvm.ptr, !llvm.ptr
    | 107 | + // CHECK: %[[PARAM_PTR:.*]] = llvm.getelementptr %[[PARAMS]][%[[ZERO]]] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.ptr
    | 108 | + // CHECK: llvm.store %[[STRUCT_PTR]], %[[PARAM_PTR]] : !llvm.ptr, !llvm.ptr
103 | 109 |   // CHECK: %[[KERNEL_PTR:.*]] = llvm.mlir.addressof @_QMmod1Psub1 : !llvm.ptr
104 |     | - // CHECK: llvm.call @_FortranACUFLaunchKernel(%[[KERNEL_PTR]], {{.*}})
    | 110 | + // CHECK: %[[NULL:.*]] = llvm.mlir.zero : !llvm.ptr
    | 111 | + // CHECK: llvm.call @_FortranACUFLaunchKernel(%[[KERNEL_PTR]], {{.*}}, %[[PARAMS]], %[[NULL]])
105 | 112 |
106 | 113 |   // -----
107 | 114 |
0 commit comments

Comments
 (0)