diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 7d615bfc12984..56f748fbbe1d6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1707,59 +1707,99 @@ static bool hasZeroDimVectors(Operation *op) {
          llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
 }
 
-/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
+/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend
+/// 1s, are considered to be 'broadcastlike'.
+static bool isBroadcastLike(Operation *op) {
+  if (isa<BroadcastOp, SplatOp>(op))
+    return true;
+
+  auto shapeCast = dyn_cast<ShapeCastOp>(op);
+  if (!shapeCast)
+    return false;
+
+  // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
+  // Checking that the destination shape has a prefix of 1s is not sufficient,
+  // for example (2,3) -> (1,3,2) is not broadcastlike. A sufficient condition
+  // is that the source shape is a suffix of the destination shape.
+  VectorType srcType = shapeCast.getSourceVectorType();
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  uint64_t srcRank = srcType.getRank();
+  ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
+  return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
+}
+
+/// Fold extract(broadcast(X)) to either extract(X) or just X.
+///
+/// Example:
+///
+///          broadcast       extract [1][2]
+/// (3, 4) --------> (2, 3, 4) ----------------> (4)
+///
+/// becomes
+///                  extract [1]
+/// (3,4) -------------------------------------> (4)
+///
+///
+/// The variable names used in this implementation correspond to the above
+/// shapes as,
+///
+/// - (3, 4) is `input` shape.
+/// - (2, 3, 4) is `broadcast` shape.
+/// - (4) is `extract` shape.
+///
+/// This folding is possible when the suffix of `input` shape is the same as
+/// `extract` shape.
 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
+
   Operation *defOp = extractOp.getVector().getDefiningOp();
-  if (!defOp || !isa<vector::BroadcastOp, vector::SplatOp>(defOp))
+  if (!defOp || !isBroadcastLike(defOp))
     return Value();
 
-  Value source = defOp->getOperand(0);
-  if (extractOp.getType() == source.getType())
-    return source;
-  auto getRank = [](Type type) {
-    return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
-                                       : 0;
-  };
+  Value input = defOp->getOperand(0);
 
-  // If splat or broadcast from a scalar, just return the source scalar.
-  unsigned broadcastSrcRank = getRank(source.getType());
-  if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
-    return source;
+  // Replace extract(broadcast(X)) with X
+  if (extractOp.getType() == input.getType())
+    return input;
 
-  unsigned extractResultRank = getRank(extractOp.getType());
-  if (extractResultRank > broadcastSrcRank)
-    return Value();
-  // Check that the dimension of the result haven't been broadcasted.
-  auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
-  auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
-  if (extractVecType && broadcastVecType &&
-      extractVecType.getShape() !=
-          broadcastVecType.getShape().take_back(extractResultRank))
+  // Get required types and ranks in the chain
+  //    input -> broadcast -> extract
+  // (scalars are treated as rank-0).
+  auto inputType = llvm::dyn_cast<VectorType>(input.getType());
+  auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
+  unsigned inputRank = inputType ? inputType.getRank() : 0;
+  unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
+  unsigned extractRank = extractType ? extractType.getRank() : 0;
+
+  // Cannot do without the broadcast if overall the rank increases.
+  if (extractRank > inputRank)
     return Value();
 
-  auto broadcastOp = cast<vector::BroadcastOp>(defOp);
-  int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
+  // The above condition guarantees that input is a vector.
+  assert(inputType && "input must be a vector type because of previous checks");
+  ArrayRef<int64_t> inputShape = inputType.getShape();
 
-  // Detect all the positions that come from "dim-1" broadcasting.
-  // These dimensions correspond to "dim-1" broadcasted dims; set the mathching
-  // extract position to `0` when extracting from the source operand.
-  llvm::SetVector<int64_t> broadcastedUnitDims =
-      broadcastOp.computeBroadcastedUnitDims();
-  SmallVector<OpFoldResult> extractPos(extractOp.getMixedPosition());
-  OpBuilder b(extractOp.getContext());
-  int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
-  for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
-    if (broadcastedUnitDims.contains(i))
-      extractPos[i] = b.getIndexAttr(0);
-  // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
-  // matching extract position when extracting from the source operand.
-  int64_t rankDiff = broadcastSrcRank - extractResultRank;
-  extractPos.erase(extractPos.begin(),
-                   std::next(extractPos.begin(), extractPos.size() - rankDiff));
-  // OpBuilder is only used as a helper to build an I64ArrayAttr.
-  auto [staticPos, dynPos] = decomposeMixedValues(extractPos);
+  // In the case where there is a broadcast dimension in the suffix, it is not
+  // possible to replace extract(broadcast(X)) with extract(X). Example:
+  //
+  //      broadcast     extract
+  //  (1) --------> (3,4) ------> (4)
+  if (extractType &&
+      extractType.getShape() != inputShape.take_back(extractRank))
+    return Value();
+
+  // Replace extract(broadcast(X)) with extract(X).
+  // First, determine the new extraction position.
+  unsigned deltaOverall = inputRank - extractRank;
+  unsigned deltaBroadcast = broadcastRank - inputRank;
+  SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
+  SmallVector<OpFoldResult> newPositions(deltaOverall);
+  IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
+  for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
+    newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
+  }
+  auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
   extractOp->setOperands(
-      llvm::to_vector(llvm::concat<Value>(ValueRange(source), dynPos)));
+      llvm::to_vector(llvm::concat<Value>(ValueRange(input), dynPos)));
   extractOp.setStaticPosition(staticPos);
   return extractOp.getResult();
 }
@@ -2204,32 +2244,18 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
 
   LogicalResult matchAndRewrite(ExtractOp extractOp,
                                 PatternRewriter &rewriter) const override {
+
     Operation *defOp = extractOp.getVector().getDefiningOp();
-    if (!defOp || !isa<vector::BroadcastOp, vector::SplatOp>(defOp))
+    VectorType outType = dyn_cast<VectorType>(extractOp.getType());
+    if (!defOp || !isBroadcastLike(defOp) || !outType)
       return failure();
 
     Value source = defOp->getOperand(0);
-    if (extractOp.getType() == source.getType())
-      return failure();
-    auto getRank = [](Type type) {
-      return llvm::isa<VectorType>(type)
-                 ? llvm::cast<VectorType>(type).getRank()
-                 : 0;
-    };
-    unsigned broadcastSrcRank = getRank(source.getType());
-    unsigned extractResultRank = getRank(extractOp.getType());
-    // We only consider the case where the rank of the source is less than or
-    // equal to the rank of the extract dst. The other cases are handled in the
-    // folding patterns.
-    if (extractResultRank < broadcastSrcRank)
-      return failure();
-    // For scalar result, the input can only be a rank-0 vector, which will
-    // be handled by the folder.
-    if (extractResultRank == 0)
+    if (isBroadcastableTo(source.getType(), outType) !=
+        BroadcastableToResult::Success)
       return failure();
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-        extractOp, extractOp.getType(), source);
+    rewriter.replaceOpWithNewOp<BroadcastOp>(extractOp, outType, source);
     return success();
   }
 };
 
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index 33177736eb5fe..1ed82954398f0 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -558,10 +558,9 @@ func.func @vector_print_vector_0d(%arg0: vector<f32>) {
 // CHECK-SAME: %[[VEC:.*]]: vector<f32>) {
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[FLAT_VEC:.*]] = vector.shape_cast %[[VEC]] : vector<f32> to vector<1xf32>
 // CHECK: vector.print punctuation <open>
 // CHECK: scf.for %[[IDX:.*]] = %[[C0]] to %[[C1]] step %[[C1]] {
-// CHECK: %[[EL:.*]] = vector.extract %[[FLAT_VEC]][%[[IDX]]] : f32 from vector<1xf32>
+// CHECK: %[[EL:.*]] = vector.extract %[[VEC]][] : f32 from vector<f32>
 // CHECK: vector.print %[[EL]] : f32 punctuation <comma>
 // CHECK: %[[IS_NOT_LAST:.*]] = arith.cmpi ult, %[[IDX]], %[[C0]] : index
 // CHECK: scf.if %[[IS_NOT_LAST]] {
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index ea2343efd246e..6809122974545 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -823,10 +823,10 @@ func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32
 
 // -----
 
-// CHECK-LABEL: fold_extract_splat
+// CHECK-LABEL: fold_extract_scalar_from_splat
 // CHECK-SAME: %[[A:.*]]: f32
 // CHECK: return %[[A]] : f32
-func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
+func.func @fold_extract_scalar_from_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
   %b = vector.splat %a : vector<1x2x4xf32>
   %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
   return %r : f32
@@ -834,6 +834,16 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : in
 
 // -----
 
+// CHECK-LABEL: fold_extract_vector_from_splat
+// CHECK: vector.broadcast {{.*}} f32 to vector<4xf32>
+func.func @fold_extract_vector_from_splat(%a : f32, %idx0 : index, %idx1 : index) -> vector<4xf32> {
+  %b = vector.splat %a : vector<1x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
+  return %r : vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
 // CHECK-SAME: %[[A:.*]]: vector<2x1xf32>
 // CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
@@ -863,6 +873,35 @@ func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
 
 // -----
 
+// Test where the shape_cast is broadcast-like.
+// CHECK-LABEL: fold_extract_shape_cast_to_lower_rank
+// CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
+// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index
+// CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
+// CHECK: return %[[B]] : vector<4xf32>
+func.func @fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>,
+    %idx0 : index, %idx1 : index) -> vector<4xf32> {
+  %b = vector.shape_cast %a : vector<2x4xf32> to vector<1x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
+  return %r : vector<4xf32>
+}
+
+// -----
+
+// Test where the shape_cast is not broadcast-like, even though it prepends 1s.
+// CHECK-LABEL: negative_fold_extract_shape_cast_to_lower_rank
+// CHECK-NEXT: vector.shape_cast
+// CHECK-NEXT: vector.extract
+// CHECK-NEXT: return
+func.func @negative_fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>,
+    %idx0 : index, %idx1 : index) -> vector<2xf32> {
+  %b = vector.shape_cast %a : vector<2x4xf32> to vector<1x4x2xf32>
+  %r = vector.extract %b[%idx0, %idx1] : vector<2xf32> from vector<1x4x2xf32>
+  return %r : vector<2xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_extract_broadcast_to_higher_rank
 // CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
 // CHECK: return %[[B]] : vector<4xf32>
@@ -890,6 +929,19 @@ func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : inde
 
 // -----
 
+// CHECK-LABEL: fold_extract_broadcastlike_shape_cast
+// CHECK-SAME: %[[A:.*]]: vector<1xf32>
+// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<1x1xf32>
+// CHECK: return %[[R]] : vector<1x1xf32>
+func.func @fold_extract_broadcastlike_shape_cast(%a : vector<1xf32>, %idx0 : index)
+    -> vector<1x1xf32> {
+  %s = vector.shape_cast %a : vector<1xf32> to vector<1x1x1xf32>
+  %r = vector.extract %s[%idx0] : vector<1x1xf32> from vector<1x1x1xf32>
+  return %r : vector<1x1xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @fold_extract_shuffle
 // CHECK-SAME: %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32>
 // CHECK-NOT: vector.shuffle
@@ -1623,7 +1675,7 @@ func.func @negative_store_to_load_tensor_memref(
     %arg0 : tensor<?x?xf32>,
     %arg1 : memref<?x?xf32>,
     %v0 : vector<4x2xf32>
-  ) -> vector<4x2xf32> 
+  ) -> vector<4x2xf32>
 {
   %c0 = arith.constant 0 : index
   %cf0 = arith.constant 0.0 : f32
@@ -1680,7 +1732,7 @@ func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor<
 // CHECK: vector.transfer_read
 func.func @negative_store_to_load_tensor_broadcast_masked(
     %arg0 : tensor<?x?xf32>, %v0 : vector<4x2xf32>, %mask : vector<4x2xi1>)
-    -> vector<4x2x6xf32> 
+    -> vector<4x2x6xf32>
 {
   %c0 = arith.constant 0 : index
   %cf0 = arith.constant 0.0 : f32