- 
                Notifications
    
You must be signed in to change notification settings  - Fork 15.1k
 
Create function declaration in the proper module #161281
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
| 
          
 @llvm/pr-subscribers-mlir-gpu Author: Renaud Kauffmann (Renaud-K) ChangesUsing  Full diff: https://github.com/llvm/llvm-project/pull/161281.diff 2 Files Affected: 
 diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 262e0e7a30c63..cc6314cbd1ffe 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -48,8 +48,8 @@ static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
 }
 
 static FailureOr<LLVM::LLVMFuncOp>
-getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter, ModuleOp module,
-          SymbolTableCollection *symbolTables) {
+getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
+          Operation *module, SymbolTableCollection *symbolTables) {
   bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
 
   if (useGenericFn)
@@ -483,8 +483,8 @@ class DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
                   ConversionPatternRewriter &rewriter) const override {
     // Insert the `free` declaration if it is not already present.
     FailureOr<LLVM::LLVMFuncOp> freeFunc =
-        getFreeFn(rewriter, getTypeConverter(), op->getParentOfType<ModuleOp>(),
-                  symbolTables);
+        getFreeFn(rewriter, getTypeConverter(),
+                  op->getParentWithTrait<OpTrait::SymbolTable>(), symbolTables);
     if (failed(freeFunc))
       return failure();
     Value allocatedPtr;
diff --git a/mlir/test/Dialect/GPU/memref-to-llvm.mlir b/mlir/test/Dialect/GPU/memref-to-llvm.mlir
new file mode 100644
index 0000000000000..81a96bf29e84f
--- /dev/null
+++ b/mlir/test/Dialect/GPU/memref-to-llvm.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt --convert-to-llvm %s | FileCheck %s
+
+// Checking that malloc and free are declared in the proper module.
+
+// CHECK: module attributes {gpu.container_module} {
+// CHECK:   llvm.func @free(!llvm.ptr)
+// CHECK:   llvm.func @malloc(i64) -> !llvm.ptr
+// CHECK:   gpu.module @kernels {
+// CHECK:     llvm.func @free(!llvm.ptr)
+// CHECK:     llvm.func @malloc(i64) -> !llvm.ptr
+// CHECK:     gpu.func @kernel_1
+// CHECK:       llvm.call @malloc({{.*}}) : (i64) -> !llvm.ptr
+// CHECK:       llvm.call @free({{.*}}) : (!llvm.ptr) -> ()
+// CHECK:       gpu.return
+// CHECK:     }
+// CHECK:   }
+// CHECK: }
+module attributes {gpu.container_module} {
+
+  gpu.module @kernels {
+    gpu.func @kernel_1() kernel {
+      %memref_a = memref.alloc() : memref<8x16xf32>
+      memref.dealloc %memref_a : memref<8x16xf32>
+      gpu.return
+    }
+  }
+
+  func.func @main() {
+    %memref_a = memref.alloc() : memref<8x16xf32>
+    memref.dealloc %memref_a : memref<8x16xf32>
+    return
+  }
+}
 | 
    
| 
          
 @llvm/pr-subscribers-mlir Author: Renaud Kauffmann (Renaud-K) ChangesUsing  Full diff: https://github.com/llvm/llvm-project/pull/161281.diff 2 Files Affected: 
 diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 262e0e7a30c63..cc6314cbd1ffe 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -48,8 +48,8 @@ static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
 }
 
 static FailureOr<LLVM::LLVMFuncOp>
-getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter, ModuleOp module,
-          SymbolTableCollection *symbolTables) {
+getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
+          Operation *module, SymbolTableCollection *symbolTables) {
   bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
 
   if (useGenericFn)
@@ -483,8 +483,8 @@ class DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
                   ConversionPatternRewriter &rewriter) const override {
     // Insert the `free` declaration if it is not already present.
     FailureOr<LLVM::LLVMFuncOp> freeFunc =
-        getFreeFn(rewriter, getTypeConverter(), op->getParentOfType<ModuleOp>(),
-                  symbolTables);
+        getFreeFn(rewriter, getTypeConverter(),
+                  op->getParentWithTrait<OpTrait::SymbolTable>(), symbolTables);
     if (failed(freeFunc))
       return failure();
     Value allocatedPtr;
diff --git a/mlir/test/Dialect/GPU/memref-to-llvm.mlir b/mlir/test/Dialect/GPU/memref-to-llvm.mlir
new file mode 100644
index 0000000000000..81a96bf29e84f
--- /dev/null
+++ b/mlir/test/Dialect/GPU/memref-to-llvm.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt --convert-to-llvm %s | FileCheck %s
+
+// Checking that malloc and free are declared in the proper module.
+
+// CHECK: module attributes {gpu.container_module} {
+// CHECK:   llvm.func @free(!llvm.ptr)
+// CHECK:   llvm.func @malloc(i64) -> !llvm.ptr
+// CHECK:   gpu.module @kernels {
+// CHECK:     llvm.func @free(!llvm.ptr)
+// CHECK:     llvm.func @malloc(i64) -> !llvm.ptr
+// CHECK:     gpu.func @kernel_1
+// CHECK:       llvm.call @malloc({{.*}}) : (i64) -> !llvm.ptr
+// CHECK:       llvm.call @free({{.*}}) : (!llvm.ptr) -> ()
+// CHECK:       gpu.return
+// CHECK:     }
+// CHECK:   }
+// CHECK: }
+module attributes {gpu.container_module} {
+
+  gpu.module @kernels {
+    gpu.func @kernel_1() kernel {
+      %memref_a = memref.alloc() : memref<8x16xf32>
+      memref.dealloc %memref_a : memref<8x16xf32>
+      gpu.return
+    }
+  }
+
+  func.func @main() {
+    %memref_a = memref.alloc() : memref<8x16xf32>
+    memref.dealloc %memref_a : memref<8x16xf32>
+    return
+  }
+}
 | 
    
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Using `memref.dealloc` in the gpu module would add a function definition for `@free` in the the top level module instead of the gpu module. The fix is to do what is already done for memref.alloc which is to use `op->getParentWithTrait<OpTrait::SymbolTable>()` instead of `op->getParentOfType<ModuleOp>()` to create the call in the proper module.
Using
memref.deallocin the gpu module would add a function definition for@freein the the top level module instead of the gpu module. The fix is to do what is already done for memref.alloc which is to useop->getParentWithTrait<OpTrait::SymbolTable>()instead ofop->getParentOfType<ModuleOp>()to create the call in the proper module.