@@ -120,52 +120,54 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<?x?xf32>, %arg0: index, %extracted_slice: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(
+    %src: tensor<?x?xf32>,
+    %output: tensor<?x?xf32>,
+    %idx: index) -> tensor<?x?xf32> {
+
   %c79 = arith.constant 79 : index
   %1 = linalg.generic {
     indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel"]
-  } outs(%extracted_slice : tensor<?x?xf32>) {
+  } outs(%output : tensor<?x?xf32>) {
   ^bb0(%out: f32):
     %2 = linalg.index 1 : index
-    %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %arg0)
-    %extracted = tensor.extract %6[%c79, %3] : tensor<?x?xf32>
+    %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %idx)
+    %extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32>
     linalg.yield %extracted : f32
   } -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
 
 // CHECK-LABEL: func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: index,
-// CHECK-SAME: %[[VAL_2:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 79 : index
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_2]], %[[VAL_4]] : tensor<?x?xf32>
-// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_2]], %[[VAL_6]] : tensor<?x?xf32>
-// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_10:.*]] = vector.create_mask %[[VAL_5]], %[[VAL_7]] : vector<1x4xi1>
-// CHECK: %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_9]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
-// CHECK: %[[VAL_12:.*]] = vector.step : vector<4xindex>
-// CHECK: %[[VAL_13:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
-// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : vector<4xindex>
-// CHECK-DAG: %[[VAL_15:.*]] = arith.constant dense<true> : vector<1x4xi1>
-// CHECK-DAG: %[[VAL_16:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32>
-// CHECK-DAG: %[[VAL_17:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_18:.*]] = arith.constant dense<79> : vector<1x4xindex>
-// CHECK-DAG: %[[VAL_19:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_20:.*]] = tensor.dim %[[VAL_0]], %[[VAL_19]] : tensor<?x?xf32>
-// CHECK: %[[VAL_21:.*]] = vector.broadcast %[[VAL_20]] : index to vector<1x4xindex>
-// CHECK: %[[VAL_22:.*]] = arith.muli %[[VAL_18]], %[[VAL_21]] : vector<1x4xindex>
-// CHECK: %[[VAL_23:.*]] = vector.broadcast %[[VAL_14]] : vector<4xindex> to vector<1x4xindex>
-// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[VAL_22]] : vector<1x4xindex>
-// CHECK: %[[VAL_25:.*]] = vector.mask %[[VAL_10]] { vector.gather %[[VAL_0]]{{\[}}%[[VAL_17]], %[[VAL_17]]] {{\[}}%[[VAL_24]]], %[[VAL_15]], %[[VAL_16]] : tensor<?x?xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
-// CHECK: %[[VAL_26:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_27:.*]] = vector.mask %[[VAL_10]] { vector.transfer_write %[[VAL_25]], %[[VAL_2]]{{\[}}%[[VAL_26]], %[[VAL_26]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<?x?xf32> } : vector<1x4xi1> -> tensor<?x?xf32>
-// CHECK: return %[[VAL_27]] : tensor<?x?xf32>
-// CHECK: }
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index)
+
+/// Create the mask
+// CHECK: %[[C79:.*]] = arith.constant 79 : index
+// CHECK: %[[DIM_0_IDX:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_0_IDX]] : tensor<?x?xf32>
+// CHECK: %[[DIM_1_IDX:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_1_IDX]] : tensor<?x?xf32>
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<1x4xi1>
+
+/// TODO: This transfer_read is redundant - remove
+// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
+
+/// Calculate the index vector
+// CHECK: %[[STEP:.*]] = vector.step : vector<4xindex>
+// CHECK: %[[IDX_BC:.*]] = vector.broadcast %[[IDX]] : index to vector<4xindex>
+// CHECK: %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<4xindex>
+// CHECK: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<4xindex> to vector<4xindex>
+
+/// Extract the starting point from the index vector
+// CHECK: %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<4xindex>
+
+/// Final read and write
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
+// CHECK: %[[WRITE:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : vector<1x4xf32>, tensor<?x?xf32> } : vector<1x4xi1> -> tensor<?x?xf32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -177,6 +179,65 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous_scalable(
+    %src: tensor<?x?xf32>,
+    %output: tensor<?x?xf32>,
+    %idx: index) -> tensor<?x?xf32> {
+
+  %c79 = arith.constant 79 : index
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%output : tensor<?x?xf32>) {
+  ^bb0(%out: f32):
+    %2 = linalg.index 1 : index
+    %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %idx)
+    %extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous_scalable(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index)
+
+/// Create the mask
+// CHECK: %[[C79:.*]] = arith.constant 79 : index
+// CHECK: %[[DIM_0_IDX:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_0_IDX]] : tensor<?x?xf32>
+// CHECK: %[[DIM_1_IDX:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_1_IDX]] : tensor<?x?xf32>
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<1x[4]xi1>
+
+/// TODO: This transfer_read is redundant - remove
+// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
+
+/// Calculate the index vector
+// CHECK: %[[STEP:.*]] = vector.step : vector<[4]xindex>
+// CHECK: %[[IDX_BC:.*]] = vector.broadcast %[[IDX]] : index to vector<[4]xindex>
+// CHECK: %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<[4]xindex>
+// CHECK: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<[4]xindex> to vector<[4]xindex>
+
+/// Extract the starting point from the index vector
+// CHECK: %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<[4]xindex>
+
+/// Final read and write
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
+// CHECK: %[[WRITE:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : vector<1x[4]xf32>, tensor<?x?xf32> } : vector<1x[4]xi1> -> tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [1, [4]] {vectorize_nd_extract} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @masked_vectorize_nd_tensor_extract_with_affine_apply_gather(%6: tensor<80x16xf32>, %arg0: index, %extracted_slice: tensor<1x3xf32>) -> tensor<1x3xf32> {
   %c16 = arith.constant 16 : index
   %1 = linalg.generic {
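For reference, test files like this are driven by a RUN line at the top of the file; a minimal sketch of what that invocation might look like (assumed here, since the RUN line is not part of this diff):

// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s

With `-split-input-file`, mlir-opt processes each `// -----`-delimited case independently; the transform interpreter runs the embedded `transform.named_sequence` scripts (e.g. `transform.structured.vectorize %0 vector_sizes [1, [4]] {vectorize_nd_extract}` for the scalable case), and FileCheck matches the vectorized output against the CHECK lines above.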