Skip to content

Commit 08df086

Browse files
committed
Simplify ctx usage
1 parent 76e0712 commit 08df086

File tree

1 file changed

+25
-27
lines changed

1 file changed

+25
-27
lines changed

python/examples/mlir/memref_management.py

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,35 +12,33 @@
1212
import lighthouse.utils as lh_utils
1313

1414

15-
def create_mlir_module(ctx: ir.Context, shape: list[int]) -> ir.Module:
16-
with ctx, ir.Location.unknown():
17-
module = ir.Module.create()
18-
with ir.InsertionPoint(module.body):
19-
mem_type = ir.MemRefType.get(shape, ir.F32Type.get())
20-
21-
# Return a new buffer initialized with input's data.
22-
@func.func(mem_type)
23-
def copy(input):
24-
new_buf = memref.alloc(mem_type, [], [])
25-
memref.copy(input, new_buf)
26-
return new_buf
27-
28-
# Free given buffer.
29-
@func.func(mem_type)
30-
def module_dealloc(input):
31-
memref.dealloc(input)
15+
def create_mlir_module(shape: list[int]) -> ir.Module:
16+
module = ir.Module.create()
17+
with ir.InsertionPoint(module.body):
18+
mem_type = ir.MemRefType.get(shape, ir.F32Type.get())
19+
20+
# Return a new buffer initialized with input's data.
21+
@func.func(mem_type)
22+
def copy(input):
23+
new_buf = memref.alloc(mem_type, [], [])
24+
memref.copy(input, new_buf)
25+
return new_buf
26+
27+
# Free given buffer.
28+
@func.func(mem_type)
29+
def module_dealloc(input):
30+
memref.dealloc(input)
3231

3332
return module
3433

3534

3635
def lower_to_llvm(operation: ir.Operation) -> None:
37-
with operation.context:
38-
pm = PassManager("builtin.module")
39-
pm.add("func.func(llvm-request-c-wrappers)")
40-
pm.add("convert-to-llvm")
41-
pm.add("reconcile-unrealized-casts")
42-
pm.add("cse")
43-
pm.add("canonicalize")
36+
pm = PassManager("builtin.module")
37+
pm.add("func.func(llvm-request-c-wrappers)")
38+
pm.add("convert-to-llvm")
39+
pm.add("reconcile-unrealized-casts")
40+
pm.add("cse")
41+
pm.add("canonicalize")
4442
pm.run(operation)
4543

4644

@@ -75,8 +73,7 @@ def main():
7573
shape = [16, 32]
7674

7775
# Create and compile test module.
78-
ctx = ir.Context()
79-
kernel = create_mlir_module(ctx, shape)
76+
kernel = create_mlir_module(shape)
8077
lower_to_llvm(kernel.operation)
8178
eng = ExecutionEngine(kernel, opt_level=3)
8279
eng.initialize()
@@ -116,4 +113,5 @@ def main():
116113

117114

118115
if __name__ == "__main__":
119-
main()
116+
with ir.Context(), ir.Location.unknown():
117+
main()

0 commit comments

Comments
 (0)