Skip to content

Commit 805451f

Browse files
[MLIR][Vector] Added ToElementsOp::fold for broadcast->to_elements pattern rewrite. (#160318)
Adds `::fold` for the new `vector.to_elements` op, folding `broadcast` into `to_elements` or no-op wherever possible. --------- Signed-off-by: keshavvinayak01 <[email protected]> Signed-off-by: Keshav Vinayak Jha <[email protected]> Co-authored-by: Jakub Kuderski <[email protected]>
1 parent c7fbe38 commit 805451f

File tree

3 files changed

+160
-1
lines changed

3 files changed

+160
-1
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -806,6 +806,7 @@ def Vector_ToElementsOp : Vector_Op<"to_elements", [
806806
let results = (outs Variadic<AnyType>:$elements);
807807
let assemblyFormat = "$source attr-dict `:` type($source)";
808808
let hasFolder = 1;
809+
let hasCanonicalizer = 1;
809810
}
810811

811812
def Vector_FromElementsOp : Vector_Op<"from_elements", [

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 119 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747

4848
#include <cassert>
4949
#include <cstdint>
50+
#include <numeric>
5051

5152
#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
5253
// Pull in all enum type and utility function definitions.
@@ -2412,9 +2413,38 @@ foldToElementsFromElements(ToElementsOp toElementsOp,
24122413
return success();
24132414
}
24142415

2416+
/// Folds vector.to_elements(vector.broadcast(%x)) for the scalar case only.
2417+
///
2418+
/// Example:
2419+
/// %b = vector.broadcast %x : i32 to vector<3xf32>
2420+
/// %e:3 = vector.to_elements %b : vector<3xf32>
2421+
/// user_op %e#0, %e#1, %e#2
2422+
/// becomes:
2423+
/// user_op %x, %x, %x
2424+
///
2425+
/// The vector source case is handled by a canonicalization pattern.
2426+
static LogicalResult
2427+
foldToElementsOfBroadcast(ToElementsOp toElementsOp,
2428+
SmallVectorImpl<OpFoldResult> &results) {
2429+
auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2430+
if (!bcastOp)
2431+
return failure();
2432+
// Vectors are handled in the ToElementsOfBroadcast RewritePattern.
2433+
if (isa<VectorType>(bcastOp.getSource().getType()))
2434+
return failure();
2435+
2436+
auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
2437+
2438+
Value scalar = bcastOp.getSource();
2439+
results.assign(resultVecType.getNumElements(), scalar);
2440+
return success();
2441+
}
2442+
24152443
/// Folder for `vector.to_elements`: tries `foldToElementsFromElements` first
/// and, only if that fails, returns the outcome of
/// `foldToElementsOfBroadcast`.
LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
                                 SmallVectorImpl<OpFoldResult> &results) {
  if (failed(foldToElementsFromElements(*this, results)))
    return foldToElementsOfBroadcast(*this, results);
  return success();
}
24192449

24202450
LogicalResult
@@ -2427,6 +2457,94 @@ ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
24272457
return success();
24282458
}
24292459

2460+
/// Canonicalize `vector.to_elements(vector.broadcast(%v))` where `%v` is a
2461+
/// vector.
2462+
/// - Build `vector.to_elements %v` and remap each destination element to the
2463+
/// corresponding source element using broadcast rules (match or 1 →
2464+
/// replicate).
2465+
///
2466+
/// Example:
2467+
/// %v = vector.broadcast %src : vector<2xf32> to vector<3x2xf32>
2468+
/// %e:6 = vector.to_elements %v : vector<3x2xf32>
2469+
/// becomes:
2470+
/// %src_elems:2 = vector.to_elements %src : vector<2xf32>
2471+
/// // uses: %src_elems#0, %src_elems#1, %src_elems#0,
2472+
/// // %src_elems#1, %src_elems#0, %src_elems#1
2473+
struct ToElementsOfBroadcast final : OpRewritePattern<ToElementsOp> {
2474+
using Base::Base;
2475+
2476+
LogicalResult matchAndRewrite(ToElementsOp toElementsOp,
2477+
PatternRewriter &rewriter) const override {
2478+
auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2479+
if (!bcastOp)
2480+
return failure();
2481+
2482+
// Only handle broadcasts from a vector source here.
2483+
auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
2484+
if (!srcType)
2485+
return failure();
2486+
2487+
auto dstType = cast<VectorType>(toElementsOp.getSource().getType());
2488+
2489+
ArrayRef<int64_t> dstShape = dstType.getShape();
2490+
ArrayRef<int64_t> srcShape = srcType.getShape();
2491+
2492+
int64_t dstRank = dstShape.size();
2493+
int64_t srcRank = srcShape.size();
2494+
2495+
// Create elements for the broadcast source vector.
2496+
auto srcElems = vector::ToElementsOp::create(
2497+
rewriter, toElementsOp.getLoc(), bcastOp.getSource());
2498+
2499+
int64_t dstCount = std::accumulate(dstShape.begin(), dstShape.end(), 1,
2500+
std::multiplies<int64_t>());
2501+
2502+
SmallVector<Value> replacements;
2503+
replacements.reserve(dstCount);
2504+
2505+
// For each element of the destination, determine which element of the
2506+
// source should be used. We walk all destination positions using a single
2507+
// counter, decode it into per-dimension indices, then build the matching
2508+
// source position: use the same index where sizes match, and use 0 where
2509+
// the source size is 1 (replication). This mapping is needed so we can
2510+
// replace each result of to_elements with the corresponding element from
2511+
// the broadcast source.
2512+
// Inner-dimension stretch example:
2513+
// %v = vector.broadcast %src : vector<2x1x2xf32> to vector<2x3x2xf32>
2514+
// %e:12 = vector.to_elements %v : vector<2x3x2xf32>
2515+
// becomes:
2516+
// %src_elems:4 = vector.to_elements %src : vector<2x1x2xf32>
2517+
// // uses: %src_elems#0, %src_elems#1, %src_elems#0,
2518+
// // %src_elems#1, %src_elems#0, %src_elems#1,
2519+
// // %src_elems#2, %src_elems#3, %src_elems#2,
2520+
// // %src_elems#3, %src_elems#2, %src_elems#3
2521+
2522+
// Row-major strides for the destination shape.
2523+
SmallVector<int64_t> dstStrides = computeStrides(dstShape);
2524+
// Row-major strides for the source shape.
2525+
SmallVector<int64_t> srcStrides = computeStrides(srcShape);
2526+
SmallVector<int64_t> dstIdx(dstRank);
2527+
SmallVector<int64_t> srcIdx(srcRank);
2528+
for (int64_t lin = 0; lin < dstCount; ++lin) {
2529+
// Convert linear destination index to per-dimension indices.
2530+
dstIdx = delinearize(lin, dstStrides);
2531+
for (int64_t k = 0; k < srcRank; ++k)
2532+
srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k];
2533+
// Convert per-dimension source indices back to a linear index.
2534+
int64_t srcLin = linearize(srcIdx, srcStrides);
2535+
replacements.push_back(srcElems.getResult(srcLin));
2536+
}
2537+
2538+
rewriter.replaceOp(toElementsOp, replacements);
2539+
return success();
2540+
}
2541+
};
2542+
2543+
/// Adds the canonicalization patterns for `vector.to_elements` to `results`;
/// currently this is only `ToElementsOfBroadcast`.
void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<ToElementsOfBroadcast>(context);
}
2547+
24302548
//===----------------------------------------------------------------------===//
24312549
// FromElementsOp
24322550
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3322,6 +3322,46 @@ func.func @from_elements_to_elements_shuffle(%a: vector<4x2xf32>) -> vector<4x2x
33223322

33233323
// -----
33243324

3325+
// to_elements of a scalar broadcast: both ops must disappear and every
// returned value must be the original scalar argument.
// CHECK-LABEL: func @to_elements_of_scalar_broadcast_folds
// CHECK-SAME: (%[[S:.*]]: f32) -> (f32, f32, f32, f32)
func.func @to_elements_of_scalar_broadcast_folds(%s: f32) -> (f32, f32, f32, f32) {
  %bcast = vector.broadcast %s : f32 to vector<4xf32>
  %elems:4 = vector.to_elements %bcast : vector<4xf32>
  // CHECK-NOT: vector.broadcast
  // CHECK-NOT: vector.to_elements
  // CHECK: return %[[S]], %[[S]], %[[S]], %[[S]]
  return %elems#0, %elems#1, %elems#2, %elems#3 : f32, f32, f32, f32
}
3335+
3336+
// -----
3337+
3338+
// to_elements of a vector<2xf32> broadcast to vector<3x2xf32>: the broadcast
// must disappear and the six results must cycle through the two elements of
// a to_elements taken directly on the broadcast source.
// CHECK-LABEL: func @to_elements_of_vector_broadcast
// CHECK-SAME: (%[[VEC:.*]]: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32)
func.func @to_elements_of_vector_broadcast(%vec: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
  %bcast = vector.broadcast %vec : vector<2xf32> to vector<3x2xf32>
  %elems:6 = vector.to_elements %bcast : vector<3x2xf32>
  // CHECK-NOT: vector.broadcast
  // CHECK: %[[SRC_ELEMS:.*]]:2 = vector.to_elements %[[VEC]]
  // CHECK: return %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1, %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1, %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1
  return %elems#0, %elems#1, %elems#2, %elems#3, %elems#4, %elems#5 : f32, f32, f32, f32, f32, f32
}
3348+
3349+
// -----
3350+
3351+
// to_elements of a broadcast that stretches a unit middle dimension
// (vector<2x1x2xf32> to vector<2x3x2xf32>): the broadcast must disappear and
// the twelve results must map back onto the four source elements.
// CHECK-LABEL: func @to_elements_of_vector_broadcast_inner_dim
// CHECK-SAME: (%[[V:.*]]: vector<2x1x2xf32>) -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)
func.func @to_elements_of_vector_broadcast_inner_dim(%v: vector<2x1x2xf32>) -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
  %bcast = vector.broadcast %v : vector<2x1x2xf32> to vector<2x3x2xf32>
  %elems:12 = vector.to_elements %bcast : vector<2x3x2xf32>
  // CHECK-NOT: vector.broadcast
  // CHECK: %[[SRC:.*]]:4 = vector.to_elements %[[V]] : vector<2x1x2xf32>
  // CHECK: return %[[SRC]]#0, %[[SRC]]#1, %[[SRC]]#0, %[[SRC]]#1, %[[SRC]]#0, %[[SRC]]#1, %[[SRC]]#2, %[[SRC]]#3, %[[SRC]]#2, %[[SRC]]#3, %[[SRC]]#2, %[[SRC]]#3
  return %elems#0, %elems#1, %elems#2, %elems#3, %elems#4, %elems#5, %elems#6, %elems#7, %elems#8, %elems#9, %elems#10, %elems#11 :
    f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32
}
3362+
3363+
// -----
3364+
33253365
// +---------------------------------------------------------------------------
33263366
// Tests for foldFromElementsToConstant
33273367
// +---------------------------------------------------------------------------

0 commit comments

Comments
 (0)