Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,7 @@ def Vector_ToElementsOp : Vector_Op<"to_elements", [
let results = (outs Variadic<AnyType>:$elements);
let assemblyFormat = "$source attr-dict `:` type($source)";
let hasFolder = 1;
let hasCanonicalizer = 1;
}

def Vector_FromElementsOp : Vector_Op<"from_elements", [
Expand Down
120 changes: 119 additions & 1 deletion mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@

#include <cassert>
#include <cstdint>
#include <numeric>

#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
// Pull in all enum type and utility function definitions.
Expand Down Expand Up @@ -2412,9 +2413,38 @@ foldToElementsFromElements(ToElementsOp toElementsOp,
return success();
}

/// Folds vector.to_elements(vector.broadcast(%x)) for the scalar case only.
///
/// Example:
/// %b = vector.broadcast %x : i32 to vector<3xf32>
/// %e:3 = vector.to_elements %b : vector<3xf32>
/// user_op %e#0, %e#1, %e#2
/// becomes:
/// user_op %x, %x, %x
///
/// The vector source case is handled by a canonicalization pattern.
static LogicalResult
foldToElementsOfBroadcast(ToElementsOp toElementsOp,
SmallVectorImpl<OpFoldResult> &results) {
auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
if (!bcastOp)
return failure();
// Vectors are handled in the ToElementsOfBroadcast RewritePattern.
if (isa<VectorType>(bcastOp.getSource().getType()))
return failure();

auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());

Value scalar = bcastOp.getSource();
results.assign(resultVecType.getNumElements(), scalar);
return success();
}

LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
return foldToElementsFromElements(*this, results);
if (succeeded(foldToElementsFromElements(*this, results)))
return success();
return foldToElementsOfBroadcast(*this, results);
}

LogicalResult
Expand All @@ -2427,6 +2457,94 @@ ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
return success();
}

/// Canonicalize `vector.to_elements(vector.broadcast(%v))` where `%v` is a
/// vector.
/// - Build `vector.to_elements %v` and remap each destination element to the
/// corresponding source element using broadcast rules (match or 1 →
/// replicate).
///
/// Example:
/// %v = vector.broadcast %src : vector<2xf32> to vector<3x2xf32>
/// %e:6 = vector.to_elements %v : vector<3x2xf32>
/// becomes:
/// %src_elems:2 = vector.to_elements %src : vector<2xf32>
/// // uses: %src_elems#0, %src_elems#1, %src_elems#0,
/// // %src_elems#1, %src_elems#0, %src_elems#1
class ToElementsOfBroadcast final : public OpRewritePattern<ToElementsOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(ToElementsOp toElementsOp,
PatternRewriter &rewriter) const override {
auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
if (!bcastOp)
return failure();

// Only handle broadcasts from a vector source here.
auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
if (!srcType)
return failure();

auto dstType = cast<VectorType>(toElementsOp.getSource().getType());

ArrayRef<int64_t> dstShape = dstType.getShape();
ArrayRef<int64_t> srcShape = srcType.getShape();

int64_t dstRank = dstShape.size();
int64_t srcRank = srcShape.size();

// Create elements for the broadcast source vector.
auto srcElems = vector::ToElementsOp::create(
rewriter, toElementsOp.getLoc(), bcastOp.getSource());

int64_t dstCount = std::accumulate(dstShape.begin(), dstShape.end(), 1,
std::multiplies<int64_t>());

SmallVector<Value> replacements;
replacements.reserve(dstCount);

// For each element of the destination, determine which element of the
// source should be used. We walk all destination positions using a single
// counter, decode it into per-dimension indices, then build the matching
// source position: use the same index where sizes match, and use 0 where
// the source size is 1 (replication). This mapping is needed so we can
// replace each result of to_elements with the corresponding element from
// the broadcast source.
// Inner-dimension stretch example:
// %v = vector.broadcast %src : vector<2x1x2xf32> to vector<2x3x2xf32>
// %e:12 = vector.to_elements %v : vector<2x3x2xf32>
// becomes:
// %src_elems:4 = vector.to_elements %src : vector<2x1x2xf32>
// // uses: %src_elems#0, %src_elems#1, %src_elems#0,
// // %src_elems#1, %src_elems#0, %src_elems#1,
// // %src_elems#2, %src_elems#3, %src_elems#2,
// // %src_elems#3, %src_elems#2, %src_elems#3

// Row-major strides for the destination shape.
SmallVector<int64_t> dstStrides = computeStrides(dstShape);
// Row-major strides for the source shape.
SmallVector<int64_t> srcStrides = computeStrides(srcShape);
SmallVector<int64_t> dstIdx(dstRank);
SmallVector<int64_t> srcIdx(srcRank);
for (int64_t lin = 0; lin < dstCount; ++lin) {
// Convert linear destination index to per-dimension indices.
dstIdx = delinearize(lin, dstStrides);
for (int64_t k = 0; k < srcRank; ++k)
srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k];
// Convert per-dimension source indices back to a linear index.
int64_t srcLin = linearize(srcIdx, srcStrides);
replacements.push_back(srcElems.getResult(srcLin));
}

rewriter.replaceOp(toElementsOp, replacements);
return success();
}
};

void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ToElementsOfBroadcast>(context);
}

//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//
Expand Down
40 changes: 40 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3326,6 +3326,46 @@ func.func @from_elements_to_elements_shuffle(%a: vector<4x2xf32>) -> vector<4x2x

// -----

// CHECK-LABEL: func @to_elements_of_scalar_broadcast_folds
// CHECK-SAME: (%[[S:.*]]: f32) -> (f32, f32, f32, f32)
func.func @to_elements_of_scalar_broadcast_folds(%s: f32) -> (f32, f32, f32, f32) {
%v = vector.broadcast %s : f32 to vector<4xf32>
%e:4 = vector.to_elements %v : vector<4xf32>
// CHECK-NOT: vector.broadcast
// CHECK-NOT: vector.to_elements
// CHECK: return %[[S]], %[[S]], %[[S]], %[[S]]
return %e#0, %e#1, %e#2, %e#3 : f32, f32, f32, f32
}

// -----

// CHECK-LABEL: func @to_elements_of_vector_broadcast
// CHECK-SAME: (%[[VEC:.*]]: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32)
func.func @to_elements_of_vector_broadcast(%vec: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
%v = vector.broadcast %vec : vector<2xf32> to vector<3x2xf32>
%e:6 = vector.to_elements %v : vector<3x2xf32>
// CHECK-NOT: vector.broadcast
// CHECK: %[[SRC_ELEMS:.*]]:2 = vector.to_elements %[[VEC]]
// CHECK: return %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1, %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1, %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1
return %e#0, %e#1, %e#2, %e#3, %e#4, %e#5 : f32, f32, f32, f32, f32, f32
}

// -----

// CHECK-LABEL: func @to_elements_of_vector_broadcast_inner_dim
// CHECK-SAME: (%[[V:.*]]: vector<2x1x2xf32>) -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)
func.func @to_elements_of_vector_broadcast_inner_dim(%v: vector<2x1x2xf32>) -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
%b = vector.broadcast %v : vector<2x1x2xf32> to vector<2x3x2xf32>
%e:12 = vector.to_elements %b : vector<2x3x2xf32>
// CHECK-NOT: vector.broadcast
// CHECK: %[[SRC:.*]]:4 = vector.to_elements %[[V]] : vector<2x1x2xf32>
// CHECK: return %[[SRC]]#0, %[[SRC]]#1, %[[SRC]]#0, %[[SRC]]#1, %[[SRC]]#0, %[[SRC]]#1, %[[SRC]]#2, %[[SRC]]#3, %[[SRC]]#2, %[[SRC]]#3, %[[SRC]]#2, %[[SRC]]#3
return %e#0, %e#1, %e#2, %e#3, %e#4, %e#5, %e#6, %e#7, %e#8, %e#9, %e#10, %e#11 :
f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32
}

// -----

// +---------------------------------------------------------------------------
// Tests for foldFromElementsToConstant
// +---------------------------------------------------------------------------
Expand Down