@@ -933,6 +933,100 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// -----
+
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.batch_mmt4d
+///----------------------------------------------------------------------------------------
+
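+// Static sizes: plain (unmasked) vectorization of linalg.batch_mmt4d.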
+func.func @batch_mmt4d(%A: memref<2x16x16x8x1xf32>, %B: memref<2x16x16x8x1xf32>, %C_in: memref<2x16x16x8x8xf32>) {
+  linalg.batch_mmt4d ins(%A, %B: memref<2x16x16x8x1xf32>, memref<2x16x16x8x1xf32>)
+                     outs(%C_in: memref<2x16x16x8x8xf32>)
+  return
+}
+
+// CHECK-LABEL: func.func @batch_mmt4d(
+// CHECK-SAME: %[[A:.*]]: memref<2x16x16x8x1xf32>, %[[B:.*]]: memref<2x16x16x8x1xf32>, %[[C:.*]]: memref<2x16x16x8x8xf32>) {
+// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x8x1xf32>
+// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x8x1xf32>
+// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<2x16x16x8x8xf32>, vector<2x16x16x8x8xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<2x16x16x16x8x8x1xf32>
+// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [3, 6] : vector<2x16x16x16x8x8x1xf32> to vector<2x16x16x8x8xf32>
+// CHECK: vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<2x16x16x8x8xf32>, memref<2x16x16x8x8xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %batch_mmt4d = transform.structured.match ops{["linalg.batch_mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %batch_mmt4d : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
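+// Scalable vectorization: the inner "n" tile dim of %B and %C_in is dynamic and is
+// vectorized with a scalable size ([4]), so masked transfers and a masked
+// multi_reduction are generated.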
+func.func @batch_mmt4d_scalable(%A: memref<2x16x16x8x1xf32>, %B: memref<2x16x16x?x1xf32>, %C_in: memref<2x16x16x8x?xf32>) {
+  linalg.batch_mmt4d ins(%A, %B: memref<2x16x16x8x1xf32>, memref<2x16x16x?x1xf32>)
+                     outs(%C_in: memref<2x16x16x8x?xf32>)
+  return
+}
+// CHECK-LABEL: func.func @batch_mmt4d_scalable(
+// CHECK-SAME: %[[A:.*]]: memref<2x16x16x8x1xf32>,
+// CHECK-SAME: %[[B:.*]]: memref<2x16x16x?x1xf32>,
+// CHECK-SAME: %[[C_IN:.*]]: memref<2x16x16x8x?xf32>) {
+// CHECK: %[[VAL_0:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_1:.*]] = arith.constant 16 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 16 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 16 : index
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[C3:.*]] = arith.constant 3 : index
+// CHECK: %[[DIM_2:.*]] = memref.dim %[[B]], %[[C3]] : memref<2x16x16x?x1xf32>
+// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
+// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
+// CHECK: %[[MASK_1:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_2]], %[[VAL_3]], %[[DIM_2]], %[[VAL_6]] : vector<2x16x16x[4]x1xi1>
+// CHECK: %[[VEC_B:.*]] = vector.mask %[[MASK_1]] { vector.transfer_read %[[B]]{{.*}} : memref<2x16x16x?x1xf32>, vector<2x16x16x16x8x[4]x1xf32> } : vector<2x16x16x[4]x1xi1> -> vector<2x16x16x16x8x[4]x1xf32>
+// CHECK: %[[MASK_2:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[C8]], %[[DIM_2]] : vector<2x16x16x8x[4]xi1>
+// CHECK: %[[VAL_15:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<2x16x16x8x?xf32>, vector<2x16x16x8x[4]xf32> } : vector<2x16x16x8x[4]xi1> -> vector<2x16x16x8x[4]xf32>
+// CHECK: %[[VAL_16:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<2x16x16x16x8x[4]x1xf32>
+// CHECK: %[[MASK_3:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]], %[[C8]], %[[DIM_2]], %[[VAL_6]] : vector<2x16x16x16x8x[4]x1xi1>
+// CHECK: %[[VAL_18:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[VAL_16]], %[[VAL_15]] [3, 6] : vector<2x16x16x16x8x[4]x1xf32> to vector<2x16x16x8x[4]xf32> } : vector<2x16x16x16x8x[4]x1xi1> -> vector<2x16x16x8x[4]xf32>
+// CHECK: vector.mask %[[MASK_2]] { vector.transfer_write %[[VAL_18]], %[[C_IN]]{{.*}} : vector<2x16x16x8x[4]xf32>, memref<2x16x16x8x?xf32> } : vector<2x16x16x8x[4]xi1>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %batch_mmt4d = transform.structured.match ops{["linalg.batch_mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %batch_mmt4d vector_sizes [2, 16, 16, 16, 8, [4], 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
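+// Same as the scalable case above, but with `assume_dynamic_dims_match_vec_sizes`:
+// the dynamic dim is assumed to match the corresponding (scalable) vector size, so
+// no masks are generated.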
+func.func @batch_mmt4d_scalable_with_assume(%A: memref<2x16x16x8x1xf32>, %B: memref<2x16x16x?x1xf32>, %C_in: memref<2x16x16x8x?xf32>) {
+  linalg.batch_mmt4d ins(%A, %B: memref<2x16x16x8x1xf32>, memref<2x16x16x?x1xf32>)
+                     outs(%C_in: memref<2x16x16x8x?xf32>)
+  return
+}
+// CHECK-LABEL: func.func @batch_mmt4d_scalable_with_assume(
+// CHECK-SAME: %[[A:.*]]: memref<2x16x16x8x1xf32>,
+// CHECK-SAME: %[[B:.*]]: memref<2x16x16x?x1xf32>,
+// CHECK-SAME: %[[C_IN:.*]]: memref<2x16x16x8x?xf32>) {
+// CHECK-NOT: mask
+// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
+// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<2x16x16x?x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
+// CHECK: %[[VAL_13:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<2x16x16x8x?xf32>, vector<2x16x16x8x[4]xf32>
+// CHECK: %[[VAL_14:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<2x16x16x16x8x[4]x1xf32>
+// CHECK: %[[VAL_15:.*]] = vector.multi_reduction <add>, %[[VAL_14]], %[[VAL_13]] [3, 6] : vector<2x16x16x16x8x[4]x1xf32> to vector<2x16x16x8x[4]xf32>
+// CHECK: vector.transfer_write %[[VAL_15]], %[[C_IN]]{{.*}} : vector<2x16x16x8x[4]xf32>, memref<2x16x16x8x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %batch_mmt4d = transform.structured.match ops{["linalg.batch_mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %batch_mmt4d vector_sizes [2, 16, 16, 16, 8, [4], 1] {assume_dynamic_dims_match_vec_sizes} : !transform.any_op
+    transform.yield
+  }
+}
+
+
 // -----
 
 ///----------------------------------------------------------------------------------------