Merged
42 changes: 42 additions & 0 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1705,6 +1705,46 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
  return extractOp.getResult();
}

/// Fold extractOp coming from ShuffleOp.
///
/// Example:
///
/// %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
/// : vector<8xf32>, vector<8xf32>
/// %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
/// ->
/// %extract = vector.extract %b[7] : f32 from vector<4xf32>
Collaborator

Suggested change
- /// %extract = vector.extract %b[7] : f32 from vector<4xf32>
+ /// %extract = vector.extract %b[7] : f32 from vector<8xf32>

///
static Value foldExtractFromShuffle(ExtractOp extractOp) {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();
Contributor

Out of curiosity, how do we support the dynamic case?

Collaborator

The only thing that comes to mind is a series of compares and branches?

Contributor

Yes, that's what I'm thinking, and I'm wondering if there are other approaches. This approach looks bad to me because it introduces branching in the IR. If we don't have a good plan, perhaps we should drop the TODO and make an explicit comment.

Contributor Author

The TODO is inherited from similar canonicalizations. I'm not sure it would need branches, but it would take at least quite a few instructions, including more vector extracts and select ops. That is definitely not a good canonical form for the dynamic case.

Contributor

I see, thanks! Yes, I also think that it is not a good canonical form. It's good that we're on the same page.
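
To make the cost discussed in this thread concrete, here is a minimal standalone sketch in plain C++ (not MLIR code) of one possible branch-free handling of a dynamic position: one static extract per mask entry, followed by a compare and select against the runtime index. The names `DynamicExtractCost` and `extractFromShuffleDynamic` are hypothetical and exist only for this illustration; the real lowering, if anyone wrote one, could look different.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model, for illustration only; none of this is MLIR API.
struct DynamicExtractCost {
  float value;  // element the dynamic extract would produce
  int extracts; // number of static extracts emitted
  int selects;  // number of compare+select pairs emitted
};

// Models `extract (shuffle a, b, mask)[idx]` when `idx` is only known at
// runtime: every mask entry needs its own static extract and its own select.
DynamicExtractCost extractFromShuffleDynamic(const std::vector<float> &a,
                                             const std::vector<float> &b,
                                             const std::vector<int64_t> &mask,
                                             int64_t idx) {
  assert(idx >= 0 && idx < static_cast<int64_t>(mask.size()));
  const int64_t inputVecSize = static_cast<int64_t>(a.size());
  DynamicExtractCost result{0.0f, 0, 0};
  for (std::size_t i = 0; i < mask.size(); ++i) {
    // Static extract from whichever input this mask entry refers to.
    const int64_t shuffleIdx = mask[i];
    const float candidate =
        shuffleIdx < inputVecSize
            ? a[static_cast<std::size_t>(shuffleIdx)]
            : b[static_cast<std::size_t>(shuffleIdx - inputVecSize)];
    ++result.extracts;
    // Compare against the runtime index and select.
    result.value =
        (idx == static_cast<int64_t>(i)) ? candidate : result.value;
    ++result.selects;
  }
  return result;
}
```

In this model the number of extracts and selects grows linearly with the mask length, whereas the static fold rewrites a single `vector.extract` in place. That matches the conclusion above that this is not a good canonical form for the dynamic case.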


  auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
  if (!shuffleOp)
    return Value();

  // TODO: 0-D or multi-dimensional vectors not supported yet.
  if (shuffleOp.getResultVectorType().getRank() != 1)
    return Value();

  int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
  auto shuffleMask = shuffleOp.getMask();
  int64_t extractIdx = extractOp.getStaticPosition()[0];
  int64_t shuffleIdx = shuffleMask[extractIdx];

  // Find the shuffled vector to extract from based on the shuffle index.
  if (shuffleIdx < inputVecSize) {
    extractOp.setOperand(0, shuffleOp.getV1());
    extractOp.setStaticPosition({shuffleIdx});
  } else {
    extractOp.setOperand(0, shuffleOp.getV2());
    extractOp.setStaticPosition({shuffleIdx - inputVecSize});
  }

  return extractOp.getResult();
}

// Fold extractOp with source coming from ShapeCast op.
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
  // TODO: Canonicalization for dynamic position not implemented yet.
@@ -1953,6 +1993,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor) {
    return res;
  if (auto res = foldExtractFromBroadcast(*this))
    return res;
  if (auto res = foldExtractFromShuffle(*this))
    return res;
  if (auto res = foldExtractFromShapeCast(*this))
    return res;
  if (auto val = foldExtractFromExtractStrided(*this))
18 changes: 18 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize.mlir
@@ -740,6 +740,24 @@ func.func @fold_extract_broadcast(%a : vector<1xf32>) -> vector<8xf32> {
  %r = vector.extract %b[0] : vector<8xf32> from vector<1x8xf32>
  return %r : vector<8xf32>
}
// -----

// CHECK-LABEL: @fold_extract_shuffle
// CHECK-SAME: %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32>
// CHECK-NOT: vector.shuffle
// CHECK: vector.extract %[[A]][0] : f32 from vector<8xf32>
// CHECK: vector.extract %[[B]][0] : f32 from vector<8xf32>
// CHECK: vector.extract %[[A]][7] : f32 from vector<8xf32>
// CHECK: vector.extract %[[B]][7] : f32 from vector<8xf32>
func.func @fold_extract_shuffle(%a : vector<8xf32>, %b : vector<8xf32>)
    -> (f32, f32, f32, f32) {
  %shuffle = vector.shuffle %a, %b [0, 8, 7, 15] : vector<8xf32>, vector<8xf32>
  %e0 = vector.extract %shuffle[0] : f32 from vector<4xf32>
  %e1 = vector.extract %shuffle[1] : f32 from vector<4xf32>
  %e2 = vector.extract %shuffle[2] : f32 from vector<4xf32>
  %e3 = vector.extract %shuffle[3] : f32 from vector<4xf32>
  return %e0, %e1, %e2, %e3 : f32, f32, f32, f32
}
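
The four CHECK lines in this test follow directly from the mask arithmetic in `foldExtractFromShuffle`. As a rough cross-check, the sketch below models that arithmetic in plain C++ outside of MLIR. `ResolvedExtract` and `resolveShuffleExtract` are hypothetical names introduced only for this illustration, with operand 0 standing for `%a` (V1) and operand 1 for `%b` (V2).

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical standalone model of the fold's index arithmetic; not MLIR API.
struct ResolvedExtract {
  int operand;      // 0 = first shuffle input (V1), 1 = second input (V2)
  int64_t position; // static position to extract from that input
};

// Maps an extract position on the shuffle result back to one of the shuffle
// inputs, mirroring the `shuffleIdx < inputVecSize` check in the fold.
ResolvedExtract resolveShuffleExtract(const std::vector<int64_t> &mask,
                                      int64_t inputVecSize,
                                      int64_t extractIdx) {
  assert(extractIdx >= 0 && extractIdx < static_cast<int64_t>(mask.size()));
  const int64_t shuffleIdx = mask[static_cast<std::size_t>(extractIdx)];
  if (shuffleIdx < inputVecSize)
    return {0, shuffleIdx};
  return {1, shuffleIdx - inputVecSize};
}
```

With the mask `[0, 8, 7, 15]` and 8-element inputs, positions 0 through 3 resolve to `%a[0]`, `%b[0]`, `%a[7]`, and `%b[7]`, which is the order the CHECK lines expect.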

// -----
