Skip to content

Commit 9627179

Browse files
committed
auto wrapper generation
1 parent 447d340 commit 9627179

File tree

9 files changed

+416
-244
lines changed

9 files changed

+416
-244
lines changed

.github/workflows/test.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ jobs:
4747
shell: bash
4848
run: |
4949
if [ ${{ matrix.os }} == 'windows-2022' ]; then
50-
pytest --ignore-glob=*test_smoke* tests
50+
pytest -s --ignore-glob=*test_smoke* tests
5151
else
52-
pytest --ignore-glob=*test_smoke* tests
52+
pytest -s --ignore-glob=*test_smoke* tests
5353
fi
5454
5555
test-against-torch-mlir-bindings:
@@ -86,9 +86,9 @@ jobs:
8686
shell: bash
8787
run: |
8888
if [ ${{ matrix.os }} == 'windows-2022' ]; then
89-
pytest tests/test_smoke.py
89+
pytest -s tests/test_smoke.py
9090
else
91-
pytest tests/test_smoke.py
91+
pytest -s tests/test_smoke.py
9292
fi
9393
9494
@@ -126,7 +126,7 @@ jobs:
126126
shell: bash
127127
run: |
128128
if [ ${{ matrix.os }} == 'windows-2022' ]; then
129-
pytest tests/test_smoke.py
129+
pytest -s tests/test_smoke.py
130130
else
131-
pytest tests/test_smoke.py
131+
pytest -s tests/test_smoke.py
132132
fi

mlir_utils/context.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,4 @@ def mlir_mod_ctx(
3737
ip = mlir.ir.InsertionPoint(module.body)
3838
stack.enter_context(ip)
3939
yield MLIRContext(context, module)
40+
context._clear_live_operations()

mlir_utils/dialects/ext/tensor.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,12 +144,15 @@ def __getitem__(self, idx: tuple) -> "Tensor":
144144
if not self.has_rank():
145145
raise ValueError("only ranked tensor slicing/indexing supported")
146146

147+
if idx is None:
148+
return expand_dims(self, (0,), loc=loc)
147149
if idx == Ellipsis or idx == slice(None):
148150
return self
149151
if isinstance(idx, tuple) and all(i == slice(None) for i in idx):
150152
return self
151-
if idx is None:
152-
return expand_dims(self, (0,), loc=loc)
153+
if isinstance(idx, tuple) and all(i == slice(None) or i is None for i in idx):
154+
nones = [i for i, n in enumerate(idx) if n is None]
155+
return expand_dims(self, nones, loc=loc)
153156

154157
idx = list((idx,) if isinstance(idx, int) else idx)
155158
for i, d in enumerate(idx):
@@ -237,6 +240,13 @@ class _Indexer:
237240
def is_constant(self):
238241
return all(_is_constant_index(i) for i in self.indices)
239242

243+
def is_full(self):
244+
return all(
245+
isinstance(idx, slice)
246+
and len(range(*idx.indices(self.in_shape[i]))) == self.in_shape[i]
247+
for i, idx in enumerate(self.indices)
248+
)
249+
240250
# waiting on hashable slices in 3.12 https://stackoverflow.com/a/76562346
241251
# @lru_cache(maxsize=1)
242252
def static_offsets(self):
@@ -551,7 +561,9 @@ def _extract_slice(
551561
indexer = _indices_to_indexer(idx, ten.shape)
552562
out = ten
553563

554-
if indexer.is_constant():
564+
if indexer.is_full():
565+
out = out
566+
elif indexer.is_constant():
555567
out = extract_slice(
556568
out,
557569
static_offsets=indexer.static_offsets(),

mlir_utils/runtime/passes.py

Lines changed: 78 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,82 @@
1-
from __future__ import annotations
2-
31
import logging
2+
import os
3+
import sys
4+
import tempfile
5+
from contextlib import ExitStack
6+
from io import StringIO
7+
from typing import Optional
8+
9+
from mlir.ir import StringAttr
10+
from mlir.passmanager import PassManager
11+
12+
from mlir_utils.util import disable_multithreading
413

514
logger = logging.getLogger(__name__)
615

716

17+
class MlirCompilerError(Exception):
18+
pass
19+
20+
21+
def get_module_name_for_debug_dump(module):
22+
if "debug_module_name" not in module.operation.attributes:
23+
return "UnnammedModule"
24+
return StringAttr(module.operation.attributes["debug_module_name"]).value
25+
26+
27+
def run_pipeline(
28+
module,
29+
pipeline: str,
30+
description: Optional[str] = None,
31+
enable_ir_printing=False,
32+
print_pipeline=False,
33+
):
34+
"""Runs `pipeline` on `module`, with a nice repro report if it fails."""
35+
module_name = get_module_name_for_debug_dump(module)
36+
try:
37+
original_stderr = sys.stderr
38+
sys.stderr = StringIO()
39+
# Lower module in place to make it ready for compiler backends.
40+
with ExitStack() as stack:
41+
stack.enter_context(module.context)
42+
asm_for_error_report = module.operation.get_asm(
43+
large_elements_limit=10,
44+
enable_debug_info=True,
45+
)
46+
pm = PassManager.parse(pipeline)
47+
if print_pipeline:
48+
print(pm)
49+
if enable_ir_printing:
50+
stack.enter_context(disable_multithreading())
51+
pm.enable_ir_printing()
52+
53+
pm.run(module.operation)
54+
except Exception as e:
55+
print(e, file=sys.stderr)
56+
filename = os.path.join(tempfile.gettempdir(), module_name + ".mlir")
57+
with open(filename, "w") as f:
58+
f.write(asm_for_error_report)
59+
debug_options = "-mlir-print-ir-after-all -mlir-disable-threading"
60+
description = description or f"{module_name} compile"
61+
62+
message = f"""\
63+
{description} failed with the following diagnostics:
64+
65+
{'*' * 80}
66+
{sys.stderr.getvalue().strip()}
67+
{'*' * 80}
68+
69+
For developers, the error can be reproduced with:
70+
$ mlir-opt {debug_options} -pass-pipeline='{pipeline}' {filename}
71+
"""
72+
trimmed_message = "\n".join([m.lstrip() for m in message.split("\n")])
73+
raise MlirCompilerError(trimmed_message)
74+
finally:
75+
sys.stderr = original_stderr
76+
77+
return module
78+
79+
880
class Pipeline:
981
_pipeline: list[str] = []
1082

@@ -13,17 +85,17 @@ def __init__(self, pipeline=None, wrapper=None):
1385
pipeline = []
1486
self._pipeline = pipeline
1587

16-
def Func(self, p: Pipeline):
88+
def Func(self, p: "Pipeline"):
1789
assert isinstance(p, Pipeline)
1890
self._pipeline.append(f"func.func({p.materialize(module=False)})")
1991
return self
2092

21-
def Spirv(self, p: Pipeline):
93+
def Spirv(self, p: "Pipeline"):
2294
assert isinstance(p, Pipeline)
2395
self._pipeline.append(f"spirv.module({p.materialize(module=False)})")
2496
return self
2597

26-
def Gpu(self, p: Pipeline):
98+
def Gpu(self, p: "Pipeline"):
2799
assert isinstance(p, Pipeline)
28100
self._pipeline.append(f"gpu.module({p.materialize(module=False)})")
29101
return self
@@ -38,13 +110,6 @@ def materialize(self, module=True):
38110
def __str__(self):
39111
return self.materialize()
40112

41-
def __add__(self, other: Pipeline):
42-
return Pipeline(self._pipeline + other._pipeline)
43-
44-
def __iadd__(self, other: Pipeline):
45-
self._pipeline += other._pipeline
46-
return self
47-
48113
def add_pass(self, pass_name, **kwargs):
49114
kwargs = {
50115
k.replace("_", "-"): int(v) if isinstance(v, bool) else v
@@ -57,6 +122,7 @@ def add_pass(self, pass_name, **kwargs):
57122
else:
58123
pass_str = f"{pass_name}"
59124
self._pipeline.append(pass_str)
125+
return self
60126

61127
def lower_to_llvm_(self):
62128
return any(["to-llvm" in p for p in self._pipeline])

0 commit comments

Comments
 (0)