- 
                Notifications
    You must be signed in to change notification settings 
- Fork 15k
[MLIR][Vector] Added ToElementsOp::fold for broadcast->to_elements pattern rewrite. #160318
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][Vector] Added ToElementsOp::fold for broadcast->to_elements pattern rewrite. #160318
Conversation
Signed-off-by: keshavvinayak01 <[email protected]>
| Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using  If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. | 
| @llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Keshav Vinayak Jha (keshavvinayak01) ChangesDescriptionAdds  Full diff: https://github.com/llvm/llvm-project/pull/160318.diff 2 Files Affected: 
 diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 347141e2773b8..4ac61418b97a5 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2395,11 +2395,103 @@ foldToElementsFromElements(ToElementsOp toElementsOp,
   return success();
 }
 
+/// Folds vector.to_elements(vector.broadcast(%x)) by creating a new
+/// vector.to_elements on the source and remapping results according to
+/// broadcast semantics.
+///
+/// Cases handled:
+///  - %x is a scalar: replicate the scalar across all results.
+///  - %x is a vector: create to_elements on source and remap/duplicate results.
+static LogicalResult
+foldToElementsOfBroadcast(ToElementsOp toElementsOp,
+                          SmallVectorImpl<OpFoldResult> &results) {
+  auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
+  if (!bcastOp)
+    return failure();
+
+  auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
+  // Bail on scalable vectors.
+  if (resultVecType.getNumScalableDims() != 0)
+    return failure();
+
+  // Case 1: scalar broadcast → replicate scalar across all results.
+  if (!isa<VectorType>(bcastOp.getSource().getType())) {
+    Value scalar = bcastOp.getSource();
+    results.assign(resultVecType.getNumElements(), scalar);
+    return success();
+  }
+
+  // Case 2: vector broadcast → create to_elements on source and remap.
+  auto srcVecType = cast<VectorType>(bcastOp.getSource().getType());
+  if (srcVecType.getNumScalableDims() != 0)
+    return failure();
+
+  // Create a temporary to_elements to get the source elements for mapping.
+  // Change the operand to the broadcast source.
+  OpBuilder builder(toElementsOp);
+  auto srcElems = builder.create<ToElementsOp>(toElementsOp.getLoc(),
+                                               bcastOp.getSource());
+
+  ArrayRef<int64_t> dstShape = resultVecType.getShape();
+  ArrayRef<int64_t> srcShape = srcVecType.getShape();
+
+  // Quick broadcastability check with right-aligned shapes.
+  unsigned dstRank = dstShape.size();
+  unsigned srcRank = srcShape.size();
+  if (srcRank > dstRank)
+    return failure();
+
+  for (unsigned i = 0; i < dstRank; ++i) {
+    int64_t dstDim = dstShape[i];
+    int64_t srcDim = 1;
+    if (i + srcRank >= dstRank)
+      srcDim = srcShape[i + srcRank - dstRank];
+    if (!(srcDim == 1 || srcDim == dstDim))
+      return failure();
+  }
+
+  int64_t dstCount = 1;
+  for (int64_t v : dstShape)
+    dstCount *= v;
+  results.clear();
+  results.reserve(dstCount);
+
+  // Pre-compute the mapping from destination linear index to source linear index
+  SmallVector<int64_t> dstToSrcMap(dstCount);
+  SmallVector<int64_t> dstIdx(dstShape.size());
+  
+  for (int64_t lin = 0; lin < dstCount; ++lin) {
+    // Convert linear index to multi-dimensional indices (row-major order)
+    int64_t temp = lin;
+    for (int64_t i = dstShape.size() - 1; i >= 0; --i) {
+      int64_t dim = dstShape[i];
+      dstIdx[i] = temp % dim;
+      temp /= dim;
+    }
+    // Right-align mapping from dst indices to src indices.
+    int64_t srcLin = 0;
+    for (unsigned k = 0; k < srcRank; ++k)
+      srcLin = srcLin * srcShape[k] + 
+        ((srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k]);
+
+    dstToSrcMap[lin] = srcLin;
+  }
+
+  // Apply the pre-computed mapping
+  for (int64_t lin = 0; lin < dstCount; ++lin) {
+    results.push_back(srcElems.getResult(dstToSrcMap[lin]));
+  }
+  return success();
+}
+
 LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
                                  SmallVectorImpl<OpFoldResult> &results) {
-  return foldToElementsFromElements(*this, results);
+  if (succeeded(foldToElementsFromElements(*this, results)))
+    return success();
+  return foldToElementsOfBroadcast(*this, results);
 }
 
+
 LogicalResult
 ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
                                ToElementsOp::Adaptor adaptor,
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 08d28be3f8f73..728c4ddd22ec7 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3326,6 +3326,32 @@ func.func @from_elements_to_elements_shuffle(%a: vector<4x2xf32>) -> vector<4x2x
 
 // -----
 
+// CHECK-LABEL: func @to_elements_of_scalar_broadcast_folds
+// CHECK-SAME: (%[[S:.*]]: f32) -> (f32, f32, f32, f32)
+func.func @to_elements_of_scalar_broadcast_folds(%s: f32) -> (f32, f32, f32, f32) {
+  %v = vector.broadcast %s : f32 to vector<4xf32>
+  %e:4 = vector.to_elements %v : vector<4xf32>
+  // CHECK-NOT: vector.broadcast
+  // CHECK-NOT: vector.to_elements
+  // CHECK: return %[[S]], %[[S]], %[[S]], %[[S]]
+  return %e#0, %e#1, %e#2, %e#3 : f32, f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func @to_elements_of_vector_broadcast
+// CHECK-SAME: (%[[VEC:.*]]: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32)
+func.func @to_elements_of_vector_broadcast(%vec: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
+  %v = vector.broadcast %vec : vector<2xf32> to vector<3x2xf32>
+  %e:6 = vector.to_elements %v : vector<3x2xf32>
+  // CHECK-NOT: vector.broadcast
+  // CHECK: %[[SRC_ELEMS:.*]]:2 = vector.to_elements %[[VEC]]
+  // CHECK: return %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1, %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1, %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1
+  return %e#0, %e#1, %e#2, %e#3, %e#4, %e#5 : f32, f32, f32, f32, f32, f32
+}
+
+// -----
+
 // +---------------------------------------------------------------------------
 // Tests for foldFromElementsToConstant
 // +---------------------------------------------------------------------------
 | 
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
1. Better comments, removing keywords and verbose docs. 2. Removed redundant "Broadcastability" check, we don't require it since the vector.broadcast op will always be valid when it reaches this logic. Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
1. Added better docs for the folder with an example 2. Removed isScalable check, not required for toElementsOp 3. Used free create method for new toElementsOp Signed-off-by: Keshav Vinayak Jha <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! It's an interesting canonicalization! A few comments
1. Added better doc for inner broadcast case 2. Added lit test for inner broadcast dim. Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM % nit + please remove the #Describtion header from the PR description / commit message -- it's obvious that it's a description.
Co-authored-by: Jakub Kuderski <[email protected]>
Co-authored-by: Jakub Kuderski <[email protected]>
| @keshavvinayak01 Congratulations on having your first Pull Request (PR) merged into the LLVM Project! Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR. Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues. How to do this, and the rest of the post-merge process, is covered in detail here. If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again. If you don't get any reports, no action is required from you. Your changes are working as expected, well done! | 
| This has caused test failures on AArch64: https://lab.llvm.org/buildbot/#/builders/143/builds/11524 One example: The host does have SVE vector extensions, and if this is running code, it's running natively. Also please set your email settings (https://llvm.org/docs/DeveloperPolicy.html#email-addresses) for future contributions so that you are notified. (notifications are only posted on PRs if that PR is the only change included in a build) | 
| Not sure which order these would have happened in. First guess is it failed to process something, so didn't create the file, so it didn't find the file it should have created. | 
| @DavidSpickett I briefly looked at  | 
| Sorry, yes I think you are right. I misread our results page! | 
…ttern rewrite. (#160318) Adds `::fold` for the new `vector.to_elements` op, folding `broadcast` into `to_elements` or no-op wherever possible. --------- Signed-off-by: keshavvinayak01 <[email protected]> Signed-off-by: Keshav Vinayak Jha <[email protected]> Co-authored-by: Jakub Kuderski <[email protected]>
…ttern rewrite. (llvm#160318) Adds `::fold` for the new `vector.to_elements` op, folding `broadcast` into `to_elements` or no-op wherever possible. --------- Signed-off-by: keshavvinayak01 <[email protected]> Signed-off-by: Keshav Vinayak Jha <[email protected]> Co-authored-by: Jakub Kuderski <[email protected]>
Adds
::foldfor the newvector.to_elementsop, foldingbroadcastintoto_elementsor no-op wherever possible.