@@ -880,22 +880,22 @@ func.func @mmt4d_scalable(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>,
// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>,
// CHECK-SAME: %[[B:.*]]: memref<16x16x?x1xf32>,
// CHECK-SAME: %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
- // CHECK: %[[VAL_0:.*]] = arith.constant 16 : index
- // CHECK: %[[VAL_1:.*]] = arith.constant 16 : index
- // CHECK: %[[VAL_2:.*]] = arith.constant 16 : index
+ // CHECK: %[[C16_M:.*]] = arith.constant 16 : index
+ // CHECK: %[[C16_N:.*]] = arith.constant 16 : index
+ // CHECK: %[[C16_K:.*]] = arith.constant 16 : index
// CHECK: %[[C8:.*]] = arith.constant 8 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[DIM_2:.*]] = memref.dim %[[B]], %[[C2]] : memref<16x16x?x1xf32>
- // CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
- // CHECK: %[[MASK_1:.*]] = vector.create_mask %[[VAL_1]], %[[VAL_2]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x[4]x1xi1>
+ // CHECK: %[[MASK_1:.*]] = vector.create_mask %[[C16_N]], %[[C16_K]], %[[DIM_2]], %[[C1]] : vector<16x16x[4]x1xi1>
// CHECK: %[[VEC_B:.*]] = vector.mask %[[MASK_1]] { vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32> } : vector<16x16x[4]x1xi1> -> vector<16x16x16x8x[4]x1xf32>
- // CHECK: %[[MASK_2:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[C8]], %[[DIM_2]] : vector<16x16x8x[4]xi1>
- // CHECK: %[[VAL_15:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32> } : vector<16x16x8x[4]xi1> -> vector<16x16x8x[4]xf32>
- // CHECK: %[[VAL_16:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
- // CHECK: %[[MASK_3:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[C8]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x16x8x[4]x1xi1>
- // CHECK: %[[VAL_18:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[VAL_16]], %[[VAL_15]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32> } : vector<16x16x16x8x[4]x1xi1> -> vector<16x16x8x[4]xf32>
- // CHECK: vector.mask %[[MASK_2]] { vector.transfer_write %[[VAL_18]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32> } : vector<16x16x8x[4]xi1>
+ // CHECK: %[[MASK_2:.*]] = vector.create_mask %[[C16_M]], %[[C16_N]], %[[C8]], %[[DIM_2]] : vector<16x16x8x[4]xi1>
+ // CHECK: %[[VEC_C:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32> } : vector<16x16x8x[4]xi1> -> vector<16x16x8x[4]xf32>
+ // CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
+ // CHECK: %[[MASK_3:.*]] = vector.create_mask %[[C16_M]], %[[C16_N]], %[[C16_K]], %[[C8]], %[[DIM_2]], %[[C1]] : vector<16x16x16x8x[4]x1xi1>
+ // CHECK: %[[RED:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32> } : vector<16x16x16x8x[4]x1xi1> -> vector<16x16x8x[4]xf32>
+ // CHECK: vector.mask %[[MASK_2]] { vector.transfer_write %[[RED]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32> } : vector<16x16x8x[4]xi1>

module attributes {transform.with_named_sequence} {
@@ -920,10 +920,10 @@ func.func @mmt4d_scalable_with_assume(%A: memref<16x16x8x1xf32>, %B: memref<16x1
// CHECK-NOT: mask
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32>
- // CHECK: %[[VAL_13:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32>
- // CHECK: %[[VAL_14:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
- // CHECK: %[[VAL_15:.*]] = vector.multi_reduction <add>, %[[VAL_14]], %[[VAL_13]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32>
- // CHECK: vector.transfer_write %[[VAL_15]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32>
+ // CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32>
+ // CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
+ // CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32>
+ // CHECK: vector.transfer_write %[[RED]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32>

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -933,6 +933,100 @@ module attributes {transform.with_named_sequence} {
}
}
+ // -----
+
+ ///----------------------------------------------------------------------------------------
+ /// Tests for linalg.batch_mmt4d
+ ///----------------------------------------------------------------------------------------
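+ /// Note: the three tests below mirror the mmt4d tests above: a fully static
+ /// case, a masked scalable case, and a scalable case where
+ /// `assume_dynamic_dims_match_vec_sizes` makes masking unnecessary.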
+
+ func.func @batch_mmt4d(%A: memref<2x16x16x8x1xf32>, %B: memref<2x16x16x8x1xf32>, %C_in: memref<2x16x16x8x8xf32>) {
+   linalg.batch_mmt4d ins(%A, %B: memref<2x16x16x8x1xf32>, memref<2x16x16x8x1xf32>)
+                      outs(%C_in: memref<2x16x16x8x8xf32>)
+   return
+ }
+
+ // CHECK-LABEL: func.func @batch_mmt4d(
+ // CHECK-SAME: %[[A:.*]]: memref<2x16x16x8x1xf32>, %[[B:.*]]: memref<2x16x16x8x1xf32>, %[[C:.*]]: memref<2x16x16x8x8xf32>) {
+ // CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x8x1xf32>
+ // CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x8x1xf32>
+ // CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<2x16x16x8x8xf32>, vector<2x16x16x8x8xf32>
+ // CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<2x16x16x16x8x8x1xf32>
+ // CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [3, 6] : vector<2x16x16x16x8x8x1xf32> to vector<2x16x16x8x8xf32>
+ // CHECK: vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<2x16x16x8x8xf32>, memref<2x16x16x8x8xf32>
+
+ module attributes {transform.with_named_sequence} {
+   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+     %batch_mmt4d = transform.structured.match ops {["linalg.batch_mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+     transform.structured.vectorize %batch_mmt4d : !transform.any_op
+     transform.yield
+   }
+ }
+
+ // -----
+
+ func.func @batch_mmt4d_scalable(%A: memref<2x16x16x8x1xf32>, %B: memref<2x16x16x?x1xf32>, %C_in: memref<2x16x16x8x?xf32>) {
+   linalg.batch_mmt4d ins(%A, %B: memref<2x16x16x8x1xf32>, memref<2x16x16x?x1xf32>)
+                      outs(%C_in: memref<2x16x16x8x?xf32>)
+   return
+ }
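+ /// The dynamic dim of %B (and the matching dim of %C_in) is vectorized with
+ /// the scalable size [4]; the reads and writes below are therefore masked,
+ /// with mask bounds taken from `memref.dim`.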
+ // CHECK-LABEL: func.func @batch_mmt4d_scalable(
+ // CHECK-SAME: %[[A:.*]]: memref<2x16x16x8x1xf32>,
+ // CHECK-SAME: %[[B:.*]]: memref<2x16x16x?x1xf32>,
+ // CHECK-SAME: %[[C_IN:.*]]: memref<2x16x16x8x?xf32>) {
+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[C16_M:.*]] = arith.constant 16 : index
+ // CHECK: %[[C16_N:.*]] = arith.constant 16 : index
+ // CHECK: %[[C16_K:.*]] = arith.constant 16 : index
+ // CHECK: %[[C8:.*]] = arith.constant 8 : index
+ // CHECK: %[[C3:.*]] = arith.constant 3 : index
+ // CHECK: %[[DIM_N_IN:.*]] = memref.dim %[[B]], %[[C3]] : memref<2x16x16x?x1xf32>
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
+ // CHECK: %[[MASK_1:.*]] = vector.create_mask %[[C2]], %[[C16_N]], %[[C16_K]], %[[DIM_N_IN]], %[[C1]] : vector<2x16x16x[4]x1xi1>
+ // CHECK: %[[VEC_B:.*]] = vector.mask %[[MASK_1]] { vector.transfer_read %[[B]]{{.*}} : memref<2x16x16x?x1xf32>, vector<2x16x16x16x8x[4]x1xf32> } : vector<2x16x16x[4]x1xi1> -> vector<2x16x16x16x8x[4]x1xf32>
+ // CHECK: %[[MASK_2:.*]] = vector.create_mask %[[C2]], %[[C16_M]], %[[C16_N]], %[[C8]], %[[DIM_N_IN]] : vector<2x16x16x8x[4]xi1>
+ // CHECK: %[[VEC_C:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<2x16x16x8x?xf32>, vector<2x16x16x8x[4]xf32> } : vector<2x16x16x8x[4]xi1> -> vector<2x16x16x8x[4]xf32>
+ // CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<2x16x16x16x8x[4]x1xf32>
+ // CHECK: %[[MASK_3:.*]] = vector.create_mask %[[C2]], %[[C16_M]], %[[C16_N]], %[[C16_K]], %[[C8]], %[[DIM_N_IN]], %[[C1]] : vector<2x16x16x16x8x[4]x1xi1>
+ // CHECK: %[[RED:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [3, 6] : vector<2x16x16x16x8x[4]x1xf32> to vector<2x16x16x8x[4]xf32> } : vector<2x16x16x16x8x[4]x1xi1> -> vector<2x16x16x8x[4]xf32>
+ // CHECK: vector.mask %[[MASK_2]] { vector.transfer_write %[[RED]], %[[C_IN]]{{.*}} : vector<2x16x16x8x[4]xf32>, memref<2x16x16x8x?xf32> } : vector<2x16x16x8x[4]xi1>
+
+ module attributes {transform.with_named_sequence} {
+   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+     %batch_mmt4d = transform.structured.match ops {["linalg.batch_mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+     transform.structured.vectorize %batch_mmt4d vector_sizes [2, 16, 16, 16, 8, [4], 1] : !transform.any_op
+     transform.yield
+   }
+ }
+
+ // -----
+
+ func.func @batch_mmt4d_scalable_with_assume(%A: memref<2x16x16x8x1xf32>, %B: memref<2x16x16x?x1xf32>, %C_in: memref<2x16x16x8x?xf32>) {
+   linalg.batch_mmt4d ins(%A, %B: memref<2x16x16x8x1xf32>, memref<2x16x16x?x1xf32>)
+                      outs(%C_in: memref<2x16x16x8x?xf32>)
+   return
+ }
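+ /// With `assume_dynamic_dims_match_vec_sizes`, the dynamic dims are assumed
+ /// to match the corresponding (scalable) vector sizes, so no masks are
+ /// generated.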
+ // CHECK-LABEL: func.func @batch_mmt4d_scalable_with_assume(
+ // CHECK-SAME: %[[A:.*]]: memref<2x16x16x8x1xf32>,
+ // CHECK-SAME: %[[B:.*]]: memref<2x16x16x?x1xf32>,
+ // CHECK-SAME: %[[C_IN:.*]]: memref<2x16x16x8x?xf32>) {
+ // CHECK-NOT: mask
+ // CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
+ // CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<2x16x16x?x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
+ // CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<2x16x16x8x?xf32>, vector<2x16x16x8x[4]xf32>
+ // CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<2x16x16x16x8x[4]x1xf32>
+ // CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [3, 6] : vector<2x16x16x16x8x[4]x1xf32> to vector<2x16x16x8x[4]xf32>
+ // CHECK: vector.transfer_write %[[RED]], %[[C_IN]]{{.*}} : vector<2x16x16x8x[4]xf32>, memref<2x16x16x8x?xf32>
+
+
+ module attributes {transform.with_named_sequence} {
+   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+     %batch_mmt4d = transform.structured.match ops {["linalg.batch_mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+     transform.structured.vectorize %batch_mmt4d vector_sizes [2, 16, 16, 16, 8, [4], 1] {assume_dynamic_dims_match_vec_sizes} : !transform.any_op
+     transform.yield
+   }
+ }
+
// -----
///----------------------------------------------------------------------------------------