-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][vector] Add pattern for dropping unit dims from for loops #109585
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir-scf Author: Quinn Dawkins (qedawkins) ChangesThis adds a pattern for dropping unit dims from the iter_args of scf.for ops using vector.shape_cast. This composes with the other patterns for dropping unit dims from elementwise ops and transposes. Full diff: https://github.com/llvm/llvm-project/pull/109585.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index 644118ca884c6b..d89d566ece62c1 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -107,6 +107,17 @@ LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs,
function_ref<void(OpBuilder &, Location, ValueRange)>
bodyBuilder = nullptr);
+/// Perform a replacement of one iter OpOperand of an scf.for to the
+/// `replacement` value with a different type. A callback is used to insert
+/// cast ops inside the block to account for type differences.
+using ValueTypeCastFnTy =
+ std::function<Value(OpBuilder &, Location loc, Type, Value)>;
+SmallVector<Value> replaceAndCastForOpIterArg(RewriterBase &rewriter,
+ scf::ForOp forOp,
+ OpOperand &operand,
+ Value replacement,
+ const ValueTypeCastFnTy &castFn);
+
} // namespace scf
} // namespace mlir
#endif // MLIR_DIALECT_SCF_SCF_H
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 6d47ff3890977a..d1c9fd2d217dad 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -772,6 +772,70 @@ LoopNest mlir::scf::buildLoopNest(
});
}
+SmallVector<Value>
+mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
+ OpOperand &operand, Value replacement,
+ const ValueTypeCastFnTy &castFn) {
+ assert(operand.getOwner() == forOp);
+ Type oldType = operand.get().getType(), newType = replacement.getType();
+
+ // 1. Create new iter operands, exactly 1 is replaced.
+ assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
+ "expected an iter OpOperand");
+ assert(operand.get().getType() != replacement.getType() &&
+ "Expected a different type");
+ SmallVector<Value> newIterOperands;
+ for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
+ if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
+ newIterOperands.push_back(replacement);
+ continue;
+ }
+ newIterOperands.push_back(opOperand.get());
+ }
+
+ // 2. Create the new forOp shell.
+ scf::ForOp newForOp = rewriter.create<scf::ForOp>(
+ forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
+ forOp.getStep(), newIterOperands);
+ newForOp->setAttrs(forOp->getAttrs());
+ Block &newBlock = newForOp.getRegion().front();
+ SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
+ newBlock.getArguments().end());
+
+ // 3. Inject an incoming cast op at the beginning of the block for the bbArg
+ // corresponding to the `replacement` value.
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(&newBlock, newBlock.begin());
+ BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
+ &newForOp->getOpOperand(operand.getOperandNumber()));
+ Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
+ newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
+
+ // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
+ Block &oldBlock = forOp.getRegion().front();
+ rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
+
+ // 5. Inject an outgoing cast op at the end of the block and yield it instead.
+ auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
+ rewriter.setInsertionPoint(clonedYieldOp);
+ unsigned yieldIdx =
+ newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
+ Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
+ clonedYieldOp.getOperand(yieldIdx));
+ SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
+ newYieldOperands[yieldIdx] = castOut;
+ rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
+ rewriter.eraseOp(clonedYieldOp);
+
+ // 6. Inject an outgoing cast op after the forOp.
+ rewriter.setInsertionPointAfter(newForOp);
+ SmallVector<Value> newResults = newForOp.getResults();
+ newResults[yieldIdx] =
+ castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
+
+ return newResults;
+}
+
namespace {
// Fold away ForOp iter arguments when:
// 1) The op yields the iter arguments.
@@ -973,76 +1037,6 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
}
};
-/// Perform a replacement of one iter OpOperand of an scf.for to the
-/// `replacement` value which is expected to be the source of a tensor.cast.
-/// tensor.cast ops are inserted inside the block to account for the type cast.
-static SmallVector<Value>
-replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand,
- Value replacement) {
- Type oldType = operand.get().getType(), newType = replacement.getType();
- assert(llvm::isa<RankedTensorType>(oldType) &&
- llvm::isa<RankedTensorType>(newType) &&
- "expected ranked tensor types");
-
- // 1. Create new iter operands, exactly 1 is replaced.
- ForOp forOp = cast<ForOp>(operand.getOwner());
- assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
- "expected an iter OpOperand");
- assert(operand.get().getType() != replacement.getType() &&
- "Expected a different type");
- SmallVector<Value> newIterOperands;
- for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
- if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
- newIterOperands.push_back(replacement);
- continue;
- }
- newIterOperands.push_back(opOperand.get());
- }
-
- // 2. Create the new forOp shell.
- scf::ForOp newForOp = rewriter.create<scf::ForOp>(
- forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
- forOp.getStep(), newIterOperands);
- newForOp->setAttrs(forOp->getAttrs());
- Block &newBlock = newForOp.getRegion().front();
- SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
- newBlock.getArguments().end());
-
- // 3. Inject an incoming cast op at the beginning of the block for the bbArg
- // corresponding to the `replacement` value.
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(&newBlock, newBlock.begin());
- BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
- &newForOp->getOpOperand(operand.getOperandNumber()));
- Value castIn = rewriter.create<tensor::CastOp>(newForOp.getLoc(), oldType,
- newRegionIterArg);
- newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
-
- // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
- Block &oldBlock = forOp.getRegion().front();
- rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
-
- // 5. Inject an outgoing cast op at the end of the block and yield it instead.
- auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
- rewriter.setInsertionPoint(clonedYieldOp);
- unsigned yieldIdx =
- newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
- Value castOut = rewriter.create<tensor::CastOp>(
- newForOp.getLoc(), newType, clonedYieldOp.getOperand(yieldIdx));
- SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
- newYieldOperands[yieldIdx] = castOut;
- rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
- rewriter.eraseOp(clonedYieldOp);
-
- // 6. Inject an outgoing cast op after the forOp.
- rewriter.setInsertionPointAfter(newForOp);
- SmallVector<Value> newResults = newForOp.getResults();
- newResults[yieldIdx] = rewriter.create<tensor::CastOp>(
- newForOp.getLoc(), oldType, newResults[yieldIdx]);
-
- return newResults;
-}
-
/// Fold scf.for iter_arg/result pairs that go through incoming/ougoing
/// a tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
///
@@ -1090,9 +1084,13 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
continue;
// Create a new ForOp with that iter operand replaced.
+ ValueTypeCastFnTy castFn = [](OpBuilder &b, Location loc, Type type,
+ Value source) {
+ return b.create<tensor::CastOp>(loc, type, source);
+ };
rewriter.replaceOp(
- op, replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
- incomingCast.getSource()));
+ op, replaceAndCastForOpIterArg(rewriter, op, iterOpOperand,
+ incomingCast.getSource(), castFn));
return success();
}
return failure();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index ad4e42b31962e1..ba32583fc3cdc4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1796,6 +1796,63 @@ struct DropUnitDimsFromTransposeOp final
}
};
+/// A pattern to drop unit dims from the iter_args of an scf.for.
+///
+/// Example:
+///
+/// BEFORE:
+/// ```mlir
+/// %res = scf.for ... iter_args(%iter = %init) -> vector<[4]x1x1x4xf32> {
+/// ...
+/// scf.yield %
+/// }
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %drop = vector.shape_cast %init
+/// : vector<4x1x1x[4]xf32> to vector<4x[4]xf32>
+/// %new_loop = scf.for ... iter_args(%iter = %drop) -> vector<[4]x4xf32> {
+/// %new_iter = vector.shape_cast %iter
+/// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
+/// ...
+/// }
+/// %res = vector.shape_cast %new_loop
+/// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
+/// ```
+struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(scf::ForOp forOp,
+ PatternRewriter &rewriter) const override {
+ /// Find the first iter_arg with droppable unit dims. Further applications
+ /// of this pattern will apply to later arguments.
+ for (OpOperand &operand : forOp.getInitArgsMutable()) {
+ auto vectorType = dyn_cast<VectorType>(operand.get().getType());
+ if (!vectorType)
+ continue;
+
+ VectorType newVectorType = dropNonScalableUnitDimFromType(vectorType);
+ if (vectorType == newVectorType)
+ continue;
+
+ // Create a new ForOp with that iter operand replaced.
+ mlir::scf::ValueTypeCastFnTy castFn = [](OpBuilder &b, Location loc,
+ Type type, Value source) {
+ return b.create<vector::ShapeCastOp>(loc, type, source);
+ };
+
+ Value replacement =
+ castFn(rewriter, forOp.getLoc(), newVectorType, operand.get());
+ rewriter.replaceOp(forOp,
+ replaceAndCastForOpIterArg(rewriter, forOp, operand,
+ replacement, castFn));
+ return success();
+ }
+ return failure();
+ }
+};
+
/// Pattern to eliminate redundant zero-constants added to reduction operands.
/// It's enough for there to be one initial zero value, so we can eliminate the
/// extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -2001,7 +2058,8 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromTransposeOp,
- ShapeCastOpFolder>(patterns.getContext(), benefit);
+ ShapeCastOpFolder, DropUnitDimsFromScfForOp>(
+ patterns.getContext(), benefit);
}
void mlir::vector::populateBubbleVectorBitCastOpPatterns(
diff --git a/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir b/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir
index af3fc924c1dbe7..8249400a43c757 100644
--- a/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir
+++ b/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir
@@ -207,3 +207,78 @@ func.func @negative_transpose_with_no_unit_dims(%vec: vector<4x2x3xf32>) -> vect
// CHECK-LABEL: func.func @negative_transpose_with_no_unit_dims
// CHECK-NOT: vector.shape_cast
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// [Pattern: DropUnitDimsFromScfForOp]
+///----------------------------------------------------------------------------------------
+
+func.func @scf_for_with_internal_unit_dims(%vec: vector<4x1x1x[4]xf32>) -> vector<4x1x1x[4]xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %res = scf.for %i = %c0 to %c4 step %c1 iter_args(%iter = %vec) -> vector<4x1x1x[4]xf32> {
+ %s = math.sqrt %iter : vector<4x1x1x[4]xf32>
+ scf.yield %s : vector<4x1x1x[4]xf32>
+ }
+ return %res : vector<4x1x1x[4]xf32>
+}
+
+// CHECK-LABEL: func.func @scf_for_with_internal_unit_dims
+// CHECK-SAME: %[[VEC:[A-Za-z0-9]+]]: vector<4x1x1x[4]xf32>
+// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<4x1x1x[4]xf32> to vector<4x[4]xf32>
+// CHECK: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[ITER:.+]] = %[[CAST]])
+// CHECK: %[[SQRT:.+]] = math.sqrt %[[ITER]] : vector<4x[4]xf32>
+// CHECK: scf.yield %[[SQRT]]
+// CHECK: %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]] : vector<4x[4]xf32> to vector<4x1x1x[4]xf32>
+// CHECK: return %[[CASTBACK]]
+
+// -----
+
+func.func @scf_for_with_all_unit_dims(%vec: vector<1x1xf32>) -> vector<1x1xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %res = scf.for %i = %c0 to %c4 step %c1 iter_args(%iter = %vec) -> vector<1x1xf32> {
+ %s = math.sqrt %iter : vector<1x1xf32>
+ scf.yield %s : vector<1x1xf32>
+ }
+ return %res : vector<1x1xf32>
+}
+
+// CHECK-LABEL: func.func @scf_for_with_all_unit_dims
+// CHECK-SAME: %[[VEC:[A-Za-z0-9]+]]: vector<1x1xf32>
+// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x1xf32> to vector<1xf32>
+// CHECK: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[ITER:.+]] = %[[CAST]])
+// CHECK: %[[SQRT:.+]] = math.sqrt %[[ITER]] : vector<1xf32>
+// CHECK: scf.yield %[[SQRT]]
+// CHECK: %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]] : vector<1xf32> to vector<1x1xf32>
+// CHECK: return %[[CASTBACK]]
+
+// -----
+
+func.func @scf_for_with_multiple_operands(%idx: index, %vec0: vector<1x4xf32>, %vec1: vector<1x4xf32>) -> vector<1x4xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %res:3 = scf.for %i = %c0 to %c4 step %c1
+ iter_args(%id = %idx, %iter0 = %vec0, %iter1 = %vec1) -> (index, vector<1x4xf32>, vector<1x4xf32>) {
+ %add = arith.addf %iter0, %iter1 : vector<1x4xf32>
+ scf.yield %id, %add, %add : index, vector<1x4xf32>, vector<1x4xf32>
+ }
+ return %res#1 : vector<1x4xf32>
+}
+
+// CHECK-LABEL: func.func @scf_for_with_multiple_operands
+// CHECK-SAME: %[[IDX:[A-Za-z0-9]+]]: index
+// CHECK-SAME: %[[VEC0:[A-Za-z0-9]+]]: vector<1x4xf32>
+// CHECK-SAME: %[[VEC1:[A-Za-z0-9]+]]: vector<1x4xf32>
+// CHECK-DAG: %[[CAST0:.+]] = vector.shape_cast %[[VEC0]] : vector<1x4xf32> to vector<4xf32>
+// CHECK-DAG: %[[CAST1:.+]] = vector.shape_cast %[[VEC1]] : vector<1x4xf32> to vector<4xf32>
+// CHECK: %[[LOOP:.+]]:3 = scf.for
+// CHECK-SAME: iter_args(%{{.*}} = %[[IDX]], %[[ITER0:.+]] = %[[CAST0]], %[[ITER1:.+]] = %[[CAST1]])
+// CHECK: %[[ADD:.+]] = arith.addf %[[ITER0]], %[[ITER1]] : vector<4xf32>
+// CHECK: scf.yield %{{.*}}, %[[ADD]], %[[ADD]]
+// CHECK: %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]]#1 : vector<4xf32> to vector<1x4xf32>
+// CHECK: return %[[CASTBACK]]
|
| /// cast ops inside the block to account for type differences. | ||
| using ValueTypeCastFnTy = | ||
| std::function<Value(OpBuilder &, Location loc, Type, Value)>; | ||
| SmallVector<Value> replaceAndCastForOpIterArg(RewriterBase &rewriter, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It feels like this helper should go in SCF/Utils/Utils.h, however that would introduce a circular dep because this was split out from a canonicalizer. Any suggestions for a better place for this helper?
MacDue
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense to me, this is helpful for getting rid of casts what otherwise won't fold away 👍
(I actually needed something like this for a downstream experiment a while ago)
hanhanW
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, just one nit about a function argument.
This adds a pattern for dropping unit dims from the iter_args of scf.for ops using vector.shape_cast. This composes with the other patterns for dropping unit dims from elementwise ops and transposes.
8f3b212 to
b225b34
Compare
banach-space
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
…m#109585) This adds a pattern for dropping unit dims from the iter_args of scf.for ops using vector.shape_cast. This composes with the other patterns for dropping unit dims from elementwise ops and transposes.
This adds a pattern for dropping unit dims from the iter_args of scf.for ops using vector.shape_cast. This composes with the other patterns for dropping unit dims from elementwise ops and transposes.