
Commit b992f9c

implement memref (and chores on tensor)
1 parent 2112e20 commit b992f9c

5 files changed: +757, -36 lines


mlir_utils/dialects/ext/arith.py

Lines changed: 3 additions & 1 deletion
@@ -203,12 +203,14 @@ def __call__(cls, *args, **kwargs):
         fold = kwargs.get("fold")
         if fold is not None and not isinstance(fold, bool):
             raise ValueError(f"{fold=} is expected to be a bool.")
+        loc = kwargs.get("loc")
+        ip = kwargs.get("ip")
         # If we're wrapping a numpy array (effectively a tensor literal),
         # then we want to make sure no one else has access to that memory.
         # Otherwise, the array will get funneled down to DenseElementsAttr.get,
         # which by default (through the Python buffer protocol) does not copy;
         # see mlir/lib/Bindings/Python/IRAttributes.cpp#L556
-        val = constant(deepcopy(arg), dtype)
+        val = constant(deepcopy(arg), dtype, loc=loc, ip=ip)
     else:
         raise NotImplementedError(f"{cls.__name__} doesn't support wrapping {arg}.")

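For context (not part of this diff): a minimal sketch of the aliasing hazard the deepcopy in the hunk above guards against. It uses only the upstream MLIR Python bindings; the array and its values are illustrative.

import numpy as np
from copy import deepcopy
from mlir.ir import Context, DenseElementsAttr

with Context():
    data = np.ones((2, 2), dtype=np.float32)
    # DenseElementsAttr.get reads the array through the Python buffer protocol
    # and by default does not copy, so the attribute can alias memory the
    # caller still owns. Building it from a private copy keeps later mutations
    # of `data` away from the IR constant.
    attr = DenseElementsAttr.get(deepcopy(data))
    data[:] = 0.0  # safe: `attr` was built from the copy, not from `data`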
mlir_utils/dialects/ext/memref.py

Lines changed: 297 additions & 0 deletions
@@ -0,0 +1,297 @@
import re
from functools import cached_property
from typing import Tuple, Sequence, Optional, Union

from mlir.ir import Type, Value, MemRefType, ShapedType, MLIRError

from mlir_utils.dialects import memref
from mlir_utils.dialects.ext.arith import Scalar, constant
from mlir_utils.dialects.ext.tensor import (
    _indices_to_indexer,
    compute_result_shape_reassoc_list,
)
import mlir_utils.types as T
from mlir_utils.util import (
    register_value_caster,
    get_user_code_loc,
    maybe_cast,
    get_result_or_results,
)

S = ShapedType.get_dynamic_size()


def _alloc(
    op_ctor,
    sizes: Sequence[int],
    element_type: Type,
    *,
    loc=None,
    ip=None,
):
    if loc is None:
        loc = get_user_code_loc()
    dynamic_sizes = []
    result_type = T.memref(*sizes, element_type)
    return maybe_cast(
        get_result_or_results(op_ctor(result_type, dynamic_sizes, [], loc=loc, ip=ip))
    )


def alloc(sizes: Sequence[Union[int, Value]], element_type: Type, *, loc=None, ip=None):
    if loc is None:
        loc = get_user_code_loc()
    return maybe_cast(
        get_result_or_results(
            _alloc(memref.AllocOp, sizes, element_type, loc=loc, ip=ip)
        )
    )


def alloca(
    sizes: Sequence[Union[int, Value]], element_type: Type, *, loc=None, ip=None
):
    if loc is None:
        loc = get_user_code_loc()
    return maybe_cast(
        get_result_or_results(
            _alloc(memref.AllocaOp, sizes, element_type, loc=loc, ip=ip)
        )
    )


def load(mem: Value, indices: Sequence[Value | int], *, loc=None, ip=None):
    if loc is None:
        loc = get_user_code_loc()
    indices = list(indices)
    for idx, i in enumerate(indices):
        if isinstance(i, int):
            indices[idx] = constant(i, index=True)
    return maybe_cast(
        get_result_or_results(memref.LoadOp.__base__(mem, indices, loc=loc, ip=ip))
    )


def store(
    value: Value, mem: Value, indices: Sequence[Value | int], *, loc=None, ip=None
):
    if loc is None:
        loc = get_user_code_loc()
    indices = list(indices)
    for idx, i in enumerate(indices):
        if isinstance(i, int):
            indices[idx] = constant(i, index=True)
    return maybe_cast(
        get_result_or_results(memref.StoreOp(value, mem, indices, loc=loc, ip=ip))
    )


def subview(
    source: "MemRef",
    static_offsets: Optional[Sequence[int]] = None,
    static_sizes: Optional[Sequence[int]] = None,
    static_strides: Optional[Sequence[int]] = None,
    *,
    loc=None,
    ip=None,
):
    if loc is None:
        loc = get_user_code_loc()
    assert static_sizes, "this convenience method only handles static sizes"
    offsets = sizes = strides = []
    result = T.memref(*static_sizes, source.dtype)
    val = memref.subview(
        result,
        source,
        offsets,
        sizes,
        strides,
        static_offsets,
        static_sizes,
        static_strides,
        loc=loc,
        ip=ip,
    )
    # dumbest hack ever - the default builder doesn't connect to inferReturnTypes
    # but the diag message does
    try:
        val.owner.verify()
        return val
    except MLIRError as e:
        diag = str(e.error_diagnostics[0])
        correct_type = re.findall(r"'memref<(.*)>'", diag)
        assert len(correct_type) == 1
        correct_type = Type.parse(f"memref<{correct_type[0]}>")
        val.owner.erase()
        return memref.subview(
            correct_type,
            source,
            offsets,
            sizes,
            strides,
            static_offsets,
            static_sizes,
            static_strides,
            loc=loc,
            ip=ip,
        )


@register_value_caster(MemRefType.static_typeid)
class MemRef(Value):
    def __str__(self):
        return f"{self.__class__.__name__}({self.get_name()}, {self.type})"

    def __repr__(self):
        return str(self)

    @staticmethod
    def isinstance(other: Value):
        return isinstance(other, Value) and MemRefType.isinstance(other.type)

    @cached_property
    def _shaped_type(self) -> ShapedType:
        return ShapedType(self.type)

    def has_static_shape(self) -> bool:
        return self._shaped_type.has_static_shape

    def has_rank(self) -> bool:
        return self._shaped_type.has_rank

    @cached_property
    def shape(self) -> Tuple[int, ...]:
        return tuple(self._shaped_type.shape)

    @cached_property
    def dtype(self) -> Type:
        return self._shaped_type.element_type

    def __getitem__(self, idx: tuple) -> "MemRef":
        loc = get_user_code_loc()

        if not self.has_rank():
            raise ValueError("only ranked memref slicing/indexing supported")

        if idx == Ellipsis or idx == slice(None):
            return self
        if isinstance(idx, tuple) and all(i == slice(None) for i in idx):
            return self
        if idx is None:
            return expand_shape(self, (0,), loc=loc)

        idx = list((idx,) if isinstance(idx, int) else idx)
        for i, d in enumerate(idx):
            if isinstance(d, int):
                idx[i] = constant(d, index=True, loc=loc)

        if all(isinstance(d, Scalar) for d in idx) and len(idx) == len(self.shape):
            return load(self, idx, loc=loc)
        else:
            return _subview(self, tuple(idx), loc=loc)

    def __setitem__(self, idx, source):
        loc = get_user_code_loc()

        if not self.has_rank():
            raise ValueError("only ranked memref slicing/indexing supported")

        idx = list((idx,) if isinstance(idx, int) else idx)
        for i, d in enumerate(idx):
            if isinstance(d, int):
                idx[i] = constant(d, index=True, loc=loc)

        if all(isinstance(d, Scalar) for d in idx) and len(idx) == len(self.shape):
            assert isinstance(
                source, Scalar
            ), "coordinate insert requires scalar element"
            store(source, self, idx, loc=loc)
        else:
            _copy_to_subview(self, source, tuple(idx), loc=loc)


def expand_shape(
    inp,
    newaxis_dims,
    *,
    loc=None,
    ip=None,
) -> MemRef:
    """Expand the shape of a memref.

    Insert new axes that will appear at the `newaxis_dims` positions in the
    expanded memref shape.

    Args:
      inp: Input memref-like.
      newaxis_dims: Positions in the expanded shape where the new axes are placed.

    Returns:
      View of `inp` with the number of dimensions increased.

    """
    if loc is None:
        loc = get_user_code_loc()

    if len(newaxis_dims) == 0:
        return inp

    result_shape, reassoc_list = compute_result_shape_reassoc_list(
        inp.shape, newaxis_dims
    )

    return MemRef(
        memref.expand_shape(
            T.memref(*result_shape, inp.dtype), inp, reassoc_list, loc=loc, ip=ip
        )
    )


def _subview(
    mem: MemRef,
    idx,
    *,
    loc=None,
    ip=None,
) -> MemRef:
    if loc is None:
        loc = get_user_code_loc()

    indexer = _indices_to_indexer(idx, mem.shape)
    out = mem

    if indexer.is_constant():
        out = subview(
            out,
            static_offsets=indexer.static_offsets(),
            static_sizes=indexer.static_sizes(),
            static_strides=indexer.static_strides(),
            loc=loc,
            ip=ip,
        )
    else:
        raise ValueError(f"non-constant indices not supported {indexer}")

    # This adds newaxis/None dimensions.
    return expand_shape(out, indexer.newaxis_dims, loc=loc, ip=ip)


def _copy_to_subview(
    dest: MemRef,
    source: MemRef,
    idx,
    *,
    loc=None,
    ip=None,
):
    if loc is None:
        loc = get_user_code_loc()
    if isinstance(source, Scalar):
        source = expand_shape(source, (0,), loc=loc, ip=ip)

    dest_subview = _subview(dest, idx, loc=loc, ip=ip)
    assert (
        dest_subview.shape == source.shape
    ), f"Expected matching shape for dest subview {dest_subview.shape} and source {source.shape=}"
    return memref.copy(source, dest_subview, loc=loc, ip=ip)

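Not part of the commit: a rough usage sketch of the new memref helpers, assuming an active MLIR context and insertion point (as the rest of mlir_utils sets up, e.g. inside one of its module/function builders) and that float constants are cast to Scalar by the registered value casters; shapes, names, and exact constant/type signatures here are assumptions for illustration.

from mlir.ir import F32Type

import mlir_utils.dialects.ext.memref as memref
from mlir_utils.dialects.ext.arith import constant

# static 10x10 f32 allocation -> memref.alloc
mem = memref.alloc((10, 10), F32Type.get())

# integer indices become index constants; full-rank scalar access lowers to
# memref.store / memref.load
mem[2, 3] = constant(1.0, F32Type.get())
x = mem[2, 3]

# static slices lower to memref.subview (possibly via the verify/re-build
# fallback above, which picks up the inferred strided layout); assigning a
# same-shaped memref into a slice lowers to memref.copy
tile = mem[0:4, 0:4]
mem[4:8, 4:8] = tile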