diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index 026fd0f0c4774..b99a8a3fe17b1 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -611,7 +611,7 @@ struct UIToFPI1Pattern final : public OpConversionPattern { // IndexCastOp //===----------------------------------------------------------------------===// -// Converts arith.index_cast to spirv.INotEqual if the target type is i1. +/// Converts arith.index_cast to spirv.INotEqual if the target type is i1. struct IndexCastIndexI1Pattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -635,6 +635,30 @@ struct IndexCastIndexI1Pattern final } }; +/// Converts arith.index_cast to spirv.Select if the source type is i1. +struct IndexCastI1IndexPattern final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isBoolScalarOrVector(adaptor.getIn().getType())) + return failure(); + + Type dstType = getTypeConverter()->convertType(op.getType()); + if (!dstType) + return getTypeConversionFailure(rewriter, op); + + Location loc = op.getLoc(); + Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); + Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); + rewriter.replaceOpWithNewOp(op, dstType, adaptor.getIn(), + one, zero); + return success(); + } +}; + //===----------------------------------------------------------------------===// // ExtSIOp //===----------------------------------------------------------------------===// @@ -1356,7 +1380,8 @@ void mlir::arith::populateArithToSPIRVPatterns( TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, - TypeCastingOpPattern, IndexCastIndexI1Pattern, + TypeCastingOpPattern, + IndexCastIndexI1Pattern, IndexCastI1IndexPattern, TypeCastingOpPattern, TypeCastingOpPattern, CmpIOpBooleanPattern, CmpIOpPattern, diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir index 0d2bda9c74807..3cb5294598994 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir @@ -759,6 +759,34 @@ func.func @index_castindexi1_3(%arg0: vector<3xindex>) { return } +// CHECK-LABEL: index_casti1index_1 +func.func @index_casti1index_1(%arg0 : i1) { + // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32 + // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32 + // CHECK: spirv.Select %{{.+}}, %[[ONE]], %[[ZERO]] : i1, i32 + %0 = arith.index_cast %arg0 : i1 to index + return +} + +// CHECK-LABEL: index_casti1index_2 +func.func @index_casti1index_2(%arg0 : vector<1xi1>) -> vector<1xindex> { + // Single-element vectors do not exist in SPIRV. + // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32 + // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32 + // CHECK: spirv.Select %{{.+}}, %[[ONE]], %[[ZERO]] : i1, i32 + %0 = arith.index_cast %arg0 : vector<1xi1> to vector<1xindex> + return %0 : vector<1xindex> +} + +// CHECK-LABEL: index_casti1index_3 +func.func @index_casti1index_3(%arg0 : vector<3xi1>) { + // CHECK: %[[ZERO:.+]] = spirv.Constant dense<0> : vector<3xi32> + // CHECK: %[[ONE:.+]] = spirv.Constant dense<1> : vector<3xi32> + // CHECK: spirv.Select %{{.+}}, %[[ONE]], %[[ZERO]] : vector<3xi1>, vector<3xi32> + %0 = arith.index_cast %arg0 : vector<3xi1> to vector<3xindex> + return +} + // CHECK-LABEL: @bit_cast func.func @bit_cast(%arg0: vector<2xf32>, %arg1: i64) { // CHECK: spirv.Bitcast %{{.+}} : vector<2xf32> to vector<2xi32>