Skip to content

Commit 8de85e7

Browse files
authored
[mlir][linalg] Add support for scalable vectorization of linalg.batch_mmt4d (#152984)
This PR builds upon the previous #146531 and enables scalable vectorization for `batch_mmt4d` as well. --------- Signed-off-by: Ege Beysel <[email protected]>
1 parent c96d0da commit 8de85e7

File tree

2 files changed

+110
-15
lines changed

2 files changed

+110
-15
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2609,6 +2609,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
26092609
return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
26102610
isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
26112611
isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
2612+
isa<linalg::BatchMmt4DOp>(op) ||
26122613
hasReductionIterator(linalgOp));
26132614
}
26142615

mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir

Lines changed: 109 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -880,22 +880,22 @@ func.func @mmt4d_scalable(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>,
880880
// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>,
881881
// CHECK-SAME: %[[B:.*]]: memref<16x16x?x1xf32>,
882882
// CHECK-SAME: %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
883-
// CHECK: %[[VAL_0:.*]] = arith.constant 16 : index
884-
// CHECK: %[[VAL_1:.*]] = arith.constant 16 : index
885-
// CHECK: %[[VAL_2:.*]] = arith.constant 16 : index
883+
// CHECK: %[[C16_M:.*]] = arith.constant 16 : index
884+
// CHECK: %[[C16_N:.*]] = arith.constant 16 : index
885+
// CHECK: %[[C16_K:.*]] = arith.constant 16 : index
886886
// CHECK: %[[C8:.*]] = arith.constant 8 : index
887887
// CHECK: %[[C2:.*]] = arith.constant 2 : index
888888
// CHECK: %[[DIM_2:.*]] = memref.dim %[[B]], %[[C2]] : memref<16x16x?x1xf32>
889-
// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
889+
// CHECK: %[[C1:.*]] = arith.constant 1 : index
890890
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
891-
// CHECK: %[[MASK_1:.*]] = vector.create_mask %[[VAL_1]], %[[VAL_2]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x[4]x1xi1>
891+
// CHECK: %[[MASK_1:.*]] = vector.create_mask %[[C16_N]], %[[C16_K]], %[[DIM_2]], %[[C1]] : vector<16x16x[4]x1xi1>
892892
// CHECK: %[[VEC_B:.*]] = vector.mask %[[MASK_1]] { vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32> } : vector<16x16x[4]x1xi1> -> vector<16x16x16x8x[4]x1xf32>
893-
// CHECK: %[[MASK_2:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[C8]], %[[DIM_2]] : vector<16x16x8x[4]xi1>
894-
// CHECK: %[[VAL_15:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32> } : vector<16x16x8x[4]xi1> -> vector<16x16x8x[4]xf32>
895-
// CHECK: %[[VAL_16:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
896-
// CHECK: %[[MASK_3:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[C8]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x16x8x[4]x1xi1>
897-
// CHECK: %[[VAL_18:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[VAL_16]], %[[VAL_15]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32> } : vector<16x16x16x8x[4]x1xi1> -> vector<16x16x8x[4]xf32>
898-
// CHECK: vector.mask %[[MASK_2]] { vector.transfer_write %[[VAL_18]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32> } : vector<16x16x8x[4]xi1>
893+
// CHECK: %[[MASK_2:.*]] = vector.create_mask %[[C16_M]], %[[C16_N]], %[[C8]], %[[DIM_2]] : vector<16x16x8x[4]xi1>
894+
// CHECK: %[[VEC_C:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32> } : vector<16x16x8x[4]xi1> -> vector<16x16x8x[4]xf32>
895+
// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
896+
// CHECK: %[[MASK_3:.*]] = vector.create_mask %[[C16_M]], %[[C16_N]], %[[C16_K]], %[[C8]], %[[DIM_2]], %[[C1]] : vector<16x16x16x8x[4]x1xi1>
897+
// CHECK: %[[RED:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32> } : vector<16x16x16x8x[4]x1xi1> -> vector<16x16x8x[4]xf32>
898+
// CHECK: vector.mask %[[MASK_2]] { vector.transfer_write %[[RED]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32> } : vector<16x16x8x[4]xi1>
899899

900900

901901
module attributes {transform.with_named_sequence} {
@@ -920,10 +920,10 @@ func.func @mmt4d_scalable_with_assume(%A: memref<16x16x8x1xf32>, %B: memref<16x1
920920
// CHECK-NOT: mask
921921
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
922922
// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32>
923-
// CHECK: %[[VAL_13:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32>
924-
// CHECK: %[[VAL_14:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
925-
// CHECK: %[[VAL_15:.*]] = vector.multi_reduction <add>, %[[VAL_14]], %[[VAL_13]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32>
926-
// CHECK: vector.transfer_write %[[VAL_15]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32>
923+
// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32>
924+
// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
925+
// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32>
926+
// CHECK: vector.transfer_write %[[RED]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32>
927927

928928
module attributes {transform.with_named_sequence} {
929929
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -933,6 +933,100 @@ module attributes {transform.with_named_sequence} {
933933
}
934934
}
935935

936+
// -----
937+
938+
///----------------------------------------------------------------------------------------
939+
/// Tests for linalg.batch_mmt4d
940+
///----------------------------------------------------------------------------------------
941+
942+
func.func @batch_mmt4d(%A: memref<2x16x16x8x1xf32>, %B: memref<2x16x16x8x1xf32>, %C_in: memref<2x16x16x8x8xf32>) {
943+
linalg.batch_mmt4d ins(%A, %B: memref<2x16x16x8x1xf32>, memref<2x16x16x8x1xf32>)
944+
outs(%C_in: memref<2x16x16x8x8xf32>)
945+
return
946+
}
947+
948+
// CHECK-LABEL: func.func @batch_mmt4d(
949+
// CHECK-SAME: %[[A:.*]]: memref<2x16x16x8x1xf32>, %[[B:.*]]: memref<2x16x16x8x1xf32>, %[[C:.*]]: memref<2x16x16x8x8xf32>) {
950+
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x8x1xf32>
951+
// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x8x1xf32>
952+
// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<2x16x16x8x8xf32>, vector<2x16x16x8x8xf32>
953+
// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<2x16x16x16x8x8x1xf32>
954+
// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [3, 6] : vector<2x16x16x16x8x8x1xf32> to vector<2x16x16x8x8xf32>
955+
// CHECK: vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<2x16x16x8x8xf32>, memref<2x16x16x8x8xf32>
956+
957+
module attributes {transform.with_named_sequence} {
958+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
959+
%batch_mmt4d = transform.structured.match ops{["linalg.batch_mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
960+
transform.structured.vectorize %batch_mmt4d : !transform.any_op
961+
transform.yield
962+
}
963+
}
964+
965+
// -----
966+
967+
func.func @batch_mmt4d_scalable(%A: memref<2x16x16x8x1xf32>, %B: memref<2x16x16x?x1xf32>, %C_in: memref<2x16x16x8x?xf32>) {
968+
linalg.batch_mmt4d ins(%A, %B: memref<2x16x16x8x1xf32>, memref<2x16x16x?x1xf32>)
969+
outs(%C_in: memref<2x16x16x8x?xf32>)
970+
return
971+
}
972+
// CHECK-LABEL: func.func @batch_mmt4d_scalable(
973+
// CHECK-SAME: %[[A:.*]]: memref<2x16x16x8x1xf32>,
974+
// CHECK-SAME: %[[B:.*]]: memref<2x16x16x?x1xf32>,
975+
// CHECK-SAME: %[[C_IN:.*]]: memref<2x16x16x8x?xf32>) {
976+
// CHECK: %[[C2:.*]] = arith.constant 2 : index
977+
// CHECK: %[[C16_M:.*]] = arith.constant 16 : index
978+
// CHECK: %[[C16_N:.*]] = arith.constant 16 : index
979+
// CHECK: %[[C16_K:.*]] = arith.constant 16 : index
980+
// CHECK: %[[C8:.*]] = arith.constant 8 : index
981+
// CHECK: %[[C3:.*]] = arith.constant 3 : index
982+
// CHECK: %[[DIM_N_IN:.*]] = memref.dim %[[B]], %[[C3]] : memref<2x16x16x?x1xf32>
983+
// CHECK: %[[C1:.*]] = arith.constant 1 : index
984+
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
985+
// CHECK: %[[MASK_1:.*]] = vector.create_mask %[[C2]], %[[C16_N]], %[[C16_K]], %[[DIM_N_IN]], %[[C1]] : vector<2x16x16x[4]x1xi1>
986+
// CHECK: %[[VEC_B:.*]] = vector.mask %[[MASK_1]] { vector.transfer_read %[[B]]{{.*}} : memref<2x16x16x?x1xf32>, vector<2x16x16x16x8x[4]x1xf32> } : vector<2x16x16x[4]x1xi1> -> vector<2x16x16x16x8x[4]x1xf32>
987+
// CHECK: %[[MASK_2:.*]] = vector.create_mask %[[C2]], %[[C16_M]], %[[C16_N]], %[[C8]], %[[DIM_N_IN]] : vector<2x16x16x8x[4]xi1>
988+
// CHECK: %[[VEC_C:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<2x16x16x8x?xf32>, vector<2x16x16x8x[4]xf32> } : vector<2x16x16x8x[4]xi1> -> vector<2x16x16x8x[4]xf32>
989+
// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<2x16x16x16x8x[4]x1xf32>
990+
// CHECK: %[[MASK_3:.*]] = vector.create_mask %[[C2]], %[[C16_M]], %[[C16_N]], %[[C16_K]], %[[C8]], %[[DIM_N_IN]], %[[C1]] : vector<2x16x16x16x8x[4]x1xi1>
991+
// CHECK: %[[RED:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [3, 6] : vector<2x16x16x16x8x[4]x1xf32> to vector<2x16x16x8x[4]xf32> } : vector<2x16x16x16x8x[4]x1xi1> -> vector<2x16x16x8x[4]xf32>
992+
// CHECK: vector.mask %[[MASK_2]] { vector.transfer_write %[[RED]], %[[C_IN]]{{.*}} : vector<2x16x16x8x[4]xf32>, memref<2x16x16x8x?xf32> } : vector<2x16x16x8x[4]xi1>
993+
994+
module attributes {transform.with_named_sequence} {
995+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
996+
%batch_mmt4d = transform.structured.match ops{["linalg.batch_mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
997+
transform.structured.vectorize %batch_mmt4d vector_sizes [2, 16, 16, 16, 8, [4], 1] : !transform.any_op
998+
transform.yield
999+
}
1000+
}
1001+
1002+
// -----
1003+
1004+
func.func @batch_mmt4d_scalable_with_assume(%A: memref<2x16x16x8x1xf32>, %B: memref<2x16x16x?x1xf32>, %C_in: memref<2x16x16x8x?xf32>) {
1005+
linalg.batch_mmt4d ins(%A, %B: memref<2x16x16x8x1xf32>, memref<2x16x16x?x1xf32>)
1006+
outs(%C_in: memref<2x16x16x8x?xf32>)
1007+
return
1008+
}
1009+
// CHECK-LABEL: func.func @batch_mmt4d_scalable_with_assume(
1010+
// CHECK-SAME: %[[A:.*]]: memref<2x16x16x8x1xf32>,
1011+
// CHECK-SAME: %[[B:.*]]: memref<2x16x16x?x1xf32>,
1012+
// CHECK-SAME: %[[C_IN:.*]]: memref<2x16x16x8x?xf32>) {
1013+
// CHECK-NOT: mask
1014+
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
1015+
// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<2x16x16x?x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
1016+
// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<2x16x16x8x?xf32>, vector<2x16x16x8x[4]xf32>
1017+
// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<2x16x16x16x8x[4]x1xf32>
1018+
// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [3, 6] : vector<2x16x16x16x8x[4]x1xf32> to vector<2x16x16x8x[4]xf32>
1019+
// CHECK: vector.transfer_write %[[RED]], %[[C_IN]]{{.*}} : vector<2x16x16x8x[4]xf32>, memref<2x16x16x8x?xf32>
1020+
1021+
module attributes {transform.with_named_sequence} {
1022+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
1023+
%batch_mmt4d = transform.structured.match ops{["linalg.batch_mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
1024+
transform.structured.vectorize %batch_mmt4d vector_sizes [2, 16, 16, 16, 8, [4], 1] {assume_dynamic_dims_match_vec_sizes} : !transform.any_op
1025+
transform.yield
1026+
}
1027+
}
1028+
1029+
9361030
// -----
9371031

9381032
///----------------------------------------------------------------------------------------

0 commit comments

Comments (0)