From a0da17ca5a5da46212aa913ae697fac9e402fcc5 Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Tue, 5 Nov 2024 17:19:38 -0800 Subject: [PATCH 1/2] [mlir][Vector] Add vector.extract(vector.shuffle) folder This PR adds a folder for extracting an element from a vector shuffle. It turns something like: ``` %shuffle = vector.shuffle %a, %b [0, 8, 7, 15] : vector<8xf32>, vector<8xf32> %extract = vector.extract %shuffle[3] : f32 from vector<4xf32> ``` into: ``` %extract = vector.extract %b[7] : f32 from vector<4xf32> ``` --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 42 ++++++++++++++++++++++ mlir/test/Dialect/Vector/canonicalize.mlir | 18 ++++++++++ 2 files changed, 60 insertions(+) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index d8913251e56e9..723044aa2b66e 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1705,6 +1705,46 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) { return extractOp.getResult(); } +/// Fold extractOp coming from ShuffleOp. +/// +/// Example: +/// +/// %shuffle = vector.shuffle %a, %b [0, 8, 7, 15] +/// : vector<8xf32>, vector<8xf32> +/// %extract = vector.extract %shuffle[3] : f32 from vector<4xf32> +/// -> +/// %extract = vector.extract %b[7] : f32 from vector<4xf32> +/// +static Value foldExtractFromShuffle(ExtractOp extractOp) { + // TODO: Canonicalization for dynamic position not implemented yet. + if (extractOp.hasDynamicPosition()) + return Value(); + + auto shuffleOp = extractOp.getVector().getDefiningOp(); + if (!shuffleOp) + return Value(); + + // TODO: 0-D or multi-dimensional vectors not supported yet. + if (shuffleOp.getResultVectorType().getRank() != 1) + return Value(); + + int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0]; + auto shuffleMask = shuffleOp.getMask(); + int64_t extractIdx = extractOp.getStaticPosition()[0]; + int64_t shuffleIdx = shuffleMask[extractIdx]; + + // Find the shuffled vector to extract from based on the shuffle index. + if (shuffleIdx < inputVecSize) { + extractOp.setOperand(0, shuffleOp.getV1()); + extractOp.setStaticPosition({shuffleIdx}); + } else { + extractOp.setOperand(0, shuffleOp.getV2()); + extractOp.setStaticPosition({shuffleIdx - inputVecSize}); + } + + return extractOp.getResult(); +} + // Fold extractOp with source coming from ShapeCast op. static Value foldExtractFromShapeCast(ExtractOp extractOp) { // TODO: Canonicalization for dynamic position not implemented yet. @@ -1953,6 +1993,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor) { return res; if (auto res = foldExtractFromBroadcast(*this)) return res; + if (auto res = foldExtractFromShuffle(*this)) + return res; if (auto res = foldExtractFromShapeCast(*this)) return res; if (auto val = foldExtractFromExtractStrided(*this)) diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index df87f86765a3a..5ae769090dac6 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -740,6 +740,24 @@ func.func @fold_extract_broadcast(%a : vector<1xf32>) -> vector<8xf32> { %r = vector.extract %b[0] : vector<8xf32> from vector<1x8xf32> return %r : vector<8xf32> } +// ----- + +// CHECK-LABEL: @fold_extract_shuffle +// CHECK-SAME: %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32> +// CHECK-NOT: vector.shuffle +// CHECK: vector.extract %[[A]][0] : f32 from vector<8xf32> +// CHECK: vector.extract %[[B]][0] : f32 from vector<8xf32> +// CHECK: vector.extract %[[A]][7] : f32 from vector<8xf32> +// CHECK: vector.extract %[[B]][7] : f32 from vector<8xf32> +func.func @fold_extract_shuffle(%a : vector<8xf32>, %b : vector<8xf32>) + -> (f32, f32, f32, f32) { + %shuffle = vector.shuffle %a, %b [0, 8, 7, 15] : vector<8xf32>, vector<8xf32> + %e0 = vector.extract %shuffle[0] : f32 from vector<4xf32> + %e1 = vector.extract %shuffle[1] : f32 from vector<4xf32> + %e2 = vector.extract %shuffle[2] : f32 from vector<4xf32> + %e3 = vector.extract %shuffle[3] : f32 from vector<4xf32> + return %e0, %e1, %e2, %e3 : f32, f32, f32, f32 +} // ----- From 4c814a2449cd3eb2eec6f97ec23e3c73cf356ee1 Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Wed, 6 Nov 2024 15:31:03 -0800 Subject: [PATCH 2/2] Feedback --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 723044aa2b66e..db199a46e1637 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1713,10 +1713,11 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) { /// : vector<8xf32>, vector<8xf32> /// %extract = vector.extract %shuffle[3] : f32 from vector<4xf32> /// -> -/// %extract = vector.extract %b[7] : f32 from vector<4xf32> +/// %extract = vector.extract %b[7] : f32 from vector<8xf32> /// static Value foldExtractFromShuffle(ExtractOp extractOp) { - // TODO: Canonicalization for dynamic position not implemented yet. + // Dynamic positions are not folded as the resulting code would be more + // complex than the input code. if (extractOp.hasDynamicPosition()) return Value();