@@ -363,17 +363,39 @@ func.func @xfer_read_minor_identity_transposed_masked_scalable(
363363}
364364
365365///----------------------------------------------------------------------------------------
366- /// vector.transfer_read
366+ /// [Pattern: TransferOpReduceRank]
367+ ///
368+ /// IN: vector.transfer_read (minor identity map + broadcast)
369+ /// OUT: vector.transfer_read + vector.broadcast
367370///----------------------------------------------------------------------------------------
368- /// TODO: Review and categorize
371+
372+ // CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims
373+ // CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>, %[[IDX:.*]]: index) -> vector<8x4x2x3xf32> {
374+ // CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]]{{.*}} permutation_map = #[[$MAP]]} : memref<?x?x?x?xf32>, vector<4x2x3xf32>
375+ // CHECK: %[[BC:.*]] = vector.broadcast %[[T_READ]] : vector<4x2x3xf32> to vector<8x4x2x3xf32>
376+ // CHECK: return %[[BC]] : vector<8x4x2x3xf32>
377+ func.func @xfer_read_minor_identitiy_bcast_dims (
378+ %mem: memref <?x?x?x?xf32 >,
379+ %idx: index ) -> vector <8 x4 x2 x3 xf32 > {
380+
381+ %pad = arith.constant 0.000000e+00 : f32
382+
383+ %res = vector.transfer_read %mem [%idx , %idx , %idx , %idx ], %pad {
384+ in_bounds = [true , true , true , true ],
385+ permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (0 , d1 , 0 , d3 )>
386+ } : memref <?x?x?x?xf32 >, vector <8 x4 x2 x3 xf32 >
387+
388+ return %res : vector <8 x4 x2 x3 xf32 >
389+ }
369390
370391// CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims_scalable
371392// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>, %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
372393// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]]{{.*}} permutation_map = #[[$MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
373394// CHECK: %[[BC:.*]] = vector.broadcast %[[T_READ]] : vector<[4]x2x3xf32> to vector<8x[4]x2x3xf32>
374395// CHECK: return %[[BC]] : vector<8x[4]x2x3xf32>
375396func.func @xfer_read_minor_identitiy_bcast_dims_scalable (
376- %mem: memref <?x?x?x?xf32 >, %idx: index ) -> vector <8 x[4 ]x2 x3 xf32 > {
397+ %mem: memref <?x?x?x?xf32 >,
398+ %idx: index ) -> vector <8 x[4 ]x2 x3 xf32 > {
377399
378400 %pad = arith.constant 0.000000e+00 : f32
379401
@@ -385,31 +407,80 @@ func.func @xfer_read_minor_identitiy_bcast_dims_scalable(
385407 return %res : vector <8 x[4 ]x2 x3 xf32 >
386408}
387409
410+ // CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims_with_mask
411+ // CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>
412+ // CHECK-SAME: %[[MASK:.*]]: vector<4x3xi1>
413+ // CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x4x2x3xf32>
414+ // CHECK: %[[PASS_THROUGH:.*]] = arith.constant 0.000000e+00 : f32
415+ // CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]], %[[PASS_THROUGH]], %[[MASK]]{{.*}} permutation_map = #[[$MAP]]} : memref<?x?x?x?xf32>, vector<4x2x3xf32>
416+ // CHECK: %[[BC:.*]] = vector.broadcast %[[T_READ]] : vector<4x2x3xf32> to vector<8x4x2x3xf32>
417+ // CHECK: return %[[BC]] : vector<8x4x2x3xf32>
418+ func.func @xfer_read_minor_identitiy_bcast_dims_with_mask (
419+ %mem: memref <?x?x?x?xf32 >,
420+ %mask: vector <4 x3 xi1 >,
421+ %idx: index ) -> vector <8 x4 x2 x3 xf32 > {
422+
423+ %pad = arith.constant 0.000000e+00 : f32
424+
425+ %res = vector.transfer_read %mem [%idx , %idx , %idx , %idx ], %pad , %mask {
426+ in_bounds = [true , true , true , true ],
427+ permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (0 , d1 , 0 , d3 )>
428+ } : memref <?x?x?x?xf32 >, vector <8 x4 x2 x3 xf32 >
429+
430+ return %res : vector <8 x4 x2 x3 xf32 >
431+ }
432+
433+ // CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims_with_mask_scalable
434+ // CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>
435+ // CHECK-SAME: %[[MASK:.*]]: vector<[4]x3xi1>
436+ // CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32>
437+ // CHECK: %[[PASS_THROUGH:.*]] = arith.constant 0.000000e+00 : f32
438+ // CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]], %[[PASS_THROUGH]], %[[MASK]]{{.*}} permutation_map = #[[$MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
439+ // CHECK: %[[BC:.*]] = vector.broadcast %[[T_READ]] : vector<[4]x2x3xf32> to vector<8x[4]x2x3xf32>
440+ // CHECK: return %[[BC]] : vector<8x[4]x2x3xf32>
441+ func.func @xfer_read_minor_identitiy_bcast_dims_with_mask_scalable (
442+ %mem: memref <?x?x?x?xf32 >,
443+ %mask: vector <[4 ]x3 xi1 >,
444+ %idx: index ) -> vector <8 x[4 ]x2 x3 xf32 > {
445+
446+ %pad = arith.constant 0.000000e+00 : f32
447+
448+ %res = vector.transfer_read %mem [%idx , %idx , %idx , %idx ], %pad , %mask {
449+ in_bounds = [true , true , true , true ],
450+ permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (0 , d1 , 0 , d3 )>
451+ } : memref <?x?x?x?xf32 >, vector <8 x[4 ]x2 x3 xf32 >
452+
453+ return %res : vector <8 x[4 ]x2 x3 xf32 >
454+ }
455+
388456// The vector.mask-wrapped version is not supported: the transfer_read is left unchanged and no vector.broadcast is generated
389457
390458// CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims_masked
391459// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>,
392- // CHECK-SAME: %[[MASK:.*]]: vector<[4]x3xi1 >
393- // CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32 > {
460+ // CHECK-SAME: %[[MASK:.*]]: vector<4x3xi1 >
461+ // CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x4x2x3xf32 > {
394462// CHECK-NOT: vector.broadcast
395- // CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[MEM]]{{.*}} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32 > } : vector<[4]x3xi1 > -> vector<8x[4]x2x3xf32 >
463+ // CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[MEM]]{{.*}} : memref<?x?x?x?xf32>, vector<8x4x2x3xf32 > } : vector<4x3xi1 > -> vector<8x4x2x3xf32 >
396464func.func @xfer_read_minor_identitiy_bcast_dims_masked (
397465 %mem: memref <?x?x?x?xf32 >,
398- %mask: vector <[ 4 ]x 3 x i1 >,
399- %idx: index ) -> vector <8 x[ 4 ]x 2 x 3 x f32 > {
466+ %mask: vector <4 x 3 x i1 >,
467+ %idx: index ) -> vector <8 x 4 x 2 x 3 x f32 > {
400468
401469 %pad = arith.constant 0.000000e+00 : f32
402470
403471 %res = vector.mask %mask {
404472 vector.transfer_read %mem [%idx , %idx , %idx , %idx ], %pad {
405473 in_bounds = [true , true , true , true ],
406474 permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (0 , d1 , 0 , d3 )>
407- } : memref <?x?x?x?xf32 >, vector <8 x[ 4 ]x 2 x 3 x f32 >
408- } : vector <[ 4 ]x 3 x i1 > -> vector <8 x[ 4 ]x 2 x 3 x f32 >
475+ } : memref <?x?x?x?xf32 >, vector <8 x 4 x 2 x 3 x f32 >
476+ } : vector <4 x 3 x i1 > -> vector <8 x 4 x 2 x 3 x f32 >
409477
410- return %res : vector <8 x[ 4 ]x 2 x 3 x f32 >
478+ return %res : vector <8 x 4 x 2 x 3 x f32 >
411479}
412480
481+ ///----------------------------------------------------------------------------------------
482+ // Transform Dialect (TD) sequence
483+ ///----------------------------------------------------------------------------------------
413484module attributes {transform.with_named_sequence } {
414485 transform.named_sequence @__transform_main (%module_op: !transform.any_op {transform.readonly }) {
415486 %f = transform.structured.match ops {[" func.func" ]} in %module_op
0 commit comments