Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,7 @@ def Vector_ToElementsOp : Vector_Op<"to_elements", [
let results = (outs Variadic<AnyType>:$elements);
let assemblyFormat = "$source attr-dict `:` type($source)";
let hasFolder = 1;
let hasCanonicalizer = 1;
}

def Vector_FromElementsOp : Vector_Op<"from_elements", [
Expand Down
120 changes: 119 additions & 1 deletion mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@

#include <cassert>
#include <cstdint>
#include <numeric>

#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
// Pull in all enum type and utility function definitions.
Expand Down Expand Up @@ -2412,9 +2413,38 @@ foldToElementsFromElements(ToElementsOp toElementsOp,
return success();
}

/// Folds vector.to_elements(vector.broadcast(%x)) for the scalar case only.
///
/// Example:
/// %b = vector.broadcast %x : i32 to vector<3xf32>
/// %e:3 = vector.to_elements %b : vector<3xf32>
/// user_op %e#0, %e#1, %e#2
/// becomes:
/// user_op %x, %x, %x
///
/// The vector source case is handled by a canonicalization pattern.
static LogicalResult
foldToElementsOfBroadcast(ToElementsOp toElementsOp,
SmallVectorImpl<OpFoldResult> &results) {
auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
if (!bcastOp)
return failure();
// Vectors are handled in the ToElementsOfBroadcast RewritePattern.
if (isa<VectorType>(bcastOp.getSource().getType()))
return failure();

auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());

Value scalar = bcastOp.getSource();
results.assign(resultVecType.getNumElements(), scalar);
return success();
}

LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
return foldToElementsFromElements(*this, results);
if (succeeded(foldToElementsFromElements(*this, results)))
return success();
return foldToElementsOfBroadcast(*this, results);
}

LogicalResult
Expand All @@ -2427,6 +2457,94 @@ ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
return success();
}

/// Canonicalize `vector.to_elements(vector.broadcast(%v))` where `%v` is a
/// vector.
/// - Build `vector.to_elements %v` and remap each destination element to the
/// corresponding source element using broadcast rules (match or 1 →
/// replicate).
///
/// Example:
/// %v = vector.broadcast %src : vector<2xf32> to vector<3x2xf32>
/// %e:6 = vector.to_elements %v : vector<3x2xf32>
/// becomes:
/// %src_elems:2 = vector.to_elements %src : vector<2xf32>
/// // uses: %src_elems#0, %src_elems#1, %src_elems#0,
/// // %src_elems#1, %src_elems#0, %src_elems#1
struct ToElementsOfBroadcast final : OpRewritePattern<ToElementsOp> {
using Base::Base;

LogicalResult matchAndRewrite(ToElementsOp toElementsOp,
PatternRewriter &rewriter) const override {
auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
if (!bcastOp)
return failure();

// Only handle broadcasts from a vector source here.
auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
if (!srcType)
return failure();

auto dstType = cast<VectorType>(toElementsOp.getSource().getType());

ArrayRef<int64_t> dstShape = dstType.getShape();
ArrayRef<int64_t> srcShape = srcType.getShape();

int64_t dstRank = dstShape.size();
int64_t srcRank = srcShape.size();

// Create elements for the broadcast source vector.
auto srcElems = vector::ToElementsOp::create(
rewriter, toElementsOp.getLoc(), bcastOp.getSource());

int64_t dstCount = std::accumulate(dstShape.begin(), dstShape.end(), 1,
std::multiplies<int64_t>());

SmallVector<Value> replacements;
replacements.reserve(dstCount);

// For each element of the destination, determine which element of the
// source should be used. We walk all destination positions using a single
// counter, decode it into per-dimension indices, then build the matching
// source position: use the same index where sizes match, and use 0 where
// the source size is 1 (replication). This mapping is needed so we can
// replace each result of to_elements with the corresponding element from
// the broadcast source.
// Inner-dimension stretch example:
// %v = vector.broadcast %src : vector<2x1x2xf32> to vector<2x3x2xf32>
// %e:12 = vector.to_elements %v : vector<2x3x2xf32>
// becomes:
// %src_elems:4 = vector.to_elements %src : vector<2x1x2xf32>
// // uses: %src_elems#0, %src_elems#1, %src_elems#0,
// // %src_elems#1, %src_elems#0, %src_elems#1,
// // %src_elems#2, %src_elems#3, %src_elems#2,
// // %src_elems#3, %src_elems#2, %src_elems#3

// Row-major strides for the destination shape.
SmallVector<int64_t> dstStrides = computeStrides(dstShape);
// Row-major strides for the source shape.
SmallVector<int64_t> srcStrides = computeStrides(srcShape);
SmallVector<int64_t> dstIdx(dstRank);
SmallVector<int64_t> srcIdx(srcRank);
for (int64_t lin = 0; lin < dstCount; ++lin) {
// Convert linear destination index to per-dimension indices.
dstIdx = delinearize(lin, dstStrides);
for (int64_t k = 0; k < srcRank; ++k)
srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k];
// Convert per-dimension source indices back to a linear index.
int64_t srcLin = linearize(srcIdx, srcStrides);
replacements.push_back(srcElems.getResult(srcLin));
}

rewriter.replaceOp(toElementsOp, replacements);
return success();
}
};

void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ToElementsOfBroadcast>(context);
}

//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//
Expand Down
40 changes: 40 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3326,6 +3326,46 @@ func.func @from_elements_to_elements_shuffle(%a: vector<4x2xf32>) -> vector<4x2x

// -----

// CHECK-LABEL: func @to_elements_of_scalar_broadcast_folds
// CHECK-SAME: (%[[S:.*]]: f32) -> (f32, f32, f32, f32)
func.func @to_elements_of_scalar_broadcast_folds(%s: f32) -> (f32, f32, f32, f32) {
%v = vector.broadcast %s : f32 to vector<4xf32>
%e:4 = vector.to_elements %v : vector<4xf32>
// CHECK-NOT: vector.broadcast
// CHECK-NOT: vector.to_elements
// CHECK: return %[[S]], %[[S]], %[[S]], %[[S]]
return %e#0, %e#1, %e#2, %e#3 : f32, f32, f32, f32
}

// -----

// CHECK-LABEL: func @to_elements_of_vector_broadcast
// CHECK-SAME: (%[[VEC:.*]]: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32)
func.func @to_elements_of_vector_broadcast(%vec: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
%v = vector.broadcast %vec : vector<2xf32> to vector<3x2xf32>
%e:6 = vector.to_elements %v : vector<3x2xf32>
// CHECK-NOT: vector.broadcast
// CHECK: %[[SRC_ELEMS:.*]]:2 = vector.to_elements %[[VEC]]
// CHECK: return %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1, %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1, %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1
return %e#0, %e#1, %e#2, %e#3, %e#4, %e#5 : f32, f32, f32, f32, f32, f32
}

// -----

// CHECK-LABEL: func @to_elements_of_vector_broadcast_inner_dim
// CHECK-SAME: (%[[V:.*]]: vector<2x1x2xf32>) -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)
func.func @to_elements_of_vector_broadcast_inner_dim(%v: vector<2x1x2xf32>) -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
%b = vector.broadcast %v : vector<2x1x2xf32> to vector<2x3x2xf32>
%e:12 = vector.to_elements %b : vector<2x3x2xf32>
// CHECK-NOT: vector.broadcast
// CHECK: %[[SRC:.*]]:4 = vector.to_elements %[[V]] : vector<2x1x2xf32>
// CHECK: return %[[SRC]]#0, %[[SRC]]#1, %[[SRC]]#0, %[[SRC]]#1, %[[SRC]]#0, %[[SRC]]#1, %[[SRC]]#2, %[[SRC]]#3, %[[SRC]]#2, %[[SRC]]#3, %[[SRC]]#2, %[[SRC]]#3
return %e#0, %e#1, %e#2, %e#3, %e#4, %e#5, %e#6, %e#7, %e#8, %e#9, %e#10, %e#11 :
f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32
}

// -----

// +---------------------------------------------------------------------------
// Tests for foldFromElementsToConstant
// +---------------------------------------------------------------------------
Expand Down