44
55from mlir .ir import Type , Value , MemRefType , ShapedType , MLIRError
66
7- from mlir_utils .dialects import memref
7+ import mlir_utils .types as T
8+ from mlir_utils .dialects import memref , arith
89from mlir_utils .dialects .ext .arith import Scalar , constant
910from mlir_utils .dialects .ext .tensor import (
1011 _indices_to_indexer ,
1112 compute_result_shape_reassoc_list ,
1213)
13- import mlir_utils .types as T
1414from mlir_utils .util import (
1515 register_value_caster ,
1616 get_user_code_loc ,
@@ -88,6 +88,8 @@ def store(
8888
8989def subview (
9090 source : "MemRef" ,
91+ offsets : Optional [Sequence [Value ]] = None ,
92+ strides : Optional [Sequence [Value ]] = None ,
9193 static_offsets : Optional [Sequence [int ]] = None ,
9294 static_sizes : Optional [Sequence [int ]] = None ,
9395 static_strides : Optional [Sequence [int ]] = None ,
@@ -97,11 +99,23 @@ def subview(
9799):
98100 if loc is None :
99101 loc = get_user_code_loc ()
102+ if offsets is None :
103+ offsets = []
104+ if static_offsets is None :
105+ static_offsets = []
106+ if strides is None :
107+ strides = []
108+ if static_strides is None :
109+ static_strides = []
100110 assert static_sizes , f"this convenience method only handles static sizes"
101- offsets = sizes = strides = []
102- result = T .memref (* static_sizes , source .dtype )
111+ sizes = []
112+ wrong_type = T .memref (* static_sizes , source .dtype )
113+ if offsets and static_offsets :
114+ assert all (s == S for s in static_offsets )
115+ if strides and static_strides :
116+ assert all (s == S for s in static_strides )
103117 val = memref .subview (
104- result ,
118+ wrong_type ,
105119 source ,
106120 offsets ,
107121 sizes ,
@@ -270,7 +284,51 @@ def _subview(
270284 ip = ip ,
271285 )
272286 else :
273- raise ValueError (f"non-constant indices not supported { indexer } " )
287+ # special tile case
288+ offsets = [None ] * len (indexer .in_shape )
289+ static_offsets = [None ] * len (indexer .in_shape )
290+ static_sizes = [None ] * len (indexer .in_shape )
291+ static_strides = [None ] * len (indexer .in_shape )
292+ for i , ind in enumerate (indexer .indices ):
293+ maybe_size = maybe_cast (ind .stop .owner .operands [1 ])
294+ if (
295+ isinstance (ind .start .owner .opview , arith .MulIOp )
296+ and isinstance (ind .stop .owner .opview , arith .MulIOp )
297+ and isinstance (ind .stop .owner .operands [0 ].owner .opview , arith .AddIOp )
298+ and ind .start .owner .operands [0 ]
299+ == ind .stop .owner .operands [0 ].owner .operands [0 ]
300+ and maybe_size .is_constant ()
301+ and isinstance (ind .step , int )
302+ or isinstance (ind .step , Scalar )
303+ and ind .step .is_constant ()
304+ ):
305+ offsets [i ] = ind .start
306+ static_offsets [i ] = S
307+ static_sizes [i ] = maybe_size .literal_value
308+ static_strides [i ] = (
309+ ind .step .literal_value if isinstance (ind .step , Scalar ) else ind .step
310+ )
311+ else :
312+ raise RuntimeError (f"indexing not supported { indexer .indices } " )
313+ offsets = list (filter (None , offsets ))
314+ static_offsets = list (filter (None , static_offsets ))
315+ static_sizes = list (filter (None , static_sizes ))
316+ static_strides = list (filter (None , static_strides ))
317+ assert (
318+ len (offsets )
319+ == len (static_sizes )
320+ == len (static_strides )
321+ == len (indexer .in_shape )
322+ ), f"not each slice is statically known: { indexer .indices } "
323+ out = subview (
324+ out ,
325+ offsets = offsets ,
326+ static_offsets = static_offsets ,
327+ static_sizes = static_sizes ,
328+ static_strides = static_strides ,
329+ loc = loc ,
330+ ip = ip ,
331+ )
274332
275333 # This adds newaxis/None dimensions.
276334 return expand_shape (out , indexer .newaxis_dims , loc = loc , ip = ip )
0 commit comments