diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 252c0b72456df..bb5d686fdd4d2 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -806,6 +806,7 @@ def Vector_ToElementsOp : Vector_Op<"to_elements", [
   let results = (outs Variadic<AnyType>:$elements);
   let assemblyFormat = "$source attr-dict `:` type($source)";
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Vector_FromElementsOp : Vector_Op<"from_elements", [
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b0132e889302f..14e235f253f69 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -47,6 +47,7 @@
 
 #include <cassert>
 #include <cstdint>
+#include <numeric>
 
 #include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
 // Pull in all enum type and utility function definitions.
@@ -2412,9 +2413,38 @@ foldToElementsFromElements(ToElementsOp toElementsOp,
   return success();
 }
 
+/// Folds vector.to_elements(vector.broadcast(%x)) for the scalar case only.
+///
+/// Example:
+///  %b = vector.broadcast %x : f32 to vector<3xf32>
+///  %e:3 = vector.to_elements %b : vector<3xf32>
+///  user_op %e#0, %e#1, %e#2
+/// becomes:
+///  user_op %x, %x, %x
+///
+/// The vector source case is handled by a canonicalization pattern.
+static LogicalResult
+foldToElementsOfBroadcast(ToElementsOp toElementsOp,
+                          SmallVectorImpl<OpFoldResult> &results) {
+  auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
+  if (!bcastOp)
+    return failure();
+  // Vectors are handled in the ToElementsOfBroadcast RewritePattern.
+  if (isa<VectorType>(bcastOp.getSource().getType()))
+    return failure();
+
+  auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
+
+  Value scalar = bcastOp.getSource();
+  results.assign(resultVecType.getNumElements(), scalar);
+  return success();
+}
+
 LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
                                  SmallVectorImpl<OpFoldResult> &results) {
-  return foldToElementsFromElements(*this, results);
+  if (succeeded(foldToElementsFromElements(*this, results)))
+    return success();
+  return foldToElementsOfBroadcast(*this, results);
 }
 
 LogicalResult
@@ -2427,6 +2457,94 @@ ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
   return success();
 }
 
+/// Canonicalize `vector.to_elements(vector.broadcast(%v))` where `%v` is a
+/// vector.
+/// - Build `vector.to_elements %v` and remap each destination element to the
+///   corresponding source element using broadcast rules (match or 1 →
+///   replicate).
+///
+/// Example:
+///   %v = vector.broadcast %src : vector<2xf32> to vector<3x2xf32>
+///   %e:6 = vector.to_elements %v : vector<3x2xf32>
+/// becomes:
+///   %src_elems:2 = vector.to_elements %src : vector<2xf32>
+///   // uses: %src_elems#0, %src_elems#1, %src_elems#0,
+///   //       %src_elems#1, %src_elems#0, %src_elems#1
+struct ToElementsOfBroadcast final : OpRewritePattern<ToElementsOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(ToElementsOp toElementsOp,
+                                PatternRewriter &rewriter) const override {
+    auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
+    if (!bcastOp)
+      return failure();
+
+    // Only handle broadcasts from a vector source here.
+    auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
+    if (!srcType)
+      return failure();
+
+    auto dstType = cast<VectorType>(toElementsOp.getSource().getType());
+
+    ArrayRef<int64_t> dstShape = dstType.getShape();
+    ArrayRef<int64_t> srcShape = srcType.getShape();
+
+    int64_t dstRank = dstShape.size();
+    int64_t srcRank = srcShape.size();
+
+    // Create elements for the broadcast source vector.
+    auto srcElems = vector::ToElementsOp::create(
+        rewriter, toElementsOp.getLoc(), bcastOp.getSource());
+
+    int64_t dstCount = std::accumulate(dstShape.begin(), dstShape.end(),
+                                       int64_t{1}, std::multiplies<int64_t>());
+
+    SmallVector<Value> replacements;
+    replacements.reserve(dstCount);
+
+    // For each element of the destination, determine which element of the
+    // source should be used. We walk all destination positions using a single
+    // counter, decode it into per-dimension indices, then build the matching
+    // source position: use the same index where sizes match, and use 0 where
+    // the source size is 1 (replication). This mapping is needed so we can
+    // replace each result of to_elements with the corresponding element from
+    // the broadcast source.
+    // Inner-dimension stretch example:
+    //   %v = vector.broadcast %src : vector<2x1x2xf32> to vector<2x3x2xf32>
+    //   %e:12 = vector.to_elements %v : vector<2x3x2xf32>
+    // becomes:
+    //   %src_elems:4 = vector.to_elements %src : vector<2x1x2xf32>
+    //   // uses: %src_elems#0, %src_elems#1, %src_elems#0,
+    //   //       %src_elems#1, %src_elems#0, %src_elems#1,
+    //   //       %src_elems#2, %src_elems#3, %src_elems#2,
+    //   //       %src_elems#3, %src_elems#2, %src_elems#3
+
+    // Row-major strides for the destination shape.
+    SmallVector<int64_t> dstStrides = computeStrides(dstShape);
+    // Row-major strides for the source shape.
+    SmallVector<int64_t> srcStrides = computeStrides(srcShape);
+    SmallVector<int64_t> dstIdx(dstRank);
+    SmallVector<int64_t> srcIdx(srcRank);
+    for (int64_t lin = 0; lin < dstCount; ++lin) {
+      // Convert linear destination index to per-dimension indices.
+      dstIdx = delinearize(lin, dstStrides);
+      for (int64_t k = 0; k < srcRank; ++k)
+        srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k];
+      // Convert per-dimension source indices back to a linear index.
+      int64_t srcLin = linearize(srcIdx, srcStrides);
+      replacements.push_back(srcElems.getResult(srcLin));
+    }
+
+    rewriter.replaceOp(toElementsOp, replacements);
+    return success();
+  }
+};
+
+void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                               MLIRContext *context) {
+  results.add<ToElementsOfBroadcast>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // FromElementsOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index bccf5d5b77b0e..7e81298dd9e70 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3326,6 +3326,46 @@ func.func @from_elements_to_elements_shuffle(%a: vector<4x2xf32>) -> vector<4x2x
 
 // -----
 
+// CHECK-LABEL: func @to_elements_of_scalar_broadcast_folds
+// CHECK-SAME: (%[[S:.*]]: f32) -> (f32, f32, f32, f32)
+func.func @to_elements_of_scalar_broadcast_folds(%s: f32) -> (f32, f32, f32, f32) {
+  %v = vector.broadcast %s : f32 to vector<4xf32>
+  %e:4 = vector.to_elements %v : vector<4xf32>
+  // CHECK-NOT: vector.broadcast
+  // CHECK-NOT: vector.to_elements
+  // CHECK: return %[[S]], %[[S]], %[[S]], %[[S]]
+  return %e#0, %e#1, %e#2, %e#3 : f32, f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func @to_elements_of_vector_broadcast
+// CHECK-SAME: (%[[VEC:.*]]: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32)
+func.func @to_elements_of_vector_broadcast(%vec: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
+  %v = vector.broadcast %vec : vector<2xf32> to vector<3x2xf32>
+  %e:6 = vector.to_elements %v : vector<3x2xf32>
+  // CHECK-NOT: vector.broadcast
+  // CHECK: %[[SRC_ELEMS:.*]]:2 = vector.to_elements %[[VEC]]
+  // CHECK: return %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1, %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1, %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1
+  return %e#0, %e#1, %e#2, %e#3, %e#4, %e#5 : f32, f32, f32, f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func @to_elements_of_vector_broadcast_inner_dim
+// CHECK-SAME: (%[[V:.*]]: vector<2x1x2xf32>) -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)
+func.func @to_elements_of_vector_broadcast_inner_dim(%v: vector<2x1x2xf32>) -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
+  %b = vector.broadcast %v : vector<2x1x2xf32> to vector<2x3x2xf32>
+  %e:12 = vector.to_elements %b : vector<2x3x2xf32>
+  // CHECK-NOT: vector.broadcast
+  // CHECK: %[[SRC:.*]]:4 = vector.to_elements %[[V]] : vector<2x1x2xf32>
+  // CHECK: return %[[SRC]]#0, %[[SRC]]#1, %[[SRC]]#0, %[[SRC]]#1, %[[SRC]]#0, %[[SRC]]#1, %[[SRC]]#2, %[[SRC]]#3, %[[SRC]]#2, %[[SRC]]#3, %[[SRC]]#2, %[[SRC]]#3
+  return %e#0, %e#1, %e#2, %e#3, %e#4, %e#5, %e#6, %e#7, %e#8, %e#9, %e#10, %e#11 :
+    f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32
+}
+
+// -----
+
 // +---------------------------------------------------------------------------
 // Tests for foldFromElementsToConstant
 // +---------------------------------------------------------------------------