Skip to content

Commit 609899f

Browse files
authored
[flang][cuda] Avoid stack corruption when setting kernel launch parameters (#119469)
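To obtain a pointer to a structure member, `getelementptr` requires two indices: the first indexes through the base pointer to select the structure itself (normally 0), and the second selects the member within that structure. `GPULaunchKernelConversion` was omitting the first index, so the member position was instead applied as an offset in units of whole structures; for any member other than the first, the resulting address lies past the end of the allocated temporary, and the subsequent store may corrupt the stack. This PR corrects the indices used to address the structure that serves as the temporary holding the kernel launch arguments.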
1 parent 377d1f0 commit 609899f

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

flang/lib/Optimizer/Transforms/CUFGPUToLLVMConversion.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ static mlir::Value createKernelArgArray(mlir::Location loc,
4242
auto structTy = mlir::LLVM::LLVMStructType::getLiteral(ctx, structTypes);
4343
auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext());
4444
mlir::Type i32Ty = rewriter.getI32Type();
45+
auto zero = rewriter.create<mlir::LLVM::ConstantOp>(
46+
loc, i32Ty, rewriter.getIntegerAttr(i32Ty, 0));
4547
auto one = rewriter.create<mlir::LLVM::ConstantOp>(
4648
loc, i32Ty, rewriter.getIntegerAttr(i32Ty, 1));
4749
mlir::Value argStruct =
@@ -55,7 +57,8 @@ static mlir::Value createKernelArgArray(mlir::Location loc,
5557
auto indice = rewriter.create<mlir::LLVM::ConstantOp>(
5658
loc, i32Ty, rewriter.getIntegerAttr(i32Ty, i));
5759
mlir::Value structMember = rewriter.create<LLVM::GEPOp>(
58-
loc, ptrTy, structTy, argStruct, mlir::ArrayRef<mlir::Value>({indice}));
60+
loc, ptrTy, structTy, argStruct,
61+
mlir::ArrayRef<mlir::Value>({zero, indice}));
5962
rewriter.create<LLVM::StoreOp>(loc, arg, structMember);
6063
mlir::Value arrayMember = rewriter.create<LLVM::GEPOp>(
6164
loc, ptrTy, ptrTy, argArray, mlir::ArrayRef<mlir::Value>({indice}));

flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<i1, dense<8> : ve
102102
// CHECK: %[[STRUCT:.*]] = llvm.alloca %{{.*}} x !llvm.struct<(ptr)> : (i32) -> !llvm.ptr
103103
// CHECK: %[[PARAMS:.*]] = llvm.alloca %{{.*}} x !llvm.ptr : (i32) -> !llvm.ptr
104104
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
105-
// CHECK: %[[STRUCT_PTR:.*]] = llvm.getelementptr %[[STRUCT]][%[[ZERO]]] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<(ptr)>
105+
// CHECK: %[[STRUCT_PTR:.*]] = llvm.getelementptr %[[STRUCT]][%{{.*}}, {{.*}}] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<(ptr)>
106106
// CHECK: llvm.store %{{.*}}, %[[STRUCT_PTR]] : !llvm.ptr, !llvm.ptr
107107
// CHECK: %[[PARAM_PTR:.*]] = llvm.getelementptr %[[PARAMS]][%[[ZERO]]] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.ptr
108108
// CHECK: llvm.store %[[STRUCT_PTR]], %[[PARAM_PTR]] : !llvm.ptr, !llvm.ptr

0 commit comments

Comments
 (0)