|
12 | 12 | import lighthouse.utils as lh_utils |
13 | 13 |
|
14 | 14 |
|
15 | | -def create_mlir_module(ctx: ir.Context, shape: list[int]) -> ir.Module: |
16 | | - with ctx, ir.Location.unknown(): |
17 | | - module = ir.Module.create() |
18 | | - with ir.InsertionPoint(module.body): |
19 | | - mem_type = ir.MemRefType.get(shape, ir.F32Type.get()) |
20 | | - |
21 | | - # Return a new buffer initialized with input's data. |
22 | | - @func.func(mem_type) |
23 | | - def copy(input): |
24 | | - new_buf = memref.alloc(mem_type, [], []) |
25 | | - memref.copy(input, new_buf) |
26 | | - return new_buf |
27 | | - |
28 | | - # Free given buffer. |
29 | | - @func.func(mem_type) |
30 | | - def module_dealloc(input): |
31 | | - memref.dealloc(input) |
| 15 | +def create_mlir_module(shape: list[int]) -> ir.Module: |
| 16 | + module = ir.Module.create() |
| 17 | + with ir.InsertionPoint(module.body): |
| 18 | + mem_type = ir.MemRefType.get(shape, ir.F32Type.get()) |
| 19 | + |
| 20 | + # Return a new buffer initialized with input's data. |
| 21 | + @func.func(mem_type) |
| 22 | + def copy(input): |
| 23 | + new_buf = memref.alloc(mem_type, [], []) |
| 24 | + memref.copy(input, new_buf) |
| 25 | + return new_buf |
| 26 | + |
| 27 | + # Free given buffer. |
| 28 | + @func.func(mem_type) |
| 29 | + def module_dealloc(input): |
| 30 | + memref.dealloc(input) |
32 | 31 |
|
33 | 32 | return module |
34 | 33 |
|
35 | 34 |
|
36 | 35 | def lower_to_llvm(operation: ir.Operation) -> None: |
37 | | - with operation.context: |
38 | | - pm = PassManager("builtin.module") |
39 | | - pm.add("func.func(llvm-request-c-wrappers)") |
40 | | - pm.add("convert-to-llvm") |
41 | | - pm.add("reconcile-unrealized-casts") |
42 | | - pm.add("cse") |
43 | | - pm.add("canonicalize") |
| 36 | + pm = PassManager("builtin.module") |
| 37 | + pm.add("func.func(llvm-request-c-wrappers)") |
| 38 | + pm.add("convert-to-llvm") |
| 39 | + pm.add("reconcile-unrealized-casts") |
| 40 | + pm.add("cse") |
| 41 | + pm.add("canonicalize") |
44 | 42 | pm.run(operation) |
45 | 43 |
|
46 | 44 |
|
@@ -75,8 +73,7 @@ def main(): |
75 | 73 | shape = [16, 32] |
76 | 74 |
|
77 | 75 | # Create and compile test module. |
78 | | - ctx = ir.Context() |
79 | | - kernel = create_mlir_module(ctx, shape) |
| 76 | + kernel = create_mlir_module(shape) |
80 | 77 | lower_to_llvm(kernel.operation) |
81 | 78 | eng = ExecutionEngine(kernel, opt_level=3) |
82 | 79 | eng.initialize() |
@@ -116,4 +113,5 @@ def main(): |
116 | 113 |
|
117 | 114 |
|
118 | 115 | if __name__ == "__main__": |
119 | | - main() |
| 116 | + with ir.Context(), ir.Location.unknown(): |
| 117 | + main() |
0 commit comments