From c004d4f4e9e08230c7dcddf40f52bbd1e725cfb9 Mon Sep 17 00:00:00 2001 From: Bynaryman Date: Tue, 25 Nov 2025 09:39:08 +0100 Subject: [PATCH] Integration of our custom pass into CIRCT --- include/circt/Transforms/Passes.h | 1 + include/circt/Transforms/Passes.td | 14 ++ lib/Transforms/CMakeLists.txt | 1 + lib/Transforms/ConvertIndexToUInt.cpp | 154 +++++++++++++++++++++ test/Transforms/convert-index-to-uint.mlir | 81 +++++++++++ 5 files changed, 251 insertions(+) create mode 100644 lib/Transforms/ConvertIndexToUInt.cpp create mode 100644 test/Transforms/convert-index-to-uint.mlir diff --git a/include/circt/Transforms/Passes.h b/include/circt/Transforms/Passes.h index 9b3578d3e318..e8804b0dd06e 100644 --- a/include/circt/Transforms/Passes.h +++ b/include/circt/Transforms/Passes.h @@ -40,6 +40,7 @@ void populateArithToCombPatterns(mlir::RewritePatternSet &patterns, TypeConverter &typeConverter); std::unique_ptr createMapArithToCombPass(); +std::unique_ptr createConvertIndexToUIntPass(); std::unique_ptr createFlattenMemRefPass(); std::unique_ptr createFlattenMemRefCallsPass(); std::unique_ptr createStripDebugInfoWithPredPass( diff --git a/include/circt/Transforms/Passes.td b/include/circt/Transforms/Passes.td index 916419c381af..ce849b4dde7a 100644 --- a/include/circt/Transforms/Passes.td +++ b/include/circt/Transforms/Passes.td @@ -70,6 +70,20 @@ def MapArithToCombPass : Pass<"map-arith-to-comb"> { let dependentDialects = ["circt::hw::HWDialect, mlir::arith::ArithDialect, circt::comb::CombDialect"]; } +def ConvertIndexToUInt : Pass<"convert-index-to-uint", "::mlir::ModuleOp"> { + let summary = "Rewrite index-based switch comparisons into unsigned integer ops."; + let description = [{ + Replace `arith.cmpi` operations whose operands are `index` values (often + produced when lowering `scf.index_switch`) with comparisons over the + original integer type so that downstream hardware mapping passes (e.g. + `--map-arith-to-comb`) do not encounter unsupported index-typed arithmetic. + The pass converts any associated index constants and erases the redundant + casts that become dead afterwards. + }]; + let constructor = "circt::createConvertIndexToUIntPass()"; + let dependentDialects = ["mlir::arith::ArithDialect"]; +} + def InsertMergeBlocks : Pass<"insert-merge-blocks", "::mlir::ModuleOp"> { let summary = "Insert explicit merge blocks"; let description = [{ diff --git a/lib/Transforms/CMakeLists.txt b/lib/Transforms/CMakeLists.txt index 437689ea4994..15e7bbe7226a 100644 --- a/lib/Transforms/CMakeLists.txt +++ b/lib/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ add_circt_library(CIRCTTransforms HierarchicalRunner.cpp IndexSwitchToIf.cpp InsertMergeBlocks.cpp + ConvertIndexToUInt.cpp MapArithToComb.cpp MaximizeSSA.cpp MemoryBanking.cpp diff --git a/lib/Transforms/ConvertIndexToUInt.cpp b/lib/Transforms/ConvertIndexToUInt.cpp new file mode 100644 index 000000000000..a48e09b9ac50 --- /dev/null +++ b/lib/Transforms/ConvertIndexToUInt.cpp @@ -0,0 +1,154 @@ +//===- ConvertIndexToUInt.cpp - Rewrite index compares to integers --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Contains the definitions of the ConvertIndexToUInt pass. +// +//===----------------------------------------------------------------------===// + +#include "circt/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace circt { +#define GEN_PASS_DEF_CONVERTINDEXTOUINT +#include "circt/Transforms/Passes.h.inc" +} // namespace circt + +using namespace mlir; +using namespace circt; + +namespace { + +/// Rewrite `arith.cmpi` operations that still reason about `index` values into +/// pure integer comparisons so that subsequent hardware mappings only observe +/// integer arithmetic. +class IndexCmpToIntegerPattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::CmpIOp op, + PatternRewriter &rewriter) const override { + if (!op.getLhs().getType().isIndex()) + return failure(); + + FailureOr targetType = getTargetIntegerType(op); + if (failed(targetType)) + return failure(); + + auto convertOperand = [&](Value operand) -> FailureOr { + if (auto castOp = operand.getDefiningOp()) { + Value source = castOp.getIn(); + auto srcType = dyn_cast(source.getType()); + if (!srcType || srcType != *targetType) + return failure(); + return source; + } + + if (auto constOp = operand.getDefiningOp()) { + if (!constOp.getType().isIndex()) + return failure(); + + auto value = dyn_cast(constOp.getValue()); + if (!value) + return failure(); + + auto attr = rewriter.getIntegerAttr(*targetType, value.getInt()); + auto newConst = + arith::ConstantOp::create(rewriter, constOp.getLoc(), attr); + return newConst.getResult(); + } + + return failure(); + }; + + FailureOr lhs = convertOperand(op.getLhs()); + FailureOr rhs = convertOperand(op.getRhs()); + if (failed(lhs) || failed(rhs)) + return failure(); + + rewriter.replaceOpWithNewOp(op, op.getPredicate(), *lhs, + *rhs); + return success(); + } + +private: + static FailureOr getTargetIntegerType(arith::CmpIOp op) { + auto pickType = [](Value operand) -> FailureOr { + if (auto castOp = operand.getDefiningOp()) { + if (auto srcType = dyn_cast(castOp.getIn().getType())) + return srcType; + } + return failure(); + }; + + auto lhsType = pickType(op.getLhs()); + if (succeeded(lhsType)) + return *lhsType; + + auto rhsType = pickType(op.getRhs()); + if (succeeded(rhsType)) + return *rhsType; + + return failure(); + } +}; + +/// Drop `arith.index_cast` that became unused once comparisons were rewritten. +class DropUnusedIndexCastPattern + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::IndexCastOp op, + PatternRewriter &rewriter) const override { + if (!op->use_empty()) + return failure(); + rewriter.eraseOp(op); + return success(); + } +}; + +/// Remove `arith.constant` index definitions that no longer feed any user. +class DropUnusedIndexConstantPattern + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::ConstantOp op, + PatternRewriter &rewriter) const override { + if (!op.getType().isIndex() || !op->use_empty()) + return failure(); + rewriter.eraseOp(op); + return success(); + } +}; + +struct ConvertIndexToUIntPass + : public circt::impl::ConvertIndexToUIntBase { + void runOnOperation() override { + MLIRContext *ctx = &getContext(); + RewritePatternSet patterns(ctx); + patterns.add(ctx); + + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr circt::createConvertIndexToUIntPass() { + return std::make_unique(); +} diff --git a/test/Transforms/convert-index-to-uint.mlir b/test/Transforms/convert-index-to-uint.mlir new file mode 100644 index 000000000000..d082c68ba3a2 --- /dev/null +++ b/test/Transforms/convert-index-to-uint.mlir @@ -0,0 +1,81 @@ +// RUN: circt-opt -split-input-file --switch-to-if --convert-index-to-uint --canonicalize %s | FileCheck %s + +// CHECK-LABEL: func.func @simple_cmp +// CHECK-NOT: arith.index_cast +// CHECK-NOT: : index +// CHECK: %[[C2:.*]] = arith.constant 2 : i4 +// CHECK: %[[CMP0:.*]] = arith.cmpi ult, %arg0, %[[C2]] : i4 +// CHECK: %[[CMP1:.*]] = arith.cmpi eq, %arg1, %[[C2]] : i4 +// CHECK: %[[RES:.*]] = arith.andi %[[CMP0]], %[[CMP1]] : i1 +// CHECK: return %[[RES]] : i1 +module { + func.func @simple_cmp(%arg0: i4, %arg1: i4) -> i1 { + %a = arith.index_cast %arg0 : i4 to index + %c2 = arith.constant 2 : index + %cmp0 = arith.cmpi ult, %a, %c2 : index + %b = arith.index_cast %arg1 : i4 to index + %cmp1 = arith.cmpi eq, %b, %c2 : index + %res = arith.andi %cmp0, %cmp1 : i1 + return %res : i1 + } +} + +// ----- + +// CHECK-LABEL: func.func @single_case +// CHECK-NOT: arith.index_cast +// CHECK: %[[C5:.*]] = arith.constant 5 : i8 +// CHECK: %[[CMP:.*]] = arith.cmpi eq, %arg0, %[[C5]] : i8 +// CHECK: return %[[CMP]] : i1 +module { + func.func @single_case(%cond: i8) -> i1 { + %switch_val = arith.index_cast %cond : i8 to index + %0 = scf.index_switch %switch_val -> i1 + case 5 { + %t = arith.constant true + scf.yield %t : i1 + } + default { + %f = arith.constant false + scf.yield %f : i1 + } + return %0 : i1 + } +} + +// ----- + +// CHECK-LABEL: func.func @multi_case +// CHECK-NOT: arith.index_cast +// CHECK-NOT: : index +// CHECK-DAG: %[[CNEG:.*]] = arith.constant -3 : i3 +// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i3 +// CHECK: %[[MASK:.*]] = arith.trunci %{{.*}} : i16 to i3 +// CHECK: %[[CMP0:.*]] = arith.cmpi eq, %[[MASK]], %[[ZERO]] : i3 +// CHECK: %[[RES:.*]] = scf.if %[[CMP0]] -> (i1) { +// CHECK: } else { +// CHECK: %[[CMP1:.*]] = arith.cmpi eq, %[[MASK]], %[[CNEG]] : i3 +// CHECK: scf.yield %[[CMP1]] : i1 +// CHECK: } +// CHECK: return %[[RES]] : i1 +module { + func.func @multi_case(%arg0: i16) -> i1 { + %c5_i16 = arith.constant 5 : i16 + %cst_true = arith.constant true + %cst_false = arith.constant false + %shr = arith.shrui %arg0, %c5_i16 : i16 + %mask = arith.trunci %shr : i16 to i3 + %switch_val = arith.index_cast %mask : i3 to index + %0 = scf.index_switch %switch_val -> i1 + case 0 { + scf.yield %cst_true : i1 + } + case 5 { + scf.yield %cst_true : i1 + } + default { + scf.yield %cst_false : i1 + } + return %0 : i1 + } +}