Skip to content

Commit 830c7e1

Browse files
authored
test memref<vector> (#131)
1 parent 5078a84 commit 830c7e1

File tree

17 files changed

+819
-143
lines changed

17 files changed

+819
-143
lines changed

.github/workflows/test.yml

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,13 @@ jobs:
6666
- name: Test
6767
shell: bash
6868
run: |
69-
if [ ${{ matrix.os }} == 'windows-2022' ]; then
70-
pytest -s tests
71-
else
72-
pytest --capture=tee-sys tests
73-
fi
69+
70+
pytest tests
7471
7572
- name: Test mwe
7673
shell: bash
7774
run: |
75+
7876
python examples/mwe.py
7977
8078
test-other-host-bindings:

mlir/extras/dialects/ext/arith.py

Lines changed: 44 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from functools import cached_property, partialmethod
77
from typing import Optional, Tuple, Union
88

9+
import numpy as np
910
from bytecode import ConcreteBytecode
1011
from einspect.structs import PyTypeObject
1112

@@ -124,18 +125,21 @@ def constant(
124125

125126

126127
def index_cast(
127-
value: Value,
128+
in_: Value,
128129
*,
129130
to: Type = None,
131+
out: Type = None,
130132
loc: Location = None,
131133
ip: InsertionPoint = None,
132134
) -> Value:
133135
if loc is None:
134136
loc = get_user_code_loc()
135-
if to is None:
136-
to = IndexType.get()
137+
assert bool(to) != bool(out), "either `to` or `out` must be set"
138+
res_type = out or to
139+
if res_type is None:
140+
res_type = IndexType.get()
137141
return get_op_result_or_op_results(
138-
arith_dialect.IndexCastOp(to, value, loc=loc, ip=ip)
142+
arith_dialect.IndexCastOp(res_type, in_, loc=loc, ip=ip)
139143
)
140144

141145

@@ -234,26 +238,26 @@ def __call__(cls, *args, **kwargs):
234238
@register_attribute_builder("Arith_CmpIPredicateAttr", replace=True)
235239
def _arith_CmpIPredicateAttr(predicate: Union[str, Attribute], context: Context):
236240
predicates = {
237-
"eq": CmpIPredicate.eq,
238-
"ne": CmpIPredicate.ne,
239-
"slt": CmpIPredicate.slt,
240-
"sle": CmpIPredicate.sle,
241-
"sgt": CmpIPredicate.sgt,
242-
"sge": CmpIPredicate.sge,
243-
"ult": CmpIPredicate.ult,
244-
"ule": CmpIPredicate.ule,
245-
"ugt": CmpIPredicate.ugt,
246-
"uge": CmpIPredicate.uge,
247-
0: CmpIPredicate.eq,
248-
1: CmpIPredicate.ne,
249-
2: CmpIPredicate.slt,
250-
3: CmpIPredicate.sle,
251-
4: CmpIPredicate.sgt,
252-
5: CmpIPredicate.sge,
253-
6: CmpIPredicate.ult,
254-
7: CmpIPredicate.ule,
255-
8: CmpIPredicate.ugt,
256-
9: CmpIPredicate.uge,
241+
"eq": arith_dialect.CmpIPredicate.eq,
242+
"ne": arith_dialect.CmpIPredicate.ne,
243+
"slt": arith_dialect.CmpIPredicate.slt,
244+
"sle": arith_dialect.CmpIPredicate.sle,
245+
"sgt": arith_dialect.CmpIPredicate.sgt,
246+
"sge": arith_dialect.CmpIPredicate.sge,
247+
"ult": arith_dialect.CmpIPredicate.ult,
248+
"ule": arith_dialect.CmpIPredicate.ule,
249+
"ugt": arith_dialect.CmpIPredicate.ugt,
250+
"uge": arith_dialect.CmpIPredicate.uge,
251+
0: arith_dialect.CmpIPredicate.eq,
252+
1: arith_dialect.CmpIPredicate.ne,
253+
2: arith_dialect.CmpIPredicate.slt,
254+
3: arith_dialect.CmpIPredicate.sle,
255+
4: arith_dialect.CmpIPredicate.sgt,
256+
5: arith_dialect.CmpIPredicate.sge,
257+
6: arith_dialect.CmpIPredicate.ult,
258+
7: arith_dialect.CmpIPredicate.ule,
259+
8: arith_dialect.CmpIPredicate.ugt,
260+
9: arith_dialect.CmpIPredicate.uge,
257261
}
258262
if isinstance(predicate, Attribute):
259263
return predicate
@@ -264,29 +268,29 @@ def _arith_CmpIPredicateAttr(predicate: Union[str, Attribute], context: Context)
264268
@register_attribute_builder("Arith_CmpFPredicateAttr", replace=True)
265269
def _arith_CmpFPredicateAttr(predicate: Union[str, Attribute], context: Context):
266270
predicates = {
267-
"false": CmpFPredicate.AlwaysFalse,
271+
"false": arith_dialect.CmpFPredicate.AlwaysFalse,
268272
# ordered comparison
269273
# An ordered comparison checks if neither operand is NaN.
270-
"oeq": CmpFPredicate.OEQ,
271-
"ogt": CmpFPredicate.OGT,
272-
"oge": CmpFPredicate.OGE,
273-
"olt": CmpFPredicate.OLT,
274-
"ole": CmpFPredicate.OLE,
275-
"one": CmpFPredicate.ONE,
274+
"oeq": arith_dialect.CmpFPredicate.OEQ,
275+
"ogt": arith_dialect.CmpFPredicate.OGT,
276+
"oge": arith_dialect.CmpFPredicate.OGE,
277+
"olt": arith_dialect.CmpFPredicate.OLT,
278+
"ole": arith_dialect.CmpFPredicate.OLE,
279+
"one": arith_dialect.CmpFPredicate.ONE,
276280
# no clue what this one is
277-
"ord": CmpFPredicate.ORD,
281+
"ord": arith_dialect.CmpFPredicate.ORD,
278282
# unordered comparison
279283
# Conversely, an unordered comparison checks if either operand is a NaN.
280-
"ueq": CmpFPredicate.UEQ,
281-
"ugt": CmpFPredicate.UGT,
282-
"uge": CmpFPredicate.UGE,
283-
"ult": CmpFPredicate.ULT,
284-
"ule": CmpFPredicate.ULE,
285-
"une": CmpFPredicate.UNE,
284+
"ueq": arith_dialect.CmpFPredicate.UEQ,
285+
"ugt": arith_dialect.CmpFPredicate.UGT,
286+
"uge": arith_dialect.CmpFPredicate.UGE,
287+
"ult": arith_dialect.CmpFPredicate.ULT,
288+
"ule": arith_dialect.CmpFPredicate.ULE,
289+
"une": arith_dialect.CmpFPredicate.UNE,
286290
# no clue what this one is
287-
"uno": CmpFPredicate.UNO,
291+
"uno": arith_dialect.CmpFPredicate.UNO,
288292
# return always true
289-
"true": CmpFPredicate.AlwaysTrue,
293+
"true": arith_dialect.CmpFPredicate.AlwaysTrue,
290294
}
291295
if isinstance(predicate, Attribute):
292296
return predicate

mlir/extras/dialects/ext/func.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
)
2525

2626

27+
_call = call
28+
29+
2730
def call(
2831
callee_or_results: Union[FuncOp, List[Type]],
2932
arguments_or_callee: Union[List[Value], FlatSymbolRefAttr, str],

mlir/extras/dialects/ext/memref.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from ._shaped_value import ShapedValue
77
from .arith import Scalar, constant
88
from .tensor import _indices_to_indexer, compute_result_shape_reassoc_list
9+
from .vector import Vector
910
from ... import types as T
1011
from ...meta import region_op
1112
from ...util import (
@@ -28,7 +29,7 @@
2829
S = ShapedType.get_dynamic_size()
2930

3031

31-
def _alloc(
32+
def __alloc(
3233
op_ctor,
3334
sizes: Sequence[Union[int, Value]],
3435
element_type: Type,
@@ -64,6 +65,9 @@ def _alloc(
6465
)
6566

6667

68+
_alloc = alloc
69+
70+
6771
def alloc(
6872
sizes: Union[int, Value],
6973
element_type: Type = None,
@@ -74,7 +78,7 @@ def alloc(
7478
):
7579
if loc is None:
7680
loc = get_user_code_loc()
77-
return _alloc(
81+
return __alloc(
7882
AllocOp,
7983
sizes,
8084
element_type,
@@ -85,6 +89,9 @@ def alloc(
8589
)
8690

8791

92+
_alloca = alloca
93+
94+
8895
def alloca(
8996
sizes: Union[int, Value],
9097
element_type: Type = None,
@@ -95,7 +102,7 @@ def alloca(
95102
):
96103
if loc is None:
97104
loc = get_user_code_loc()
98-
return _alloc(
105+
return __alloc(
99106
AllocaOp,
100107
sizes,
101108
element_type,
@@ -106,26 +113,31 @@ def alloca(
106113
)
107114

108115

109-
def load(mem: Value, indices: Sequence[Union[Value, int]], *, loc=None, ip=None):
116+
def load(memref: Value, indices: Sequence[Union[Value, int]], *, loc=None, ip=None):
110117
if loc is None:
111118
loc = get_user_code_loc()
112119
indices = list(indices)
113120
for idx, i in enumerate(indices):
114121
if isinstance(i, int):
115122
indices[idx] = constant(i, index=True)
116-
return get_op_result_or_op_results(LoadOp(mem, indices, loc=loc, ip=ip))
123+
return get_op_result_or_op_results(LoadOp(memref, indices, loc=loc, ip=ip))
117124

118125

119126
def store(
120-
value: Value, mem: Value, indices: Sequence[Union[Value, int]], *, loc=None, ip=None
127+
value: Value,
128+
memref: Value,
129+
indices: Sequence[Union[Value, int]],
130+
*,
131+
loc=None,
132+
ip=None,
121133
):
122134
if loc is None:
123135
loc = get_user_code_loc()
124136
indices = list(indices)
125137
for idx, i in enumerate(indices):
126138
if isinstance(i, int):
127139
indices[idx] = constant(i, index=True)
128-
return get_op_result_or_op_results(StoreOp(value, mem, indices, loc=loc, ip=ip))
140+
return get_op_result_or_op_results(StoreOp(value, memref, indices, loc=loc, ip=ip))
129141

130142

131143
@register_value_caster(MemRefType.static_typeid)
@@ -176,7 +188,9 @@ def __setitem__(self, idx, val):
176188
if isinstance(val, (int, float)):
177189
# TODO: this is an unchecked conversion
178190
val = Scalar(val, dtype=self.dtype)
179-
assert isinstance(val, Scalar), "coordinate insert requires scalar element"
191+
assert isinstance(
192+
val, (Scalar, Vector)
193+
), "coordinate insert requires scalar element"
180194
store(val, self, idx, loc=loc)
181195
else:
182196
_copy_to_subview(self, val, tuple(idx), loc=loc)

mlir/extras/dialects/ext/scf.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def placeholder_opaque_t():
105105
return opaque("scf", "placeholder")
106106

107107

108-
for_ = region_op(_build_for, terminator=yield__)
108+
for__ = region_op(_build_for, terminator=yield__)
109109

110110

111111
@_cext.register_operation(_Dialect, replace=True)
@@ -365,7 +365,11 @@ def another_reduce(reduce_op):
365365
return r
366366

367367

368-
def yield_(*args):
368+
def yield_(*args, results_=None):
369+
if len(args):
370+
assert results_ is None, "must provide results_ or args"
371+
if results_ is not None:
372+
args = results_
369373
if len(args) == 1 and isinstance(args[0], (list, OpResultList)):
370374
args = list(args[0])
371375
y = yield__(args)

mlir/extras/dialects/ext/vector.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,11 +142,13 @@ def extract(vector, position, *, loc=None, ip=None):
142142
_insert = insert
143143

144144

145-
def insert(vector, val, position, *, loc=None, ip=None):
145+
def insert(vector, val, positions, *, loc=None, ip=None):
146146
if loc is None:
147147
loc = get_user_code_loc()
148+
if len(positions) == 0:
149+
raise ValueError("positions cannot be empty")
148150
dynamic_position, _packed_position, static_position = _dispatch_mixed_values(
149-
position
151+
positions
150152
)
151153
return _insert(
152154
val,

mlir/extras/runtime/passes.py

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -135,35 +135,50 @@ def add_pass(self, pass_name, **kwargs):
135135
self._pipeline.append(pass_str)
136136
return self
137137

138-
def lower_to_llvm_(self):
139-
return any(["to-llvm" in p for p in self._pipeline])
140-
141-
def bufferize(self):
142-
return (
143-
self.Func(Pipeline().empty_tensor_to_alloc_tensor())
144-
.one_shot_bufferize()
145-
.Func(Pipeline().buffer_deallocation_simplification())
146-
)
147-
148138
def lower_to_llvm(self):
139+
# https://github.com/makslevental/llvm-project/blob/f6643263631bcb0d191ef923963ac1a5ca9ac5fd/mlir/test/lib/Dialect/LLVM/TestLowerToLLVM.cpp#L44
149140
return (
150-
self.cse()
151-
.Func(Pipeline().lower_affine().arith_expand().convert_math_to_llvm())
152-
.convert_math_to_libm()
153-
.expand_strided_metadata()
154-
.finalize_memref_to_llvm()
141+
self.Func(
142+
Pipeline()
143+
# Blanket-convert any remaining high-level vector ops to loops if any remain.
144+
.convert_vector_to_scf()
145+
# Blanket-convert any remaining linalg ops to loops if any remain.
146+
.convert_linalg_to_loops()
147+
)
148+
# Blanket-convert any remaining affine ops if any remain.
149+
.lower_affine()
150+
# Convert SCF to CF (always needed).
155151
.convert_scf_to_cf()
156-
.convert_cf_to_llvm()
152+
# Sprinkle some cleanups.
153+
.canonicalize()
157154
.cse()
155+
# Convert vector to LLVM (always needed).
156+
.convert_vector_to_llvm(force_32bit_vector_indices=True)
157+
# Convert Math to LLVM (always needed).
158+
.Func(Pipeline().convert_math_to_llvm())
159+
# Expand complicated MemRef operations before lowering them.
160+
.expand_strided_metadata()
161+
# The expansion may create affine expressions. Get rid of them.
158162
.lower_affine()
159-
.Func(Pipeline().convert_arith_to_llvm())
163+
# Convert MemRef to LLVM (always needed).
164+
.finalize_memref_to_llvm()
165+
# Convert Func to LLVM (always needed).
160166
.convert_func_to_llvm()
161-
.canonicalize()
162-
.convert_openmp_to_llvm()
163-
.cse()
167+
.convert_arith_to_llvm()
168+
.convert_cf_to_llvm()
169+
# Convert Index to LLVM (always needed).
170+
.convert_index_to_llvm()
171+
# Convert remaining unrealized_casts (always needed).
164172
.reconcile_unrealized_casts()
165173
)
166174

175+
def bufferize(self):
176+
return (
177+
self.Func(Pipeline().empty_tensor_to_alloc_tensor())
178+
.one_shot_bufferize()
179+
.Func(Pipeline().buffer_deallocation_simplification())
180+
)
181+
167182
def lower_to_openmp(self):
168183
return self.convert_scf_to_openmp().Func(Pipeline().lower_affine())
169184

mlir/extras/runtime/refbackend.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,8 @@ def cb(op):
238238
return False
239239

240240
kernel_func = find_ops(module.operation, cb, single=True)
241+
if isinstance(kernel_func, list) and len(kernel_func) == 0:
242+
raise ValueError(f"couldn't find kernel_func {kernel_name=}")
241243
if len(kernel_func.function_type.value.results) and generate_return_consumer:
242244
with InsertionPoint(module.body):
243245
return_consumer = make_return_consumer(kernel_func)

mlir/extras/util.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,3 +437,4 @@ def __getitem__(self, item):
437437

438438
# f is not a bound method since it was decorated...
439439
return self.f(self.instance, item, **kwargs)
440+

0 commit comments

Comments
 (0)