|
14 | 14 | #include "circt/Dialect/Arc/ModelInfo.h" |
15 | 15 | #include "circt/Dialect/Comb/CombOps.h" |
16 | 16 | #include "circt/Dialect/Seq/SeqOps.h" |
| 17 | +#include "circt/Support/ConversionPatternSet.h" |
17 | 18 | #include "circt/Support/Namespace.h" |
18 | 19 | #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" |
19 | 20 | #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" |
|
22 | 23 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
23 | 24 | #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
24 | 25 | #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" |
25 | | -#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" |
| 26 | +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
26 | 27 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
27 | 28 | #include "mlir/Dialect/Index/IR/IndexOps.h" |
28 | 29 | #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" |
@@ -598,6 +599,53 @@ struct SimEmitValueOpLowering |
598 | 599 |
|
599 | 600 | } // namespace |
600 | 601 |
|
| 602 | +static LogicalResult convert(arc::ExecuteOp op, arc::ExecuteOp::Adaptor adaptor, |
| 603 | + ConversionPatternRewriter &rewriter, |
| 604 | + const TypeConverter &converter) { |
| 605 | + // Convert the argument types in the body blocks. |
| 606 | + if (failed(rewriter.convertRegionTypes(&op.getBody(), converter))) |
| 607 | + return failure(); |
| 608 | + |
| 609 | + // Split the block at the current insertion point such that we can branch into |
| 610 | + // the `arc.execute` body region, and have `arc.output` branch back to the |
| 611 | + // point after the `arc.execute`. |
| 612 | + auto *blockBefore = rewriter.getInsertionBlock(); |
| 613 | + auto *blockAfter = |
| 614 | + rewriter.splitBlock(blockBefore, rewriter.getInsertionPoint()); |
| 615 | + |
| 616 | + // Branch to the entry block. |
| 617 | + rewriter.setInsertionPointToEnd(blockBefore); |
| 618 | + mlir::cf::BranchOp::create(rewriter, op.getLoc(), &op.getBody().front(), |
| 619 | + adaptor.getInputs()); |
| 620 | + |
| 621 | + // Make all `arc.output` terminators branch to the block after the |
| 622 | + // `arc.execute` op. |
| 623 | + for (auto &block : op.getBody()) { |
| 624 | + auto outputOp = dyn_cast<arc::OutputOp>(block.getTerminator()); |
| 625 | + if (!outputOp) |
| 626 | + continue; |
| 627 | + rewriter.setInsertionPointToEnd(&block); |
| 628 | + rewriter.replaceOpWithNewOp<mlir::cf::BranchOp>(outputOp, blockAfter, |
| 629 | + outputOp.getOperands()); |
| 630 | + } |
| 631 | + |
| 632 | + // Inline the body region between the before and after blocks. |
| 633 | + rewriter.inlineRegionBefore(op.getBody(), blockAfter); |
| 634 | + |
| 635 | + // Add arguments to the block after the `arc.execute`, replace the op's |
| 636 | + // results with the arguments, then perform block signature conversion. |
| 637 | + SmallVector<Value> args; |
| 638 | + args.reserve(op.getNumResults()); |
| 639 | + for (auto result : op.getResults()) |
| 640 | + args.push_back(blockAfter->addArgument(result.getType(), result.getLoc())); |
| 641 | + rewriter.replaceOp(op, args); |
| 642 | + auto conversion = converter.convertBlockSignature(blockAfter); |
| 643 | + if (!conversion) |
| 644 | + return failure(); |
| 645 | + rewriter.applySignatureConversion(blockAfter, *conversion, &converter); |
| 646 | + return success(); |
| 647 | +} |
| 648 | + |
601 | 649 | //===----------------------------------------------------------------------===// |
602 | 650 | // Pass Implementation |
603 | 651 | //===----------------------------------------------------------------------===// |
@@ -667,7 +715,7 @@ void LowerArcToLLVMPass::runOnOperation() { |
667 | 715 | }); |
668 | 716 |
|
669 | 717 | // Setup the conversion patterns. |
670 | | - RewritePatternSet patterns(&getContext()); |
| 718 | + ConversionPatternSet patterns(&getContext(), converter); |
671 | 719 |
|
672 | 720 | // MLIR patterns. |
673 | 721 | populateSCFToControlFlowConversionPatterns(patterns); |
@@ -708,6 +756,7 @@ void LowerArcToLLVMPass::runOnOperation() { |
708 | 756 | ZeroCountOpLowering |
709 | 757 | >(converter, &getContext()); |
710 | 758 | // clang-format on |
| 759 | + patterns.add<ExecuteOp>(convert); |
711 | 760 |
|
712 | 761 | SmallVector<ModelInfo> models; |
713 | 762 | if (failed(collectModels(getOperation(), models))) { |
|
0 commit comments