Skip to content

Commit c5a141b

Browse files
ianaylmshahneokuhar
authored
[mlir][spirv] Add pattern matching for arith.index_cast index to i1 for ArithToSPIRV (#156031)
Currently, `arith.index_cast` gets converted to `OpSConvert`: https://github.com/llvm/llvm-project/blob/9bf5bf3baf3c7aec82cdd235c6a2fd57b4dd55ab/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp#L1331 [OpSConvert requires its operands to be of integer type](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpSConvert), which poses an issue for `i1` since SPIRV distinguishes between booleans and integers. As a result, the following example doesn't get converted, leaving behind illegal ops: ``` %0 = arith.index_cast %arg0 : index to i1 ``` This PR adds additional logic to convert `arith.index_casts` to SPIRV dialect when casting from `index` to `i1`. Converting `index_cast`s from `i1` to `index` is submitted as #155729. --------- Co-authored-by: Md Abdullah Shahneous Bari <[email protected]> Co-authored-by: Jakub Kuderski <[email protected]>
1 parent 303bea8 commit c5a141b

File tree

2 files changed

+54
-1
lines changed

2 files changed

+54
-1
lines changed

mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,34 @@ struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
607607
}
608608
};
609609

610+
//===----------------------------------------------------------------------===//
611+
// IndexCastOp
612+
//===----------------------------------------------------------------------===//
613+
614+
// Converts arith.index_cast to spirv.INotEqual if the target type is i1.
615+
struct IndexCastIndexI1Pattern final
616+
: public OpConversionPattern<arith::IndexCastOp> {
617+
using OpConversionPattern::OpConversionPattern;
618+
619+
LogicalResult
620+
matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
621+
ConversionPatternRewriter &rewriter) const override {
622+
if (!isBoolScalarOrVector(op.getType()))
623+
return failure();
624+
625+
Type dstType = getTypeConverter()->convertType(op.getType());
626+
if (!dstType)
627+
return getTypeConversionFailure(rewriter, op);
628+
629+
Location loc = op.getLoc();
630+
Value zeroIdx =
631+
spirv::ConstantOp::getZero(adaptor.getIn().getType(), loc, rewriter);
632+
rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(op, dstType, zeroIdx,
633+
adaptor.getIn());
634+
return success();
635+
}
636+
};
637+
610638
//===----------------------------------------------------------------------===//
611639
// ExtSIOp
612640
//===----------------------------------------------------------------------===//
@@ -1328,7 +1356,7 @@ void mlir::arith::populateArithToSPIRVPatterns(
13281356
TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
13291357
TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
13301358
TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1331-
TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1359+
TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>, IndexCastIndexI1Pattern,
13321360
TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
13331361
TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
13341362
CmpIOpBooleanPattern, CmpIOpPattern,

mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -734,6 +734,31 @@ func.func @index_castui4(%arg0: index) {
734734
return
735735
}
736736

737+
// CHECK-LABEL: index_castindexi1_1
738+
func.func @index_castindexi1_1(%arg0: index) {
739+
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
740+
// CHECK: spirv.INotEqual %[[ZERO]], %{{.+}} : i32
741+
%0 = arith.index_cast %arg0 : index to i1
742+
return
743+
}
744+
745+
// CHECK-LABEL: index_castindexi1_2
746+
func.func @index_castindexi1_2(%arg0: vector<1xindex>) -> vector<1xi1> {
747+
// Single-element vectors do not exist in SPIRV.
748+
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
749+
// CHECK: spirv.INotEqual %[[ZERO]], %{{.+}} : i32
750+
%0 = arith.index_cast %arg0 : vector<1xindex> to vector<1xi1>
751+
return %0 : vector<1xi1>
752+
}
753+
754+
// CHECK-LABEL: index_castindexi1_3
755+
func.func @index_castindexi1_3(%arg0: vector<3xindex>) {
756+
// CHECK: %[[ZERO:.+]] = spirv.Constant dense<0> : vector<3xi32>
757+
// CHECK: spirv.INotEqual %[[ZERO]], %{{.+}} : vector<3xi32>
758+
%0 = arith.index_cast %arg0 : vector<3xindex> to vector<3xi1>
759+
return
760+
}
761+
737762
// CHECK-LABEL: @bit_cast
738763
func.func @bit_cast(%arg0: vector<2xf32>, %arg1: i64) {
739764
// CHECK: spirv.Bitcast %{{.+}} : vector<2xf32> to vector<2xi32>

0 commit comments

Comments
 (0)