diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index 265293b83f84c..026fd0f0c4774 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -607,6 +607,34 @@ struct UIToFPI1Pattern final : public OpConversionPattern { } }; +//===----------------------------------------------------------------------===// +// IndexCastOp +//===----------------------------------------------------------------------===// + +// Converts arith.index_cast to spirv.INotEqual if the target type is i1. +struct IndexCastIndexI1Pattern final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isBoolScalarOrVector(op.getType())) + return failure(); + + Type dstType = getTypeConverter()->convertType(op.getType()); + if (!dstType) + return getTypeConversionFailure(rewriter, op); + + Location loc = op.getLoc(); + Value zeroIdx = + spirv::ConstantOp::getZero(adaptor.getIn().getType(), loc, rewriter); + rewriter.replaceOpWithNewOp(op, dstType, zeroIdx, + adaptor.getIn()); + return success(); + } +}; + //===----------------------------------------------------------------------===// // ExtSIOp //===----------------------------------------------------------------------===// @@ -1328,7 +1356,7 @@ void mlir::arith::populateArithToSPIRVPatterns( TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, - TypeCastingOpPattern, + TypeCastingOpPattern, IndexCastIndexI1Pattern, TypeCastingOpPattern, TypeCastingOpPattern, CmpIOpBooleanPattern, CmpIOpPattern, diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir index 6e2352e706acc..0d2bda9c74807 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir @@ -734,6 +734,31 @@ func.func @index_castui4(%arg0: index) { return } +// CHECK-LABEL: index_castindexi1_1 +func.func @index_castindexi1_1(%arg0: index) { + // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32 + // CHECK: spirv.INotEqual %[[ZERO]], %{{.+}} : i32 + %0 = arith.index_cast %arg0 : index to i1 + return +} + +// CHECK-LABEL: index_castindexi1_2 +func.func @index_castindexi1_2(%arg0: vector<1xindex>) -> vector<1xi1> { + // Single-element vectors do not exist in SPIRV. + // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32 + // CHECK: spirv.INotEqual %[[ZERO]], %{{.+}} : i32 + %0 = arith.index_cast %arg0 : vector<1xindex> to vector<1xi1> + return %0 : vector<1xi1> +} + +// CHECK-LABEL: index_castindexi1_3 +func.func @index_castindexi1_3(%arg0: vector<3xindex>) { + // CHECK: %[[ZERO:.+]] = spirv.Constant dense<0> : vector<3xi32> + // CHECK: spirv.INotEqual %[[ZERO]], %{{.+}} : vector<3xi32> + %0 = arith.index_cast %arg0 : vector<3xindex> to vector<3xi1> + return +} + // CHECK-LABEL: @bit_cast func.func @bit_cast(%arg0: vector<2xf32>, %arg1: i64) { // CHECK: spirv.Bitcast %{{.+}} : vector<2xf32> to vector<2xi32>