@@ -607,6 +607,34 @@ struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
607607 }
608608};
609609
610+ // ===----------------------------------------------------------------------===//
611+ // IndexCastOp
612+ // ===----------------------------------------------------------------------===//
613+
614+ // Converts arith.index_cast to spirv.INotEqual if the target type is i1.
615+ struct IndexCastIndexI1Pattern final
616+ : public OpConversionPattern<arith::IndexCastOp> {
617+ using OpConversionPattern::OpConversionPattern;
618+
619+ LogicalResult
620+ matchAndRewrite (arith::IndexCastOp op, OpAdaptor adaptor,
621+ ConversionPatternRewriter &rewriter) const override {
622+ if (!isBoolScalarOrVector (op.getType ()))
623+ return failure ();
624+
625+ Type dstType = getTypeConverter ()->convertType (op.getType ());
626+ if (!dstType)
627+ return getTypeConversionFailure (rewriter, op);
628+
629+ Location loc = op.getLoc ();
630+ Value zeroIdx =
631+ spirv::ConstantOp::getZero (adaptor.getIn ().getType (), loc, rewriter);
632+ rewriter.replaceOpWithNewOp <spirv::INotEqualOp>(op, dstType, zeroIdx,
633+ adaptor.getIn ());
634+ return success ();
635+ }
636+ };
637+
610638// ===----------------------------------------------------------------------===//
611639// ExtSIOp
612640// ===----------------------------------------------------------------------===//
@@ -1328,7 +1356,7 @@ void mlir::arith::populateArithToSPIRVPatterns(
13281356 TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
13291357 TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
13301358 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1331- TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1359+ TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>, IndexCastIndexI1Pattern,
13321360 TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
13331361 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
13341362 CmpIOpBooleanPattern, CmpIOpPattern,
0 commit comments