@@ -187,3 +187,49 @@ module attributes {transform.with_named_sequence} {
187187 transform.yield
188188 }
189189}
190+
191+ // -----
192+
193+
// Unmasked transfer_read into vector<8x[4]x2x3xf32> (one scalable dim) whose
// permutation map has broadcast (constant 0) results. The checks below expect
// the read to be rewritten to a rank-3 transfer_read of vector<[4]x2x3xf32>
// followed by a vector.broadcast that restores the leading dim of size 8.
// SSA values are matched through FileCheck captures rather than hard-coded
// names so the test does not depend on the printer's value numbering.

// CHECK:           #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>
// CHECK:           func.func @transfer_read_reduce_rank_scalable(
// CHECK-SAME:        %[[ARG_0:.*]]: memref<?x?x?x?xf32>) -> vector<8x[4]x2x3xf32> {
// CHECK:             %[[C0:.*]] = arith.constant 0 : index
// CHECK:             %[[TFR:.*]] = vector.transfer_read %[[ARG_0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]]{{.*}} permutation_map = #[[MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
// CHECK:             %[[BC:.*]] = vector.broadcast %[[TFR]] : vector<[4]x2x3xf32> to vector<8x[4]x2x3xf32>
// CHECK:             return %[[BC]] : vector<8x[4]x2x3xf32>
func.func @transfer_read_reduce_rank_scalable(%mem: memref<?x?x?x?xf32>) -> vector<8x[4]x2x3xf32> {
  %c0 = arith.constant 0 : index
  %cst_0 = arith.constant 0.000000e+00 : f32
  // Results 0 and 2 of the permutation map are broadcast dims.
  %1 = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst_0
    {in_bounds = [true, true, true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>}
    : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32>
  return %1 : vector<8x[4]x2x3xf32>
}
209+
// Masked case not supported: the transfer_read nested inside vector.mask is
// expected to keep its full vector<8x[4]x2x3xf32> result type, and no
// vector.broadcast may appear before the vector.mask op.
// The mask and the memref argument are matched through FileCheck captures
// rather than hard-coded SSA names (%0, %arg0).

// CHECK-LABEL:   func.func @masked_transfer_read_reduce_rank(
// CHECK-SAME:        %[[ARG_0:.*]]: memref<?x?x?x?xf32>,
// CHECK-SAME:        %[[DIM:.*]]: index) -> vector<8x[4]x2x3xf32> {
// CHECK-NOT:         vector.broadcast
// CHECK:             %[[MASK:.*]] = vector.create_mask %[[DIM]], %[[DIM]] : vector<[4]x3xi1>
// CHECK-NOT:         vector.broadcast
// CHECK:             %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[ARG_0]]{{.*}} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
// CHECK:             return %[[RES]] : vector<8x[4]x2x3xf32>
func.func @masked_transfer_read_reduce_rank(%mem: memref<?x?x?x?xf32>, %dim: index) -> vector<8x[4]x2x3xf32> {
  %c0 = arith.constant 0 : index
  %cst_0 = arith.constant 0.000000e+00 : f32
  // The mask is built from the same dynamic bound for both of its dims.
  %mask = vector.create_mask %dim, %dim : vector<[4]x3xi1>
  %res = vector.mask %mask { vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst_0
    {in_bounds = [true, true, true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>}
    : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
  return %res : vector<8x[4]x2x3xf32>
}
225+
// Transform script for this test chunk: collect every func.func in the payload
// module and apply the vector transfer permutation patterns to them.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    // The op name must be spelled exactly "func.func"; any surrounding
    // whitespace inside the string literal would make the match come up empty.
    %f = transform.structured.match ops{["func.func"]} in %module_op
      : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %f {
      transform.apply_patterns.vector.transfer_permutation_patterns
    } : !transform.any_op
    transform.yield
  }
}
0 commit comments