Skip to content

Commit e13403a

Browse files
committed
add refbackend and passes (todo: generate passes but needs upstream patch)
1 parent 576c4eb commit e13403a

File tree

8 files changed

+2289
-3
lines changed

8 files changed

+2289
-3
lines changed
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
import re
2+
from textwrap import dedent, indent
3+
4+
from yaml import safe_load
5+
6+
try:
7+
from yaml import CLoader as Loader, CDumper as Dumper
8+
except ImportError:
9+
from yaml import Loader, Dumper
10+
11+
import multiprocess as multiprocessing
12+
import os
13+
14+
from .._mlir._mlir_libs._nelli_mlir import print_help
15+
16+
17+
def capture_help():
18+
def method(w):
19+
os.dup2(w.fileno(), 1)
20+
print_help()
21+
22+
r, w = multiprocessing.Pipe()
23+
24+
reader = os.fdopen(r.fileno(), "r")
25+
26+
process = multiprocessing.Process(None, method, "print_help", (w,))
27+
process.start()
28+
29+
lines = []
30+
31+
capturing = False
32+
33+
for i in range(1000):
34+
line = reader.readline()
35+
if "Passes:" in line:
36+
capturing = True
37+
continue
38+
if "Pass Pipelines:" in line:
39+
break
40+
if capturing:
41+
lines.append(line)
42+
43+
# you don't need these because help already exits
44+
# reader.close()
45+
# process.join()
46+
47+
# print("".join(lines))
48+
return lines
49+
50+
51+
def fixup_lines_into_yaml(lines):
52+
i = 0
53+
while i < len(lines):
54+
l = lines[i]
55+
prev_line = lines[i - 1]
56+
leading_spaces = len(prev_line) - len(prev_line.strip())
57+
while l.strip().startswith("="):
58+
lines[i] = " " * leading_spaces + l
59+
i += 1
60+
l = lines[i]
61+
62+
i += 1
63+
64+
lines = dedent("".join(lines)).splitlines(keepends=True)
65+
for i, l in enumerate(lines):
66+
l = l.replace(" - ", " # ")
67+
if l.startswith("--"):
68+
l = re.sub(r"^--((\w|-)+)", r"\1:", l, re.M)
69+
l = re.sub(r"^ {2}--", " - ", l, re.M)
70+
l = re.sub(r"^ {7}=", " - ", l, re.M)
71+
l = l.replace("=<value>", "=<value>:")
72+
lines[i] = l
73+
74+
# with open("passes_new.txt", "w") as f:
75+
# f.write("".join(lines))
76+
yml = safe_load("".join(lines))
77+
return yml
78+
79+
80+
def parse_passes(yml):
81+
illegal_names = {"global": "global_"}
82+
illegal_names_back = {"global_": "global"}
83+
ident = 4
84+
all_types = set()
85+
for pass_name, maybe_args in yml.items():
86+
# print(k, v)
87+
py_args = []
88+
if maybe_args is None:
89+
pass
90+
else:
91+
for a in maybe_args:
92+
if isinstance(a, str):
93+
if "=" in a:
94+
name, typ = a.split("=")
95+
all_types.add(typ)
96+
match typ:
97+
case item if item in {
98+
"<long>",
99+
"<ulong>",
100+
"<number>",
101+
"<uint>",
102+
"<int>",
103+
}:
104+
arg_typ = "int"
105+
case "<string>" | "<pass-manager>":
106+
arg_typ = "str"
107+
case _:
108+
raise RuntimeError(a)
109+
else:
110+
name = a
111+
arg_typ = "bool"
112+
name = illegal_names.get(name, name)
113+
arg_name = f"{name.replace('-', '_')}"
114+
py_args.append((arg_name, arg_typ))
115+
else:
116+
assert isinstance(a, dict)
117+
assert len(a) == 1
118+
name, _typ = list(a.keys())[0].split("=")
119+
arg_name = f"{name.replace('-', '_')}"
120+
py_args.append((arg_name, "str"))
121+
122+
if py_args:
123+
py_args_str = ", ".join([f"{n}=None" for n, t in py_args])
124+
print(
125+
indent(
126+
f"def {pass_name.replace('-', '_')}(self, {py_args_str}):",
127+
prefix=" " * ident,
128+
)
129+
)
130+
mlir_args = []
131+
for n, t in py_args:
132+
if t in {"int", "str"}:
133+
print(
134+
indent(
135+
f"if {n} is not None and isinstance({n}, (list, tuple)):",
136+
prefix=" " * ident * 2,
137+
)
138+
)
139+
print(
140+
indent(f"{n} = ','.join(map(str, {n}))", prefix=" " * ident * 3)
141+
)
142+
mlir_args.append(f"{n}={n}")
143+
print(
144+
indent(
145+
dedent(
146+
f"""\
147+
self._add_pass("{pass_name}", {', '.join(mlir_args)})
148+
return self
149+
"""
150+
),
151+
prefix=" " * ident * 2,
152+
)
153+
)
154+
155+
else:
156+
print(
157+
indent(
158+
dedent(
159+
f"""\
160+
def {pass_name.replace('-', '_')}(self):
161+
self._add_pass("{pass_name}")
162+
return self
163+
"""
164+
),
165+
prefix=" " * ident,
166+
)
167+
)
168+
169+
170+
if __name__ == "__main__":
171+
lines = capture_help()
172+
yml = fixup_lines_into_yaml(lines)
173+
parse_passes(yml)

mlir_utils/ast/util.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,18 @@ def bind(func, instance, as_name=None):
4646
def copy_func(f, new_code):
4747
"""Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)"""
4848
g = types.FunctionType(
49-
new_code,
50-
f.__globals__,
49+
code=new_code,
50+
globals={
51+
**f.__globals__,
52+
**{
53+
fr: f.__closure__[i].cell_contents
54+
for i, fr in enumerate(f.__code__.co_freevars)
55+
},
56+
},
5157
name=f.__name__,
5258
argdefs=f.__defaults__,
53-
closure=f.__closure__,
59+
# TODO(max): ValueError: foo requires closure of length 0, not 1
60+
# closure=f.__closure__,
5461
)
5562
g.__kwdefaults__ = f.__kwdefaults__
5663
g.__dict__.update(f.__dict__)

mlir_utils/dialects/ext/func.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,45 @@ def func(
144144
loc=loc,
145145
ip=ip,
146146
)
147+
148+
149+
def call(symbol_name, call_args, return_types, *, loc=None, ip=None):
150+
if loc is None:
151+
loc = get_user_code_loc()
152+
return maybe_cast(
153+
get_result_or_results(
154+
CallOp.__base__(
155+
return_types,
156+
FlatSymbolRefAttr.get(symbol_name),
157+
call_args,
158+
loc=loc,
159+
ip=ip,
160+
)
161+
)
162+
)
163+
164+
165+
def declare(
166+
symbol_name,
167+
input_types: list,
168+
result_types=None,
169+
func_op_ctor=FuncOp,
170+
):
171+
if result_types is None:
172+
result_types = []
173+
assert all(
174+
isinstance(a, Type) for a in input_types
175+
), f"wrong func args {input_types}"
176+
assert all(
177+
isinstance(a, Type) for a in result_types
178+
), f"wrong func results {result_types}"
179+
180+
function_type = FunctionType.get(inputs=input_types, results=result_types)
181+
sym_name = func_op_ctor(
182+
name=symbol_name, type=function_type, visibility="private"
183+
).sym_name
184+
185+
def callable(*call_args):
186+
return call(sym_name.value, call_args, result_types)
187+
188+
return callable

mlir_utils/runtime/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)