Skip to content

Commit 611d5ae

Browse files
committed
generate passes from upstream td
1 parent e13403a commit 611d5ae

File tree

6 files changed

+3128
-1034
lines changed

6 files changed

+3128
-1034
lines changed
Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
import glob
2+
import json
3+
import keyword
4+
import platform
5+
import shutil
6+
import subprocess
7+
import sys
8+
from dataclasses import dataclass
9+
from pathlib import Path
10+
from subprocess import PIPE
11+
from textwrap import dedent, indent
12+
13+
from mlir._mlir_libs import include
14+
15+
include_path = Path(include.__path__[0])
16+
17+
18+
def dump_json(td_path: Path):
19+
llvm_tblgen_name = "llvm-tblgen"
20+
if platform.system() == "Windows":
21+
llvm_tblgen_name += ".exe"
22+
23+
# try from mlir-native-tools
24+
llvm_tblgen_path = Path(sys.prefix) / "bin" / llvm_tblgen_name
25+
# try to find using which
26+
if not llvm_tblgen_path.exists():
27+
llvm_tblgen_path = shutil.which(llvm_tblgen_name)
28+
assert Path(llvm_tblgen_path).exists() is not None, "couldn't find llvm-tblgen"
29+
30+
args = [f"-I{include_path}", f"-I{td_path.parent}", str(td_path), "-dump-json"]
31+
res = subprocess.run(
32+
[llvm_tblgen_path] + args,
33+
cwd=Path(".").cwd(),
34+
check=True,
35+
stdout=PIPE,
36+
stderr=subprocess.DEVNULL,
37+
)
38+
res_json = json.loads(res.stdout.decode("utf-8"))
39+
40+
return res_json
41+
42+
43+
@dataclass
44+
class Option:
45+
argument: str
46+
description: str
47+
type: str
48+
additional_opt_flags: str
49+
default_value: str
50+
list_option: bool = False
51+
52+
53+
@dataclass
54+
class Pass:
55+
name: str
56+
argument: str
57+
options: list[Option]
58+
description: str
59+
summary: str
60+
61+
62+
TYPE_MAP = {
63+
"::mlir::gpu::amd::Runtime": '"gpu::amd::Runtime"',
64+
"OpPassManager": '"OpPassManager"',
65+
"bool": "bool",
66+
"double": "float",
67+
"enum FusionMode": '"FusionMode"',
68+
"int": "int",
69+
"int32_t": "int",
70+
"int64_t": "int",
71+
"mlir::SparseParallelizationStrategy": '"SparseParallelizationStrategy"',
72+
"mlir::arm_sme::ArmStreaming": '"arm_sme::ArmStreaming"',
73+
"std::string": "str",
74+
"uint64_t": "int",
75+
"unsigned": "int",
76+
}
77+
78+
79+
def generate_pass_method(pass_: Pass):
80+
ident = 4
81+
py_args = []
82+
for o in pass_.options:
83+
argument = o.argument.replace("-", "_")
84+
if keyword.iskeyword(argument):
85+
argument += "_"
86+
type = TYPE_MAP[o.type]
87+
if o.list_option:
88+
type = f"list[{type}]"
89+
py_args.append((argument, type))
90+
91+
def print_options_doc_string(pass_):
92+
print(
93+
indent(
94+
f"'''{pass_.summary}",
95+
prefix=" " * ident * 2,
96+
)
97+
)
98+
if pass_.description:
99+
for l in pass_.description.split("\n"):
100+
print(
101+
indent(
102+
f"{l}",
103+
prefix=" " * ident,
104+
)
105+
)
106+
if pass_.options:
107+
print(
108+
indent(
109+
f"Args:",
110+
prefix=" " * ident * 2,
111+
)
112+
)
113+
for o in pass_.options:
114+
print(
115+
indent(
116+
f"{o.argument}: {o.description}",
117+
prefix=" " * ident * 3,
118+
)
119+
)
120+
print(
121+
indent(
122+
f"'''",
123+
prefix=" " * ident * 2,
124+
)
125+
)
126+
127+
pass_name = pass_.argument
128+
if py_args:
129+
py_args_str = ", ".join([f"{n}: {t} = None" for n, t in py_args])
130+
print(
131+
indent(
132+
f"def {pass_name.replace('-', '_')}(self, {py_args_str}):",
133+
prefix=" " * ident,
134+
)
135+
)
136+
print_options_doc_string(pass_)
137+
138+
mlir_args = []
139+
for n, t in py_args:
140+
if "list" in t:
141+
print(
142+
indent(
143+
f"if {n} is not None and isinstance({n}, (list, tuple)):",
144+
prefix=" " * ident * 2,
145+
)
146+
)
147+
print(indent(f"{n} = ','.join(map(str, {n}))", prefix=" " * ident * 3))
148+
mlir_args.append(f"{n}={n}")
149+
print(
150+
indent(
151+
dedent(
152+
f"""\
153+
self.add_pass("{pass_name}", {', '.join(mlir_args)})
154+
return self
155+
"""
156+
),
157+
prefix=" " * ident * 2,
158+
)
159+
)
160+
161+
else:
162+
print(
163+
indent(
164+
dedent(
165+
f"""\
166+
def {pass_name.replace('-', '_')}(self):"""
167+
),
168+
prefix=" " * ident,
169+
)
170+
)
171+
print_options_doc_string(pass_)
172+
print(
173+
indent(
174+
dedent(
175+
f"""\
176+
self.add_pass("{pass_name}")
177+
return self
178+
"""
179+
),
180+
prefix=" " * ident * 2,
181+
)
182+
)
183+
184+
185+
def gather_passes_from_td_json(j):
186+
passes = []
187+
for pass_ in j["!instanceof"]["Pass"]:
188+
pass_ = j[pass_]
189+
options = []
190+
for o in pass_["options"]:
191+
option = j[o["def"]]
192+
option = Option(
193+
argument=option["argument"],
194+
description=option["description"],
195+
type=option["type"],
196+
additional_opt_flags=option["additionalOptFlags"],
197+
default_value=option["defaultValue"],
198+
list_option="ListOption" in option["!superclasses"],
199+
)
200+
options.append(option)
201+
pass_ = Pass(
202+
name=pass_["!name"],
203+
argument=pass_["argument"],
204+
options=options,
205+
description=pass_["description"],
206+
summary=pass_["summary"],
207+
)
208+
passes.append(pass_)
209+
210+
return passes
211+
212+
213+
if __name__ == "__main__":
214+
passes = []
215+
for td in glob.glob(str(include_path / "**" / "*.td"), recursive=True):
216+
try:
217+
j = dump_json(Path(td))
218+
if j["!instanceof"]["Pass"]:
219+
passes.extend(gather_passes_from_td_json(j))
220+
except:
221+
continue
222+
223+
for p in sorted(passes, key=lambda p: p.argument):
224+
generate_pass_method(p)
225+
226+
for p in sorted(passes, key=lambda p: p.argument):
227+
argument = p.argument.replace("-", "_")
228+
print(f"{argument} = Pipeline().{argument}")

0 commit comments

Comments
 (0)