|
| 1 | +//===- SelectPass.cpp - Select pass code ----------------------------------===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +// |
| 9 | +// SelectPass allows to run multiple different set of passes based on attribute |
| 10 | +// value on some top-level op. |
| 11 | +// |
| 12 | +//===----------------------------------------------------------------------===// |
| 13 | + |
| 14 | +#include "mlir/Transforms/Passes.h" |
| 15 | + |
| 16 | +#include "mlir/Pass/Pass.h" |
| 17 | +#include "mlir/Pass/PassManager.h" |
| 18 | + |
| 19 | +namespace mlir { |
| 20 | +#define GEN_PASS_DEF_SELECTPASS |
| 21 | +#include "mlir/Transforms/Passes.h.inc" |
| 22 | +} // namespace mlir |
| 23 | + |
| 24 | +using namespace mlir; |
| 25 | + |
| 26 | +namespace { |
| 27 | +struct SelectPass final : public impl::SelectPassBase<SelectPass> { |
| 28 | + using SelectPassBase::SelectPassBase; |
| 29 | + |
| 30 | + SelectPass( |
| 31 | + std::string name_, std::string selectCondName_, |
| 32 | + ArrayRef<std::pair<StringRef, std::function<void(OpPassManager &)>>> |
| 33 | + populateFuncs) { |
| 34 | + name = std::move(name_); |
| 35 | + selectCondName = std::move(selectCondName_); |
| 36 | + |
| 37 | + SmallVector<std::string> selectVals; |
| 38 | + SmallVector<std::string> selectPpls; |
| 39 | + selectVals.reserve(populateFuncs.size()); |
| 40 | + selectPpls.reserve(populateFuncs.size()); |
| 41 | + selectPassManagers.reserve(populateFuncs.size()); |
| 42 | + for (auto &&[name, populate] : populateFuncs) { |
| 43 | + selectVals.emplace_back(name); |
| 44 | + |
| 45 | + auto &pm = selectPassManagers.emplace_back(); |
| 46 | + populate(pm); |
| 47 | + |
| 48 | + llvm::raw_string_ostream os(selectPpls.emplace_back()); |
| 49 | + pm.printAsTextualPipeline(os); |
| 50 | + } |
| 51 | + |
| 52 | + selectValues = selectVals; |
| 53 | + selectPipelines = selectPpls; |
| 54 | + } |
| 55 | + |
| 56 | + LogicalResult initializeOptions( |
| 57 | + StringRef options, |
| 58 | + function_ref<LogicalResult(const Twine &)> errorHandler) override { |
| 59 | + if (failed(SelectPassBase::initializeOptions(options, errorHandler))) |
| 60 | + return failure(); |
| 61 | + |
| 62 | + if (selectCondName.empty()) |
| 63 | + return errorHandler("Invalid select-cond-name"); |
| 64 | + |
| 65 | + if (selectValues.size() != selectPipelines.size()) |
| 66 | + return errorHandler("Values and pipelines size mismatch"); |
| 67 | + |
| 68 | + selectPassManagers.resize(selectPipelines.size()); |
| 69 | + |
| 70 | + for (auto &&[i, pipeline] : llvm::enumerate(selectPipelines)) { |
| 71 | + if (failed(parsePassPipeline(pipeline, selectPassManagers[i]))) |
| 72 | + return errorHandler("Failed to parse pipeline"); |
| 73 | + } |
| 74 | + |
| 75 | + return success(); |
| 76 | + } |
| 77 | + |
| 78 | + LogicalResult initialize(MLIRContext *context) override { |
| 79 | + condAttrName = StringAttr::get(context, selectCondName); |
| 80 | + |
| 81 | + selectAttrs.reserve(selectAttrs.size()); |
| 82 | + for (StringRef value : selectValues) |
| 83 | + selectAttrs.emplace_back(StringAttr::get(context, value)); |
| 84 | + |
| 85 | + return success(); |
| 86 | + } |
| 87 | + |
| 88 | + void getDependentDialects(DialectRegistry ®istry) const override { |
| 89 | + for (const OpPassManager &pipeline : selectPassManagers) |
| 90 | + pipeline.getDependentDialects(registry); |
| 91 | + } |
| 92 | + |
| 93 | + void runOnOperation() override { |
| 94 | + Operation *op = getOperation(); |
| 95 | + Attribute condAttrValue = op->getAttr(condAttrName); |
| 96 | + if (!condAttrValue) { |
| 97 | + op->emitError("Condition attribute not present: ") << condAttrName; |
| 98 | + return signalPassFailure(); |
| 99 | + } |
| 100 | + |
| 101 | + for (auto &&[value, pm] : |
| 102 | + llvm::zip_equal(selectAttrs, selectPassManagers)) { |
| 103 | + if (value != condAttrValue) |
| 104 | + continue; |
| 105 | + |
| 106 | + if (failed(runPipeline(pm, op))) |
| 107 | + return signalPassFailure(); |
| 108 | + |
| 109 | + return; |
| 110 | + } |
| 111 | + |
| 112 | + op->emitError("Unhandled condition value: ") << condAttrValue; |
| 113 | + return signalPassFailure(); |
| 114 | + } |
| 115 | + |
| 116 | +protected: |
| 117 | + StringRef getName() const override { return name; } |
| 118 | + |
| 119 | +private: |
| 120 | + StringAttr condAttrName; |
| 121 | + SmallVector<Attribute> selectAttrs; |
| 122 | + SmallVector<OpPassManager> selectPassManagers; |
| 123 | +}; |
| 124 | +} // namespace |
| 125 | + |
| 126 | +std::unique_ptr<Pass> mlir::createSelectPass( |
| 127 | + std::string name, std::string selectCondName, |
| 128 | + ArrayRef<std::pair<StringRef, std::function<void(OpPassManager &)>>> |
| 129 | + populateFuncs) { |
| 130 | + return std::make_unique<SelectPass>(std::move(name), |
| 131 | + std::move(selectCondName), populateFuncs); |
| 132 | +} |
0 commit comments