-
Notifications
You must be signed in to change notification settings - Fork 396
[Transform] Add convert-index-to-uint transform to normalize index compares before comb mapping #9263
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[Transform] Add convert-index-to-uint transform to normalize index compares before comb mapping #9263
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,154 @@ | ||
| //===- ConvertIndexToUInt.cpp - Rewrite index compares to integers --------===// | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: this header format is no longer required - can just use: |
||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
| // | ||
| // Contains the definitions of the ConvertIndexToUInt pass. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: slightly redundant comment given the filename? I also think these are not entirely necessary if you see for example the datapath dialect files they do not have these comments |
||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #include "circt/Transforms/Passes.h" | ||
|
|
||
| #include "mlir/Dialect/Arith/IR/Arith.h" | ||
| #include "mlir/IR/BuiltinAttributes.h" | ||
| #include "mlir/IR/BuiltinOps.h" | ||
| #include "mlir/IR/BuiltinTypes.h" | ||
| #include "mlir/IR/PatternMatch.h" | ||
| #include "mlir/Support/LogicalResult.h" | ||
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
|
||
| namespace circt { | ||
| #define GEN_PASS_DEF_CONVERTINDEXTOUINT | ||
| #include "circt/Transforms/Passes.h.inc" | ||
| } // namespace circt | ||
|
|
||
| using namespace mlir; | ||
| using namespace circt; | ||
|
|
||
| namespace { | ||
|
|
||
| /// Rewrite `arith.cmpi` operations that still reason about `index` values into | ||
| /// pure integer comparisons so that subsequent hardware mappings only observe | ||
| /// integer arithmetic. | ||
| class IndexCmpToIntegerPattern : public OpRewritePattern<arith::CmpIOp> { | ||
| public: | ||
| using OpRewritePattern::OpRewritePattern; | ||
|
|
||
| LogicalResult matchAndRewrite(arith::CmpIOp op, | ||
| PatternRewriter &rewriter) const override { | ||
| if (!op.getLhs().getType().isIndex()) | ||
| return failure(); | ||
|
Comment on lines
+42
to
+43
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: worth also checking Rhs? Or adding an assertion if you get past the lhs check that RHS is index type? |
||
|
|
||
| FailureOr<IntegerType> targetType = getTargetIntegerType(op); | ||
| if (failed(targetType)) | ||
| return failure(); | ||
|
|
||
| auto convertOperand = [&](Value operand) -> FailureOr<Value> { | ||
| if (auto castOp = operand.getDefiningOp<arith::IndexCastOp>()) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: can we add a comment explaining what this block is matching and doing? |
||
| Value source = castOp.getIn(); | ||
| auto srcType = dyn_cast<IntegerType>(source.getType()); | ||
| if (!srcType || srcType != *targetType) | ||
| return failure(); | ||
| return source; | ||
| } | ||
|
|
||
| if (auto constOp = operand.getDefiningOp<arith::ConstantOp>()) { | ||
| if (!constOp.getType().isIndex()) | ||
| return failure(); | ||
|
|
||
| auto value = dyn_cast<IntegerAttr>(constOp.getValue()); | ||
| if (!value) | ||
| return failure(); | ||
|
|
||
| auto attr = rewriter.getIntegerAttr(*targetType, value.getInt()); | ||
| auto newConst = | ||
| arith::ConstantOp::create(rewriter, constOp.getLoc(), attr); | ||
| return newConst.getResult(); | ||
| } | ||
|
|
||
| return failure(); | ||
| }; | ||
|
|
||
| FailureOr<Value> lhs = convertOperand(op.getLhs()); | ||
| FailureOr<Value> rhs = convertOperand(op.getRhs()); | ||
| if (failed(lhs) || failed(rhs)) | ||
| return failure(); | ||
|
|
||
| rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), *lhs, | ||
| *rhs); | ||
| return success(); | ||
| } | ||
|
|
||
| private: | ||
| static FailureOr<IntegerType> getTargetIntegerType(arith::CmpIOp op) { | ||
| auto pickType = [](Value operand) -> FailureOr<IntegerType> { | ||
| if (auto castOp = operand.getDefiningOp<arith::IndexCastOp>()) { | ||
| if (auto srcType = dyn_cast<IntegerType>(castOp.getIn().getType())) | ||
| return srcType; | ||
| } | ||
| return failure(); | ||
| }; | ||
|
|
||
| auto lhsType = pickType(op.getLhs()); | ||
| if (succeeded(lhsType)) | ||
| return *lhsType; | ||
|
|
||
| auto rhsType = pickType(op.getRhs()); | ||
| if (succeeded(rhsType)) | ||
| return *rhsType; | ||
|
|
||
| return failure(); | ||
| } | ||
| }; | ||
|
|
||
| /// Drop `arith.index_cast` that became unused once comparisons were rewritten. | ||
| class DropUnusedIndexCastPattern | ||
| : public OpRewritePattern<arith::IndexCastOp> { | ||
| public: | ||
| using OpRewritePattern::OpRewritePattern; | ||
|
|
||
| LogicalResult matchAndRewrite(arith::IndexCastOp op, | ||
| PatternRewriter &rewriter) const override { | ||
| if (!op->use_empty()) | ||
| return failure(); | ||
| rewriter.eraseOp(op); | ||
| return success(); | ||
| } | ||
| }; | ||
|
|
||
| /// Remove `arith.constant` index definitions that no longer feed any user. | ||
| class DropUnusedIndexConstantPattern | ||
| : public OpRewritePattern<arith::ConstantOp> { | ||
| public: | ||
| using OpRewritePattern::OpRewritePattern; | ||
|
|
||
| LogicalResult matchAndRewrite(arith::ConstantOp op, | ||
| PatternRewriter &rewriter) const override { | ||
| if (!op.getType().isIndex() || !op->use_empty()) | ||
| return failure(); | ||
| rewriter.eraseOp(op); | ||
| return success(); | ||
| } | ||
| }; | ||
|
|
||
| struct ConvertIndexToUIntPass | ||
| : public circt::impl::ConvertIndexToUIntBase<ConvertIndexToUIntPass> { | ||
| void runOnOperation() override { | ||
| MLIRContext *ctx = &getContext(); | ||
| RewritePatternSet patterns(ctx); | ||
| patterns.add<IndexCmpToIntegerPattern, DropUnusedIndexCastPattern, | ||
| DropUnusedIndexConstantPattern>(ctx); | ||
|
|
||
| if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) | ||
| signalPassFailure(); | ||
| } | ||
| }; | ||
|
|
||
| } // namespace | ||
|
|
||
| std::unique_ptr<mlir::Pass> circt::createConvertIndexToUIntPass() { | ||
| return std::make_unique<ConvertIndexToUIntPass>(); | ||
| } | ||
|
Comment on lines
+152
to
+154
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @uenoku - these definitions get automatically generated now right?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's not automatically generated if there is " |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,81 @@ | ||
| // RUN: circt-opt -split-input-file --switch-to-if --convert-index-to-uint --canonicalize %s | FileCheck %s | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In general CIRCT tries to avoid chaining passes in tests - can you split this into two tests (possibly in the same file as is done here integration_test/circt-lec/comb.mlir)? That way we can isolate testing of convert-index-to-uint? |
||
|
|
||
| // CHECK-LABEL: func.func @simple_cmp | ||
| // CHECK-NOT: arith.index_cast | ||
| // CHECK-NOT: : index | ||
| // CHECK: %[[C2:.*]] = arith.constant 2 : i4 | ||
| // CHECK: %[[CMP0:.*]] = arith.cmpi ult, %arg0, %[[C2]] : i4 | ||
| // CHECK: %[[CMP1:.*]] = arith.cmpi eq, %arg1, %[[C2]] : i4 | ||
| // CHECK: %[[RES:.*]] = arith.andi %[[CMP0]], %[[CMP1]] : i1 | ||
| // CHECK: return %[[RES]] : i1 | ||
| module { | ||
| func.func @simple_cmp(%arg0: i4, %arg1: i4) -> i1 { | ||
| %a = arith.index_cast %arg0 : i4 to index | ||
| %c2 = arith.constant 2 : index | ||
| %cmp0 = arith.cmpi ult, %a, %c2 : index | ||
| %b = arith.index_cast %arg1 : i4 to index | ||
| %cmp1 = arith.cmpi eq, %b, %c2 : index | ||
| %res = arith.andi %cmp0, %cmp1 : i1 | ||
| return %res : i1 | ||
| } | ||
| } | ||
|
|
||
| // ----- | ||
|
|
||
| // CHECK-LABEL: func.func @single_case | ||
| // CHECK-NOT: arith.index_cast | ||
| // CHECK: %[[C5:.*]] = arith.constant 5 : i8 | ||
| // CHECK: %[[CMP:.*]] = arith.cmpi eq, %arg0, %[[C5]] : i8 | ||
| // CHECK: return %[[CMP]] : i1 | ||
| module { | ||
| func.func @single_case(%cond: i8) -> i1 { | ||
| %switch_val = arith.index_cast %cond : i8 to index | ||
| %0 = scf.index_switch %switch_val -> i1 | ||
| case 5 { | ||
| %t = arith.constant true | ||
| scf.yield %t : i1 | ||
| } | ||
| default { | ||
| %f = arith.constant false | ||
| scf.yield %f : i1 | ||
| } | ||
| return %0 : i1 | ||
|
Comment on lines
+33
to
+42
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Once tests are split this will probably need to use different CHECK-*** labels - see examples in circt-lec tests of how to specify multiple filecheck labels |
||
| } | ||
| } | ||
|
|
||
| // ----- | ||
|
|
||
| // CHECK-LABEL: func.func @multi_case | ||
| // CHECK-NOT: arith.index_cast | ||
| // CHECK-NOT: : index | ||
| // CHECK-DAG: %[[CNEG:.*]] = arith.constant -3 : i3 | ||
| // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i3 | ||
| // CHECK: %[[MASK:.*]] = arith.trunci %{{.*}} : i16 to i3 | ||
| // CHECK: %[[CMP0:.*]] = arith.cmpi eq, %[[MASK]], %[[ZERO]] : i3 | ||
| // CHECK: %[[RES:.*]] = scf.if %[[CMP0]] -> (i1) { | ||
| // CHECK: } else { | ||
|
Comment on lines
+55
to
+56
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Q: Is there a missing yield line here? for the return of the if branch? |
||
| // CHECK: %[[CMP1:.*]] = arith.cmpi eq, %[[MASK]], %[[CNEG]] : i3 | ||
| // CHECK: scf.yield %[[CMP1]] : i1 | ||
| // CHECK: } | ||
| // CHECK: return %[[RES]] : i1 | ||
| module { | ||
| func.func @multi_case(%arg0: i16) -> i1 { | ||
| %c5_i16 = arith.constant 5 : i16 | ||
| %cst_true = arith.constant true | ||
| %cst_false = arith.constant false | ||
| %shr = arith.shrui %arg0, %c5_i16 : i16 | ||
| %mask = arith.trunci %shr : i16 to i3 | ||
| %switch_val = arith.index_cast %mask : i3 to index | ||
| %0 = scf.index_switch %switch_val -> i1 | ||
| case 0 { | ||
| scf.yield %cst_true : i1 | ||
| } | ||
| case 5 { | ||
| scf.yield %cst_true : i1 | ||
| } | ||
| default { | ||
| scf.yield %cst_false : i1 | ||
| } | ||
| return %0 : i1 | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: a concrete example potentially one from the tests would help to understand what transformation is performed?