Skip to content
30 changes: 29 additions & 1 deletion mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,34 @@ struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
}
};

//===----------------------------------------------------------------------===//
// IndexCastOp
//===----------------------------------------------------------------------===//

// Converts arith.index_cast to spirv.INotEqual if the target type is i1.
struct IndexCastIndexI1Pattern final
: public OpConversionPattern<arith::IndexCastOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!isBoolScalarOrVector(op.getType()))
return failure();

Type dstType = getTypeConverter()->convertType(op.getType());
if (!dstType)
return getTypeConversionFailure(rewriter, op);

Location loc = op.getLoc();
Value zeroIdx =
spirv::ConstantOp::getZero(adaptor.getIn().getType(), loc, rewriter);
rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(op, dstType, zeroIdx,
adaptor.getIn());
return success();
}
};

//===----------------------------------------------------------------------===//
// ExtSIOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1328,7 +1356,7 @@ void mlir::arith::populateArithToSPIRVPatterns(
TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>, IndexCastIndexI1Pattern,
TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
CmpIOpBooleanPattern, CmpIOpPattern,
Expand Down
25 changes: 25 additions & 0 deletions mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,31 @@ func.func @index_castui4(%arg0: index) {
return
}

// CHECK-LABEL: index_castindexi1_1
func.func @index_castindexi1_1(%arg0 : index) {
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
// CHECK: spirv.INotEqual %[[ZERO]], %{{.+}} : i32
%0 = arith.index_cast %arg0 : index to i1
return
}

// CHECK-LABEL: index_castindexi1_2
func.func @index_castindexi1_2(%arg0 : vector<1xindex>) -> vector<1xi1> {
// Single-element vectors do not exist in SPIRV.
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
// CHECK: spirv.INotEqual %[[ZERO]], %{{.+}} : i32
%0 = arith.index_cast %arg0 : vector<1xindex> to vector<1xi1>
return %0 : vector<1xi1>
}

// CHECK-LABEL: index_castindexi1_3
func.func @index_castindexi1_3(%arg0 : vector<3xindex>) {
// CHECK: %[[ZERO:.+]] = spirv.Constant dense<0> : vector<3xi32>
// CHECK: spirv.INotEqual %[[ZERO]], %{{.+}} : vector<3xi32>
%0 = arith.index_cast %arg0 : vector<3xindex> to vector<3xi1>
return
}

// CHECK-LABEL: @bit_cast
func.func @bit_cast(%arg0: vector<2xf32>, %arg1: i64) {
// CHECK: spirv.Bitcast %{{.+}} : vector<2xf32> to vector<2xi32>
Expand Down