Skip to content

Commit cf86eb6

Browse files
committed
Fix compat with vector<?xi1>
1 parent 2815dbd commit cf86eb6

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -619,12 +619,11 @@ struct IndexCastIndexI1Pattern final
619619
LogicalResult
620620
matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
621621
ConversionPatternRewriter &rewriter) const override {
622-
Type srcType = adaptor.getIn().getType();
623-
if (!op.getType().isInteger(1))
622+
if (!isBoolScalarOrVector(op.getType()))
624623
return failure();
625624

626625
Location loc = op.getLoc();
627-
Value zeroIdx = spirv::ConstantOp::getZero(srcType, loc, rewriter);
626+
Value zeroIdx = spirv::ConstantOp::getZero(adaptor.getIn().getType(), loc, rewriter);
628627
rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(op, op.getType(), zeroIdx,
629628
adaptor.getIn());
630629
return success();

mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -734,14 +734,22 @@ func.func @index_castui4(%arg0: index) {
734734
return
735735
}
736736

737-
// CHECK-LABEL: index_castindexi1
738-
func.func @index_castindexi1(%arg0 : index) {
737+
// CHECK-LABEL: index_castindexi1_1
738+
func.func @index_castindexi1_1(%arg0 : index) {
739739
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
740740
// CHECK: spirv.INotEqual %[[ZERO]], %{{.+}} : i32
741741
%0 = arith.index_cast %arg0 : index to i1
742742
return
743743
}
744744

745+
// CHECK-LABEL: index_castindexi1_2
746+
func.func @index_castindexi1_2(%arg0 : vector<3xindex>) {
747+
// CHECK: %[[ZERO:.+]] = spirv.Constant dense<0> : vector<3xi32>
748+
// CHECK: spirv.INotEqual %[[ZERO]], %{{.+}} : vector<3xi32>
749+
%0 = arith.index_cast %arg0 : vector<3xindex> to vector<3xi1>
750+
return
751+
}
752+
745753
// CHECK-LABEL: @bit_cast
746754
func.func @bit_cast(%arg0: vector<2xf32>, %arg1: i64) {
747755
// CHECK: spirv.Bitcast %{{.+}} : vector<2xf32> to vector<2xi32>

0 commit comments

Comments
 (0)