diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h index 644118ca884c6..b62c941797947 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h +++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h @@ -107,6 +107,17 @@ LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, function_ref bodyBuilder = nullptr); +/// Perform a replacement of one iter OpOperand of an scf.for to the +/// `replacement` value with a different type. A callback is used to insert +/// cast ops inside the block to account for type differences. +using ValueTypeCastFnTy = + llvm::function_ref; +SmallVector replaceAndCastForOpIterArg(RewriterBase &rewriter, + scf::ForOp forOp, + OpOperand &operand, + Value replacement, + const ValueTypeCastFnTy &castFn); + } // namespace scf } // namespace mlir #endif // MLIR_DIALECT_SCF_SCF_H diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 6d47ff3890977..d1c9fd2d217da 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -772,6 +772,70 @@ LoopNest mlir::scf::buildLoopNest( }); } +SmallVector +mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, + OpOperand &operand, Value replacement, + const ValueTypeCastFnTy &castFn) { + assert(operand.getOwner() == forOp); + Type oldType = operand.get().getType(), newType = replacement.getType(); + + // 1. Create new iter operands, exactly 1 is replaced. + assert(operand.getOperandNumber() >= forOp.getNumControlOperands() && + "expected an iter OpOperand"); + assert(operand.get().getType() != replacement.getType() && + "Expected a different type"); + SmallVector newIterOperands; + for (OpOperand &opOperand : forOp.getInitArgsMutable()) { + if (opOperand.getOperandNumber() == operand.getOperandNumber()) { + newIterOperands.push_back(replacement); + continue; + } + newIterOperands.push_back(opOperand.get()); + } + + // 2. Create the new forOp shell. + scf::ForOp newForOp = rewriter.create( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), newIterOperands); + newForOp->setAttrs(forOp->getAttrs()); + Block &newBlock = newForOp.getRegion().front(); + SmallVector newBlockTransferArgs(newBlock.getArguments().begin(), + newBlock.getArguments().end()); + + // 3. Inject an incoming cast op at the beginning of the block for the bbArg + // corresponding to the `replacement` value. + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(&newBlock, newBlock.begin()); + BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg( + &newForOp->getOpOperand(operand.getOperandNumber())); + Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg); + newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn; + + // 4. Steal the old block ops, mapping to the newBlockTransferArgs. + Block &oldBlock = forOp.getRegion().front(); + rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs); + + // 5. Inject an outgoing cast op at the end of the block and yield it instead. + auto clonedYieldOp = cast(newBlock.getTerminator()); + rewriter.setInsertionPoint(clonedYieldOp); + unsigned yieldIdx = + newRegionIterArg.getArgNumber() - forOp.getNumInductionVars(); + Value castOut = castFn(rewriter, newForOp.getLoc(), newType, + clonedYieldOp.getOperand(yieldIdx)); + SmallVector newYieldOperands = clonedYieldOp.getOperands(); + newYieldOperands[yieldIdx] = castOut; + rewriter.create(newForOp.getLoc(), newYieldOperands); + rewriter.eraseOp(clonedYieldOp); + + // 6. Inject an outgoing cast op after the forOp. + rewriter.setInsertionPointAfter(newForOp); + SmallVector newResults = newForOp.getResults(); + newResults[yieldIdx] = + castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]); + + return newResults; +} + namespace { // Fold away ForOp iter arguments when: // 1) The op yields the iter arguments. @@ -973,76 +1037,6 @@ struct SimplifyTrivialLoops : public OpRewritePattern { } }; -/// Perform a replacement of one iter OpOperand of an scf.for to the -/// `replacement` value which is expected to be the source of a tensor.cast. -/// tensor.cast ops are inserted inside the block to account for the type cast. -static SmallVector -replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand, - Value replacement) { - Type oldType = operand.get().getType(), newType = replacement.getType(); - assert(llvm::isa(oldType) && - llvm::isa(newType) && - "expected ranked tensor types"); - - // 1. Create new iter operands, exactly 1 is replaced. - ForOp forOp = cast(operand.getOwner()); - assert(operand.getOperandNumber() >= forOp.getNumControlOperands() && - "expected an iter OpOperand"); - assert(operand.get().getType() != replacement.getType() && - "Expected a different type"); - SmallVector newIterOperands; - for (OpOperand &opOperand : forOp.getInitArgsMutable()) { - if (opOperand.getOperandNumber() == operand.getOperandNumber()) { - newIterOperands.push_back(replacement); - continue; - } - newIterOperands.push_back(opOperand.get()); - } - - // 2. Create the new forOp shell. - scf::ForOp newForOp = rewriter.create( - forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), - forOp.getStep(), newIterOperands); - newForOp->setAttrs(forOp->getAttrs()); - Block &newBlock = newForOp.getRegion().front(); - SmallVector newBlockTransferArgs(newBlock.getArguments().begin(), - newBlock.getArguments().end()); - - // 3. Inject an incoming cast op at the beginning of the block for the bbArg - // corresponding to the `replacement` value. - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(&newBlock, newBlock.begin()); - BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg( - &newForOp->getOpOperand(operand.getOperandNumber())); - Value castIn = rewriter.create(newForOp.getLoc(), oldType, - newRegionIterArg); - newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn; - - // 4. Steal the old block ops, mapping to the newBlockTransferArgs. - Block &oldBlock = forOp.getRegion().front(); - rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs); - - // 5. Inject an outgoing cast op at the end of the block and yield it instead. - auto clonedYieldOp = cast(newBlock.getTerminator()); - rewriter.setInsertionPoint(clonedYieldOp); - unsigned yieldIdx = - newRegionIterArg.getArgNumber() - forOp.getNumInductionVars(); - Value castOut = rewriter.create( - newForOp.getLoc(), newType, clonedYieldOp.getOperand(yieldIdx)); - SmallVector newYieldOperands = clonedYieldOp.getOperands(); - newYieldOperands[yieldIdx] = castOut; - rewriter.create(newForOp.getLoc(), newYieldOperands); - rewriter.eraseOp(clonedYieldOp); - - // 6. Inject an outgoing cast op after the forOp. - rewriter.setInsertionPointAfter(newForOp); - SmallVector newResults = newForOp.getResults(); - newResults[yieldIdx] = rewriter.create( - newForOp.getLoc(), oldType, newResults[yieldIdx]); - - return newResults; -} - /// Fold scf.for iter_arg/result pairs that go through incoming/ougoing /// a tensor.cast op pair so as to pull the tensor.cast inside the scf.for: /// @@ -1090,9 +1084,13 @@ struct ForOpTensorCastFolder : public OpRewritePattern { continue; // Create a new ForOp with that iter operand replaced. + ValueTypeCastFnTy castFn = [](OpBuilder &b, Location loc, Type type, + Value source) { + return b.create(loc, type, source); + }; rewriter.replaceOp( - op, replaceTensorCastForOpIterArg(rewriter, iterOpOperand, - incomingCast.getSource())); + op, replaceAndCastForOpIterArg(rewriter, op, iterOpOperand, + incomingCast.getSource(), castFn)); return success(); } return failure(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index ad4e42b31962e..8fcef54f12edf 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1796,6 +1796,62 @@ struct DropUnitDimsFromTransposeOp final } }; +/// A pattern to drop unit dims from the iter_args of an scf.for. +/// +/// Example: +/// +/// BEFORE: +/// ```mlir +/// %res = scf.for ... iter_args(%iter = %init) -> vector<[4]x1x1x4xf32> { +/// ... +/// scf.yield % +/// } +/// ``` +/// +/// AFTER: +/// ```mlir +/// %drop = vector.shape_cast %init +/// : vector<4x1x1x[4]xf32> to vector<4x[4]xf32> +/// %new_loop = scf.for ... iter_args(%iter = %drop) -> vector<[4]x4xf32> { +/// %new_iter = vector.shape_cast %iter +/// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32> +/// ... +/// } +/// %res = vector.shape_cast %new_loop +/// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32> +/// ``` +struct DropUnitDimsFromScfForOp final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ForOp forOp, + PatternRewriter &rewriter) const override { + /// Find the first iter_arg with droppable unit dims. Further applications + /// of this pattern will apply to later arguments. + for (OpOperand &operand : forOp.getInitArgsMutable()) { + auto vectorType = dyn_cast(operand.get().getType()); + if (!vectorType) + continue; + + VectorType newVectorType = dropNonScalableUnitDimFromType(vectorType); + if (vectorType == newVectorType) + continue; + + // Create a new ForOp with that iter operand replaced. + auto castFn = [](OpBuilder &b, Location loc, Type type, Value source) { + return b.create(loc, type, source); + }; + + Value replacement = + castFn(rewriter, forOp.getLoc(), newVectorType, operand.get()); + rewriter.replaceOp(forOp, + replaceAndCastForOpIterArg(rewriter, forOp, operand, + replacement, castFn)); + return success(); + } + return failure(); + } +}; + /// Pattern to eliminate redundant zero-constants added to reduction operands. /// It's enough for there to be one initial zero value, so we can eliminate the /// extra ones that feed into `vector.reduction `. These get created by the @@ -2001,7 +2057,8 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns, void mlir::vector::populateDropUnitDimWithShapeCastPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(patterns.getContext(), benefit); + ShapeCastOpFolder, DropUnitDimsFromScfForOp>( + patterns.getContext(), benefit); } void mlir::vector::populateBubbleVectorBitCastOpPatterns( diff --git a/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir b/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir index af3fc924c1dbe..34a155fbf2fc1 100644 --- a/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir +++ b/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir @@ -207,3 +207,100 @@ func.func @negative_transpose_with_no_unit_dims(%vec: vector<4x2x3xf32>) -> vect // CHECK-LABEL: func.func @negative_transpose_with_no_unit_dims // CHECK-NOT: vector.shape_cast + +// ----- + +///---------------------------------------------------------------------------------------- +/// [Pattern: DropUnitDimsFromScfForOp] +///---------------------------------------------------------------------------------------- + +func.func @scf_for_with_internal_unit_dims(%vec: vector<4x1x1x[4]xf32>) -> vector<4x1x1x[4]xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %res = scf.for %i = %c0 to %c4 step %c1 iter_args(%iter = %vec) -> vector<4x1x1x[4]xf32> { + %s = math.sqrt %iter : vector<4x1x1x[4]xf32> + scf.yield %s : vector<4x1x1x[4]xf32> + } + return %res : vector<4x1x1x[4]xf32> +} + +// CHECK-LABEL: func.func @scf_for_with_internal_unit_dims +// CHECK-SAME: %[[VEC:[A-Za-z0-9]+]]: vector<4x1x1x[4]xf32> +// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<4x1x1x[4]xf32> to vector<4x[4]xf32> +// CHECK: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[ITER:.+]] = %[[CAST]]) +// CHECK: %[[SQRT:.+]] = math.sqrt %[[ITER]] : vector<4x[4]xf32> +// CHECK: scf.yield %[[SQRT]] +// CHECK: %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]] : vector<4x[4]xf32> to vector<4x1x1x[4]xf32> +// CHECK: return %[[CASTBACK]] + +// ----- + +func.func @scf_for_with_all_unit_dims(%vec: vector<1x1xf32>) -> vector<1x1xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %res = scf.for %i = %c0 to %c4 step %c1 iter_args(%iter = %vec) -> vector<1x1xf32> { + %s = math.sqrt %iter : vector<1x1xf32> + scf.yield %s : vector<1x1xf32> + } + return %res : vector<1x1xf32> +} + +// CHECK-LABEL: func.func @scf_for_with_all_unit_dims +// CHECK-SAME: %[[VEC:[A-Za-z0-9]+]]: vector<1x1xf32> +// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x1xf32> to vector<1xf32> +// CHECK: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[ITER:.+]] = %[[CAST]]) +// CHECK: %[[SQRT:.+]] = math.sqrt %[[ITER]] : vector<1xf32> +// CHECK: scf.yield %[[SQRT]] +// CHECK: %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]] : vector<1xf32> to vector<1x1xf32> +// CHECK: return %[[CASTBACK]] + +// ----- + +func.func @scf_for_with_multiple_operands(%idx: index, %vec0: vector<1x4xf32>, %vec1: vector<1x4xf32>) -> vector<1x4xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %res:3 = scf.for %i = %c0 to %c4 step %c1 + iter_args(%id = %idx, %iter0 = %vec0, %iter1 = %vec1) -> (index, vector<1x4xf32>, vector<1x4xf32>) { + %add = arith.addf %iter0, %iter1 : vector<1x4xf32> + scf.yield %id, %add, %add : index, vector<1x4xf32>, vector<1x4xf32> + } + return %res#1 : vector<1x4xf32> +} + +// CHECK-LABEL: func.func @scf_for_with_multiple_operands +// CHECK-SAME: %[[IDX:[A-Za-z0-9]+]]: index +// CHECK-SAME: %[[VEC0:[A-Za-z0-9]+]]: vector<1x4xf32> +// CHECK-SAME: %[[VEC1:[A-Za-z0-9]+]]: vector<1x4xf32> +// CHECK-DAG: %[[CAST0:.+]] = vector.shape_cast %[[VEC0]] : vector<1x4xf32> to vector<4xf32> +// CHECK-DAG: %[[CAST1:.+]] = vector.shape_cast %[[VEC1]] : vector<1x4xf32> to vector<4xf32> +// CHECK: %[[LOOP:.+]]:3 = scf.for +// CHECK-SAME: iter_args(%{{.*}} = %[[IDX]], %[[ITER0:.+]] = %[[CAST0]], %[[ITER1:.+]] = %[[CAST1]]) +// CHECK: %[[ADD:.+]] = arith.addf %[[ITER0]], %[[ITER1]] : vector<4xf32> +// CHECK: scf.yield %{{.*}}, %[[ADD]], %[[ADD]] +// CHECK: %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]]#1 : vector<4xf32> to vector<1x4xf32> +// CHECK: return %[[CASTBACK]] + +// ----- + +func.func @scf_for_with_scalable_unit_dims(%vec: vector<1x[1]xf32>) -> vector<1x[1]xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %res = scf.for %i = %c0 to %c4 step %c1 iter_args(%iter = %vec) -> vector<1x[1]xf32> { + %s = math.sqrt %iter : vector<1x[1]xf32> + scf.yield %s : vector<1x[1]xf32> + } + return %res : vector<1x[1]xf32> +} + +// CHECK-LABEL: func.func @scf_for_with_scalable_unit_dims +// CHECK-SAME: %[[VEC:[A-Za-z0-9]+]]: vector<1x[1]xf32> +// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x[1]xf32> to vector<[1]xf32> +// CHECK: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[ITER:.+]] = %[[CAST]]) +// CHECK: %[[SQRT:.+]] = math.sqrt %[[ITER]] : vector<[1]xf32> +// CHECK: scf.yield %[[SQRT]] +// CHECK: %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]] : vector<[1]xf32> to vector<1x[1]xf32> +// CHECK: return %[[CASTBACK]]