Skip to content

Commit b812e3d

Browse files
[mlir][vector] Add LinearizeVectorToElements (#157740)
Co-authored-by: James Newling <[email protected]>
1 parent 9d19250 commit b812e3d

File tree

2 files changed

+70
-2
lines changed

2 files changed

+70
-2
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,51 @@ struct LinearizeVectorFromElements final
798798
}
799799
};
800800

801+
/// This pattern linearizes the operand in `vector.to_elements` operations
802+
/// by converting the source type to a 1-D vector while preserving all element
803+
/// values. The transformation creates a linearized `vector.shape_cast`
804+
/// followed by a `vector.to_elements`.
805+
///
806+
/// Example:
807+
///
808+
/// %0:4 = vector.to_elements %v : vector<2x2xf32>
809+
///
810+
/// is converted to:
811+
///
812+
/// %vector_cast = vector.shape_cast %v : vector<2x2xf32> to vector<4xf32>
813+
/// %0:4 = vector.to_elements %vector_cast : vector<4xf32>
814+
///
815+
struct LinearizeVectorToElements final
816+
: public OpConversionPattern<vector::ToElementsOp> {
817+
using OpConversionPattern::OpConversionPattern;
818+
819+
LinearizeVectorToElements(const TypeConverter &typeConverter,
820+
MLIRContext *context, PatternBenefit benefit = 1)
821+
: OpConversionPattern(typeConverter, context, benefit) {}
822+
823+
LogicalResult
824+
matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
825+
ConversionPatternRewriter &rewriter) const override {
826+
827+
VectorType vecType = toElementsOp.getSource().getType();
828+
if (vecType.getRank() <= 1)
829+
return rewriter.notifyMatchFailure(
830+
toElementsOp, "the rank is already less than or equal to 1");
831+
832+
assert(vecType.getNumScalableDims() == 0 &&
833+
"to_elements does not support scalable vectors");
834+
auto vec1DType =
835+
VectorType::get({vecType.getNumElements()}, vecType.getElementType());
836+
Value shapeCast = vector::ShapeCastOp::create(
837+
rewriter, toElementsOp.getLoc(), vec1DType, toElementsOp.getSource());
838+
auto newToElementsOp =
839+
vector::ToElementsOp::create(rewriter, toElementsOp.getLoc(),
840+
toElementsOp.getResultTypes(), shapeCast);
841+
rewriter.replaceOp(toElementsOp, newToElementsOp);
842+
return success();
843+
}
844+
};
845+
801846
} // namespace
802847

803848
/// This method defines the set of operations that are linearizable, and hence
@@ -890,8 +935,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
890935
patterns
891936
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
892937
LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
893-
LinearizeVectorStore, LinearizeVectorFromElements>(
894-
typeConverter, patterns.getContext());
938+
LinearizeVectorStore, LinearizeVectorFromElements,
939+
LinearizeVectorToElements>(typeConverter, patterns.getContext());
895940
}
896941

897942
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,3 +538,26 @@ func.func @test_vector_from_elements(%arg0: f32, %arg1: f32, %arg2: f32, %arg3:
538538
%1 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32>
539539
return %1 : vector<2x2xf32>
540540
}
541+
542+
// -----
543+
544+
// CHECK-LABEL: func.func @to_elements_1d(
545+
// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
546+
// CHECK: %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
547+
// CHECK: return %[[RES]]#0, %[[RES]]#1
548+
func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
549+
%0:2 = vector.to_elements %arg0 : vector<2xf32>
550+
return %0#0, %0#1 : f32, f32
551+
}
552+
553+
// -----
554+
555+
// CHECK-LABEL: func.func @to_elements_2d(
556+
// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
557+
// CHECK: %[[CAST:.+]] = vector.shape_cast %[[ARG0]]
558+
// CHECK: %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32>
559+
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3
560+
func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
561+
%0:4 = vector.to_elements %arg0 : vector<2x2xf32>
562+
return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
563+
}

0 commit comments

Comments
 (0)