Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/circt/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ void populateArithToCombPatterns(mlir::RewritePatternSet &patterns,
TypeConverter &typeConverter);

std::unique_ptr<mlir::Pass> createMapArithToCombPass();
std::unique_ptr<mlir::Pass> createConvertIndexToUIntPass();
std::unique_ptr<mlir::Pass> createFlattenMemRefPass();
std::unique_ptr<mlir::Pass> createFlattenMemRefCallsPass();
std::unique_ptr<mlir::Pass> createStripDebugInfoWithPredPass(
Expand Down
14 changes: 14 additions & 0 deletions include/circt/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,20 @@ def MapArithToCombPass : Pass<"map-arith-to-comb"> {
let dependentDialects = ["circt::hw::HWDialect, mlir::arith::ArithDialect, circt::comb::CombDialect"];
}

def ConvertIndexToUInt : Pass<"convert-index-to-uint", "::mlir::ModuleOp"> {
let summary = "Rewrite index-based switch comparisons into unsigned integer ops.";
let description = [{
Replace `arith.cmpi` operations whose operands are `index` values (often
produced when lowering `scf.index_switch`) with comparisons over the
original integer type so that downstream hardware mapping passes (e.g.
`--map-arith-to-comb`) do not encounter unsupported index-typed arithmetic.
The pass converts any associated index constants and erases the redundant
casts that become dead afterwards.
Comment on lines +76 to +81
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: a concrete example potentially one from the tests would help to understand what transformation is performed?

}];
let constructor = "circt::createConvertIndexToUIntPass()";
let dependentDialects = ["mlir::arith::ArithDialect"];
}

def InsertMergeBlocks : Pass<"insert-merge-blocks", "::mlir::ModuleOp"> {
let summary = "Insert explicit merge blocks";
let description = [{
Expand Down
1 change: 1 addition & 0 deletions lib/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_circt_library(CIRCTTransforms
HierarchicalRunner.cpp
IndexSwitchToIf.cpp
InsertMergeBlocks.cpp
ConvertIndexToUInt.cpp
MapArithToComb.cpp
MaximizeSSA.cpp
MemoryBanking.cpp
Expand Down
154 changes: 154 additions & 0 deletions lib/Transforms/ConvertIndexToUInt.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
//===- ConvertIndexToUInt.cpp - Rewrite index compares to integers --------===//
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: this header format is no longer required - can just use:
//===----------------------------------------------------------------------===//

//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Contains the definitions of the ConvertIndexToUInt pass.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: slightly redundant comment given the filename? I also think these are not entirely necessary if you see for example the datapath dialect files they do not have these comments

//
//===----------------------------------------------------------------------===//

#include "circt/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace circt {
#define GEN_PASS_DEF_CONVERTINDEXTOUINT
#include "circt/Transforms/Passes.h.inc"
} // namespace circt

using namespace mlir;
using namespace circt;

namespace {

/// Rewrite `arith.cmpi` operations that still reason about `index` values into
/// pure integer comparisons so that subsequent hardware mappings only observe
/// integer arithmetic.
class IndexCmpToIntegerPattern : public OpRewritePattern<arith::CmpIOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(arith::CmpIOp op,
PatternRewriter &rewriter) const override {
if (!op.getLhs().getType().isIndex())
return failure();
Comment on lines +42 to +43
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: worth also checking Rhs? Or adding an assertion if you get past the lhs check that RHS is index type?


FailureOr<IntegerType> targetType = getTargetIntegerType(op);
if (failed(targetType))
return failure();

auto convertOperand = [&](Value operand) -> FailureOr<Value> {
if (auto castOp = operand.getDefiningOp<arith::IndexCastOp>()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: can we add a comment explaining what this block is matching and doing?

Value source = castOp.getIn();
auto srcType = dyn_cast<IntegerType>(source.getType());
if (!srcType || srcType != *targetType)
return failure();
return source;
}

if (auto constOp = operand.getDefiningOp<arith::ConstantOp>()) {
if (!constOp.getType().isIndex())
return failure();

auto value = dyn_cast<IntegerAttr>(constOp.getValue());
if (!value)
return failure();

auto attr = rewriter.getIntegerAttr(*targetType, value.getInt());
auto newConst =
arith::ConstantOp::create(rewriter, constOp.getLoc(), attr);
return newConst.getResult();
}

return failure();
};

FailureOr<Value> lhs = convertOperand(op.getLhs());
FailureOr<Value> rhs = convertOperand(op.getRhs());
if (failed(lhs) || failed(rhs))
return failure();

rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), *lhs,
*rhs);
return success();
}

private:
static FailureOr<IntegerType> getTargetIntegerType(arith::CmpIOp op) {
auto pickType = [](Value operand) -> FailureOr<IntegerType> {
if (auto castOp = operand.getDefiningOp<arith::IndexCastOp>()) {
if (auto srcType = dyn_cast<IntegerType>(castOp.getIn().getType()))
return srcType;
}
return failure();
};

auto lhsType = pickType(op.getLhs());
if (succeeded(lhsType))
return *lhsType;

auto rhsType = pickType(op.getRhs());
if (succeeded(rhsType))
return *rhsType;

return failure();
}
};

/// Drop `arith.index_cast` that became unused once comparisons were rewritten.
class DropUnusedIndexCastPattern
: public OpRewritePattern<arith::IndexCastOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(arith::IndexCastOp op,
PatternRewriter &rewriter) const override {
if (!op->use_empty())
return failure();
rewriter.eraseOp(op);
return success();
}
};

/// Remove `arith.constant` index definitions that no longer feed any user.
class DropUnusedIndexConstantPattern
: public OpRewritePattern<arith::ConstantOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(arith::ConstantOp op,
PatternRewriter &rewriter) const override {
if (!op.getType().isIndex() || !op->use_empty())
return failure();
rewriter.eraseOp(op);
return success();
}
};

struct ConvertIndexToUIntPass
: public circt::impl::ConvertIndexToUIntBase<ConvertIndexToUIntPass> {
void runOnOperation() override {
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
patterns.add<IndexCmpToIntegerPattern, DropUnusedIndexCastPattern,
DropUnusedIndexConstantPattern>(ctx);

if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
}
};

} // namespace

std::unique_ptr<mlir::Pass> circt::createConvertIndexToUIntPass() {
return std::make_unique<ConvertIndexToUIntPass>();
}
Comment on lines +152 to +154
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@uenoku - these definitions get automatically generated now right?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not automatically generated if there is " let constructor = "circt::createConvertIndexToUIntPass()"; in pass definition so the current implementation looks good to me.

81 changes: 81 additions & 0 deletions test/Transforms/convert-index-to-uint.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// RUN: circt-opt -split-input-file --switch-to-if --convert-index-to-uint --canonicalize %s | FileCheck %s
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general CIRCT tries to avoid chaining passes in tests - can you split this into two tests (possibly in the same file as is done here integration_test/circt-lec/comb.mlir)? That way we can isolate testing of convert-index-to-uint?


// CHECK-LABEL: func.func @simple_cmp
// CHECK-NOT: arith.index_cast
// CHECK-NOT: : index
// CHECK: %[[C2:.*]] = arith.constant 2 : i4
// CHECK: %[[CMP0:.*]] = arith.cmpi ult, %arg0, %[[C2]] : i4
// CHECK: %[[CMP1:.*]] = arith.cmpi eq, %arg1, %[[C2]] : i4
// CHECK: %[[RES:.*]] = arith.andi %[[CMP0]], %[[CMP1]] : i1
// CHECK: return %[[RES]] : i1
module {
func.func @simple_cmp(%arg0: i4, %arg1: i4) -> i1 {
%a = arith.index_cast %arg0 : i4 to index
%c2 = arith.constant 2 : index
%cmp0 = arith.cmpi ult, %a, %c2 : index
%b = arith.index_cast %arg1 : i4 to index
%cmp1 = arith.cmpi eq, %b, %c2 : index
%res = arith.andi %cmp0, %cmp1 : i1
return %res : i1
}
}

// -----

// CHECK-LABEL: func.func @single_case
// CHECK-NOT: arith.index_cast
// CHECK: %[[C5:.*]] = arith.constant 5 : i8
// CHECK: %[[CMP:.*]] = arith.cmpi eq, %arg0, %[[C5]] : i8
// CHECK: return %[[CMP]] : i1
module {
func.func @single_case(%cond: i8) -> i1 {
%switch_val = arith.index_cast %cond : i8 to index
%0 = scf.index_switch %switch_val -> i1
case 5 {
%t = arith.constant true
scf.yield %t : i1
}
default {
%f = arith.constant false
scf.yield %f : i1
}
return %0 : i1
Comment on lines +33 to +42
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once tests are split this will probably need to use different CHECK-*** labels - see examples in circt-lec tests of how to specify multiple filecheck labels

}
}

// -----

// CHECK-LABEL: func.func @multi_case
// CHECK-NOT: arith.index_cast
// CHECK-NOT: : index
// CHECK-DAG: %[[CNEG:.*]] = arith.constant -3 : i3
// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i3
// CHECK: %[[MASK:.*]] = arith.trunci %{{.*}} : i16 to i3
// CHECK: %[[CMP0:.*]] = arith.cmpi eq, %[[MASK]], %[[ZERO]] : i3
// CHECK: %[[RES:.*]] = scf.if %[[CMP0]] -> (i1) {
// CHECK: } else {
Comment on lines +55 to +56
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q: Is there a missing yield line here? for the return of the if branch?

// CHECK: %[[CMP1:.*]] = arith.cmpi eq, %[[MASK]], %[[CNEG]] : i3
// CHECK: scf.yield %[[CMP1]] : i1
// CHECK: }
// CHECK: return %[[RES]] : i1
module {
func.func @multi_case(%arg0: i16) -> i1 {
%c5_i16 = arith.constant 5 : i16
%cst_true = arith.constant true
%cst_false = arith.constant false
%shr = arith.shrui %arg0, %c5_i16 : i16
%mask = arith.trunci %shr : i16 to i3
%switch_val = arith.index_cast %mask : i3 to index
%0 = scf.index_switch %switch_val -> i1
case 0 {
scf.yield %cst_true : i1
}
case 5 {
scf.yield %cst_true : i1
}
default {
scf.yield %cst_false : i1
}
return %0 : i1
}
}