Skip to content

Commit 5fdd3a1

Browse files
yangtetrisYang Bainewling
authored
[mlir][vector] Follow-up improvements for multi-dimensional vector.from_elements support (#154664)
This PR is a follow-up to #151175 that supported lowering multi-dimensional `vector.from_elements` op to LLVM by introducing a unrolling pattern. ## Changes ### Add `vector.shape_cast` based flattening pattern for `vector.from_elements` This change introduces a new linearization pattern that uses `vector.shape_cast` to flatten multi-dimensional `vector.from_elements` operations. This provides an alternative approach to the unrolling-based method introduced in #151175. **Example:** ```mlir // Before %v = vector.from_elements %e0, %e1, %e2, %e3 : vector<2x2xf32> // After %flat = vector.from_elements %e0, %e1, %e2, %e3 : vector<4xf32> %result = vector.shape_cast %flat : vector<4xf32> to vector<2x2xf32> ``` --------- Co-authored-by: Yang Bai <[email protected]> Co-authored-by: James Newling <[email protected]>
1 parent 5a9f103 commit 5fdd3a1

File tree

2 files changed

+52
-1
lines changed

2 files changed

+52
-1
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -762,6 +762,42 @@ struct LinearizeVectorStore final
762762
}
763763
};
764764

765+
/// This pattern linearizes `vector.from_elements` operations by converting
766+
/// the result type to a 1-D vector while preserving all element values.
767+
/// The transformation creates a linearized `vector.from_elements` followed by
768+
/// a `vector.shape_cast` to restore the original multidimensional shape.
769+
///
770+
/// Example:
771+
///
772+
/// %0 = vector.from_elements %a, %b, %c, %d : vector<2x2xf32>
773+
///
774+
/// is converted to:
775+
///
776+
/// %0 = vector.from_elements %a, %b, %c, %d : vector<4xf32>
777+
/// %1 = vector.shape_cast %0 : vector<4xf32> to vector<2x2xf32>
778+
///
779+
struct LinearizeVectorFromElements final
780+
: public OpConversionPattern<vector::FromElementsOp> {
781+
using OpConversionPattern::OpConversionPattern;
782+
LinearizeVectorFromElements(const TypeConverter &typeConverter,
783+
MLIRContext *context, PatternBenefit benefit = 1)
784+
: OpConversionPattern(typeConverter, context, benefit) {}
785+
LogicalResult
786+
matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
787+
ConversionPatternRewriter &rewriter) const override {
788+
VectorType dstTy =
789+
getTypeConverter()->convertType<VectorType>(fromElementsOp.getType());
790+
assert(dstTy && "vector type destination expected.");
791+
792+
OperandRange elements = fromElementsOp.getElements();
793+
assert(elements.size() == static_cast<size_t>(dstTy.getNumElements()) &&
794+
"expected same number of elements");
795+
rewriter.replaceOpWithNewOp<vector::FromElementsOp>(fromElementsOp, dstTy,
796+
elements);
797+
return success();
798+
}
799+
};
800+
765801
} // namespace
766802

767803
/// This method defines the set of operations that are linearizable, and hence
@@ -854,7 +890,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
854890
patterns
855891
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
856892
LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
857-
LinearizeVectorStore>(typeConverter, patterns.getContext());
893+
LinearizeVectorStore, LinearizeVectorFromElements>(
894+
typeConverter, patterns.getContext());
858895
}
859896

860897
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,3 +524,17 @@ func.func @linearize_vector_store_scalable(%arg0: memref<2x8xf32>, %arg1: vector
524524
vector.store %arg1, %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x[4]xf32>
525525
return
526526
}
527+
528+
// -----
529+
530+
// Test pattern LinearizeVectorFromElements.
531+
532+
// CHECK-LABEL: test_vector_from_elements
533+
// CHECK-SAME: %[[ARG_0:.*]]: f32, %[[ARG_1:.*]]: f32, %[[ARG_2:.*]]: f32, %[[ARG_3:.*]]: f32
534+
func.func @test_vector_from_elements(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x2xf32> {
535+
// CHECK: %[[FROM_ELEMENTS:.*]] = vector.from_elements %[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]] : vector<4xf32>
536+
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[FROM_ELEMENTS]] : vector<4xf32> to vector<2x2xf32>
537+
// CHECK: return %[[CAST]] : vector<2x2xf32>
538+
%1 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32>
539+
return %1 : vector<2x2xf32>
540+
}

0 commit comments

Comments
 (0)