Skip to content

Commit 5d739cf

Browse files
authored
Create function declaration in the proper module (#161281)
Using `memref.dealloc` in the gpu module would add a function declaration for `@free` in the top-level module instead of the gpu module. The fix is to do what is already done for `memref.alloc`, which is to use `op->getParentWithTrait<OpTrait::SymbolTable>()` instead of `op->getParentOfType<ModuleOp>()` so that the declaration is created in the proper module.
1 parent d28c07b commit 5d739cf

File tree

2 files changed

+37
-4
lines changed

2 files changed

+37
-4
lines changed

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
4848
}
4949

5050
static FailureOr<LLVM::LLVMFuncOp>
51-
getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter, ModuleOp module,
52-
SymbolTableCollection *symbolTables) {
51+
getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
52+
Operation *module, SymbolTableCollection *symbolTables) {
5353
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
5454

5555
if (useGenericFn)
@@ -483,8 +483,8 @@ class DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
483483
ConversionPatternRewriter &rewriter) const override {
484484
// Insert the `free` declaration if it is not already present.
485485
FailureOr<LLVM::LLVMFuncOp> freeFunc =
486-
getFreeFn(rewriter, getTypeConverter(), op->getParentOfType<ModuleOp>(),
487-
symbolTables);
486+
getFreeFn(rewriter, getTypeConverter(),
487+
op->getParentWithTrait<OpTrait::SymbolTable>(), symbolTables);
488488
if (failed(freeFunc))
489489
return failure();
490490
Value allocatedPtr;
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// RUN: mlir-opt --convert-to-llvm %s | FileCheck %s
2+
3+
// Checking that malloc and free are declared in the proper module.
4+
5+
// CHECK: module attributes {gpu.container_module} {
6+
// CHECK: llvm.func @free(!llvm.ptr)
7+
// CHECK: llvm.func @malloc(i64) -> !llvm.ptr
8+
// CHECK: gpu.module @kernels {
9+
// CHECK: llvm.func @free(!llvm.ptr)
10+
// CHECK: llvm.func @malloc(i64) -> !llvm.ptr
11+
// CHECK: gpu.func @kernel_1
12+
// CHECK: llvm.call @malloc({{.*}}) : (i64) -> !llvm.ptr
13+
// CHECK: llvm.call @free({{.*}}) : (!llvm.ptr) -> ()
14+
// CHECK: gpu.return
15+
// CHECK: }
16+
// CHECK: }
17+
// CHECK: }
18+
module attributes {gpu.container_module} {
19+
20+
gpu.module @kernels {
21+
gpu.func @kernel_1() kernel {
22+
%memref_a = memref.alloc() : memref<8x16xf32>
23+
memref.dealloc %memref_a : memref<8x16xf32>
24+
gpu.return
25+
}
26+
}
27+
28+
func.func @main() {
29+
%memref_a = memref.alloc() : memref<8x16xf32>
30+
memref.dealloc %memref_a : memref<8x16xf32>
31+
return
32+
}
33+
}

0 commit comments

Comments
 (0)