@@ -611,7 +611,7 @@ struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
611
611
// IndexCastOp
612
612
// ===----------------------------------------------------------------------===//
613
613
614
- // Converts arith.index_cast to spirv.INotEqual if the target type is i1.
614
+ // / Converts arith.index_cast to spirv.INotEqual if the target type is i1.
615
615
struct IndexCastIndexI1Pattern final
616
616
: public OpConversionPattern<arith::IndexCastOp> {
617
617
using OpConversionPattern::OpConversionPattern;
@@ -635,6 +635,30 @@ struct IndexCastIndexI1Pattern final
635
635
}
636
636
};
637
637
638
+ // / Converts arith.index_cast to spirv.Select if the source type is i1.
639
+ struct IndexCastI1IndexPattern final
640
+ : public OpConversionPattern<arith::IndexCastOp> {
641
+ using OpConversionPattern::OpConversionPattern;
642
+
643
+ LogicalResult
644
+ matchAndRewrite (arith::IndexCastOp op, OpAdaptor adaptor,
645
+ ConversionPatternRewriter &rewriter) const override {
646
+ if (!isBoolScalarOrVector (adaptor.getIn ().getType ()))
647
+ return failure ();
648
+
649
+ Type dstType = getTypeConverter ()->convertType (op.getType ());
650
+ if (!dstType)
651
+ return getTypeConversionFailure (rewriter, op);
652
+
653
+ Location loc = op.getLoc ();
654
+ Value zero = spirv::ConstantOp::getZero (dstType, loc, rewriter);
655
+ Value one = spirv::ConstantOp::getOne (dstType, loc, rewriter);
656
+ rewriter.replaceOpWithNewOp <spirv::SelectOp>(op, dstType, adaptor.getIn (),
657
+ one, zero);
658
+ return success ();
659
+ }
660
+ };
661
+
638
662
// ===----------------------------------------------------------------------===//
639
663
// ExtSIOp
640
664
// ===----------------------------------------------------------------------===//
@@ -1356,7 +1380,8 @@ void mlir::arith::populateArithToSPIRVPatterns(
1356
1380
TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
1357
1381
TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
1358
1382
TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1359
- TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>, IndexCastIndexI1Pattern,
1383
+ TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1384
+ IndexCastIndexI1Pattern, IndexCastI1IndexPattern,
1360
1385
TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
1361
1386
TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1362
1387
CmpIOpBooleanPattern, CmpIOpPattern,
0 commit comments