From a806d817f1df1ea0f13fa9a9852fe8290da7ffaa Mon Sep 17 00:00:00 2001 From: Andrea Faulds Date: Thu, 6 Feb 2025 16:58:35 +0100 Subject: [PATCH] [mlir][spirv] Fix some issues related to converting ub.poison to SPIR-V This is a follow-up to 5df62bdc9be9c258c5ac45c8093b71e23777fa0e. That commit should not have needed to make the vector.insert and vector.extract conversions to SPIR-V directly handle the static poison index case, as there is a fold from those to ub.poison, and a conversion pattern from ub.poison to spirv.Undef, however: - The ub.poison fold result could not be materialized by the vector dialect (fixed as of d13940ee263ff50b7a71e21424913cc0266bf9d4). - The conversion pattern wasn't being populated in VectorToSPIRVPass, which is used by the tests. This commit changes this. - The ub.poison to spirv.Undef pattern rejected non-scalar types, which prevented its use for vector results. It is unclear why this restriction existed; a remark in D156163 said this was to avoid converting "user types", but it is not obvious why these shouldn't be permitted (the SPIR-V specification allows OpUndef for all types except OpTypeVoid). This commit removes this restriction. With these fixed, this commit removes the redundant static poison index handling, and updates the tests. --- mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp | 5 --- .../Conversion/VectorToSPIRV/CMakeLists.txt | 1 + .../VectorToSPIRV/VectorToSPIRV.cpp | 34 +++++++------------ .../VectorToSPIRV/VectorToSPIRVPass.cpp | 3 ++ .../Conversion/UBToSPIRV/ub-to-spirv.mlir | 3 +- .../VectorToSPIRV/vector-to-spirv.mlir | 9 ++--- 6 files changed, 23 insertions(+), 32 deletions(-) diff --git a/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp index a3806189e4060..01c35cba48c49 100644 --- a/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp +++ b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp @@ -29,11 +29,6 @@ struct PoisonOpLowering final : OpConversionPattern { matchAndRewrite(ub::PoisonOp op, OpAdaptor, ConversionPatternRewriter &rewriter) const override { Type origType = op.getType(); - if (!origType.isIntOrIndexOrFloat()) - return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { - diag << "unsupported type " << origType; - }); - Type resType = getTypeConverter()->convertType(origType); if (!resType) return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { diff --git a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt index bb9f793d7fe0f..f4cdb2cf95a30 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt @@ -15,4 +15,5 @@ add_mlir_conversion_library(MLIRVectorToSPIRV MLIRSPIRVConversion MLIRVectorDialect MLIRTransforms + MLIRUBToSPIRV ) diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 2c8bc149dc708..1ecb892a4ea92 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -182,17 +182,13 @@ struct VectorExtractOpConvert final if (std::optional id = getConstantIntValue(extractOp.getMixedPosition()[0])) { - // TODO: ExtractOp::fold() already can fold a static poison index to - // ub.poison; remove this once ub.poison can be converted to SPIR-V. - if (id == vector::ExtractOp::kPoisonIndex) { - // Arbitrary choice of poison result, intended to stick out. - Value zero = - spirv::ConstantOp::getZero(dstType, extractOp.getLoc(), rewriter); - rewriter.replaceOp(extractOp, zero); - } else - rewriter.replaceOpWithNewOp( - extractOp, dstType, adaptor.getVector(), - rewriter.getI32ArrayAttr(id.value())); + if (id == vector::ExtractOp::kPoisonIndex) + return rewriter.notifyMatchFailure( + extractOp, + "Static use of poison index handled elsewhere (folded to poison)"); + rewriter.replaceOpWithNewOp( + extractOp, dstType, adaptor.getVector(), + rewriter.getI32ArrayAttr(id.value())); } else { Value sanitizedIndex = sanitizeDynamicIndex( rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0], @@ -306,16 +302,12 @@ struct VectorInsertOpConvert final if (std::optional id = getConstantIntValue(insertOp.getMixedPosition()[0])) { - // TODO: ExtractOp::fold() already can fold a static poison index to - // ub.poison; remove this once ub.poison can be converted to SPIR-V. - if (id == vector::InsertOp::kPoisonIndex) { - // Arbitrary choice of poison result, intended to stick out. - Value zero = spirv::ConstantOp::getZero(insertOp.getDestVectorType(), - insertOp.getLoc(), rewriter); - rewriter.replaceOp(insertOp, zero); - } else - rewriter.replaceOpWithNewOp( - insertOp, adaptor.getSource(), adaptor.getDest(), id.value()); + if (id == vector::InsertOp::kPoisonIndex) + return rewriter.notifyMatchFailure( + insertOp, + "Static use of poison index handled elsewhere (folded to poison)"); + rewriter.replaceOpWithNewOp( + insertOp, adaptor.getSource(), adaptor.getDest(), id.value()); } else { Value sanitizedIndex = sanitizeDynamicIndex( rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0], diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp index cc115b1d36826..0735e1ee0c677 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp @@ -12,6 +12,7 @@ #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h" +#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h" #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" @@ -49,6 +50,8 @@ void ConvertVectorToSPIRVPass::runOnOperation() { RewritePatternSet patterns(context); populateVectorToSPIRVPatterns(typeConverter, patterns); + // Used for folds, e.g. vector.extract[-1] -> ub.poison -> spirv.Undef. + ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns); if (failed(applyPartialConversion(op, *target, std::move(patterns)))) return signalPassFailure(); diff --git a/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir b/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir index 771b53ad123b9..f497eb3bc552c 100644 --- a/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir +++ b/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir @@ -13,8 +13,7 @@ func.func @check_poison() { %1 = ub.poison : i16 // CHECK: {{.*}} = spirv.Undef : f64 %2 = ub.poison : f64 -// TODO: vector is not covered yet -// CHECK: {{.*}} = ub.poison : vector<4xf32> +// CHECK: {{.*}} = spirv.Undef : vector<4xf32> %3 = ub.poison : vector<4xf32> return } diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index 5fd7324b1d3c7..3f0bf1962e299 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -175,9 +175,10 @@ func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) { // ----- +// CHECK-LABEL: @extract_poison_idx +// CHECK: %[[R:.+]] = spirv.Undef : f32 +// CHECK: return %[[R]] func.func @extract_poison_idx(%arg0 : vector<4xf32>) -> f32 { - // CHECK: %[[ZERO:.+]] = spirv.Constant 0.000000e+00 - // CHECK: return %[[ZERO]] %0 = vector.extract %arg0[-1] : f32 from vector<4xf32> return %0: f32 } @@ -285,8 +286,8 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> { // ----- // CHECK-LABEL: @insert_poison_idx -// CHECK: %[[ZERO:.+]] = spirv.Constant dense<0.000000e+00> -// CHECK: return %[[ZERO]] +// CHECK: %[[R:.+]] = spirv.Undef : vector<4xf32> +// CHECK: return %[[R]] func.func @insert_poison_idx(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> { %1 = vector.insert %arg1, %arg0[-1] : f32 into vector<4xf32> return %1: vector<4xf32>