@@ -144,43 +144,40 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func.func @vectorize_linalg_index(%arg0: tensor<3x3x?xf32>, %arg1: tensor<1x1x?xf32>) -> tensor<1x1x?xf32> {
+#map = affine_map<(d0) -> (d0)>
+func.func @vectorize_linalg_index(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
   %0 = linalg.generic {
     indexing_maps = [#map],
-    iterator_types = ["parallel", "parallel", "parallel"]
-  } outs(%arg1 : tensor<1x1x?xf32>) {
+    iterator_types = ["parallel"]
+  } outs(%arg1 : tensor<?xf32>) {
   ^bb0(%in: f32):
     %1 = linalg.index 0 : index
-    %2 = linalg.index 1 : index
-    %3 = linalg.index 2 : index
-    %4 = tensor.extract %arg0[%1, %2, %3] : tensor<3x3x?xf32>
-    linalg.yield %4 : f32
-  } -> tensor<1x1x?xf32>
-  return %0 : tensor<1x1x?xf32>
+    %2 = tensor.extract %arg0[%1] : tensor<?xf32>
+    linalg.yield %2 : f32
+  } -> tensor<?xf32>
+  return %0 : tensor<?xf32>
 }
 
 // CHECK-LABEL: @vectorize_linalg_index
-// CHECK-SAME: %[[SRC:.*]]: tensor<3x3x?xf32>, %[[DST:.*]]: tensor<1x1x?xf32>
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[DST_DIM2:.*]] = tensor.dim %[[DST]], %[[C2]] : tensor<1x1x?xf32>
-// CHECK: %[[MASK:.*]] = vector.create_mask %[[C1]], %[[C1]], %[[DST_DIM2]] : vector<1x1x[4]xi1>
-// CHECK: %[[INDEX_VEC:.*]] = vector.step : vector<[4]xindex>
-// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]][%c0, %c0, %2], %cst {in_bounds = [true, true, true]} : tensor<3x3x?xf32>, vector<1x1x[4]xf32> } : vector<1x1x[4]xi1> -> vector<1x1x[4]xf32>
-// CHECK: %[[OUT:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[DST]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x[4]xf32>, tensor<1x1x?xf32> } : vector<1x1x[4]xi1> -> tensor<1x1x?xf32>
-// CHECK: return %[[OUT]] : tensor<1x1x?xf32>
+// CHECK-SAME: %[[SRC:.*]]: tensor<?xf32>, %[[DST:.*]]: tensor<?xf32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DST_DIM0:.*]] = tensor.dim %[[DST]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[DST_DIM0]] : vector<[4]xi1>
+// CHECK-DAG: %[[STEP:.+]] = vector.step : vector<[4]xindex>
+// CHECK-DAG: %[[STEP_ELEMENT:.+]] = vector.extractelement %[[STEP]][%c0_i32 : i32] : vector<[4]xindex>
+
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]][%[[STEP_ELEMENT]]], %cst {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32> } : vector<[4]xi1> -> vector<[4]xf32>
+// CHECK: %[[OUT:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[DST]]{{\[}}%[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, tensor<?xf32> } : vector<[4]xi1> -> tensor<?xf32>
+// CHECK: return %[[OUT]] : tensor<?xf32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %0 vector_sizes [1, 1, [4]] {vectorize_nd_extract} : !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [[4]] {vectorize_nd_extract} : !transform.any_op
 
     %func = transform.structured.match ops{["func.func"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.canonicalization
       transform.apply_patterns.linalg.tiling_canonicalization
     } : !transform.any_op
     transform.yield