-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
83 lines (60 loc) · 2.31 KB
/
main.py
File metadata and controls
83 lines (60 loc) · 2.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from pathlib import Path
from egglog import *
from xdsl.context import MLContext
from xdsl.dialects import (
arith,
func,
linalg,
memref,
printf,
scf,
tensor,
)
from xdsl.dialects.builtin import Builtin
from xdsl.parser import Parser as IRParser
from xdsl.printer import Printer
from converter import Converter
from eggie.rewrites import rewrites_ruleset
def context() -> MLContext:
    """Return a new ``MLContext`` with this script's dialects registered.

    The dialects loaded, in this order, are: arith, builtin, func, linalg,
    memref, printf, scf and tensor. Every call constructs a fresh context.
    """
    ml_context = MLContext()
    dialects = (
        arith.Arith,
        Builtin,
        func.Func,
        linalg.Linalg,
        memref.MemRef,
        printf.Printf,
        scf.Scf,
        tensor.Tensor,
    )
    for dialect in dialects:
        ml_context.load_dialect(dialect)
    return ml_context
def _process_mlir_file(
    mlir_file: Path,
    eggs_dir: Path,
    converted_dir: Path,
    iterations: int = 10,
) -> None:
    """Round-trip one MLIR file through egglog and write the artifacts.

    Three files are written:

    * ``<eggs_dir>/<stem>.egg``: the egglog form of the parsed module.
    * ``<converted_dir>/<stem>.egg``: the expression extracted after running
      ``rewrites_ruleset`` for ``iterations`` iterations.
    * ``<converted_dir>/<stem>.mlir``: that extracted expression converted
      back to MLIR.

    Args:
        mlir_file: Path of the ``.mlir`` input to process.
        eggs_dir: Existing directory for the pre-rewrite egglog dump.
        converted_dir: Existing directory for the post-rewrite outputs.
        iterations: Number of e-graph rewrite iterations to run.
    """
    file_name = mlir_file.stem
    print(f"Processing mlir: {file_name}")

    # Parse the input module.
    source = mlir_file.read_text(encoding="utf-8")
    module_op = IRParser(context(), source, name=f"{mlir_file}").parse_module()

    # Encode as egglog and dump it for inspection.
    egglog_region = Converter.to_egglog(module_op)
    (eggs_dir / f"{file_name}.egg").write_text(str(egglog_region), encoding="utf-8")

    # Saturate with the rewrite rules, then extract the chosen expression.
    egraph = EGraph(save_egglog_string=True)
    expr = egraph.let("expr", egglog_region)
    egraph.run(iterations, ruleset=rewrites_ruleset)
    # egraph.display()
    print("Extracting expression")
    extracted = egraph.extract(expr)
    (converted_dir / f"{file_name}.egg").write_text(str(extracted), encoding="utf-8")

    # Convert the extracted expression back to MLIR and print it to disk.
    converted_module_op = Converter.to_mlir(extracted, context())
    with open(converted_dir / f"{file_name}.mlir", "w", encoding="utf-8") as f:
        printer = Printer(stream=f)
        printer.print(converted_module_op)
    # NOTE(review): structural-equivalence check disabled in the original; re-enable once supported.
    # assert module_op.is_structurally_equivalent(converted_module_op)


def main() -> None:
    """Process every ``.mlir`` file in ``data/mlir`` next to this script.

    Outputs go to ``data/eggs`` and ``data/converted``. Both directories are
    created if they are missing.
    """
    data_dir = Path(__file__).parent / "data"
    models_dir = data_dir / "mlir"
    eggs_dir = data_dir / "eggs"
    converted_dir = data_dir / "converted"

    # Without these, the first open(..., "w") on a fresh checkout raises
    # FileNotFoundError.
    eggs_dir.mkdir(parents=True, exist_ok=True)
    converted_dir.mkdir(parents=True, exist_ok=True)

    # iterdir() yields entries in arbitrary order; sort for reproducible runs.
    mlir_files = sorted(f for f in models_dir.iterdir() if f.suffix == ".mlir")
    for mlir_file in mlir_files:
        _process_mlir_file(mlir_file, eggs_dir, converted_dir)


if __name__ == "__main__":
    main()