@@ -611,7 +611,7 @@ struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
611611// IndexCastOp
612612// ===----------------------------------------------------------------------===//
613613
614- // Converts arith.index_cast to spirv.INotEqual if the target type is i1.
614+ // / Converts arith.index_cast to spirv.INotEqual if the target type is i1.
615615struct IndexCastIndexI1Pattern final
616616 : public OpConversionPattern<arith::IndexCastOp> {
617617 using OpConversionPattern::OpConversionPattern;
@@ -635,6 +635,30 @@ struct IndexCastIndexI1Pattern final
635635 }
636636};
637637
638+ // / Converts arith.index_cast to spirv.Select if the source type is i1.
639+ struct IndexCastI1IndexPattern final
640+ : public OpConversionPattern<arith::IndexCastOp> {
641+ using OpConversionPattern::OpConversionPattern;
642+
643+ LogicalResult
644+ matchAndRewrite (arith::IndexCastOp op, OpAdaptor adaptor,
645+ ConversionPatternRewriter &rewriter) const override {
646+ if (!isBoolScalarOrVector (adaptor.getIn ().getType ()))
647+ return failure ();
648+
649+ Type dstType = getTypeConverter ()->convertType (op.getType ());
650+ if (!dstType)
651+ return getTypeConversionFailure (rewriter, op);
652+
653+ Location loc = op.getLoc ();
654+ Value zero = spirv::ConstantOp::getZero (dstType, loc, rewriter);
655+ Value one = spirv::ConstantOp::getOne (dstType, loc, rewriter);
656+ rewriter.replaceOpWithNewOp <spirv::SelectOp>(op, dstType, adaptor.getIn (),
657+ one, zero);
658+ return success ();
659+ }
660+ };
661+
638662// ===----------------------------------------------------------------------===//
639663// ExtSIOp
640664// ===----------------------------------------------------------------------===//
@@ -1356,7 +1380,8 @@ void mlir::arith::populateArithToSPIRVPatterns(
13561380 TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
13571381 TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
13581382 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1359- TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>, IndexCastIndexI1Pattern,
1383+ TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1384+ IndexCastIndexI1Pattern, IndexCastI1IndexPattern,
13601385 TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
13611386 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
13621387 CmpIOpBooleanPattern, CmpIOpPattern,
0 commit comments