66
77"""Python API for IREE's high-level tensor dialects."""
88
9- from typing import Any , List , Sequence , Tuple , Union
9+ from typing import Any , Dict , List , Optional , Sequence , Tuple , Union
1010
1111import functools
1212
3636BuildableScalarValue = Union [IrValueScalar , Value ]
3737BuildableTensorDimDecl = Union [int , Value ]
3838BuildableTensorType = IrValueTensor
39- BuildableIndexType = Union [Value , int ]
39+ BuildableIndexType = Union [BuildableScalarValue , int ]
40+ BuildableIndexLengthType = Union [
41+ BuildableTensorDimDecl , Tuple [BuildableTensorDimDecl , BuildableTensorDimDecl ]
42+ ]
43+ BuildableSliceType = Sequence [BuildableIndexLengthType ]
4044StaticIndexType = int
4145
4246
@@ -52,9 +56,12 @@ def cast_tensor_value(x: BuildableTensorType) -> IrValueTensor:
5256 return x
5357
5458
def cast_index_value(
    x: BuildableIndexType, *, constant_cache: Optional[Dict[int, Value]] = None
) -> Value:
    """Coerces a buildable index (Python int or IR value) to an index Value.

    Intrinsic wrappers are unwrapped first; bare ints are then materialized
    via build_index_value, optionally memoized through constant_cache so that
    repeated integer constants share a single IR op.
    """
    x = unwrap_intrinsic_value(x)
    if not isinstance(x, int):
        return x
    return build_index_value(x, constant_cache=constant_cache)
6067
@@ -144,6 +151,137 @@ def tensor_empty(
144151 result .set_dynamic_dim_values (dyn_dim_values )
145152 return result
146153
@emitter
def tensor_reshape(
    self, source: BuildableTensorType, *result_dims: BuildableTensorDimDecl
) -> "IrValueTensor":
    """Reshapes `source` into the shape described by `result_dims`.

    Both the source's and the result's dynamic dimensions are passed to the
    flow.tensor.reshape op as SSA values; integer constants emitted for the
    source dims are memoized through a local cache.
    """
    cache: Dict[int, Value] = {}
    source = cast_tensor_value(source)
    dim_decls, dynamic_dims = cast_tensor_dim_decl(result_dims)
    new_type = RankedTensorType.get(dim_decls, source.ir_type.element_type)
    reshaped = flow_d.TensorReshapeOp(
        new_type,
        source.ir_value,
        source.get_only_dynamic_dim_values(constant_cache=cache),
        dynamic_dims,
    ).result
    wrapped = IrValueTensor(reshaped, dtype=source.dtype)
    wrapped.set_dynamic_dim_values(dynamic_dims)
    return wrapped
173+
@emitter
def tensor_slice(
    self, source: BuildableTensorType, *indices: BuildableSliceType
) -> "IrValueTensor":
    """Extracts a slice of a tensor.

    The given indices must match the rank of the source and each index is
    interpreted as `(start_index[, length])`, where the `length` is taken
    to be 1 if only a single value is given for an index.

    Raises:
      ValueError: if the number of indices does not match the source rank,
        or if an index pack is neither a single value nor a 2-tuple.
    """
    source = cast_tensor_value(source)
    source_value = source.ir_value
    rank = source.rank
    if len(indices) != rank:
        raise ValueError(
            f"Slice indices must match the source rank. Got {len(indices)}, expected {rank}"
        )
    # Unpack start_indices and lengths.
    start_indices: List[BuildableIndexType] = []
    lengths: List[BuildableIndexType] = []
    for index_pack in indices:
        if isinstance(index_pack, (tuple, list)):
            if len(index_pack) != 2:
                raise ValueError(
                    f"Slice indices expected to be a single value or a 2-tuple. Got {index_pack}"
                )
            start_indices.append(index_pack[0])
            lengths.append(index_pack[1])
        else:
            # A bare value means a length-1 slice at that position.
            start_indices.append(index_pack)
            lengths.append(1)

    # A single cache for every index constant emitted by this op (lengths,
    # start indices, and source dynamic dims), matching the convention used
    # by tensor_update. cast_index_value memoizes int constants through it,
    # so the previous hand-rolled index_value_cache is unnecessary.
    constant_cache: Dict[int, Value] = {}
    length_values: List[Value] = []
    result_shape: List[int] = []
    result_dynamic_dims: List[Value] = []
    for raw_length in lengths:
        length_value = cast_index_value(raw_length, constant_cache=constant_cache)
        if isinstance(raw_length, int):
            # Static length: bake it into the result type.
            result_shape.append(raw_length)
        else:
            # Dynamic length: the result dim is dynamic and carried as SSA.
            result_shape.append(ShapedTypeDynamicSizeSentinel)
            result_dynamic_dims.append(length_value)
        length_values.append(length_value)
    assert len(length_values) == rank
    assert result_shape.count(ShapedTypeDynamicSizeSentinel) == len(
        result_dynamic_dims
    )

    # Process start indices.
    start_index_values = [
        cast_index_value(idx, constant_cache=constant_cache) for idx in start_indices
    ]
    # Emit.
    result_type = RankedTensorType.get(result_shape, source.ir_type.element_type)
    result_value = flow_d.TensorSliceOp(
        result_type,
        source_value,
        source.get_only_dynamic_dim_values(constant_cache=constant_cache),
        start_index_values,
        length_values,
        result_dynamic_dims,
    ).result
    result = IrValueTensor(result_value, dtype=source.dtype)
    result.set_dynamic_dim_values(result_dynamic_dims)
    return result
252+
@emitter
def tensor_update(
    self,
    target: BuildableTensorType,
    update: BuildableTensorType,
    *start_indices: BuildableIndexType,
) -> "IrValueTensor":
    """Applies an update to a target at start_indices and returns the mutated target."""
    # One cache shared across all index constants emitted by this op.
    cache: Dict[int, Value] = {}
    target = cast_tensor_value(target)
    target_dyn_dims = target.get_only_dynamic_dim_values(constant_cache=cache)
    update = cast_tensor_value(update)
    update_dyn_dims = update.get_only_dynamic_dim_values(constant_cache=cache)
    index_values = [
        cast_index_value(idx, constant_cache=cache) for idx in start_indices
    ]
    updated = flow_d.TensorUpdateOp(
        target.ir_value,
        target_dyn_dims,
        index_values,
        update.ir_value,
        update_dyn_dims,
    ).result
    result = IrValueTensor(updated, target.dtype)
    result.set_dynamic_dim_values(target_dyn_dims)
    return result
284+
147285 @emitter
148286 def tensor_splat (
149287 self ,
0 commit comments