Skip to content
Merged
94 changes: 93 additions & 1 deletion mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2395,11 +2395,103 @@ foldToElementsFromElements(ToElementsOp toElementsOp,
return success();
}

/// Folds vector.to_elements(vector.broadcast(%x)) by creating a new
/// vector.to_elements on the source and remapping results according to
/// broadcast semantics.
///
/// Cases handled:
/// - %x is a scalar: replicate the scalar across all results.
/// - %x is a vector: create to_elements on source and remap/duplicate results.
static LogicalResult
foldToElementsOfBroadcast(ToElementsOp toElementsOp,
SmallVectorImpl<OpFoldResult> &results) {
auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
if (!bcastOp)
return failure();

auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
// Bail on scalable vectors.
if (resultVecType.getNumScalableDims() != 0)
return failure();

// Case 1: scalar broadcast → replicate scalar across all results.
if (!isa<VectorType>(bcastOp.getSource().getType())) {
Value scalar = bcastOp.getSource();
results.assign(resultVecType.getNumElements(), scalar);
return success();
}

// Case 2: vector broadcast → create to_elements on source and remap.
auto srcVecType = cast<VectorType>(bcastOp.getSource().getType());
if (srcVecType.getNumScalableDims() != 0)
return failure();

// Create a temporary to_elements to get the source elements for mapping.
// Change the operand to the broadcast source.
OpBuilder builder(toElementsOp);
auto srcElems = builder.create<ToElementsOp>(toElementsOp.getLoc(),
bcastOp.getSource());

ArrayRef<int64_t> dstShape = resultVecType.getShape();
ArrayRef<int64_t> srcShape = srcVecType.getShape();

// Quick broadcastability check with right-aligned shapes.
unsigned dstRank = dstShape.size();
unsigned srcRank = srcShape.size();
if (srcRank > dstRank)
return failure();

for (unsigned i = 0; i < dstRank; ++i) {
int64_t dstDim = dstShape[i];
int64_t srcDim = 1;
if (i + srcRank >= dstRank)
srcDim = srcShape[i + srcRank - dstRank];
if (!(srcDim == 1 || srcDim == dstDim))
return failure();
}

int64_t dstCount = 1;
for (int64_t v : dstShape)
dstCount *= v;
results.clear();
results.reserve(dstCount);

// Pre-compute the mapping from destination linear index to source linear index
SmallVector<int64_t> dstToSrcMap(dstCount);
SmallVector<int64_t> dstIdx(dstShape.size());

for (int64_t lin = 0; lin < dstCount; ++lin) {
// Convert linear index to multi-dimensional indices (row-major order)
int64_t temp = lin;
for (int64_t i = dstShape.size() - 1; i >= 0; --i) {
int64_t dim = dstShape[i];
dstIdx[i] = temp % dim;
temp /= dim;
}
// Right-align mapping from dst indices to src indices.
int64_t srcLin = 0;
for (unsigned k = 0; k < srcRank; ++k)
srcLin = srcLin * srcShape[k] +
((srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k]);

dstToSrcMap[lin] = srcLin;
}

// Apply the pre-computed mapping
for (int64_t lin = 0; lin < dstCount; ++lin) {
results.push_back(srcElems.getResult(dstToSrcMap[lin]));
}
return success();
}

LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
return foldToElementsFromElements(*this, results);
if (succeeded(foldToElementsFromElements(*this, results)))
return success();
return foldToElementsOfBroadcast(*this, results);
}


LogicalResult
ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
ToElementsOp::Adaptor adaptor,
Expand Down
26 changes: 26 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3326,6 +3326,32 @@ func.func @from_elements_to_elements_shuffle(%a: vector<4x2xf32>) -> vector<4x2x

// -----

// CHECK-LABEL: func @to_elements_of_scalar_broadcast_folds
// CHECK-SAME: (%[[S:.*]]: f32) -> (f32, f32, f32, f32)
func.func @to_elements_of_scalar_broadcast_folds(%s: f32) -> (f32, f32, f32, f32) {
%v = vector.broadcast %s : f32 to vector<4xf32>
%e:4 = vector.to_elements %v : vector<4xf32>
// CHECK-NOT: vector.broadcast
// CHECK-NOT: vector.to_elements
// CHECK: return %[[S]], %[[S]], %[[S]], %[[S]]
return %e#0, %e#1, %e#2, %e#3 : f32, f32, f32, f32
}

// -----

// CHECK-LABEL: func @to_elements_of_vector_broadcast
// CHECK-SAME: (%[[VEC:.*]]: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32)
func.func @to_elements_of_vector_broadcast(%vec: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
%v = vector.broadcast %vec : vector<2xf32> to vector<3x2xf32>
%e:6 = vector.to_elements %v : vector<3x2xf32>
// CHECK-NOT: vector.broadcast
// CHECK: %[[SRC_ELEMS:.*]]:2 = vector.to_elements %[[VEC]]
// CHECK: return %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1, %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1, %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1
return %e#0, %e#1, %e#2, %e#3, %e#4, %e#5 : f32, f32, f32, f32, f32, f32
}

// -----

// +---------------------------------------------------------------------------
// Tests for foldFromElementsToConstant
// +---------------------------------------------------------------------------
Expand Down