Skip to content

Conversation

@qedawkins
Copy link
Contributor

This adds a pattern for dropping unit dims from the iter_args of scf.for ops using vector.shape_cast. This composes with the other patterns for dropping unit dims from elementwise ops and transposes.

@llvmbot
Copy link
Member

llvmbot commented Sep 22, 2024

@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-scf

Author: Quinn Dawkins (qedawkins)

Changes

This adds a pattern for dropping unit dims from the iter_args of scf.for ops using vector.shape_cast. This composes with the other patterns for dropping unit dims from elementwise ops and transposes.


Full diff: https://github.com/llvm/llvm-project/pull/109585.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/IR/SCF.h (+11)
  • (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+70-72)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+59-1)
  • (modified) mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir (+75)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index 644118ca884c6b..d89d566ece62c1 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -107,6 +107,17 @@ LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs,
                        function_ref<void(OpBuilder &, Location, ValueRange)>
                            bodyBuilder = nullptr);
 
+/// Perform a replacement of one iter OpOperand of an scf.for to the
+/// `replacement` value with a different type. A callback is used to insert
+/// cast ops inside the block to account for type differences.
+using ValueTypeCastFnTy =
+    std::function<Value(OpBuilder &, Location loc, Type, Value)>;
+SmallVector<Value> replaceAndCastForOpIterArg(RewriterBase &rewriter,
+                                              scf::ForOp forOp,
+                                              OpOperand &operand,
+                                              Value replacement,
+                                              const ValueTypeCastFnTy &castFn);
+
 } // namespace scf
 } // namespace mlir
 #endif // MLIR_DIALECT_SCF_SCF_H
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 6d47ff3890977a..d1c9fd2d217dad 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -772,6 +772,70 @@ LoopNest mlir::scf::buildLoopNest(
                        });
 }
 
+SmallVector<Value>
+mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
+                                      OpOperand &operand, Value replacement,
+                                      const ValueTypeCastFnTy &castFn) {
+  assert(operand.getOwner() == forOp);
+  Type oldType = operand.get().getType(), newType = replacement.getType();
+
+  // 1. Create new iter operands, exactly 1 is replaced.
+  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
+         "expected an iter OpOperand");
+  assert(operand.get().getType() != replacement.getType() &&
+         "Expected a different type");
+  SmallVector<Value> newIterOperands;
+  for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
+    if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
+      newIterOperands.push_back(replacement);
+      continue;
+    }
+    newIterOperands.push_back(opOperand.get());
+  }
+
+  // 2. Create the new forOp shell.
+  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
+      forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
+      forOp.getStep(), newIterOperands);
+  newForOp->setAttrs(forOp->getAttrs());
+  Block &newBlock = newForOp.getRegion().front();
+  SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
+                                             newBlock.getArguments().end());
+
+  // 3. Inject an incoming cast op at the beginning of the block for the bbArg
+  // corresponding to the `replacement` value.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(&newBlock, newBlock.begin());
+  BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
+      &newForOp->getOpOperand(operand.getOperandNumber()));
+  Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
+  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
+
+  // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
+  Block &oldBlock = forOp.getRegion().front();
+  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
+
+  // 5. Inject an outgoing cast op at the end of the block and yield it instead.
+  auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
+  rewriter.setInsertionPoint(clonedYieldOp);
+  unsigned yieldIdx =
+      newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
+  Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
+                         clonedYieldOp.getOperand(yieldIdx));
+  SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
+  newYieldOperands[yieldIdx] = castOut;
+  rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
+  rewriter.eraseOp(clonedYieldOp);
+
+  // 6. Inject an outgoing cast op after the forOp.
+  rewriter.setInsertionPointAfter(newForOp);
+  SmallVector<Value> newResults = newForOp.getResults();
+  newResults[yieldIdx] =
+      castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
+
+  return newResults;
+}
+
 namespace {
 // Fold away ForOp iter arguments when:
 // 1) The op yields the iter arguments.
@@ -973,76 +1037,6 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
   }
 };
 
-/// Perform a replacement of one iter OpOperand of an scf.for to the
-/// `replacement` value which is expected to be the source of a tensor.cast.
-/// tensor.cast ops are inserted inside the block to account for the type cast.
-static SmallVector<Value>
-replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand,
-                              Value replacement) {
-  Type oldType = operand.get().getType(), newType = replacement.getType();
-  assert(llvm::isa<RankedTensorType>(oldType) &&
-         llvm::isa<RankedTensorType>(newType) &&
-         "expected ranked tensor types");
-
-  // 1. Create new iter operands, exactly 1 is replaced.
-  ForOp forOp = cast<ForOp>(operand.getOwner());
-  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
-         "expected an iter OpOperand");
-  assert(operand.get().getType() != replacement.getType() &&
-         "Expected a different type");
-  SmallVector<Value> newIterOperands;
-  for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
-    if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
-      newIterOperands.push_back(replacement);
-      continue;
-    }
-    newIterOperands.push_back(opOperand.get());
-  }
-
-  // 2. Create the new forOp shell.
-  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
-      forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
-      forOp.getStep(), newIterOperands);
-  newForOp->setAttrs(forOp->getAttrs());
-  Block &newBlock = newForOp.getRegion().front();
-  SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
-                                             newBlock.getArguments().end());
-
-  // 3. Inject an incoming cast op at the beginning of the block for the bbArg
-  // corresponding to the `replacement` value.
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(&newBlock, newBlock.begin());
-  BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
-      &newForOp->getOpOperand(operand.getOperandNumber()));
-  Value castIn = rewriter.create<tensor::CastOp>(newForOp.getLoc(), oldType,
-                                                 newRegionIterArg);
-  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
-
-  // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
-  Block &oldBlock = forOp.getRegion().front();
-  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
-
-  // 5. Inject an outgoing cast op at the end of the block and yield it instead.
-  auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
-  rewriter.setInsertionPoint(clonedYieldOp);
-  unsigned yieldIdx =
-      newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
-  Value castOut = rewriter.create<tensor::CastOp>(
-      newForOp.getLoc(), newType, clonedYieldOp.getOperand(yieldIdx));
-  SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
-  newYieldOperands[yieldIdx] = castOut;
-  rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
-  rewriter.eraseOp(clonedYieldOp);
-
-  // 6. Inject an outgoing cast op after the forOp.
-  rewriter.setInsertionPointAfter(newForOp);
-  SmallVector<Value> newResults = newForOp.getResults();
-  newResults[yieldIdx] = rewriter.create<tensor::CastOp>(
-      newForOp.getLoc(), oldType, newResults[yieldIdx]);
-
-  return newResults;
-}
-
 /// Fold scf.for iter_arg/result pairs that go through incoming/ougoing
 /// a tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
 ///
@@ -1090,9 +1084,13 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
         continue;
 
       // Create a new ForOp with that iter operand replaced.
+      ValueTypeCastFnTy castFn = [](OpBuilder &b, Location loc, Type type,
+                                    Value source) {
+        return b.create<tensor::CastOp>(loc, type, source);
+      };
       rewriter.replaceOp(
-          op, replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
-                                            incomingCast.getSource()));
+          op, replaceAndCastForOpIterArg(rewriter, op, iterOpOperand,
+                                         incomingCast.getSource(), castFn));
       return success();
     }
     return failure();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index ad4e42b31962e1..ba32583fc3cdc4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1796,6 +1796,63 @@ struct DropUnitDimsFromTransposeOp final
   }
 };
 
+/// A pattern to drop unit dims from the iter_args of an scf.for.
+///
+/// Example:
+///
+///  BEFORE:
+///  ```mlir
+///  %res = scf.for ... iter_args(%iter = %init) -> vector<[4]x1x1x4xf32> {
+///    ...
+///    scf.yield %
+///  }
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %drop = vector.shape_cast %init
+///    : vector<4x1x1x[4]xf32> to vector<4x[4]xf32>
+///  %new_loop = scf.for ... iter_args(%iter = %drop) -> vector<[4]x4xf32> {
+///    %new_iter = vector.shape_cast %iter
+///      : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
+///    ...
+///  }
+///  %res = vector.shape_cast %new_loop
+///    : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
+///  ```
+struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(scf::ForOp forOp,
+                                PatternRewriter &rewriter) const override {
+    /// Find the first iter_arg with droppable unit dims. Further applications
+    /// of this pattern will apply to later arguments.
+    for (OpOperand &operand : forOp.getInitArgsMutable()) {
+      auto vectorType = dyn_cast<VectorType>(operand.get().getType());
+      if (!vectorType)
+        continue;
+
+      VectorType newVectorType = dropNonScalableUnitDimFromType(vectorType);
+      if (vectorType == newVectorType)
+        continue;
+
+      // Create a new ForOp with that iter operand replaced.
+      mlir::scf::ValueTypeCastFnTy castFn = [](OpBuilder &b, Location loc,
+                                               Type type, Value source) {
+        return b.create<vector::ShapeCastOp>(loc, type, source);
+      };
+
+      Value replacement =
+          castFn(rewriter, forOp.getLoc(), newVectorType, operand.get());
+      rewriter.replaceOp(forOp,
+                         replaceAndCastForOpIterArg(rewriter, forOp, operand,
+                                                    replacement, castFn));
+      return success();
+    }
+    return failure();
+  }
+};
+
 /// Pattern to eliminate redundant zero-constants added to reduction operands.
 /// It's enough for there to be one initial zero value, so we can eliminate the
 /// extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -2001,7 +2058,8 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
 void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromTransposeOp,
-               ShapeCastOpFolder>(patterns.getContext(), benefit);
+               ShapeCastOpFolder, DropUnitDimsFromScfForOp>(
+      patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
diff --git a/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir b/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir
index af3fc924c1dbe7..8249400a43c757 100644
--- a/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir
+++ b/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir
@@ -207,3 +207,78 @@ func.func @negative_transpose_with_no_unit_dims(%vec: vector<4x2x3xf32>) -> vect
 
 // CHECK-LABEL: func.func @negative_transpose_with_no_unit_dims
 // CHECK-NOT: vector.shape_cast
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// [Pattern: DropUnitDimsFromScfForOp]
+///----------------------------------------------------------------------------------------
+
+func.func @scf_for_with_internal_unit_dims(%vec: vector<4x1x1x[4]xf32>) -> vector<4x1x1x[4]xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %res = scf.for %i = %c0 to %c4 step %c1 iter_args(%iter = %vec) -> vector<4x1x1x[4]xf32> {
+    %s = math.sqrt %iter : vector<4x1x1x[4]xf32>
+    scf.yield %s : vector<4x1x1x[4]xf32>
+  }
+  return %res : vector<4x1x1x[4]xf32>
+}
+
+// CHECK-LABEL: func.func @scf_for_with_internal_unit_dims
+//  CHECK-SAME:   %[[VEC:[A-Za-z0-9]+]]: vector<4x1x1x[4]xf32>
+//       CHECK:   %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<4x1x1x[4]xf32> to vector<4x[4]xf32>
+//       CHECK:   %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[ITER:.+]] = %[[CAST]])
+//       CHECK:     %[[SQRT:.+]] = math.sqrt %[[ITER]] : vector<4x[4]xf32>
+//       CHECK:     scf.yield %[[SQRT]]
+//       CHECK:   %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]] : vector<4x[4]xf32> to vector<4x1x1x[4]xf32>
+//       CHECK:   return %[[CASTBACK]]
+
+// -----
+
+func.func @scf_for_with_all_unit_dims(%vec: vector<1x1xf32>) -> vector<1x1xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %res = scf.for %i = %c0 to %c4 step %c1 iter_args(%iter = %vec) -> vector<1x1xf32> {
+    %s = math.sqrt %iter : vector<1x1xf32>
+    scf.yield %s : vector<1x1xf32>
+  }
+  return %res : vector<1x1xf32>
+}
+
+// CHECK-LABEL: func.func @scf_for_with_all_unit_dims
+//  CHECK-SAME:   %[[VEC:[A-Za-z0-9]+]]: vector<1x1xf32>
+//       CHECK:   %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x1xf32> to vector<1xf32>
+//       CHECK:   %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[ITER:.+]] = %[[CAST]])
+//       CHECK:     %[[SQRT:.+]] = math.sqrt %[[ITER]] : vector<1xf32>
+//       CHECK:     scf.yield %[[SQRT]]
+//       CHECK:   %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]] : vector<1xf32> to vector<1x1xf32>
+//       CHECK:   return %[[CASTBACK]]
+
+// -----
+
+func.func @scf_for_with_multiple_operands(%idx: index, %vec0: vector<1x4xf32>, %vec1: vector<1x4xf32>) -> vector<1x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %res:3 = scf.for %i = %c0 to %c4 step %c1
+    iter_args(%id = %idx, %iter0 = %vec0, %iter1 = %vec1) -> (index, vector<1x4xf32>, vector<1x4xf32>) {
+    %add = arith.addf %iter0, %iter1 : vector<1x4xf32>
+    scf.yield %id, %add, %add : index, vector<1x4xf32>, vector<1x4xf32>
+  }
+  return %res#1 : vector<1x4xf32>
+}
+
+// CHECK-LABEL: func.func @scf_for_with_multiple_operands
+//  CHECK-SAME:   %[[IDX:[A-Za-z0-9]+]]: index
+//  CHECK-SAME:   %[[VEC0:[A-Za-z0-9]+]]: vector<1x4xf32>
+//  CHECK-SAME:   %[[VEC1:[A-Za-z0-9]+]]: vector<1x4xf32>
+//   CHECK-DAG:   %[[CAST0:.+]] = vector.shape_cast %[[VEC0]] : vector<1x4xf32> to vector<4xf32>
+//   CHECK-DAG:   %[[CAST1:.+]] = vector.shape_cast %[[VEC1]] : vector<1x4xf32> to vector<4xf32>
+//       CHECK:   %[[LOOP:.+]]:3 = scf.for
+//  CHECK-SAME:     iter_args(%{{.*}} = %[[IDX]], %[[ITER0:.+]] = %[[CAST0]], %[[ITER1:.+]] = %[[CAST1]])
+//       CHECK:     %[[ADD:.+]] = arith.addf %[[ITER0]], %[[ITER1]] : vector<4xf32>
+//       CHECK:     scf.yield %{{.*}}, %[[ADD]], %[[ADD]]
+//       CHECK:   %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]]#1 : vector<4xf32> to vector<1x4xf32>
+//       CHECK:   return %[[CASTBACK]]

/// cast ops inside the block to account for type differences.
using ValueTypeCastFnTy =
std::function<Value(OpBuilder &, Location loc, Type, Value)>;
SmallVector<Value> replaceAndCastForOpIterArg(RewriterBase &rewriter,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It feels like this helper should go in SCF/Utils/Utils.h, however that would introduce a circular dep because this was split out from a canonicalizer. Any suggestions for a better place for this helper?

Copy link
Member

@MacDue MacDue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense to me, this is helpful for getting rid of casts what otherwise won't fold away 👍
(I actually needed something like this for a downstream experiment a while ago)

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just one nit about a function argument.

This adds a pattern for dropping unit dims from the iter_args of scf.for
ops using vector.shape_cast. This composes with the other patterns for
dropping unit dims from elementwise ops and transposes.
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@qedawkins qedawkins merged commit a3b34e6 into llvm:main Sep 27, 2024
4 checks passed
@qedawkins qedawkins deleted the drop_for_unit_dims branch September 27, 2024 15:43
Sterling-Augustine pushed a commit to Sterling-Augustine/llvm-project that referenced this pull request Sep 27, 2024
…m#109585)

This adds a pattern for dropping unit dims from the iter_args of scf.for
ops using vector.shape_cast. This composes with the other patterns for
dropping unit dims from elementwise ops and transposes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants