Skip to content

Commit 154c550

Browse files
committed
[mlir][Vector] Improve dynamic support for vector.extract(broadcast) folders
1 parent 402efa7 commit 154c550

File tree

2 files changed

+67
-27
lines changed

2 files changed

+67
-27
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1648,10 +1648,6 @@ static bool hasZeroDimVectors(Operation *op) {
16481648

16491649
/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
16501650
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1651-
// TODO: Canonicalization for dynamic position not implemented yet.
1652-
if (extractOp.hasDynamicPosition())
1653-
return Value();
1654-
16551651
Operation *defOp = extractOp.getVector().getDefiningOp();
16561652
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
16571653
return Value();
@@ -1680,6 +1676,16 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
16801676
broadcastVecType.getShape().take_back(extractResultRank))
16811677
return Value();
16821678

1679+
// The dim-1 broadcast -> ExtractOp folder requires in place operation
1680+
// modifications. For dynamic position, this means we have to change the
1681+
// number of operands. This cannot be done in place since it changes the
1682+
// operation storage. For dynamic positions, the dim-1 broadcasting should
1683+
// be implemented as a canonicalization pattern.
1684+
// TODO: Implement canonicalization pattern for dim-1 broadcasting +
1685+
// extractop.
1686+
if (extractOp.hasDynamicPosition())
1687+
return Value();
1688+
16831689
auto broadcastOp = cast<vector::BroadcastOp>(defOp);
16841690
int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
16851691

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 57 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -652,24 +652,44 @@ func.func @fold_extract_transpose(
652652

653653
// -----
654654

655-
// CHECK-LABEL: fold_extract_broadcast
655+
// CHECK-LABEL: fold_extract_broadcast_same_type
656656
// CHECK-SAME: %[[A:.*]]: f32
657657
// CHECK: return %[[A]] : f32
658-
func.func @fold_extract_broadcast(%a : f32) -> f32 {
658+
func.func @fold_extract_broadcast_same_type(%a : f32,
659+
%idx0 : index,
660+
%idx1 : index) -> f32 {
659661
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
660-
%r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
662+
// The indices don't matter for this folder, so we use mixed indices.
663+
%r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
661664
return %r : f32
662665
}
663666

664667
// -----
665668

666-
// CHECK-LABEL: fold_extract_broadcast_0dvec
669+
// CHECK-LABEL: fold_extract_broadcast_same_type_vec
670+
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
671+
// CHECK: return %[[A]] : vector<4xf32>
672+
func.func @fold_extract_broadcast_same_type_vec(%a : vector<4xf32>,
673+
%idx0 : index)
674+
-> vector<4xf32> {
675+
%b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
676+
// The indices don't matter for this folder, so we use mixed indices.
677+
%r = vector.extract %b[0, %idx0] : vector<4xf32> from vector<1x2x4xf32>
678+
return %r : vector<4xf32>
679+
}
680+
681+
// -----
682+
683+
// CHECK-LABEL: fold_extract_broadcast_0dvec_and_scalar
667684
// CHECK-SAME: %[[A:.*]]: vector<f32>
668685
// CHECK: %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
669686
// CHECK: return %[[B]] : f32
670-
func.func @fold_extract_broadcast_0dvec(%a : vector<f32>) -> f32 {
687+
func.func @fold_extract_broadcast_0dvec_and_scalar(%a : vector<f32>,
688+
%idx0 : index,
689+
%idx1 : index) -> f32 {
671690
%b = vector.broadcast %a : vector<f32> to vector<1x2x4xf32>
672-
%r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
691+
// The indices don't matter for this folder, so we use mixed indices.
692+
%r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
673693
return %r : f32
674694
}
675695

@@ -689,57 +709,71 @@ func.func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32
689709
// CHECK-LABEL: fold_extract_splat
690710
// CHECK-SAME: %[[A:.*]]: f32
691711
// CHECK: return %[[A]] : f32
692-
func.func @fold_extract_splat(%a : f32) -> f32 {
712+
func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index) -> f32 {
693713
%b = vector.splat %a : vector<1x2x4xf32>
694-
%r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
714+
// The indices don't matter for this folder, so we use mixed indices.
715+
%r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
695716
return %r : f32
696717
}
697718

698719
// -----
699720

700-
// CHECK-LABEL: fold_extract_broadcast_vector
721+
// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
701722
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
702-
// CHECK: return %[[A]] : vector<4xf32>
703-
func.func @fold_extract_broadcast_vector(%a : vector<4xf32>) -> vector<4xf32> {
723+
// CHECK: %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
724+
// CHECK: return %[[R]] : f32
725+
func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<4xf32>) -> f32 {
704726
%b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
705-
%r = vector.extract %b[0, 1] : vector<4xf32> from vector<1x2x4xf32>
706-
return %r : vector<4xf32>
727+
%r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
728+
return %r : f32
707729
}
708730

709731
// -----
710732

711-
// CHECK-LABEL: fold_extract_broadcast
733+
// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting_dynamic_nyi
712734
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
713-
// CHECK: %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
735+
// CHECK-SAME: %[[IDX:.*]]: index
736+
// CHECK: %[[B:.*]] = vector.broadcast %[[A]] : vector<4xf32> to vector<1x2x4xf32>
737+
// CHECK: %[[R:.*]] = vector.extract %[[B]][%[[IDX]], 1, 2]
714738
// CHECK: return %[[R]] : f32
715-
func.func @fold_extract_broadcast(%a : vector<4xf32>) -> f32 {
739+
// This folder is not yet implemented. Check that this does not fold.
740+
func.func @fold_extract_broadcast_dim1_broadcasting_dynamic_nyi(
741+
%a : vector<4xf32>,
742+
%idx : index) -> f32 {
716743
%b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
717-
%r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
744+
%r = vector.extract %b[%idx, 1, 2] : f32 from vector<1x2x4xf32>
718745
return %r : f32
719746
}
720747

721748
// -----
722749

723-
// CHECK-LABEL: fold_extract_broadcast
750+
// CHECK-LABEL: canonicalize_extract_broadcast_to_higher_rank
724751
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
725752
// CHECK: return %[[B]] : vector<4xf32>
726-
func.func @fold_extract_broadcast(%a : f32) -> vector<4xf32> {
753+
func.func @canonicalize_extract_broadcast_to_higher_rank(%a : f32,
754+
%idx0 : index)
755+
-> vector<4xf32> {
727756
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
728-
%r = vector.extract %b[0, 1] : vector<4xf32> from vector<1x2x4xf32>
757+
// The indices don't matter for this canonicalizer, so we use mixed indices.
758+
%r = vector.extract %b[0, %idx0] : vector<4xf32> from vector<1x2x4xf32>
729759
return %r : vector<4xf32>
730760
}
731761

732762
// -----
733763

734-
// CHECK-LABEL: fold_extract_broadcast
764+
// CHECK-LABEL: canonicalize_extract_broadcast_to_equal_rank
735765
// CHECK-SAME: %[[A:.*]]: vector<1xf32>
736766
// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
737767
// CHECK: return %[[R]] : vector<8xf32>
738-
func.func @fold_extract_broadcast(%a : vector<1xf32>) -> vector<8xf32> {
768+
func.func @canonicalize_extract_broadcast_to_equal_rank(%a : vector<1xf32>,
769+
%idx0 : index)
770+
-> vector<8xf32> {
739771
%b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32>
740-
%r = vector.extract %b[0] : vector<8xf32> from vector<1x8xf32>
772+
// The indices don't matter for this canonicalizer, so we use mixed indices.
773+
%r = vector.extract %b[%idx0] : vector<8xf32> from vector<1x8xf32>
741774
return %r : vector<8xf32>
742775
}
776+
743777
// -----
744778

745779
// CHECK-LABEL: @fold_extract_shuffle

0 commit comments

Comments
 (0)