Skip to content

Commit ce2f5b0

Browse files
Add API support for emitting raw tensor slice, update, and reshape. (#59)
1 parent 54bb2a8 commit ce2f5b0

File tree

4 files changed

+280
-13
lines changed

4 files changed

+280
-13
lines changed

python/shark_turbine/aot/support/ir_utils.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -264,9 +264,19 @@ def build_index_attribute(value: int) -> IntegerAttr:
264264
return IntegerAttr.get(IndexType.get(), value)
265265

266266

267-
def build_index_value(value: int) -> Value:
268-
return arith_d.ConstantOp(IndexType.get(), value).result
269-
270-
271-
def build_tensor_dim_value(t: Value, dim: int) -> Value:
272-
return tensor_d.DimOp(t, build_index_value(dim)).result
267+
def build_index_value(
268+
value: int, constant_cache: Optional[Dict[int, Value]] = None
269+
) -> Value:
270+
if constant_cache is not None and value in constant_cache:
271+
return constant_cache[value]
272+
index_value = arith_d.ConstantOp(IndexType.get(), value).result
273+
if constant_cache is not None:
274+
constant_cache[value] = index_value
275+
return index_value
276+
277+
278+
def build_tensor_dim_value(
279+
t: Value, dim: int, constant_cache: Optional[Dict[int, Value]] = None
280+
) -> Value:
281+
dim_value = build_index_value(dim, constant_cache=constant_cache)
282+
return tensor_d.DimOp(t, dim_value).result

python/shark_turbine/aot/support/procedural/iree_emitter.py

Lines changed: 142 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
"""Python API for IREE's high-level tensor dialects."""
88

9-
from typing import Any, List, Sequence, Tuple, Union
9+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
1010

1111
import functools
1212

@@ -36,7 +36,11 @@
3636
BuildableScalarValue = Union[IrValueScalar, Value]
3737
BuildableTensorDimDecl = Union[int, Value]
3838
BuildableTensorType = IrValueTensor
39-
BuildableIndexType = Union[Value, int]
39+
BuildableIndexType = Union[BuildableScalarValue, int]
40+
BuildableIndexLengthType = Union[
41+
BuildableTensorDimDecl, Tuple[BuildableTensorDimDecl, BuildableTensorDimDecl]
42+
]
43+
BuildableSliceType = Sequence[BuildableIndexLengthType]
4044
StaticIndexType = int
4145

4246

@@ -52,9 +56,12 @@ def cast_tensor_value(x: BuildableTensorType) -> IrValueTensor:
5256
return x
5357

5458

55-
def cast_index_value(x: BuildableIndexType) -> Value:
59+
def cast_index_value(
60+
x: BuildableIndexType, *, constant_cache: Optional[Dict[int, Value]] = None
61+
) -> Value:
62+
x = unwrap_intrinsic_value(x)
5663
if isinstance(x, int):
57-
return build_index_value(x)
64+
return build_index_value(x, constant_cache=constant_cache)
5865
else:
5966
return x
6067

@@ -144,6 +151,137 @@ def tensor_empty(
144151
result.set_dynamic_dim_values(dyn_dim_values)
145152
return result
146153

154+
@emitter
155+
def tensor_reshape(
156+
self, source: BuildableTensorType, *result_dims: BuildableTensorDimDecl
157+
) -> "IrValueTensor":
158+
constant_cache: Dict[int, Value] = {}
159+
source = cast_tensor_value(source)
160+
result_dim_decls, result_dynamic_dims = cast_tensor_dim_decl(result_dims)
161+
result_type = RankedTensorType.get(
162+
result_dim_decls, source.ir_type.element_type
163+
)
164+
result_value = flow_d.TensorReshapeOp(
165+
result_type,
166+
source.ir_value,
167+
source.get_only_dynamic_dim_values(constant_cache=constant_cache),
168+
result_dynamic_dims,
169+
).result
170+
result = IrValueTensor(result_value, dtype=source.dtype)
171+
result.set_dynamic_dim_values(result_dynamic_dims)
172+
return result
173+
174+
@emitter
175+
def tensor_slice(
176+
self, source: BuildableTensorType, *indices: BuildableSliceType
177+
) -> "IrValueTensor":
178+
"""Extracts a slice of a tensor.
179+
180+
The given indices must match the rank of the source and each index is
181+
interpreted as `(start_index[, length])`, where the `length` is taken
182+
to be 1 if only a single value is given for an index.
183+
"""
184+
source = cast_tensor_value(source)
185+
source_value = source.ir_value
186+
rank = source.rank
187+
if len(indices) != rank:
188+
raise ValueError(
189+
f"Slice indices must match the source rank. Got {len(indices)}, expected {rank}"
190+
)
191+
# Unpack start_indices and lengths.
192+
start_indices: List[BuildableIndexType] = []
193+
lengths: List[BuildableIndexType] = []
194+
for index_pack in indices:
195+
if isinstance(index_pack, (tuple, list)):
196+
if len(index_pack) == 2:
197+
start_indices.append(index_pack[0])
198+
lengths.append(index_pack[1])
199+
continue
200+
else:
201+
start_indices.append(index_pack)
202+
lengths.append(1)
203+
continue
204+
raise ValueError(
205+
f"Slice indices expected to be a single value or a 2-tuple. Got {index_pack}"
206+
)
207+
208+
# Process the lengths into a result shape and input length.
209+
index_value_cache: Dict[int, Value] = {}
210+
length_values: List[Value] = []
211+
result_shape: List[int] = []
212+
result_dynamic_dims: List[Value] = []
213+
for raw_length in lengths:
214+
if isinstance(raw_length, int):
215+
# Static.
216+
result_shape.append(raw_length)
217+
if raw_length in index_value_cache:
218+
# Cached.
219+
length_values.append(index_value_cache[raw_length])
220+
else:
221+
# Not cached.
222+
length_value = cast_index_value(raw_length)
223+
index_value_cache[raw_length] = length_value
224+
length_values.append(length_value)
225+
else:
226+
# Dynamic.
227+
result_shape.append(ShapedTypeDynamicSizeSentinel)
228+
length_value = cast_index_value(raw_length)
229+
length_values.append(length_value)
230+
result_dynamic_dims.append(length_value)
231+
assert len(length_values) == rank
232+
assert result_shape.count(ShapedTypeDynamicSizeSentinel) == len(
233+
result_dynamic_dims
234+
)
235+
236+
# Process start indices.
237+
start_index_values = [cast_index_value(idx) for idx in start_indices]
238+
# Emit.
239+
result_type = RankedTensorType.get(result_shape, source.ir_type.element_type)
240+
constant_cache: Dict[int, Value] = {}
241+
result_value = flow_d.TensorSliceOp(
242+
result_type,
243+
source_value,
244+
source.get_only_dynamic_dim_values(constant_cache=constant_cache),
245+
start_index_values,
246+
length_values,
247+
result_dynamic_dims,
248+
).result
249+
result = IrValueTensor(result_value, dtype=source.dtype)
250+
result.set_dynamic_dim_values(result_dynamic_dims)
251+
return result
252+
253+
@emitter
254+
def tensor_update(
255+
self,
256+
target: BuildableTensorType,
257+
update: BuildableTensorType,
258+
*start_indices: BuildableIndexType,
259+
) -> "IrValueTensor":
260+
"""Applies an update to a target at start_indices and returns the mutated target."""
261+
constant_cache: Dict[int, Value] = {}
262+
target = cast_tensor_value(target)
263+
target_dynamic_dims = target.get_only_dynamic_dim_values(
264+
constant_cache=constant_cache
265+
)
266+
update = cast_tensor_value(update)
267+
update_dynamic_dims = update.get_only_dynamic_dim_values(
268+
constant_cache=constant_cache
269+
)
270+
start_index_dim_values = [
271+
cast_index_value(idx, constant_cache=constant_cache)
272+
for idx in start_indices
273+
]
274+
result_value = flow_d.TensorUpdateOp(
275+
target.ir_value,
276+
target_dynamic_dims,
277+
start_index_dim_values,
278+
update.ir_value,
279+
update_dynamic_dims,
280+
).result
281+
result = IrValueTensor(result_value, target.dtype)
282+
result.set_dynamic_dim_values(target_dynamic_dims)
283+
return result
284+
147285
@emitter
148286
def tensor_splat(
149287
self,

python/shark_turbine/aot/support/procedural/primitives.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@
99
# operate on instances of these.
1010

1111
from typing import (
12+
Dict,
1213
List,
1314
Optional,
1415
Sequence,
16+
Tuple,
1517
Union,
1618
)
1719

@@ -161,7 +163,12 @@ def __repr__(self):
161163
def resolve_ir_values(self, proc_trace: IrTrace) -> Sequence[Value]:
162164
return (self.ir_value,)
163165

164-
def get_dim_value(self, index: int) -> Value:
166+
def get_dim_value(
167+
self,
168+
index: int,
169+
*,
170+
constant_cache: Optional[Dict[int, Value]] = None,
171+
) -> Value:
165172
"""Gets a dimension as an Index value.
166173
167174
Requires that an InsertionPoint and Location are on the context stack.
@@ -178,9 +185,21 @@ def get_dim_value(self, index: int) -> Value:
178185
# TODO: Add MLIR API support for creating an insertion point after
179186
# an operation and use that to set the InsertionPoint to the
180187
# earliest point.
181-
dim_value = build_tensor_dim_value(self.ir_value, index)
188+
dim_value = build_tensor_dim_value(
189+
self.ir_value, index, constant_cache=constant_cache
190+
)
182191
self._cached_dim_values[index] = dim_value
183192
return dim_value
184193
else:
185194
# Dynamic dim is known.
186195
return dynamic_dim
196+
197+
def get_only_dynamic_dim_values(
198+
self, *, constant_cache: Optional[Dict[int, Value]] = None
199+
) -> List[Value]:
200+
"""Returns a list of *only* the dynamic dim Values."""
201+
values: List[Value] = []
202+
for i, sentinel in enumerate(self._dynamic_dims):
203+
if sentinel is not Empty:
204+
values.append(self.get_dim_value(i, constant_cache=constant_cache))
205+
return values

tests/aot/iree_procedural_test.py

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,110 @@ def foobar(self, x=AbstractIndex, y=AbstractF32):
8282
inst = BasicModule(context=Context())
8383
module_str = str(CompiledModule.get_mlir_module(inst))
8484
print(module_str)
85-
self.assertIn("util.global private mutable @_x.global {noinline} : tensor<?x34xf32>", module_str)
85+
self.assertIn(
86+
"util.global private mutable @_x.global {noinline} : tensor<?x34xf32>",
87+
module_str,
88+
)
8689
self.assertIn("%0 = flow.tensor.splat", module_str)
8790
self.assertIn("util.global.store %0, @_x.global : tensor<?x34xf32>", module_str)
8891

92+
def testTensorSliceStatic(self):
93+
class BasicModule(CompiledModule):
94+
def foobar(self, x=AbstractTensor(3, 4)):
95+
return IREE.tensor_slice(x, 0, (1, 3))
96+
97+
inst = BasicModule(context=Context())
98+
module_str = str(CompiledModule.get_mlir_module(inst))
99+
print(module_str)
100+
self.assertIn(
101+
"flow.tensor.slice %arg0[%c0, %c1_0 for %c1, %c3] : tensor<3x4xf32> -> tensor<1x3xf32>",
102+
module_str,
103+
)
104+
105+
def testTensorSliceDynamicIndex(self):
106+
class SliceDynamicIndex(CompiledModule):
107+
def foobar(self, x=AbstractIndex):
108+
empty = IREE.tensor_empty(x, 16)
109+
return IREE.tensor_slice(empty, x, 4)
110+
111+
inst = SliceDynamicIndex(context=Context())
112+
module_str = str(CompiledModule.get_mlir_module(inst))
113+
print(module_str)
114+
self.assertIn(
115+
"flow.tensor.slice %0[%arg0, %c4 for %c1, %c1] : tensor<?x16xf32>{%arg0} -> tensor<1x1xf32>",
116+
module_str,
117+
)
118+
119+
def testTensorSliceDynamicLength(self):
120+
class SliceDynamicIndex(CompiledModule):
121+
def foobar(self, x=AbstractIndex, y=AbstractIndex):
122+
empty = IREE.tensor_empty(x, 16)
123+
return IREE.tensor_slice(empty, (x, y), 4)
124+
125+
inst = SliceDynamicIndex(context=Context())
126+
module_str = str(CompiledModule.get_mlir_module(inst))
127+
print(module_str)
128+
self.assertIn(
129+
"flow.tensor.slice %0[%arg0, %c4 for %arg1, %c1] : tensor<?x16xf32>{%arg0} -> tensor<?x1xf32>{%arg1}",
130+
module_str,
131+
)
132+
133+
def testTensorUpdateStatic(self):
134+
class UpdateStatic(CompiledModule):
135+
def foobar(
136+
self,
137+
target=AbstractTensor(4, 4),
138+
update=AbstractTensor(2, 2),
139+
i=AbstractIndex,
140+
j=AbstractIndex,
141+
):
142+
return IREE.tensor_update(target, update, i, j)
143+
144+
inst = UpdateStatic(context=Context())
145+
module_str = str(CompiledModule.get_mlir_module(inst))
146+
print(module_str)
147+
self.assertIn(
148+
"flow.tensor.update %arg1, %arg0[%arg2, %arg3] : tensor<2x2xf32> -> %arg0 as tensor<4x4xf32>",
149+
module_str,
150+
)
151+
152+
def testTensorUpdateDynamic(self):
153+
class UpdateDynamic(CompiledModule):
154+
def foobar(
155+
self,
156+
x=AbstractIndex,
157+
y=AbstractIndex,
158+
i=AbstractIndex,
159+
j=AbstractIndex,
160+
value=AbstractF32,
161+
):
162+
target = IREE.tensor_empty(x, y)
163+
update = IREE.tensor_splat(i, j, value=value, dtype=torch.float32)
164+
return IREE.tensor_update(target, update, 2, 2)
165+
166+
inst = UpdateDynamic(context=Context())
167+
module_str = str(CompiledModule.get_mlir_module(inst))
168+
print(module_str)
169+
self.assertIn(
170+
"flow.tensor.update %1, %0[%c2, %c2] : tensor<?x?xf32>{%arg2, %arg3} -> %0 as tensor<?x?xf32>{%arg0, %arg1}",
171+
module_str,
172+
)
173+
174+
def testTensorReshape(self):
175+
class ReshapeModule(CompiledModule):
176+
def foobar(self, x=AbstractIndex, y=AbstractIndex):
177+
empty = IREE.tensor_empty(x, 16)
178+
reshaped = IREE.tensor_reshape(empty, 1, y, y)
179+
return reshaped
180+
181+
inst = ReshapeModule(context=Context())
182+
module_str = str(CompiledModule.get_mlir_module(inst))
183+
print(module_str)
184+
self.assertIn(
185+
"flow.tensor.reshape %0 : tensor<?x16xf32>{%arg0} -> tensor<1x?x?xf32>{%arg1, %arg1}",
186+
module_str,
187+
)
188+
89189

90190
if __name__ == "__main__":
91191
logging.basicConfig(level=logging.DEBUG)

0 commit comments

Comments (0)