Skip to content

Commit d209fe0

Browse files
authored
add more amdgpu examples (#133)
1 parent 85e598c commit d209fe0

File tree

9 files changed

+837
-432
lines changed

9 files changed

+837
-432
lines changed

mlir/extras/ast/canonicalize.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from abc import ABC, abstractmethod
88
from dis import findlinestarts
99
from opcode import opmap
10-
from typing import List, Union, Sequence
10+
from typing import List, Union, Sequence, get_type_hints
1111

1212
import astunparse
1313
from bytecode import ConcreteBytecode
@@ -29,6 +29,43 @@ def visit_FunctionDef(self, node: ast.FunctionDef):
2929
return node
3030

3131

32+
# https://stackoverflow.com/a/66582895/9045206
33+
class AnnotationsCollector(ast.NodeVisitor):
34+
def __init__(self):
35+
self.annotations = {}
36+
37+
def visit_AnnAssign(self, node):
38+
if node.simple:
39+
# 'simple' == a single name, not an attribute or subscription.
40+
# we can therefore count on `node.target.id` to exist. This is
41+
# the same criteria used for module and class-level variable
42+
# annotations.
43+
self.annotations[node.target.id] = node.annotation
44+
45+
46+
def function_local_annotations(func):
47+
"""Return a mapping of name to string annotations for function locals
48+
49+
Python does not retain PEP 526 "variable: annotation" variable annotations
50+
within a function body, as local variables do not have a lifetime beyond
51+
the local namespace. This function extracts the mapping from functions that
52+
have source code available.
53+
54+
"""
55+
source = inspect.getsource(func)
56+
mod = ast.parse(source)
57+
assert mod.body and isinstance(mod.body[0], (ast.FunctionDef, ast.AsyncFunctionDef))
58+
collector = AnnotationsCollector()
59+
collector.visit(mod.body[0])
60+
func.__annotations__.update(
61+
{
62+
name: ast.get_source_segment(source, node)
63+
for name, node in collector.annotations.items()
64+
}
65+
)
66+
return get_type_hints(func)
67+
68+
3269
def transform_func(f, *transformer_ctors: type(Transformer)):
3370
module = get_module_cst(f)
3471
context = types.SimpleNamespace()

mlir/extras/dialects/ext/_shaped_value.py

Lines changed: 264 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
from dataclasses import dataclass
12
from functools import cached_property, reduce
2-
from typing import Tuple
3-
43
import numpy as np
4+
from typing import Tuple, Union, List, Any
55

6-
from ....ir import DenseElementsAttr, ShapedType, Type
6+
from ....dialects.linalg.opdsl.lang.emitter import _is_index_type
7+
from .arith import Scalar
8+
from ....ir import DenseElementsAttr, ShapedType, Type, Value, RankedTensorType
79

810
S = ShapedType.get_dynamic_size()
911

@@ -61,3 +63,262 @@ def dtype(self) -> Type:
6163
cls.dtype.__set_name__(None, "dtype")
6264

6365
return cls
66+
67+
68+
@dataclass(frozen=True)
69+
class _Indexer:
70+
indices: Tuple[Union[int, Scalar, slice, "Ellipsis", None]]
71+
newaxis_dims: Tuple[int, "Ellipsis"]
72+
in_shape: Tuple[Union[Value, int]]
73+
74+
def is_constant(self):
75+
return all(_is_constant_index(i) for i in self.indices)
76+
77+
def is_full(self):
78+
return all(
79+
isinstance(idx, slice)
80+
# TODO(max): could also work for constant Scalar
81+
and all([isinstance(x, int) for x in [idx.start, idx.stop, idx.step]])
82+
and len(range(*idx.indices(self.in_shape[i]))) == self.in_shape[i]
83+
for i, idx in enumerate(self.indices)
84+
)
85+
86+
# waiting on hashable slices in 3.12 https://stackoverflow.com/a/76562346
87+
# @lru_cache(maxsize=1)
88+
def static_offsets(self):
89+
offsets = []
90+
for i in self.indices:
91+
if isinstance(i, (int, Scalar)):
92+
offsets.append(int(i))
93+
elif isinstance(i, slice):
94+
offsets.append(int(i.start))
95+
else:
96+
raise ValueError(f"idx {i} not supported with static offsets")
97+
return tuple(offsets)
98+
99+
# @lru_cache(maxsize=1)
100+
def static_sizes(self):
101+
sizes = []
102+
for i in self.indices:
103+
if isinstance(i, (int, Scalar)):
104+
sizes.append(1)
105+
elif isinstance(i, slice):
106+
start, stop, step = map(int, (i.start, i.stop, i.step))
107+
if all(isinstance(j, int) for j in (start, stop, step)):
108+
s = ((stop - start) // step) + 1
109+
if (stop - start) % step == 0:
110+
s -= 1
111+
sizes.append(s)
112+
else:
113+
raise ValueError(f"idx {i} not supported with static sizes")
114+
115+
else:
116+
raise ValueError(f"idx {i} not supported with static sizes")
117+
return tuple(sizes)
118+
119+
# @lru_cache(maxsize=1)
120+
def static_strides(self):
121+
strides = []
122+
for i in self.indices:
123+
if isinstance(i, (int, Scalar)):
124+
strides.append(1)
125+
elif isinstance(i, slice):
126+
strides.append(int(i.step))
127+
else:
128+
raise ValueError(f"idx {i} not supported with static strides")
129+
return tuple(strides)
130+
131+
132+
def _indices_to_indexer(
133+
idx: Tuple[Union[Scalar, slice, "Ellipsis", None]], in_shape: Tuple[int]
134+
) -> _Indexer:
135+
"""Processes sequence of index objects and constructs _Indexer with
136+
corresponding indexing tensor and collapse dims (i.e., scatter/gather dims).
137+
138+
Args:
139+
idx: Sequence (list or tuple) of slices, ellipses, Scalar, or Tensors.
140+
in_shape: The shape of the tensor being indexed into.
141+
142+
Returns:
143+
_Indexer object.
144+
145+
"""
146+
idx = _canonicalize_tuple_index(idx, len(in_shape))
147+
148+
in_axis = 0 # Current axis in input.
149+
out_axis = 0 # Current axis in output.
150+
indices: List[Union[Scalar, slice, Ellipsis, None]] = [slice(None)] * len(in_shape)
151+
newaxis_dims: List[int] = []
152+
153+
# nb: idx_e <-> idx_element
154+
for idx_i, idx_e in enumerate(idx):
155+
if _is_scalar(idx_e) and _has_index_type(idx_e):
156+
# Handle basic Scalar indexes.
157+
indices[in_axis] = idx_e
158+
in_axis += 1
159+
# Handle newaxis (None)
160+
elif idx_e is None:
161+
newaxis_dims.append(out_axis)
162+
out_axis += 1
163+
elif isinstance(idx_e, slice):
164+
# Normalize the slice to use None when possible
165+
start, stop, step = idx_e.start, idx_e.stop, idx_e.step
166+
if step is None or isinstance(step, int) and step == 1:
167+
step = None
168+
if step is None:
169+
if start is None or isinstance(start, int) and start == 0:
170+
start = None
171+
if (
172+
stop is None
173+
or isinstance(stop, int)
174+
and in_shape[in_axis] != ShapedType.get_dynamic_size()
175+
and stop >= in_shape[in_axis]
176+
):
177+
stop = None
178+
# Handle slice(None) and slice(None, None, -1)
179+
if (
180+
start is None
181+
and stop is None
182+
and (step is None or isinstance(step, int) and step == -1)
183+
):
184+
if step == -1:
185+
raise IndexError(
186+
f"Negative step indexing mode not yet supported:\n{idx}"
187+
)
188+
indices[in_axis] = slice(None)
189+
out_axis += 1
190+
in_axis += 1
191+
192+
# Handle slice index (only static shape supported)
193+
else:
194+
if (
195+
not isinstance(in_shape[in_axis], int)
196+
or in_shape[in_axis] == ShapedType.get_dynamic_size()
197+
):
198+
msg = (
199+
"Cannot use NumPy slice indexing on an array dimension whose "
200+
f"size is not statically known ({in_shape[in_axis]}). "
201+
)
202+
raise IndexError(msg)
203+
204+
if step is None:
205+
step = 1
206+
indices[in_axis] = slice(start, stop, step)
207+
208+
out_axis += 1
209+
in_axis += 1
210+
else:
211+
raise IndexError(f"Indexing mode not yet supported:\n{idx}")
212+
213+
for i, idx in enumerate(indices):
214+
if _is_constant_index(idx) and _is_constant_scalar(in_shape[i]):
215+
if isinstance(idx, slice):
216+
indices[i] = slice(*idx.indices(int(in_shape[i])))
217+
elif isinstance(idx, Scalar):
218+
indices[i] = int(idx)
219+
220+
return _Indexer(
221+
newaxis_dims=tuple(newaxis_dims), indices=tuple(indices), in_shape=in_shape
222+
)
223+
224+
225+
def _canonicalize_tuple_index(idx: Tuple[Any], rank: int):
226+
"""Helper to
227+
1. remove Ellipsis and replace with implicit trailing slice(None)s.
228+
2. cast Python lists of lists or numpy arrays to index Tensors
229+
230+
Args:
231+
rank: Rank of tensor.
232+
idx: Index object (Scalar, Tensor, slice, Ellipse, or None).
233+
234+
Returns:
235+
Tuple of index objects with no ellipses.
236+
"""
237+
238+
len_without_none = 0
239+
for e in idx:
240+
if e is None or e is Ellipsis:
241+
continue
242+
else:
243+
len_without_none += 1
244+
245+
if len_without_none > rank:
246+
raise IndexError(
247+
f"Too many indices for shaped type with rank: {len_without_none} "
248+
f"non-None/Ellipsis indices for dim {rank}."
249+
)
250+
ellipses = (i for i, elt in enumerate(idx) if elt is Ellipsis)
251+
ellipsis_index = next(ellipses, None)
252+
if ellipsis_index is not None:
253+
if next(ellipses, None) is not None:
254+
raise IndexError(
255+
f"Multiple ellipses (...) not supported: {list(map(type, idx))}."
256+
)
257+
colons = (slice(None),) * (rank - len_without_none)
258+
idx = idx[:ellipsis_index] + colons + idx[ellipsis_index + 1 :]
259+
elif len_without_none < rank:
260+
colons = (slice(None),) * (rank - len_without_none)
261+
idx = tuple(idx) + colons
262+
return idx
263+
264+
265+
def _is_int_arraylike(x):
266+
"""Returns True if x is array-like with integer dtype, False otherwise.
267+
268+
Positive (i.e., return True) examples are e.g., [[0], [1]], [[0, 1]],
269+
[[[0, 1]], [[0, 1]]].
270+
"""
271+
return (
272+
isinstance(x, int)
273+
and not isinstance(x, bool)
274+
or isinstance(x, (list, tuple))
275+
and all(_is_int_arraylike(e) for e in x)
276+
)
277+
278+
279+
def _is_scalar(e: Any) -> bool:
280+
"""Checks whether e is a Scalar or can be used to construct a Scalar.
281+
282+
Args:
283+
e: Anything
284+
"""
285+
return isinstance(e, Scalar) or isinstance(e, (int, float, bool))
286+
287+
288+
def _has_index_type(e: Any) -> bool:
289+
"""Checks whether e has MLIR index type or a Python value that can be used
290+
to construct an index type.
291+
292+
Args:
293+
e: Anything
294+
"""
295+
return (
296+
isinstance(e, int)
297+
or isinstance(e, np.ndarray)
298+
and e.dtype in {np.intp}
299+
or isinstance(e, Value)
300+
and _is_index_type(e.type)
301+
or isinstance(e.type, RankedTensorType)
302+
and _is_index_type(e.type.element_type)
303+
)
304+
305+
306+
def _is_constant_index(e: Any) -> bool:
307+
return (
308+
isinstance(e, Scalar)
309+
and e.is_constant()
310+
or isinstance(e, (int, float, bool))
311+
or isinstance(e, slice)
312+
and _is_constant_scalar(e.start)
313+
and _is_constant_scalar(e.stop)
314+
and _is_constant_scalar(e.step)
315+
)
316+
317+
318+
def _is_constant_scalar(e: Any) -> bool:
319+
return (
320+
isinstance(e, Scalar)
321+
and e.is_constant()
322+
or (isinstance(e, (int, float, bool)) and e != ShapedType.get_dynamic_size())
323+
or e is None
324+
)

mlir/extras/dialects/ext/memref.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33

44
import numpy as np
55

6-
from ._shaped_value import ShapedValue
6+
from ._shaped_value import ShapedValue, _indices_to_indexer
77
from .arith import Scalar, constant
8-
from .tensor import _indices_to_indexer, compute_result_shape_reassoc_list
8+
from .tensor import compute_result_shape_reassoc_list
99
from .vector import Vector
1010
from ... import types as T
1111
from ...meta import region_op
@@ -15,7 +15,7 @@
1515
infer_mlir_type,
1616
)
1717
from ...._mlir_libs._mlir import register_value_caster
18-
from ....dialects import memref, arith
18+
from ....dialects import memref, arith, vector
1919
from ....dialects._ods_common import get_op_result_or_op_results
2020
from ....dialects.memref import *
2121
from ....ir import (
@@ -191,7 +191,15 @@ def __setitem__(self, idx, val):
191191
assert isinstance(
192192
val, (Scalar, Vector)
193193
), "coordinate insert requires scalar element"
194-
store(val, self, idx, loc=loc)
194+
if isinstance(val, Scalar):
195+
store(val, self, idx, loc=loc)
196+
elif isinstance(val, Vector):
197+
return vector.StoreOp(
198+
valueToStore=val,
199+
base=self,
200+
indices=idx,
201+
loc=loc,
202+
)
195203
else:
196204
_copy_to_subview(self, val, tuple(idx), loc=loc)
197205

0 commit comments

Comments
 (0)