From 0c26bd77df30543f65898ceb53a4ed61d4af7160 Mon Sep 17 00:00:00 2001 From: Andrea Faulds Date: Wed, 4 Dec 2024 17:11:18 +0100 Subject: [PATCH] [mlir][spirv][vector] Support converting vector.from_elements to SPIR-V --- .../VectorToSPIRV/VectorToSPIRV.cpp | 31 +++++++++++++++++-- .../VectorToSPIRV/vector-to-spirv.mlir | 29 +++++++++++++++++ 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 656b1cb3e99a1..d3731db1ce55c 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -220,6 +220,32 @@ struct VectorFmaOpConvert final : public OpConversionPattern { } }; +struct VectorFromElementsOpConvert final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(vector::FromElementsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultType = getTypeConverter()->convertType(op.getType()); + if (!resultType) + return failure(); + OperandRange elements = op.getElements(); + if (isa(resultType)) { + // In the case with a single scalar operand / single-element result, + // pass through the scalar. + rewriter.replaceOp(op, elements[0]); + return success(); + } + // SPIRVTypeConverter rejects vectors with rank > 1, so multi-dimensional + // vector.from_elements cases should not need to be handled, only 1d. + assert(cast(resultType).getRank() == 1); + rewriter.replaceOpWithNewOp(op, resultType, + elements); + return success(); + } +}; + struct VectorInsertOpConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -952,8 +978,9 @@ void mlir::populateVectorToSPIRVPatterns( VectorBitcastConvert, VectorBroadcastConvert, VectorExtractElementOpConvert, VectorExtractOpConvert, VectorExtractStridedSliceOpConvert, VectorFmaOpConvert, - VectorFmaOpConvert, VectorInsertElementOpConvert, - VectorInsertOpConvert, VectorReductionPattern, + VectorFmaOpConvert, VectorFromElementsOpConvert, + VectorInsertElementOpConvert, VectorInsertOpConvert, + VectorReductionPattern, VectorReductionPattern, VectorReductionFloatMinMax, VectorReductionFloatMinMax, VectorShapeCast, diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index 8796f153c4911..103148633bf97 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -217,6 +217,35 @@ func.func @extract_dynamic_cst(%arg0 : vector<4xf32>) -> f32 { // ----- +// CHECK-LABEL: @from_elements_0d +// CHECK-SAME: %[[ARG0:.+]]: f32 +// CHECK: %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] +// CHECK: return %[[RETVAL]] +func.func @from_elements_0d(%arg0 : f32) -> vector { + %0 = vector.from_elements %arg0 : vector + return %0: vector +} + +// CHECK-LABEL: @from_elements_1x +// CHECK-SAME: %[[ARG0:.+]]: f32 +// CHECK: %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] +// CHECK: return %[[RETVAL]] +func.func @from_elements_1x(%arg0 : f32) -> vector<1xf32> { + %0 = vector.from_elements %arg0 : vector<1xf32> + return %0: vector<1xf32> +} + +// CHECK-LABEL: @from_elements_3x +// CHECK-SAME: %[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32 +// CHECK: %[[RETVAL:.+]] = spirv.CompositeConstruct %[[ARG0]], %[[ARG1]], %[[ARG2]] : (f32, f32, f32) -> vector<3xf32> +// CHECK: return %[[RETVAL]] +func.func @from_elements_3x(%arg0 : f32, %arg1 : f32, %arg2 : f32) -> vector<3xf32> { + %0 = vector.from_elements %arg0, %arg1, %arg2 : vector<3xf32> + return %0: vector<3xf32> +} + +// ----- + // CHECK-LABEL: @insert // CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[S:.*]]: f32 // CHECK: spirv.CompositeInsert %[[S]], %[[V]][2 : i32] : f32 into vector<4xf32>