Skip to content

Commit 6aa9d92

Browse files
authored
[mlir][spirv] Add pattern matching for arith.index_cast i1 to index for ArithToSPIRV (#155729)
Currently, `arith.index_cast` gets converted to `OpSConvert`: https://github.com/llvm/llvm-project/blob/9bf5bf3baf3c7aec82cdd235c6a2fd57b4dd55ab/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp#L1331 [OpSConvert requires its operands to be of integer type](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpSConvert), which poses an issue for `i1` since SPIRV distinguishes between booleans and integers. As a result, the following example doesn't get converted, leaving behind illegal ops: ``` %0 = arith.index_cast %arg0 : i1 to index ``` This PR adds additional logic to convert `arith.index_casts` to SPIRV dialect when casting from `i1` to `index`. Converting `index_cast`s from `index` to `i1` is a part of #156031.
1 parent ded5f43 commit 6aa9d92

File tree

2 files changed

+55
-2
lines changed

2 files changed

+55
-2
lines changed

mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,7 @@ struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
611611
// IndexCastOp
612612
//===----------------------------------------------------------------------===//
613613

614-
// Converts arith.index_cast to spirv.INotEqual if the target type is i1.
614+
/// Converts arith.index_cast to spirv.INotEqual if the target type is i1.
615615
struct IndexCastIndexI1Pattern final
616616
: public OpConversionPattern<arith::IndexCastOp> {
617617
using OpConversionPattern::OpConversionPattern;
@@ -635,6 +635,30 @@ struct IndexCastIndexI1Pattern final
635635
}
636636
};
637637

638+
/// Converts arith.index_cast to spirv.Select if the source type is i1.
639+
struct IndexCastI1IndexPattern final
640+
: public OpConversionPattern<arith::IndexCastOp> {
641+
using OpConversionPattern::OpConversionPattern;
642+
643+
LogicalResult
644+
matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
645+
ConversionPatternRewriter &rewriter) const override {
646+
if (!isBoolScalarOrVector(adaptor.getIn().getType()))
647+
return failure();
648+
649+
Type dstType = getTypeConverter()->convertType(op.getType());
650+
if (!dstType)
651+
return getTypeConversionFailure(rewriter, op);
652+
653+
Location loc = op.getLoc();
654+
Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
655+
Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
656+
rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, adaptor.getIn(),
657+
one, zero);
658+
return success();
659+
}
660+
};
661+
638662
//===----------------------------------------------------------------------===//
639663
// ExtSIOp
640664
//===----------------------------------------------------------------------===//
@@ -1356,7 +1380,8 @@ void mlir::arith::populateArithToSPIRVPatterns(
13561380
TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
13571381
TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
13581382
TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1359-
TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>, IndexCastIndexI1Pattern,
1383+
TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1384+
IndexCastIndexI1Pattern, IndexCastI1IndexPattern,
13601385
TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
13611386
TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
13621387
CmpIOpBooleanPattern, CmpIOpPattern,

mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -759,6 +759,34 @@ func.func @index_castindexi1_3(%arg0: vector<3xindex>) {
759759
return
760760
}
761761

762+
// CHECK-LABEL: index_casti1index_1
763+
func.func @index_casti1index_1(%arg0 : i1) {
764+
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
765+
// CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
766+
// CHECK: spirv.Select %{{.+}}, %[[ONE]], %[[ZERO]] : i1, i32
767+
%0 = arith.index_cast %arg0 : i1 to index
768+
return
769+
}
770+
771+
// CHECK-LABEL: index_casti1index_2
772+
func.func @index_casti1index_2(%arg0 : vector<1xi1>) -> vector<1xindex> {
773+
// Single-element vectors do not exist in SPIRV.
774+
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
775+
// CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
776+
// CHECK: spirv.Select %{{.+}}, %[[ONE]], %[[ZERO]] : i1, i32
777+
%0 = arith.index_cast %arg0 : vector<1xi1> to vector<1xindex>
778+
return %0 : vector<1xindex>
779+
}
780+
781+
// CHECK-LABEL: index_casti1index_3
782+
func.func @index_casti1index_3(%arg0 : vector<3xi1>) {
783+
// CHECK: %[[ZERO:.+]] = spirv.Constant dense<0> : vector<3xi32>
784+
// CHECK: %[[ONE:.+]] = spirv.Constant dense<1> : vector<3xi32>
785+
// CHECK: spirv.Select %{{.+}}, %[[ONE]], %[[ZERO]] : vector<3xi1>, vector<3xi32>
786+
%0 = arith.index_cast %arg0 : vector<3xi1> to vector<3xindex>
787+
return
788+
}
789+
762790
// CHECK-LABEL: @bit_cast
763791
func.func @bit_cast(%arg0: vector<2xf32>, %arg1: i64) {
764792
// CHECK: spirv.Bitcast %{{.+}} : vector<2xf32> to vector<2xi32>

0 commit comments

Comments
 (0)