Skip to content

Commit 89858c2

Browse files
committed
tiling examples
1 parent 9627179 commit 89858c2

File tree

10 files changed

+527
-44
lines changed

10 files changed

+527
-44
lines changed

.github/workflows/test.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ jobs:
4949
if [ ${{ matrix.os }} == 'windows-2022' ]; then
5050
pytest -s --ignore-glob=*test_smoke* tests
5151
else
52-
pytest -s --ignore-glob=*test_smoke* tests
52+
pytest --capture=tee-sys --ignore-glob=*test_smoke* tests
5353
fi
5454
5555
test-against-torch-mlir-bindings:
@@ -86,9 +86,9 @@ jobs:
8686
shell: bash
8787
run: |
8888
if [ ${{ matrix.os }} == 'windows-2022' ]; then
89-
pytest -s tests/test_smoke.py
89+
pytest --capture=tee-sys tests/test_smoke.py
9090
else
91-
pytest -s tests/test_smoke.py
91+
pytest --capture=tee-sys tests/test_smoke.py
9292
fi
9393
9494
@@ -126,7 +126,7 @@ jobs:
126126
shell: bash
127127
run: |
128128
if [ ${{ matrix.os }} == 'windows-2022' ]; then
129-
pytest -s tests/test_smoke.py
129+
pytest --capture=tee-sys tests/test_smoke.py
130130
else
131-
pytest -s tests/test_smoke.py
131+
pytest --capture=tee-sys tests/test_smoke.py
132132
fi

mlir_utils/ast/canonicalize.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ def transform_ast(
7575
n_lines = len(inspect.getsource(f).splitlines())
7676
line_starts = list(findlinestarts(new_f_code_o))
7777
assert (
78-
line_starts[-1][1] - line_starts[0][1] == n_lines - 1
78+
max([l for _, l in line_starts]) - min([l for _, l in line_starts]) + 1
79+
== n_lines
7980
), f"something went wrong with the line numbers for the rewritten/canonicalized function"
8081
return copy_func(f, new_f_code_o)
8182

mlir_utils/dialects/ext/func.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def func(
247247
) -> FuncBase:
248248
if loc is None:
249249
loc = get_user_code_loc()
250-
return FuncBase(
250+
func = FuncBase(
251251
body_builder=f,
252252
func_op_ctor=FuncOp.__base__,
253253
return_op_ctor=ReturnOp,
@@ -259,3 +259,5 @@ def func(
259259
loc=loc,
260260
ip=ip,
261261
)
262+
func.__name__ = f.__name__
263+
return func

mlir_utils/dialects/ext/memref.py

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44

55
from mlir.ir import Type, Value, MemRefType, ShapedType, MLIRError
66

7-
from mlir_utils.dialects import memref
7+
import mlir_utils.types as T
8+
from mlir_utils.dialects import memref, arith
89
from mlir_utils.dialects.ext.arith import Scalar, constant
910
from mlir_utils.dialects.ext.tensor import (
1011
_indices_to_indexer,
1112
compute_result_shape_reassoc_list,
1213
)
13-
import mlir_utils.types as T
1414
from mlir_utils.util import (
1515
register_value_caster,
1616
get_user_code_loc,
@@ -88,6 +88,8 @@ def store(
8888

8989
def subview(
9090
source: "MemRef",
91+
offsets: Optional[Sequence[Value]] = None,
92+
strides: Optional[Sequence[Value]] = None,
9193
static_offsets: Optional[Sequence[int]] = None,
9294
static_sizes: Optional[Sequence[int]] = None,
9395
static_strides: Optional[Sequence[int]] = None,
@@ -97,11 +99,23 @@ def subview(
9799
):
98100
if loc is None:
99101
loc = get_user_code_loc()
102+
if offsets is None:
103+
offsets = []
104+
if static_offsets is None:
105+
static_offsets = []
106+
if strides is None:
107+
strides = []
108+
if static_strides is None:
109+
static_strides = []
100110
assert static_sizes, f"this convenience method only handles static sizes"
101-
offsets = sizes = strides = []
102-
result = T.memref(*static_sizes, source.dtype)
111+
sizes = []
112+
wrong_type = T.memref(*static_sizes, source.dtype)
113+
if offsets and static_offsets:
114+
assert all(s == S for s in static_offsets)
115+
if strides and static_strides:
116+
assert all(s == S for s in static_strides)
103117
val = memref.subview(
104-
result,
118+
wrong_type,
105119
source,
106120
offsets,
107121
sizes,
@@ -270,7 +284,51 @@ def _subview(
270284
ip=ip,
271285
)
272286
else:
273-
raise ValueError(f"non-constant indices not supported {indexer}")
287+
# special tile case
288+
offsets = [None] * len(indexer.in_shape)
289+
static_offsets = [None] * len(indexer.in_shape)
290+
static_sizes = [None] * len(indexer.in_shape)
291+
static_strides = [None] * len(indexer.in_shape)
292+
for i, ind in enumerate(indexer.indices):
293+
maybe_size = maybe_cast(ind.stop.owner.operands[1])
294+
if (
295+
isinstance(ind.start.owner.opview, arith.MulIOp)
296+
and isinstance(ind.stop.owner.opview, arith.MulIOp)
297+
and isinstance(ind.stop.owner.operands[0].owner.opview, arith.AddIOp)
298+
and ind.start.owner.operands[0]
299+
== ind.stop.owner.operands[0].owner.operands[0]
300+
and maybe_size.is_constant()
301+
and isinstance(ind.step, int)
302+
or isinstance(ind.step, Scalar)
303+
and ind.step.is_constant()
304+
):
305+
offsets[i] = ind.start
306+
static_offsets[i] = S
307+
static_sizes[i] = maybe_size.literal_value
308+
static_strides[i] = (
309+
ind.step.literal_value if isinstance(ind.step, Scalar) else ind.step
310+
)
311+
else:
312+
raise RuntimeError(f"indexing not supported {indexer.indices}")
313+
offsets = list(filter(None, offsets))
314+
static_offsets = list(filter(None, static_offsets))
315+
static_sizes = list(filter(None, static_sizes))
316+
static_strides = list(filter(None, static_strides))
317+
assert (
318+
len(offsets)
319+
== len(static_sizes)
320+
== len(static_strides)
321+
== len(indexer.in_shape)
322+
), f"not each slice is statically known: {indexer.indices}"
323+
out = subview(
324+
out,
325+
offsets=offsets,
326+
static_offsets=static_offsets,
327+
static_sizes=static_sizes,
328+
static_strides=static_strides,
329+
loc=loc,
330+
ip=ip,
331+
)
274332

275333
# This adds newaxis/None dimensions.
276334
return expand_shape(out, indexer.newaxis_dims, loc=loc, ip=ip)

mlir_utils/dialects/ext/tensor.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -268,10 +268,14 @@ def static_sizes(self):
268268
sizes.append(1)
269269
elif isinstance(i, slice):
270270
start, stop, step = map(int, (i.start, i.stop, i.step))
271-
s = ((stop - start) // step) + 1
272-
if (stop - start) % step == 0:
273-
s -= 1
274-
sizes.append(s)
271+
if all(isinstance(j, int) for j in (start, stop, step)):
272+
s = ((stop - start) // step) + 1
273+
if (stop - start) % step == 0:
274+
s -= 1
275+
sizes.append(s)
276+
else:
277+
raise ValueError(f"idx {i} not supported with static sizes")
278+
275279
else:
276280
raise ValueError(f"idx {i} not supported with static sizes")
277281
return tuple(sizes)
@@ -496,12 +500,12 @@ def _indices_to_indexer(
496500
elif isinstance(idx_e, slice):
497501
# Normalize the slice to use None when possible
498502
start, stop, step = idx_e.start, idx_e.stop, idx_e.step
499-
if step is None or step == 1:
503+
if step is None or isinstance(step, int) and step == 1:
500504
step = None
501505
if step is None:
502-
if start is None or start == 0:
506+
if start is None or isinstance(start, int) and start == 0:
503507
start = None
504-
if stop is None or stop >= in_shape[in_axis]:
508+
if stop is None or isinstance(stop, int) and stop >= in_shape[in_axis]:
505509
stop = None
506510
# Handle slice(None) and slice(None, None, -1)
507511
if (
@@ -529,6 +533,8 @@ def _indices_to_indexer(
529533
)
530534
raise IndexError(msg)
531535

536+
if step is None:
537+
step = 1
532538
indices[in_axis] = slice(start, stop, step)
533539

534540
out_axis += 1

mlir_utils/runtime/refbackend.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,35 @@
11
import ctypes
22
import logging
33
import os
4+
import platform
45
import warnings
56
from pathlib import Path
67

78
import numpy as np
89
from mlir import _mlir_libs
910
from mlir.dialects.func import FuncOp, CallOp
10-
from mlir.execution_engine import ExecutionEngine
1111
from mlir.ir import UnitAttr, Module, MemRefType
12-
from mlir.runtime import (
13-
UnrankedMemRefDescriptor,
14-
get_ranked_memref_descriptor,
15-
unranked_memref_to_numpy,
16-
get_unranked_memref_descriptor,
17-
)
1812

13+
try:
14+
from mlir.execution_engine import ExecutionEngine
15+
from mlir.runtime import (
16+
UnrankedMemRefDescriptor,
17+
get_ranked_memref_descriptor,
18+
unranked_memref_to_numpy,
19+
get_unranked_memref_descriptor,
20+
)
21+
except:
22+
warnings.warn("no execution engine in mlir bindings; refbackend won't work")
23+
24+
25+
import mlir_utils.types as T
1926
from mlir_utils.dialects.memref import cast
2027
from mlir_utils.runtime.passes import Pipeline, run_pipeline
21-
from mlir_utils.types import memref_type_to_np_dtype, mlir_type_to_ctype, np_dtype_to_mlir_type
22-
import mlir_utils.types as T
28+
from mlir_utils.types import (
29+
memref_type_to_np_dtype,
30+
mlir_type_to_ctype,
31+
np_dtype_to_mlir_type,
32+
)
2333
from mlir_utils.util import shlib_ext, find_ops, shlib_prefix
2434

2535
logger = logging.getLogger(__name__)
@@ -86,8 +96,8 @@ def convert_arg_to_ctype(arg, unranked=True):
8696
if isinstance(arg, CData) or isinstance(arg, (int, float, bool)):
8797
return arg
8898
elif isinstance(arg, np.ndarray):
89-
assert (
90-
np_dtype_to_mlir_type(arg.dtype.type)
99+
assert np_dtype_to_mlir_type(
100+
arg.dtype.type
91101
), f"unsupported numpy array type {arg.dtype}"
92102
if unranked:
93103
return ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(arg)))
@@ -145,6 +155,8 @@ def invoke(*args):
145155
self.ee.invoke(
146156
function_name, *[convert_arg_to_ctype(a, unranked=False) for a in args]
147157
)
158+
if self.results is not None and len(self.results) == 1:
159+
return self.results[0]
148160
return self.results
149161

150162
return invoke
@@ -193,7 +205,7 @@ def __init__(
193205
self,
194206
shared_lib_paths=None,
195207
):
196-
if shared_lib_paths is None:
208+
if shared_lib_paths is None and platform.system() != "Windows":
197209
shared_lib_paths = [
198210
ASYNC_RUNTIME_LIB_PATH,
199211
C_RUNNER_UTILS_LIB_PATH,

mlir_utils/types.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import sys
21
import ctypes
32
from functools import partial
43
from typing import Union
@@ -24,6 +23,7 @@
2423
UnrankedMemRefType,
2524
UnrankedTensorType,
2625
VectorType,
26+
StridedLayoutAttr,
2727
)
2828

2929
_index = lambda: IndexType.get()
@@ -228,9 +228,17 @@ def tensor(*args, element_type: Type = None):
228228
)
229229

230230

231-
def memref(*args, element_type: Type = None, memory_space: int = None):
231+
def memref(
232+
*args,
233+
element_type: Type = None,
234+
memory_space: int = None,
235+
layout: tuple[tuple[int, ...], int] = None,
236+
):
232237
if memory_space is None:
233238
memory_space = 0
239+
if layout is not None:
240+
strides, offset = layout
241+
layout = StridedLayoutAttr.get(offset, strides)
234242
memory_space = Attribute.parse(str(memory_space))
235243
if not len(args) or len(args) == 1 and isinstance(args[-1], Type):
236244
return shaped(
@@ -242,7 +250,9 @@ def memref(*args, element_type: Type = None, memory_space: int = None):
242250
return shaped(
243251
*args,
244252
element_type=element_type,
245-
type_constructor=partial(MemRefType.get, memory_space=memory_space),
253+
type_constructor=partial(
254+
MemRefType.get, memory_space=memory_space, layout=layout
255+
),
246256
)
247257

248258

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ dependencies = [
1111
"PyYAML",
1212
"inflection",
1313
"bytecode",
14-
"libcst",
1514
]
1615

1716
[project.optional-dependencies]

0 commit comments

Comments
 (0)