Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions mlir/include/mlir/Dialect/SCF/IR/SCF.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,17 @@ LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs,
function_ref<void(OpBuilder &, Location, ValueRange)>
bodyBuilder = nullptr);

/// Perform a replacement of one iter OpOperand of an scf.for to the
/// `replacement` value with a different type. A callback is used to insert
/// cast ops inside the block to account for type differences.
using ValueTypeCastFnTy =
llvm::function_ref<Value(OpBuilder &, Location loc, Type, Value)>;
SmallVector<Value> replaceAndCastForOpIterArg(RewriterBase &rewriter,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It feels like this helper should go in SCF/Utils/Utils.h, however that would introduce a circular dep because this was split out from a canonicalizer. Any suggestions for a better place for this helper?

scf::ForOp forOp,
OpOperand &operand,
Value replacement,
const ValueTypeCastFnTy &castFn);

} // namespace scf
} // namespace mlir
#endif // MLIR_DIALECT_SCF_SCF_H
142 changes: 70 additions & 72 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,70 @@ LoopNest mlir::scf::buildLoopNest(
});
}

SmallVector<Value>
mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
OpOperand &operand, Value replacement,
const ValueTypeCastFnTy &castFn) {
assert(operand.getOwner() == forOp);
Type oldType = operand.get().getType(), newType = replacement.getType();

// 1. Create new iter operands, exactly 1 is replaced.
assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
"expected an iter OpOperand");
assert(operand.get().getType() != replacement.getType() &&
"Expected a different type");
SmallVector<Value> newIterOperands;
for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
newIterOperands.push_back(replacement);
continue;
}
newIterOperands.push_back(opOperand.get());
}

// 2. Create the new forOp shell.
scf::ForOp newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), newIterOperands);
newForOp->setAttrs(forOp->getAttrs());
Block &newBlock = newForOp.getRegion().front();
SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
newBlock.getArguments().end());

// 3. Inject an incoming cast op at the beginning of the block for the bbArg
// corresponding to the `replacement` value.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(&newBlock, newBlock.begin());
BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
&newForOp->getOpOperand(operand.getOperandNumber()));
Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;

// 4. Steal the old block ops, mapping to the newBlockTransferArgs.
Block &oldBlock = forOp.getRegion().front();
rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);

// 5. Inject an outgoing cast op at the end of the block and yield it instead.
auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
rewriter.setInsertionPoint(clonedYieldOp);
unsigned yieldIdx =
newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
clonedYieldOp.getOperand(yieldIdx));
SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
newYieldOperands[yieldIdx] = castOut;
rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
rewriter.eraseOp(clonedYieldOp);

// 6. Inject an outgoing cast op after the forOp.
rewriter.setInsertionPointAfter(newForOp);
SmallVector<Value> newResults = newForOp.getResults();
newResults[yieldIdx] =
castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);

return newResults;
}

namespace {
// Fold away ForOp iter arguments when:
// 1) The op yields the iter arguments.
Expand Down Expand Up @@ -973,76 +1037,6 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
}
};

/// Perform a replacement of one iter OpOperand of an scf.for to the
/// `replacement` value which is expected to be the source of a tensor.cast.
/// tensor.cast ops are inserted inside the block to account for the type cast.
static SmallVector<Value>
replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand,
Value replacement) {
Type oldType = operand.get().getType(), newType = replacement.getType();
assert(llvm::isa<RankedTensorType>(oldType) &&
llvm::isa<RankedTensorType>(newType) &&
"expected ranked tensor types");

// 1. Create new iter operands, exactly 1 is replaced.
ForOp forOp = cast<ForOp>(operand.getOwner());
assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
"expected an iter OpOperand");
assert(operand.get().getType() != replacement.getType() &&
"Expected a different type");
SmallVector<Value> newIterOperands;
for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
newIterOperands.push_back(replacement);
continue;
}
newIterOperands.push_back(opOperand.get());
}

// 2. Create the new forOp shell.
scf::ForOp newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), newIterOperands);
newForOp->setAttrs(forOp->getAttrs());
Block &newBlock = newForOp.getRegion().front();
SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
newBlock.getArguments().end());

// 3. Inject an incoming cast op at the beginning of the block for the bbArg
// corresponding to the `replacement` value.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(&newBlock, newBlock.begin());
BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
&newForOp->getOpOperand(operand.getOperandNumber()));
Value castIn = rewriter.create<tensor::CastOp>(newForOp.getLoc(), oldType,
newRegionIterArg);
newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;

// 4. Steal the old block ops, mapping to the newBlockTransferArgs.
Block &oldBlock = forOp.getRegion().front();
rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);

// 5. Inject an outgoing cast op at the end of the block and yield it instead.
auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
rewriter.setInsertionPoint(clonedYieldOp);
unsigned yieldIdx =
newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
Value castOut = rewriter.create<tensor::CastOp>(
newForOp.getLoc(), newType, clonedYieldOp.getOperand(yieldIdx));
SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
newYieldOperands[yieldIdx] = castOut;
rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
rewriter.eraseOp(clonedYieldOp);

// 6. Inject an outgoing cast op after the forOp.
rewriter.setInsertionPointAfter(newForOp);
SmallVector<Value> newResults = newForOp.getResults();
newResults[yieldIdx] = rewriter.create<tensor::CastOp>(
newForOp.getLoc(), oldType, newResults[yieldIdx]);

return newResults;
}

/// Fold scf.for iter_arg/result pairs that go through incoming/ougoing
/// a tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
///
Expand Down Expand Up @@ -1090,9 +1084,13 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
continue;

// Create a new ForOp with that iter operand replaced.
ValueTypeCastFnTy castFn = [](OpBuilder &b, Location loc, Type type,
Value source) {
return b.create<tensor::CastOp>(loc, type, source);
};
rewriter.replaceOp(
op, replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
incomingCast.getSource()));
op, replaceAndCastForOpIterArg(rewriter, op, iterOpOperand,
incomingCast.getSource(), castFn));
return success();
}
return failure();
Expand Down
59 changes: 58 additions & 1 deletion mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1796,6 +1796,62 @@ struct DropUnitDimsFromTransposeOp final
}
};

/// A pattern to drop unit dims from the iter_args of an scf.for.
///
/// Example:
///
/// BEFORE:
/// ```mlir
/// %res = scf.for ... iter_args(%iter = %init) -> vector<[4]x1x1x4xf32> {
/// ...
/// scf.yield %
/// }
/// ```
///
/// AFTER:
/// ```mlir
/// %drop = vector.shape_cast %init
/// : vector<4x1x1x[4]xf32> to vector<4x[4]xf32>
/// %new_loop = scf.for ... iter_args(%iter = %drop) -> vector<[4]x4xf32> {
/// %new_iter = vector.shape_cast %iter
/// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
/// ...
/// }
/// %res = vector.shape_cast %new_loop
/// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
/// ```
struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(scf::ForOp forOp,
PatternRewriter &rewriter) const override {
/// Find the first iter_arg with droppable unit dims. Further applications
/// of this pattern will apply to later arguments.
for (OpOperand &operand : forOp.getInitArgsMutable()) {
auto vectorType = dyn_cast<VectorType>(operand.get().getType());
if (!vectorType)
continue;

VectorType newVectorType = dropNonScalableUnitDimFromType(vectorType);
if (vectorType == newVectorType)
continue;

// Create a new ForOp with that iter operand replaced.
auto castFn = [](OpBuilder &b, Location loc, Type type, Value source) {
return b.create<vector::ShapeCastOp>(loc, type, source);
};

Value replacement =
castFn(rewriter, forOp.getLoc(), newVectorType, operand.get());
rewriter.replaceOp(forOp,
replaceAndCastForOpIterArg(rewriter, forOp, operand,
replacement, castFn));
return success();
}
return failure();
}
};

/// Pattern to eliminate redundant zero-constants added to reduction operands.
/// It's enough for there to be one initial zero value, so we can eliminate the
/// extra ones that feed into `vector.reduction <add>`. These get created by the
Expand Down Expand Up @@ -2001,7 +2057,8 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromTransposeOp,
ShapeCastOpFolder>(patterns.getContext(), benefit);
ShapeCastOpFolder, DropUnitDimsFromScfForOp>(
patterns.getContext(), benefit);
}

void mlir::vector::populateBubbleVectorBitCastOpPatterns(
Expand Down
97 changes: 97 additions & 0 deletions mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,100 @@ func.func @negative_transpose_with_no_unit_dims(%vec: vector<4x2x3xf32>) -> vect

// CHECK-LABEL: func.func @negative_transpose_with_no_unit_dims
// CHECK-NOT: vector.shape_cast

// -----

///----------------------------------------------------------------------------------------
/// [Pattern: DropUnitDimsFromScfForOp]
///----------------------------------------------------------------------------------------

func.func @scf_for_with_internal_unit_dims(%vec: vector<4x1x1x[4]xf32>) -> vector<4x1x1x[4]xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%res = scf.for %i = %c0 to %c4 step %c1 iter_args(%iter = %vec) -> vector<4x1x1x[4]xf32> {
%s = math.sqrt %iter : vector<4x1x1x[4]xf32>
scf.yield %s : vector<4x1x1x[4]xf32>
}
return %res : vector<4x1x1x[4]xf32>
}

// CHECK-LABEL: func.func @scf_for_with_internal_unit_dims
// CHECK-SAME: %[[VEC:[A-Za-z0-9]+]]: vector<4x1x1x[4]xf32>
// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<4x1x1x[4]xf32> to vector<4x[4]xf32>
// CHECK: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[ITER:.+]] = %[[CAST]])
// CHECK: %[[SQRT:.+]] = math.sqrt %[[ITER]] : vector<4x[4]xf32>
// CHECK: scf.yield %[[SQRT]]
// CHECK: %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]] : vector<4x[4]xf32> to vector<4x1x1x[4]xf32>
// CHECK: return %[[CASTBACK]]

// -----

func.func @scf_for_with_all_unit_dims(%vec: vector<1x1xf32>) -> vector<1x1xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%res = scf.for %i = %c0 to %c4 step %c1 iter_args(%iter = %vec) -> vector<1x1xf32> {
%s = math.sqrt %iter : vector<1x1xf32>
scf.yield %s : vector<1x1xf32>
}
return %res : vector<1x1xf32>
}

// CHECK-LABEL: func.func @scf_for_with_all_unit_dims
// CHECK-SAME: %[[VEC:[A-Za-z0-9]+]]: vector<1x1xf32>
// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x1xf32> to vector<1xf32>
// CHECK: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[ITER:.+]] = %[[CAST]])
// CHECK: %[[SQRT:.+]] = math.sqrt %[[ITER]] : vector<1xf32>
// CHECK: scf.yield %[[SQRT]]
// CHECK: %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]] : vector<1xf32> to vector<1x1xf32>
// CHECK: return %[[CASTBACK]]

// -----

func.func @scf_for_with_multiple_operands(%idx: index, %vec0: vector<1x4xf32>, %vec1: vector<1x4xf32>) -> vector<1x4xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%res:3 = scf.for %i = %c0 to %c4 step %c1
iter_args(%id = %idx, %iter0 = %vec0, %iter1 = %vec1) -> (index, vector<1x4xf32>, vector<1x4xf32>) {
%add = arith.addf %iter0, %iter1 : vector<1x4xf32>
scf.yield %id, %add, %add : index, vector<1x4xf32>, vector<1x4xf32>
}
return %res#1 : vector<1x4xf32>
}

// CHECK-LABEL: func.func @scf_for_with_multiple_operands
// CHECK-SAME: %[[IDX:[A-Za-z0-9]+]]: index
// CHECK-SAME: %[[VEC0:[A-Za-z0-9]+]]: vector<1x4xf32>
// CHECK-SAME: %[[VEC1:[A-Za-z0-9]+]]: vector<1x4xf32>
// CHECK-DAG: %[[CAST0:.+]] = vector.shape_cast %[[VEC0]] : vector<1x4xf32> to vector<4xf32>
// CHECK-DAG: %[[CAST1:.+]] = vector.shape_cast %[[VEC1]] : vector<1x4xf32> to vector<4xf32>
// CHECK: %[[LOOP:.+]]:3 = scf.for
// CHECK-SAME: iter_args(%{{.*}} = %[[IDX]], %[[ITER0:.+]] = %[[CAST0]], %[[ITER1:.+]] = %[[CAST1]])
// CHECK: %[[ADD:.+]] = arith.addf %[[ITER0]], %[[ITER1]] : vector<4xf32>
// CHECK: scf.yield %{{.*}}, %[[ADD]], %[[ADD]]
// CHECK: %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]]#1 : vector<4xf32> to vector<1x4xf32>
// CHECK: return %[[CASTBACK]]

// -----

func.func @scf_for_with_scalable_unit_dims(%vec: vector<1x[1]xf32>) -> vector<1x[1]xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%res = scf.for %i = %c0 to %c4 step %c1 iter_args(%iter = %vec) -> vector<1x[1]xf32> {
%s = math.sqrt %iter : vector<1x[1]xf32>
scf.yield %s : vector<1x[1]xf32>
}
return %res : vector<1x[1]xf32>
}

// CHECK-LABEL: func.func @scf_for_with_scalable_unit_dims
// CHECK-SAME: %[[VEC:[A-Za-z0-9]+]]: vector<1x[1]xf32>
// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x[1]xf32> to vector<[1]xf32>
// CHECK: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[ITER:.+]] = %[[CAST]])
// CHECK: %[[SQRT:.+]] = math.sqrt %[[ITER]] : vector<[1]xf32>
// CHECK: scf.yield %[[SQRT]]
// CHECK: %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]] : vector<[1]xf32> to vector<1x[1]xf32>
// CHECK: return %[[CASTBACK]]
Loading