@@ -880,22 +880,22 @@ func.func @mmt4d_scalable(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>,
880880// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>,
881881// CHECK-SAME: %[[B:.*]]: memref<16x16x?x1xf32>,
882882// CHECK-SAME: %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
883- // CHECK: %[[VAL_0:.*]] = arith.constant 16 : index
884- // CHECK: %[[VAL_1:.*]] = arith.constant 16 : index
885- // CHECK: %[[VAL_2:.*]] = arith.constant 16 : index
883+ // CHECK: %[[C16_M:.*]] = arith.constant 16 : index
884+ // CHECK: %[[C16_N:.*]] = arith.constant 16 : index
885+ // CHECK: %[[C16_K:.*]] = arith.constant 16 : index
886886// CHECK: %[[C8:.*]] = arith.constant 8 : index
887887// CHECK: %[[C2:.*]] = arith.constant 2 : index
888888// CHECK: %[[DIM_2:.*]] = memref.dim %[[B]], %[[C2]] : memref<16x16x?x1xf32>
889- // CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
889+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
890890// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
891- // CHECK: %[[MASK_1:.*]] = vector.create_mask %[[VAL_1]], %[[VAL_2]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x[4]x1xi1>
891+ // CHECK: %[[MASK_1:.*]] = vector.create_mask %[[C16_N]], %[[C16_K]], %[[DIM_2]], %[[C1]] : vector<16x16x[4]x1xi1>
892892// CHECK: %[[VEC_B:.*]] = vector.mask %[[MASK_1]] { vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32> } : vector<16x16x[4]x1xi1> -> vector<16x16x16x8x[4]x1xf32>
893- // CHECK: %[[MASK_2:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[C8]], %[[DIM_2]] : vector<16x16x8x[4]xi1>
894- // CHECK: %[[VAL_15:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32> } : vector<16x16x8x[4]xi1> -> vector<16x16x8x[4]xf32>
895- // CHECK: %[[VAL_16:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
896- // CHECK: %[[MASK_3:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[C8]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x16x8x[4]x1xi1>
897- // CHECK: %[[VAL_18:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[VAL_16]], %[[VAL_15]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32> } : vector<16x16x16x8x[4]x1xi1> -> vector<16x16x8x[4]xf32>
898- // CHECK: vector.mask %[[MASK_2]] { vector.transfer_write %[[VAL_18]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32> } : vector<16x16x8x[4]xi1>
893+ // CHECK: %[[MASK_2:.*]] = vector.create_mask %[[C16_M]], %[[C16_N]], %[[C8]], %[[DIM_2]] : vector<16x16x8x[4]xi1>
894+ // CHECK: %[[VEC_C:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32> } : vector<16x16x8x[4]xi1> -> vector<16x16x8x[4]xf32>
895+ // CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
896+ // CHECK: %[[MASK_3:.*]] = vector.create_mask %[[C16_M]], %[[C16_N]], %[[C16_K]], %[[C8]], %[[DIM_2]], %[[C1]] : vector<16x16x16x8x[4]x1xi1>
897+ // CHECK: %[[RED:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32> } : vector<16x16x16x8x[4]x1xi1> -> vector<16x16x8x[4]xf32>
898+ // CHECK: vector.mask %[[MASK_2]] { vector.transfer_write %[[RED]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32> } : vector<16x16x8x[4]xi1>
899899
900900
901901module attributes {transform.with_named_sequence} {
@@ -920,10 +920,10 @@ func.func @mmt4d_scalable_with_assume(%A: memref<16x16x8x1xf32>, %B: memref<16x1
920920// CHECK-NOT: mask
921921// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
922922// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32>
923- // CHECK: %[[VAL_13:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32>
924- // CHECK: %[[VAL_14:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
925- // CHECK: %[[VAL_15:.*]] = vector.multi_reduction <add>, %[[VAL_14]], %[[VAL_13]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32>
926- // CHECK: vector.transfer_write %[[VAL_15]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32>
923+ // CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32>
924+ // CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
925+ // CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32>
926+ // CHECK: vector.transfer_write %[[RED]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32>
927927
928928module attributes {transform.with_named_sequence} {
929929 transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -936,7 +936,7 @@ module attributes {transform.with_named_sequence} {
936936// -----
937937
938938///----------------------------------------------------------------------------------------
939- /// Tests for linalg.batch_batch_mmt4d
939+ /// Tests for linalg.batch_mmt4d
940940///----------------------------------------------------------------------------------------
941941
942942func.func @batch_mmt4d(%A: memref<2x16x16x8x1xf32>, %B: memref<2x16x16x8x1xf32>, %C_in: memref<2x16x16x8x8xf32>) {
@@ -973,23 +973,23 @@ func.func @batch_mmt4d_scalable(%A: memref<2x16x16x8x1xf32>, %B: memref<2x16x16x
973973// CHECK-SAME: %[[A:.*]]: memref<2x16x16x8x1xf32>,
974974// CHECK-SAME: %[[B:.*]]: memref<2x16x16x?x1xf32>,
975975// CHECK-SAME: %[[C_IN:.*]]: memref<2x16x16x8x?xf32>) {
976- // CHECK: %[[VAL_0:.*]] = arith.constant 2 : index
977- // CHECK: %[[VAL_1:.*]] = arith.constant 16 : index
978- // CHECK: %[[VAL_2:.*]] = arith.constant 16 : index
979- // CHECK: %[[VAL_3:.*]] = arith.constant 16 : index
976+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
977+ // CHECK: %[[C16_M:.*]] = arith.constant 16 : index
978+ // CHECK: %[[C16_N:.*]] = arith.constant 16 : index
979+ // CHECK: %[[C16_K:.*]] = arith.constant 16 : index
980980// CHECK: %[[C8:.*]] = arith.constant 8 : index
981981// CHECK: %[[C3:.*]] = arith.constant 3 : index
982- // CHECK: %[[DIM_2:.*]] = memref.dim %[[B]], %[[C3]] : memref<2x16x16x?x1xf32>
983- // CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
982+ // CHECK: %[[DIM_N_IN:.*]] = memref.dim %[[B]], %[[C3]] : memref<2x16x16x?x1xf32>
983+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
984984// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
985- // CHECK: %[[MASK_1:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_2]], %[[VAL_3]], %[[DIM_2]], %[[VAL_6]] : vector<2x16x16x[4]x1xi1>
985+ // CHECK: %[[MASK_1:.*]] = vector.create_mask %[[C2]], %[[C16_N]], %[[C16_K]], %[[DIM_N_IN]], %[[C1]] : vector<2x16x16x[4]x1xi1>
986986// CHECK: %[[VEC_B:.*]] = vector.mask %[[MASK_1]] { vector.transfer_read %[[B]]{{.*}} : memref<2x16x16x?x1xf32>, vector<2x16x16x16x8x[4]x1xf32> } : vector<2x16x16x[4]x1xi1> -> vector<2x16x16x16x8x[4]x1xf32>
987- // CHECK: %[[MASK_2:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[C8]], %[[DIM_2]] : vector<2x16x16x8x[4]xi1>
988- // CHECK: %[[VAL_15:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<2x16x16x8x?xf32>, vector<2x16x16x8x[4]xf32> } : vector<2x16x16x8x[4]xi1> -> vector<2x16x16x8x[4]xf32>
989- // CHECK: %[[VAL_16:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<2x16x16x16x8x[4]x1xf32>
990- // CHECK: %[[MASK_3:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]], %[[C8]], %[[DIM_2]], %[[VAL_6]] : vector<2x16x16x16x8x[4]x1xi1>
991- // CHECK: %[[VAL_18:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[VAL_16]], %[[VAL_15]] [3, 6] : vector<2x16x16x16x8x[4]x1xf32> to vector<2x16x16x8x[4]xf32> } : vector<2x16x16x16x8x[4]x1xi1> -> vector<2x16x16x8x[4]xf32>
992- // CHECK: vector.mask %[[MASK_2]] { vector.transfer_write %[[VAL_18]], %[[C_IN]]{{.*}} : vector<2x16x16x8x[4]xf32>, memref<2x16x16x8x?xf32> } : vector<2x16x16x8x[4]xi1>
987+ // CHECK: %[[MASK_2:.*]] = vector.create_mask %[[C2]], %[[C16_M]], %[[C16_N]], %[[C8]], %[[DIM_N_IN]] : vector<2x16x16x8x[4]xi1>
988+ // CHECK: %[[VEC_C:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<2x16x16x8x?xf32>, vector<2x16x16x8x[4]xf32> } : vector<2x16x16x8x[4]xi1> -> vector<2x16x16x8x[4]xf32>
989+ // CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<2x16x16x16x8x[4]x1xf32>
990+ // CHECK: %[[MASK_3:.*]] = vector.create_mask %[[C2]], %[[C16_M]], %[[C16_N]], %[[C16_K]], %[[C8]], %[[DIM_N_IN]], %[[C1]] : vector<2x16x16x16x8x[4]x1xi1>
991+ // CHECK: %[[RED:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [3, 6] : vector<2x16x16x16x8x[4]x1xf32> to vector<2x16x16x8x[4]xf32> } : vector<2x16x16x16x8x[4]x1xi1> -> vector<2x16x16x8x[4]xf32>
992+ // CHECK: vector.mask %[[MASK_2]] { vector.transfer_write %[[RED]], %[[C_IN]]{{.*}} : vector<2x16x16x8x[4]xf32>, memref<2x16x16x8x?xf32> } : vector<2x16x16x8x[4]xi1>
993993
994994module attributes {transform.with_named_sequence} {
995995 transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -1013,10 +1013,10 @@ func.func @batch_mmt4d_scalable_with_assume(%A: memref<2x16x16x8x1xf32>, %B: mem
10131013// CHECK-NOT: mask
10141014// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
10151015// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<2x16x16x?x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
1016- // CHECK: %[[VAL_13:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<2x16x16x8x?xf32>, vector<2x16x16x8x[4]xf32>
1017- // CHECK: %[[VAL_14:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<2x16x16x16x8x[4]x1xf32>
1018- // CHECK: %[[VAL_15:.*]] = vector.multi_reduction <add>, %[[VAL_14]], %[[VAL_13]] [3, 6] : vector<2x16x16x16x8x[4]x1xf32> to vector<2x16x16x8x[4]xf32>
1019- // CHECK: vector.transfer_write %[[VAL_15]], %[[C_IN]]{{.*}} : vector<2x16x16x8x[4]xf32>, memref<2x16x16x8x?xf32>
1016+ // CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<2x16x16x8x?xf32>, vector<2x16x16x8x[4]xf32>
1017+ // CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<2x16x16x16x8x[4]x1xf32>
1018+ // CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [3, 6] : vector<2x16x16x16x8x[4]x1xf32> to vector<2x16x16x8x[4]xf32>
1019+ // CHECK: vector.transfer_write %[[RED]], %[[C_IN]]{{.*}} : vector<2x16x16x8x[4]xf32>, memref<2x16x16x8x?xf32>
10201020
10211021module attributes {transform.with_named_sequence} {
10221022 transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
0 commit comments