Skip to content

Commit ac383ab

Browse files
committed
Amend reviewer comments, add vector support
1 parent e61ff6a commit ac383ab

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -611,16 +611,16 @@ struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
611611
// IndexCastOp
612612
//===----------------------------------------------------------------------===//
613613

614-
/// Converts arith.index_cast to spirv.Select if the source type is i1
614+
/// Converts arith.index_cast to spirv.Select if the source type is i1.
615615
struct IndexCastI1IndexPattern final
616616
: public OpConversionPattern<arith::IndexCastOp> {
617617
using OpConversionPattern::OpConversionPattern;
618618

619619
LogicalResult
620620
matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
621621
ConversionPatternRewriter &rewriter) const override {
622-
Type srcType = adaptor.getOperands().front().getType();
623-
if (!srcType.isInteger(1))
622+
Type srcType = adaptor.getIn().getType();
623+
if (!isBoolScalarOrVector(srcType))
624624
return failure();
625625

626626
Type dstType = getTypeConverter()->convertType(op.getType());
@@ -631,7 +631,7 @@ struct IndexCastI1IndexPattern final
631631
Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
632632
Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
633633
rewriter.replaceOpWithNewOp<spirv::SelectOp>(
634-
op, dstType, adaptor.getOperands().front(), one, zero);
634+
op, dstType, adaptor.getIn(), one, zero);
635635
return success();
636636
}
637637
};

mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -734,15 +734,24 @@ func.func @index_castui4(%arg0: index) {
734734
return
735735
}
736736

737-
// CHECK-LABEL: index_casti1index
738-
func.func @index_casti1index(%arg0 : i1) {
737+
// CHECK-LABEL: index_casti1index_1
738+
func.func @index_casti1index_1(%arg0 : i1) {
739739
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
740740
// CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
741741
// CHECK: spirv.Select %{{.+}}, %[[ONE]], %[[ZERO]] : i1, i32
742742
%0 = arith.index_cast %arg0 : i1 to index
743743
return
744744
}
745745

746+
// CHECK-LABEL: index_casti1index_2
747+
func.func @index_casti1index_2(%arg0 : vector<3xi1>) {
748+
// CHECK: %[[ZERO:.+]] = spirv.Constant dense<0> : vector<3xi32>
749+
// CHECK: %[[ONE:.+]] = spirv.Constant dense<1> : vector<3xi32>
750+
// CHECK: spirv.Select %{{.+}}, %[[ONE]], %[[ZERO]] : vector<3xi1>, vector<3xi32>
751+
%0 = arith.index_cast %arg0 : vector<3xi1> to vector<3xindex>
752+
return
753+
}
754+
746755
// CHECK-LABEL: @bit_cast
747756
func.func @bit_cast(%arg0: vector<2xf32>, %arg1: i64) {
748757
// CHECK: spirv.Bitcast %{{.+}} : vector<2xf32> to vector<2xi32>

0 commit comments

Comments
 (0)