Skip to content

Commit a3b34e6

Browse files
authored
[mlir][vector] Add pattern for dropping unit dims from for loops (#109585)
This adds a pattern for dropping unit dims from the iter_args of scf.for ops using vector.shape_cast. This composes with the other patterns for dropping unit dims from elementwise ops and transposes.
1 parent e2cc63d commit a3b34e6

File tree

4 files changed

+236
-73
lines changed

4 files changed

+236
-73
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCF.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,17 @@ LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs,
107107
function_ref<void(OpBuilder &, Location, ValueRange)>
108108
bodyBuilder = nullptr);
109109

110+
/// Perform a replacement of one iter OpOperand of an scf.for to the
111+
/// `replacement` value with a different type. A callback is used to insert
112+
/// cast ops inside the block to account for type differences.
113+
using ValueTypeCastFnTy =
114+
llvm::function_ref<Value(OpBuilder &, Location loc, Type, Value)>;
115+
SmallVector<Value> replaceAndCastForOpIterArg(RewriterBase &rewriter,
116+
scf::ForOp forOp,
117+
OpOperand &operand,
118+
Value replacement,
119+
const ValueTypeCastFnTy &castFn);
120+
110121
} // namespace scf
111122
} // namespace mlir
112123
#endif // MLIR_DIALECT_SCF_SCF_H

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 70 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -772,6 +772,70 @@ LoopNest mlir::scf::buildLoopNest(
772772
});
773773
}
774774

775+
/// Builds a copy of `forOp` in which the iter operand `operand` is initialized
/// with `replacement` (whose type differs from the original), moves the old
/// body into the new loop, and uses `castFn` to bridge the type difference at
/// the three places where the old and the new type meet. Returns the values
/// that stand in for the results of `forOp`; the caller is responsible for
/// replacing the old op with them.
SmallVector<Value>
mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
                                      OpOperand &operand, Value replacement,
                                      const ValueTypeCastFnTy &castFn) {
  assert(operand.getOwner() == forOp);
  Type originalType = operand.get().getType();
  Type replacementType = replacement.getType();

  // Collect the init values of the new loop: identical to the old ones except
  // for the single operand being replaced.
  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
         "expected an iter OpOperand");
  assert(operand.get().getType() != replacement.getType() &&
         "Expected a different type");
  const unsigned replacedOperandNo = operand.getOperandNumber();
  SmallVector<Value> initValues;
  for (OpOperand &init : forOp.getInitArgsMutable())
    initValues.push_back(init.getOperandNumber() == replacedOperandNo
                             ? replacement
                             : init.get());

  // Create the new loop (bounds, step and attributes are taken over from the
  // old one). Its block arguments are the default replacements for the block
  // arguments of the old body.
  auto newLoop = rewriter.create<scf::ForOp>(
      forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
      forOp.getStep(), initValues);
  newLoop->setAttrs(forOp->getAttrs());
  Block &newBody = newLoop.getRegion().front();
  SmallVector<Value, 4> bodyArgReplacements(newBody.getArguments().begin(),
                                            newBody.getArguments().end());

  // The old body still expects the original type: cast the re-typed region
  // iter_arg back at the very top of the new body and hand the old body the
  // result of that cast instead of the block argument itself.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(&newBody);
  BlockArgument replacedIterArg = newLoop.getTiedLoopRegionIterArg(
      &newLoop->getOpOperand(replacedOperandNo));
  bodyArgReplacements[replacedIterArg.getArgNumber()] =
      castFn(rewriter, newLoop.getLoc(), originalType, replacedIterArg);

  // Move the operations of the old body into the new loop.
  rewriter.mergeBlocks(&forOp.getRegion().front(), &newBody,
                       bodyArgReplacements);

  // The moved terminator yields the original type: cast the affected value to
  // the new type right before it and swap in a terminator yielding the cast.
  auto oldYield = cast<scf::YieldOp>(newBody.getTerminator());
  rewriter.setInsertionPoint(oldYield);
  const unsigned resultIdx =
      replacedIterArg.getArgNumber() - forOp.getNumInductionVars();
  SmallVector<Value> yieldedValues = oldYield.getOperands();
  yieldedValues[resultIdx] =
      castFn(rewriter, newLoop.getLoc(), replacementType,
             oldYield.getOperand(resultIdx));
  rewriter.create<scf::YieldOp>(newLoop.getLoc(), yieldedValues);
  rewriter.eraseOp(oldYield);

  // Users of the old loop expect the original type: cast the affected result
  // of the new loop back right after the loop.
  rewriter.setInsertionPointAfter(newLoop);
  SmallVector<Value> results = newLoop.getResults();
  Value &replacedResult = results[resultIdx];
  replacedResult =
      castFn(rewriter, newLoop.getLoc(), originalType, replacedResult);

  return results;
}
838+
775839
namespace {
776840
// Fold away ForOp iter arguments when:
777841
// 1) The op yields the iter arguments.
@@ -973,76 +1037,6 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
9731037
}
9741038
};
9751039

976-
/// Perform a replacement of one iter OpOperand of an scf.for to the
977-
/// `replacement` value which is expected to be the source of a tensor.cast.
978-
/// tensor.cast ops are inserted inside the block to account for the type cast.
979-
static SmallVector<Value>
980-
replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand,
981-
Value replacement) {
982-
Type oldType = operand.get().getType(), newType = replacement.getType();
983-
assert(llvm::isa<RankedTensorType>(oldType) &&
984-
llvm::isa<RankedTensorType>(newType) &&
985-
"expected ranked tensor types");
986-
987-
// 1. Create new iter operands, exactly 1 is replaced.
988-
ForOp forOp = cast<ForOp>(operand.getOwner());
989-
assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
990-
"expected an iter OpOperand");
991-
assert(operand.get().getType() != replacement.getType() &&
992-
"Expected a different type");
993-
SmallVector<Value> newIterOperands;
994-
for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
995-
if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
996-
newIterOperands.push_back(replacement);
997-
continue;
998-
}
999-
newIterOperands.push_back(opOperand.get());
1000-
}
1001-
1002-
// 2. Create the new forOp shell.
1003-
scf::ForOp newForOp = rewriter.create<scf::ForOp>(
1004-
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
1005-
forOp.getStep(), newIterOperands);
1006-
newForOp->setAttrs(forOp->getAttrs());
1007-
Block &newBlock = newForOp.getRegion().front();
1008-
SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
1009-
newBlock.getArguments().end());
1010-
1011-
// 3. Inject an incoming cast op at the beginning of the block for the bbArg
1012-
// corresponding to the `replacement` value.
1013-
OpBuilder::InsertionGuard g(rewriter);
1014-
rewriter.setInsertionPoint(&newBlock, newBlock.begin());
1015-
BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
1016-
&newForOp->getOpOperand(operand.getOperandNumber()));
1017-
Value castIn = rewriter.create<tensor::CastOp>(newForOp.getLoc(), oldType,
1018-
newRegionIterArg);
1019-
newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
1020-
1021-
// 4. Steal the old block ops, mapping to the newBlockTransferArgs.
1022-
Block &oldBlock = forOp.getRegion().front();
1023-
rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
1024-
1025-
// 5. Inject an outgoing cast op at the end of the block and yield it instead.
1026-
auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
1027-
rewriter.setInsertionPoint(clonedYieldOp);
1028-
unsigned yieldIdx =
1029-
newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
1030-
Value castOut = rewriter.create<tensor::CastOp>(
1031-
newForOp.getLoc(), newType, clonedYieldOp.getOperand(yieldIdx));
1032-
SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
1033-
newYieldOperands[yieldIdx] = castOut;
1034-
rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
1035-
rewriter.eraseOp(clonedYieldOp);
1036-
1037-
// 6. Inject an outgoing cast op after the forOp.
1038-
rewriter.setInsertionPointAfter(newForOp);
1039-
SmallVector<Value> newResults = newForOp.getResults();
1040-
newResults[yieldIdx] = rewriter.create<tensor::CastOp>(
1041-
newForOp.getLoc(), oldType, newResults[yieldIdx]);
1042-
1043-
return newResults;
1044-
}
1045-
10461040
/// Fold scf.for iter_arg/result pairs that go through incoming/outgoing
10471041
/// a tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
10481042
///
@@ -1090,9 +1084,13 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
10901084
continue;
10911085

10921086
// Create a new ForOp with that iter operand replaced.
1087+
ValueTypeCastFnTy castFn = [](OpBuilder &b, Location loc, Type type,
1088+
Value source) {
1089+
return b.create<tensor::CastOp>(loc, type, source);
1090+
};
10931091
rewriter.replaceOp(
1094-
op, replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
1095-
incomingCast.getSource()));
1092+
op, replaceAndCastForOpIterArg(rewriter, op, iterOpOperand,
1093+
incomingCast.getSource(), castFn));
10961094
return success();
10971095
}
10981096
return failure();

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1796,6 +1796,62 @@ struct DropUnitDimsFromTransposeOp final
17961796
}
17971797
};
17981798

1799+
/// A pattern to drop unit dims from the iter_args of an scf.for.
1800+
///
1801+
/// Example:
1802+
///
1803+
/// BEFORE:
1804+
/// ```mlir
1805+
/// %res = scf.for ... iter_args(%iter = %init) -> vector<[4]x1x1x4xf32> {
1806+
/// ...
1807+
/// scf.yield %
1808+
/// }
1809+
/// ```
1810+
///
1811+
/// AFTER:
1812+
/// ```mlir
1813+
/// %drop = vector.shape_cast %init
1814+
/// : vector<4x1x1x[4]xf32> to vector<4x[4]xf32>
1815+
/// %new_loop = scf.for ... iter_args(%iter = %drop) -> vector<[4]x4xf32> {
1816+
/// %new_iter = vector.shape_cast %iter
1817+
/// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
1818+
/// ...
1819+
/// }
1820+
/// %res = vector.shape_cast %new_loop
1821+
/// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
1822+
/// ```
1823+
struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
1824+
using OpRewritePattern::OpRewritePattern;
1825+
1826+
LogicalResult matchAndRewrite(scf::ForOp forOp,
1827+
PatternRewriter &rewriter) const override {
1828+
/// Find the first iter_arg with droppable unit dims. Further applications
1829+
/// of this pattern will apply to later arguments.
1830+
for (OpOperand &operand : forOp.getInitArgsMutable()) {
1831+
auto vectorType = dyn_cast<VectorType>(operand.get().getType());
1832+
if (!vectorType)
1833+
continue;
1834+
1835+
VectorType newVectorType = dropNonScalableUnitDimFromType(vectorType);
1836+
if (vectorType == newVectorType)
1837+
continue;
1838+
1839+
// Create a new ForOp with that iter operand replaced.
1840+
auto castFn = [](OpBuilder &b, Location loc, Type type, Value source) {
1841+
return b.create<vector::ShapeCastOp>(loc, type, source);
1842+
};
1843+
1844+
Value replacement =
1845+
castFn(rewriter, forOp.getLoc(), newVectorType, operand.get());
1846+
rewriter.replaceOp(forOp,
1847+
replaceAndCastForOpIterArg(rewriter, forOp, operand,
1848+
replacement, castFn));
1849+
return success();
1850+
}
1851+
return failure();
1852+
}
1853+
};
1854+
17991855
/// Pattern to eliminate redundant zero-constants added to reduction operands.
18001856
/// It's enough for there to be one initial zero value, so we can eliminate the
18011857
/// extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -2001,7 +2057,8 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
20012057
/// Adds the unit-dim dropping patterns visible here
/// (DropUnitDimFromElementwiseOps, DropUnitDimsFromTransposeOp,
/// ShapeCastOpFolder and, with this change, DropUnitDimsFromScfForOp) to
/// `patterns`, all constructed with the context of `patterns` and `benefit`.
void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
20022058
RewritePatternSet &patterns, PatternBenefit benefit) {
20032059
patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromTransposeOp,
2004-
ShapeCastOpFolder>(patterns.getContext(), benefit);
2060+
ShapeCastOpFolder, DropUnitDimsFromScfForOp>(
2061+
patterns.getContext(), benefit);
20052062
}
20062063

20072064
void mlir::vector::populateBubbleVectorBitCastOpPatterns(

mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,3 +207,100 @@ func.func @negative_transpose_with_no_unit_dims(%vec: vector<4x2x3xf32>) -> vect
207207

208208
// CHECK-LABEL: func.func @negative_transpose_with_no_unit_dims
209209
// CHECK-NOT: vector.shape_cast
210+
211+
// -----
212+
213+
///----------------------------------------------------------------------------------------
214+
/// [Pattern: DropUnitDimsFromScfForOp]
215+
///----------------------------------------------------------------------------------------
216+
217+
// Loop-carried vector with two interior fixed-size unit dims. The expectations
// below require the loop to carry vector<4x[4]xf32> instead, with shape casts
// before and after the loop.
func.func @scf_for_with_internal_unit_dims(%vec: vector<4x1x1x[4]xf32>) -> vector<4x1x1x[4]xf32> {
218+
%c0 = arith.constant 0 : index
219+
%c1 = arith.constant 1 : index
220+
%c4 = arith.constant 4 : index
221+
%res = scf.for %i = %c0 to %c4 step %c1 iter_args(%iter = %vec) -> vector<4x1x1x[4]xf32> {
222+
%s = math.sqrt %iter : vector<4x1x1x[4]xf32>
223+
scf.yield %s : vector<4x1x1x[4]xf32>
224+
}
225+
return %res : vector<4x1x1x[4]xf32>
226+
}
227+
228+
// CHECK-LABEL: func.func @scf_for_with_internal_unit_dims
229+
// CHECK-SAME: %[[VEC:[A-Za-z0-9]+]]: vector<4x1x1x[4]xf32>
230+
// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<4x1x1x[4]xf32> to vector<4x[4]xf32>
231+
// CHECK: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[ITER:.+]] = %[[CAST]])
232+
// CHECK: %[[SQRT:.+]] = math.sqrt %[[ITER]] : vector<4x[4]xf32>
233+
// CHECK: scf.yield %[[SQRT]]
234+
// CHECK: %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]] : vector<4x[4]xf32> to vector<4x1x1x[4]xf32>
235+
// CHECK: return %[[CASTBACK]]
236+
237+
// -----
238+
239+
// Every dim of the loop-carried vector is a unit dim. The expectations below
// require the loop to carry vector<1xf32>, i.e. one unit dim is kept.
func.func @scf_for_with_all_unit_dims(%vec: vector<1x1xf32>) -> vector<1x1xf32> {
240+
%c0 = arith.constant 0 : index
241+
%c1 = arith.constant 1 : index
242+
%c4 = arith.constant 4 : index
243+
%res = scf.for %i = %c0 to %c4 step %c1 iter_args(%iter = %vec) -> vector<1x1xf32> {
244+
%s = math.sqrt %iter : vector<1x1xf32>
245+
scf.yield %s : vector<1x1xf32>
246+
}
247+
return %res : vector<1x1xf32>
248+
}
249+
250+
// CHECK-LABEL: func.func @scf_for_with_all_unit_dims
251+
// CHECK-SAME: %[[VEC:[A-Za-z0-9]+]]: vector<1x1xf32>
252+
// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x1xf32> to vector<1xf32>
253+
// CHECK: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[ITER:.+]] = %[[CAST]])
254+
// CHECK: %[[SQRT:.+]] = math.sqrt %[[ITER]] : vector<1xf32>
255+
// CHECK: scf.yield %[[SQRT]]
256+
// CHECK: %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]] : vector<1xf32> to vector<1x1xf32>
257+
// CHECK: return %[[CASTBACK]]
258+
259+
// -----
260+
261+
// Loop with one index iter_arg and two vector iter_args. The expectations
// below require the index iter_arg to stay as is and both vector iter_args to
// be initialized from shape casts to vector<4xf32>.
func.func @scf_for_with_multiple_operands(%idx: index, %vec0: vector<1x4xf32>, %vec1: vector<1x4xf32>) -> vector<1x4xf32> {
262+
%c0 = arith.constant 0 : index
263+
%c1 = arith.constant 1 : index
264+
%c4 = arith.constant 4 : index
265+
%res:3 = scf.for %i = %c0 to %c4 step %c1
266+
iter_args(%id = %idx, %iter0 = %vec0, %iter1 = %vec1) -> (index, vector<1x4xf32>, vector<1x4xf32>) {
267+
%add = arith.addf %iter0, %iter1 : vector<1x4xf32>
268+
scf.yield %id, %add, %add : index, vector<1x4xf32>, vector<1x4xf32>
269+
}
270+
return %res#1 : vector<1x4xf32>
271+
}
272+
273+
// CHECK-LABEL: func.func @scf_for_with_multiple_operands
274+
// CHECK-SAME: %[[IDX:[A-Za-z0-9]+]]: index
275+
// CHECK-SAME: %[[VEC0:[A-Za-z0-9]+]]: vector<1x4xf32>
276+
// CHECK-SAME: %[[VEC1:[A-Za-z0-9]+]]: vector<1x4xf32>
277+
// CHECK-DAG: %[[CAST0:.+]] = vector.shape_cast %[[VEC0]] : vector<1x4xf32> to vector<4xf32>
278+
// CHECK-DAG: %[[CAST1:.+]] = vector.shape_cast %[[VEC1]] : vector<1x4xf32> to vector<4xf32>
279+
// CHECK: %[[LOOP:.+]]:3 = scf.for
280+
// CHECK-SAME: iter_args(%{{.*}} = %[[IDX]], %[[ITER0:.+]] = %[[CAST0]], %[[ITER1:.+]] = %[[CAST1]])
281+
// CHECK: %[[ADD:.+]] = arith.addf %[[ITER0]], %[[ITER1]] : vector<4xf32>
282+
// CHECK: scf.yield %{{.*}}, %[[ADD]], %[[ADD]]
283+
// CHECK: %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]]#1 : vector<4xf32> to vector<1x4xf32>
284+
// CHECK: return %[[CASTBACK]]
285+
286+
// -----
287+
288+
// Loop-carried vector with a fixed-size unit dim followed by a scalable unit
// dim. The expectations below require only the fixed-size unit dim to be
// dropped: the loop carries vector<[1]xf32>.
func.func @scf_for_with_scalable_unit_dims(%vec: vector<1x[1]xf32>) -> vector<1x[1]xf32> {
289+
%c0 = arith.constant 0 : index
290+
%c1 = arith.constant 1 : index
291+
%c4 = arith.constant 4 : index
292+
%res = scf.for %i = %c0 to %c4 step %c1 iter_args(%iter = %vec) -> vector<1x[1]xf32> {
293+
%s = math.sqrt %iter : vector<1x[1]xf32>
294+
scf.yield %s : vector<1x[1]xf32>
295+
}
296+
return %res : vector<1x[1]xf32>
297+
}
298+
299+
// CHECK-LABEL: func.func @scf_for_with_scalable_unit_dims
300+
// CHECK-SAME: %[[VEC:[A-Za-z0-9]+]]: vector<1x[1]xf32>
301+
// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x[1]xf32> to vector<[1]xf32>
302+
// CHECK: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[ITER:.+]] = %[[CAST]])
303+
// CHECK: %[[SQRT:.+]] = math.sqrt %[[ITER]] : vector<[1]xf32>
304+
// CHECK: scf.yield %[[SQRT]]
305+
// CHECK: %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]] : vector<[1]xf32> to vector<1x[1]xf32>
306+
// CHECK: return %[[CASTBACK]]

0 commit comments

Comments
 (0)