From f446699bc016509b7a7c6c0a2170b61d0b8709c8 Mon Sep 17 00:00:00 2001 From: Ian Li Date: Wed, 27 Aug 2025 16:51:17 -0700 Subject: [PATCH 1/8] [mlir][spirv] Add pattern matching for arith.index_cast i1 to index --- .../Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 37 ++++++++++++++++++- .../ArithToSPIRV/arith-to-spirv.mlir | 7 ++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index 265293b83f84c..172f322a12fd8 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -607,6 +607,41 @@ struct UIToFPI1Pattern final : public OpConversionPattern { } }; +//===----------------------------------------------------------------------===// +// IndexCastOp +//===----------------------------------------------------------------------===// + +/// Converts arith.index_cast to spirv.Select if the type of source is i1 or +/// vector of i1. +struct IndexCastI1Pattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type srcType = adaptor.getOperands().front().getType(); + if (!srcType.isInteger(1)) + return failure(); + + Type dstType = getTypeConverter()->convertType(op.getType()); + if (!dstType) + return getTypeConversionFailure(rewriter, op); + // if (!dstType.isIndex()) { + // llvm::errs() << "why doesnt this work?\n"; + // return failure(); + // } + + auto *converter = this->template getTypeConverter(); + Location loc = op.getLoc(); + Type spirvI32T = converter->getIndexType(); + Value zero = spirv::ConstantOp::getZero(spirvI32T, loc, rewriter); + Value one = spirv::ConstantOp::getOne(spirvI32T, loc, rewriter); + auto newOp = rewriter.replaceOpWithNewOp( + op, dstType, adaptor.getOperands().front(), one, zero); + return success(); + } +}; + //===----------------------------------------------------------------------===// // ExtSIOp //===----------------------------------------------------------------------===// @@ -1328,7 +1363,7 @@ void mlir::arith::populateArithToSPIRVPatterns( TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, - TypeCastingOpPattern, + TypeCastingOpPattern, IndexCastI1Pattern, TypeCastingOpPattern, TypeCastingOpPattern, CmpIOpBooleanPattern, CmpIOpPattern, diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir index 6e2352e706acc..8bb63fff861ce 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir @@ -734,6 +734,13 @@ func.func @index_castui4(%arg0: index) { return } +// CHECK-LABEL: index_casti1_1 +func.func @index_casti1_1(%arg0 : i1) -> index { + // CHECK: spirv.Select %{{.+}}, %{{.+}}, %{{.+}} : i1, i32 + %0 = arith.index_cast %arg0 : i1 to index + return %0 : index +} + // CHECK-LABEL: @bit_cast func.func @bit_cast(%arg0: vector<2xf32>, %arg1: i64) { // CHECK: spirv.Bitcast %{{.+}} : vector<2xf32> to vector<2xi32> From 1ba2dcf5c81b98b67ebf95c9052c28119f230e99 Mon Sep 17 00:00:00 2001 From: Ian Li Date: Thu, 28 Aug 2025 13:37:27 -0700 Subject: [PATCH 2/8] Remove redundancy, add missing lit checks --- .../lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 16 +++++----------- .../Conversion/ArithToSPIRV/arith-to-spirv.mlir | 8 +++++--- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index 172f322a12fd8..b9e04e456ff72 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -613,7 +613,7 @@ struct UIToFPI1Pattern final : public OpConversionPattern { /// Converts arith.index_cast to spirv.Select if the type of source is i1 or /// vector of i1. -struct IndexCastI1Pattern final : public OpConversionPattern { +struct IndexCastI1IndexPattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -626,17 +626,11 @@ struct IndexCastI1Pattern final : public OpConversionPattern Type dstType = getTypeConverter()->convertType(op.getType()); if (!dstType) return getTypeConversionFailure(rewriter, op); - // if (!dstType.isIndex()) { - // llvm::errs() << "why doesnt this work?\n"; - // return failure(); - // } - auto *converter = this->template getTypeConverter(); Location loc = op.getLoc(); - Type spirvI32T = converter->getIndexType(); - Value zero = spirv::ConstantOp::getZero(spirvI32T, loc, rewriter); - Value one = spirv::ConstantOp::getOne(spirvI32T, loc, rewriter); - auto newOp = rewriter.replaceOpWithNewOp( + Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); + Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); + rewriter.replaceOpWithNewOp( op, dstType, adaptor.getOperands().front(), one, zero); return success(); } @@ -1363,7 +1357,7 @@ void mlir::arith::populateArithToSPIRVPatterns( TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, - TypeCastingOpPattern, IndexCastI1Pattern, + TypeCastingOpPattern, IndexCastI1IndexPattern, TypeCastingOpPattern, TypeCastingOpPattern, CmpIOpBooleanPattern, CmpIOpPattern, diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir index 8bb63fff861ce..938a5ccfed542 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir @@ -734,9 +734,11 @@ func.func @index_castui4(%arg0: index) { return } -// CHECK-LABEL: index_casti1_1 -func.func @index_casti1_1(%arg0 : i1) -> index { - // CHECK: spirv.Select %{{.+}}, %{{.+}}, %{{.+}} : i1, i32 +// CHECK-LABEL: index_casti1index_1 +func.func @index_casti1index_1(%arg0 : i1) -> index { + // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32 + // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32 + // CHECK: spirv.Select %{{.+}}, %[[ONE]], %[[ZERO]] : i1, i32 %0 = arith.index_cast %arg0 : i1 to index return %0 : index } From ac32e57f1fdfc358e39e9e01a5787f38d0d3c513 Mon Sep 17 00:00:00 2001 From: Ian Li Date: Thu, 28 Aug 2025 13:51:03 -0700 Subject: [PATCH 3/8] remove redundant return --- mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir index 938a5ccfed542..e86b04527383d 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir @@ -735,12 +735,12 @@ func.func @index_castui4(%arg0: index) { } // CHECK-LABEL: index_casti1index_1 -func.func @index_casti1index_1(%arg0 : i1) -> index { +func.func @index_casti1index_1(%arg0 : i1) { // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32 // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32 // CHECK: spirv.Select %{{.+}}, %[[ONE]], %[[ZERO]] : i1, i32 %0 = arith.index_cast %arg0 : i1 to index - return %0 : index + return } // CHECK-LABEL: @bit_cast From 13f3d477dd241f50d478e6536b8d2f57547e5fd4 Mon Sep 17 00:00:00 2001 From: Ian Li Date: Fri, 29 Aug 2025 08:35:59 -0700 Subject: [PATCH 4/8] clang-format --- mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index b9e04e456ff72..b55322816fd31 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -613,7 +613,8 @@ struct UIToFPI1Pattern final : public OpConversionPattern { /// Converts arith.index_cast to spirv.Select if the type of source is i1 or /// vector of i1. -struct IndexCastI1IndexPattern final : public OpConversionPattern { +struct IndexCastI1IndexPattern final + : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult From e61ff6a7774e437a0319f84e5aa9c363bd9f2ff7 Mon Sep 17 00:00:00 2001 From: Ian Li Date: Fri, 29 Aug 2025 13:07:37 -0700 Subject: [PATCH 5/8] rewrite comments to conform with sister PR --- mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 3 +-- mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index b55322816fd31..09d2a1bbf9d45 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -611,8 +611,7 @@ struct UIToFPI1Pattern final : public OpConversionPattern { // IndexCastOp //===----------------------------------------------------------------------===// -/// Converts arith.index_cast to spirv.Select if the type of source is i1 or -/// vector of i1. +/// Converts arith.index_cast to spirv.Select if the source type is i1 struct IndexCastI1IndexPattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir index e86b04527383d..7968fce644e4b 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir @@ -734,8 +734,8 @@ func.func @index_castui4(%arg0: index) { return } -// CHECK-LABEL: index_casti1index_1 -func.func @index_casti1index_1(%arg0 : i1) { +// CHECK-LABEL: index_casti1index +func.func @index_casti1index(%arg0 : i1) { // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32 // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32 // CHECK: spirv.Select %{{.+}}, %[[ONE]], %[[ZERO]] : i1, i32 From ac383ab7e3362c2e0f5412294b32e989acf3f506 Mon Sep 17 00:00:00 2001 From: Ian Li Date: Tue, 2 Sep 2025 10:49:53 -0700 Subject: [PATCH 6/8] Amend reviewer comments, add vector support --- mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 8 ++++---- .../Conversion/ArithToSPIRV/arith-to-spirv.mlir | 13 +++++++++++-- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index 09d2a1bbf9d45..af297b1c918bf 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -611,7 +611,7 @@ struct UIToFPI1Pattern final : public OpConversionPattern { // IndexCastOp //===----------------------------------------------------------------------===// -/// Converts arith.index_cast to spirv.Select if the source type is i1 +/// Converts arith.index_cast to spirv.Select if the source type is i1. struct IndexCastI1IndexPattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -619,8 +619,8 @@ struct IndexCastI1IndexPattern final LogicalResult matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Type srcType = adaptor.getOperands().front().getType(); - if (!srcType.isInteger(1)) + Type srcType = adaptor.getIn().getType(); + if (!isBoolScalarOrVector(srcType)) return failure(); Type dstType = getTypeConverter()->convertType(op.getType()); @@ -631,7 +631,7 @@ struct IndexCastI1IndexPattern final Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); rewriter.replaceOpWithNewOp( - op, dstType, adaptor.getOperands().front(), one, zero); + op, dstType, adaptor.getIn(), one, zero); return success(); } }; diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir index 7968fce644e4b..9f575250aab2e 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir @@ -734,8 +734,8 @@ func.func @index_castui4(%arg0: index) { return } -// CHECK-LABEL: index_casti1index -func.func @index_casti1index(%arg0 : i1) { +// CHECK-LABEL: index_casti1index_1 +func.func @index_casti1index_1(%arg0 : i1) { // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32 // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32 // CHECK: spirv.Select %{{.+}}, %[[ONE]], %[[ZERO]] : i1, i32 @@ -743,6 +743,15 @@ func.func @index_casti1index(%arg0 : i1) { return } +// CHECK-LABEL: index_casti1index_2 +func.func @index_casti1index_2(%arg0 : vector<3xi1>) { + // CHECK: %[[ZERO:.+]] = spirv.Constant dense<0> : vector<3xi32> + // CHECK: %[[ONE:.+]] = spirv.Constant dense<1> : vector<3xi32> + // CHECK: spirv.Select %{{.+}}, %[[ONE]], %[[ZERO]] : vector<3xi1>, vector<3xi32> + %0 = arith.index_cast %arg0 : vector<3xi1> to vector<3xindex> + return +} + // CHECK-LABEL: @bit_cast func.func @bit_cast(%arg0: vector<2xf32>, %arg1: i64) { // CHECK: spirv.Bitcast %{{.+}} : vector<2xf32> to vector<2xi32> From fba473fb6f1fbc22bee355cb4cef287c243c1d5f Mon Sep 17 00:00:00 2001 From: Ian Li Date: Tue, 2 Sep 2025 10:52:19 -0700 Subject: [PATCH 7/8] clang-format --- mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index af297b1c918bf..f53a9edb52993 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -619,8 +619,7 @@ struct IndexCastI1IndexPattern final LogicalResult matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Type srcType = adaptor.getIn().getType(); - if (!isBoolScalarOrVector(srcType)) + if (!isBoolScalarOrVector(adaptor.getIn().getType())) return failure(); Type dstType = getTypeConverter()->convertType(op.getType()); @@ -630,8 +629,8 @@ struct IndexCastI1IndexPattern final Location loc = op.getLoc(); Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); - rewriter.replaceOpWithNewOp( - op, dstType, adaptor.getIn(), one, zero); + rewriter.replaceOpWithNewOp(op, dstType, adaptor.getIn(), + one, zero); return success(); } }; From cda887aed906a0c2c765f8381abdb70a25925a0e Mon Sep 17 00:00:00 2001 From: Ian Li Date: Tue, 2 Sep 2025 18:37:43 -0700 Subject: [PATCH 8/8] Add test for single-element vectors --- .../test/Conversion/ArithToSPIRV/arith-to-spirv.mlir | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir index 9f575250aab2e..02f37d9431d36 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir @@ -744,7 +744,17 @@ func.func @index_casti1index_1(%arg0 : i1) { } // CHECK-LABEL: index_casti1index_2 -func.func @index_casti1index_2(%arg0 : vector<3xi1>) { +func.func @index_casti1index_2(%arg0 : vector<1xi1>) -> vector<1xindex> { + // Single-element vectors do not exist in SPIRV. + // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32 + // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32 + // CHECK: spirv.Select %{{.+}}, %[[ONE]], %[[ZERO]] : i1, i32 + %0 = arith.index_cast %arg0 : vector<1xi1> to vector<1xindex> + return %0 : vector<1xindex> +} + +// CHECK-LABEL: index_casti1index_3 +func.func @index_casti1index_3(%arg0 : vector<3xi1>) { // CHECK: %[[ZERO:.+]] = spirv.Constant dense<0> : vector<3xi32> // CHECK: %[[ONE:.+]] = spirv.Constant dense<1> : vector<3xi32> // CHECK: spirv.Select %{{.+}}, %[[ONE]], %[[ZERO]] : vector<3xi1>, vector<3xi32>