From c4837020e95cbc6effed46db936b80c93e72dc9b Mon Sep 17 00:00:00 2001 From: Schuyler Eldridge Date: Fri, 8 Aug 2025 16:25:29 -0400 Subject: [PATCH 01/14] [FIRRTL] Add InferDomains pass Add a pass that does domain inference and checking. This is used to verify the legality of a FIRRTL circuit with respect to its domains. E.g., this pass is intended to be used for checking for illegal clock domain crossings. Signed-off-by: Schuyler Eldridge --- .../circt/Dialect/FIRRTL/FIRRTLOpInterfaces.h | 22 + include/circt/Dialect/FIRRTL/Passes.td | 13 + include/circt/Firtool/Firtool.h | 8 + include/circt/Support/InstanceGraph.h | 7 + lib/Dialect/FIRRTL/Transforms/CMakeLists.txt | 1 + .../FIRRTL/Transforms/InferDomains.cpp | 1339 +++++++++++++++++ lib/Firtool/Firtool.cpp | 11 +- test/Dialect/FIRRTL/infer-domains-errors.mlir | 128 ++ test/Dialect/FIRRTL/infer-domains.mlir | 250 +++ 9 files changed, 1778 insertions(+), 1 deletion(-) create mode 100644 lib/Dialect/FIRRTL/Transforms/InferDomains.cpp create mode 100644 test/Dialect/FIRRTL/infer-domains-errors.mlir create mode 100644 test/Dialect/FIRRTL/infer-domains.mlir diff --git a/include/circt/Dialect/FIRRTL/FIRRTLOpInterfaces.h b/include/circt/Dialect/FIRRTL/FIRRTLOpInterfaces.h index e8339dd7320b..44a9903b5520 100644 --- a/include/circt/Dialect/FIRRTL/FIRRTLOpInterfaces.h +++ b/include/circt/Dialect/FIRRTL/FIRRTLOpInterfaces.h @@ -85,6 +85,28 @@ struct PortInfo { annotations(annos), domains(domains) {} }; +inline bool operator==(const PortInfo &lhs, const PortInfo &rhs) { + if (lhs.name != rhs.name) + return false; + if (lhs.type != rhs.type) + return false; + if (lhs.direction != rhs.direction) + return false; + if (lhs.sym != rhs.sym) + return false; + if (lhs.loc != rhs.loc) + return false; + if (lhs.annotations != rhs.annotations) + return false; + if (lhs.domains != rhs.domains) + return false; + return true; +} + +inline bool operator!=(const PortInfo &lhs, const PortInfo &rhs) { + return !(lhs == rhs); +} + enum class ConnectBehaviorKind { /// Classic FIRRTL connections: last connect 'wins' across paths; /// conditionally applied under 'when'. diff --git a/include/circt/Dialect/FIRRTL/Passes.td b/include/circt/Dialect/FIRRTL/Passes.td index 15cf686d618c..caf8eeb7dffb 100644 --- a/include/circt/Dialect/FIRRTL/Passes.td +++ b/include/circt/Dialect/FIRRTL/Passes.td @@ -909,6 +909,19 @@ def CheckLayers : Pass<"firrtl-check-layers", "firrtl::CircuitOp"> { }]; } +def InferDomains : Pass<"firrtl-infer-domains", "firrtl::CircuitOp"> { + let summary = "Infer and type check all firrtl domains"; + let description = [{ + This pass does domain inference on a FIRRTL circuit. The end result of this + is either a corrrctly domain-checked FIRRTL circuit or failure with verbose + error messages indicating why the FIRRTL circuit has illegal domain + constructs. + + E.g., this pass can be used to check for illegal clock-domain-crossings if + clock domains are specified for signals in the design. + }]; +} + def LowerDomains : Pass<"firrtl-lower-domains", "firrtl::CircuitOp"> { let summary = "lower domain information to properties"; let description = [{ diff --git a/include/circt/Firtool/Firtool.h b/include/circt/Firtool/Firtool.h index 28e5f6d5d897..5069dca326d2 100644 --- a/include/circt/Firtool/Firtool.h +++ b/include/circt/Firtool/Firtool.h @@ -144,6 +144,8 @@ class FirtoolOptions { bool getEmitAllBindFiles() const { return emitAllBindFiles; } + bool shouldInferDomains() const { return inferDomains; } + // Setters, used by the CAPI FirtoolOptions &setOutputFilename(StringRef name) { outputFilename = name; @@ -393,6 +395,11 @@ class FirtoolOptions { return *this; } + FirtoolOptions &setInferDomains(bool value) { + inferDomains = value; + return *this; + } + private: std::string outputFilename; @@ -447,6 +454,7 @@ class FirtoolOptions { bool lintStaticAsserts; bool lintXmrsInDesign; bool emitAllBindFiles; + bool inferDomains; }; void registerFirtoolCLOptions(); diff --git a/include/circt/Support/InstanceGraph.h b/include/circt/Support/InstanceGraph.h index a26cf5ddce77..70fdcf2e67f9 100644 --- a/include/circt/Support/InstanceGraph.h +++ b/include/circt/Support/InstanceGraph.h @@ -59,6 +59,11 @@ class InstanceGraphNode; class InstanceRecord : public llvm::ilist_node_with_parent { public: + /// Get the op that this is tracking. + Operation *getOperation() { + return instance.getOperation(); + } + /// Get the instance-like op that this is tracking. template auto getInstance() { @@ -113,6 +118,8 @@ class InstanceGraphNode : public llvm::ilist_node { public: InstanceGraphNode() : module(nullptr) {} + Operation *getOperation() { return module.getOperation(); } + /// Get the module that this node is tracking. template auto getModule() { diff --git a/lib/Dialect/FIRRTL/Transforms/CMakeLists.txt b/lib/Dialect/FIRRTL/Transforms/CMakeLists.txt index f8a1ec9e4045..d24f11c7563b 100755 --- a/lib/Dialect/FIRRTL/Transforms/CMakeLists.txt +++ b/lib/Dialect/FIRRTL/Transforms/CMakeLists.txt @@ -17,6 +17,7 @@ add_circt_dialect_library(CIRCTFIRRTLTransforms GrandCentral.cpp IMConstProp.cpp IMDeadCodeElim.cpp + InferDomains.cpp InferReadWrite.cpp InferResets.cpp InferWidths.cpp diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp new file mode 100644 index 000000000000..1558857221cc --- /dev/null +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -0,0 +1,1339 @@ +//===- InferDomains.cpp - Infer and Check FIRRTL Domains ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass implements FIRRTL domain inference and checking with canonical +// domain representation. Domain sequences are canonicalized by sorting and +// removing duplicates, making domain order irrelevant and allowing duplicate +// domains to be treated as equivalent. The result of this pass is either a +// correctly domain-inferred circuit or pass failure if the circuit contains +// illegal domain crossings. +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/FIRRTL/FIRRTLInstanceGraph.h" +#include "circt/Dialect/FIRRTL/FIRRTLOps.h" +#include "circt/Dialect/FIRRTL/FIRRTLUtils.h" +#include "circt/Dialect/FIRRTL/Passes.h" +#include "circt/Support/Debug.h" +#include "circt/Support/Namespace.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TinyPtrVector.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "firrtl-infer-domains" +#undef NDEBUG + +namespace circt { +namespace firrtl { +#define GEN_PASS_DEF_INFERDOMAINS +#include "circt/Dialect/FIRRTL/Passes.h.inc" +} // namespace firrtl +} // namespace circt + +using namespace circt; +using namespace firrtl; + +using InstanceIterator = InstanceGraphNode::UseIterator; +using InstanceRange = llvm::iterator_range; +using PortInsertions = SmallVector>; + +//====-------------------------------------------------------------------------- +// Helpers for working with module or instance domain info. +//====-------------------------------------------------------------------------- + +/// From a domain info attribute, get the domain-type of a domain value at +/// index i. +static StringAttr getDomainPortTypeName(ArrayAttr info, size_t i) { + if (info.empty()) + return nullptr; + auto ref = cast(info[i]); + return ref.getAttr(); +} + +/// From a domain info attribute, get the row of associated domains for a +/// hardware value at index i. +static auto getPortDomainAssociation(ArrayAttr info, size_t i) { + if (info.empty()) + return info.getAsRange(); + return cast(info[i]).getAsRange(); +} + +/// Return true if the value is a port on the module. +static bool isPort(FModuleOp module, BlockArgument arg) { + return arg.getOwner()->getParentOp() == module; +} + +/// Return true if the value is a port on the module. +static bool isPort(FModuleOp module, Value value) { + auto arg = dyn_cast(value); + if (!arg) + return false; + return isPort(module, arg); +} + +//====-------------------------------------------------------------------------- +// Circuit-wide state. +//====-------------------------------------------------------------------------- + +/// Each declared domain in the circuit is assigned an index, based on the order +/// in which it appears. Domain associations for hardware values are represented +/// as a list of domains, sorted by the index of the domain type. +using DomainTypeID = size_t; + +/// Information about the domains in the circuit. Able to map domains to their +/// type ID, which in this pass is the canonical way to reference the type +/// of a domain. +namespace { +struct CircuitDomainInfo { + CircuitDomainInfo(CircuitOp circuit) { processCircuit(circuit); } + + ArrayRef getDomains() const { return domainTable; } + size_t getNumDomains() const { return domainTable.size(); } + DomainOp getDomain(DomainTypeID id) const { return domainTable[id]; } + + DomainTypeID getDomainTypeID(DomainOp op) const { + return typeIDTable.at(op.getNameAttr()); + } + + DomainTypeID getDomainTypeID(StringAttr name) const { + return typeIDTable.at(name); + } + + DomainTypeID getDomainTypeID(FlatSymbolRefAttr ref) const { + return getDomainTypeID(ref.getAttr()); + } + + DomainTypeID getDomainTypeID(ArrayAttr info, size_t i) const { + auto name = getDomainPortTypeName(info, i); + return getDomainTypeID(name); + } + + DomainTypeID getDomainTypeID(Value value) const { + assert(isa(value.getType())); + if (auto arg = dyn_cast(value)) { + auto *block = arg.getOwner(); + auto *owner = block->getParentOp(); + auto module = cast(owner); + auto info = module.getDomainInfoAttr(); + auto i = arg.getArgNumber(); + return getDomainTypeID(info, i); + } + + auto result = dyn_cast(value); + auto *owner = result.getOwner(); + auto instance = cast(owner); + auto info = instance.getDomainInfoAttr(); + auto i = result.getResultNumber(); + return getDomainTypeID(info, i); + } + +private: + void processDomain(DomainOp op) { + auto index = domainTable.size(); + auto name = op.getNameAttr(); + domainTable.push_back(op); + typeIDTable.insert({name, index}); + } + + void processCircuit(CircuitOp circuit) { + for (auto decl : circuit.getOps()) + processDomain(decl); + } + + /// A map from domain type ID to op. + SmallVector domainTable; + + /// A map from domain name to type ID. + DenseMap typeIDTable; +}; + +/// Information about the changes made to the interface of a module, which can +/// be replayed onto an instance. +struct ModuleUpdateInfo { + /// The updated domain information for a module. + ArrayAttr portDomainInfo; + /// The domain ports which have been inserted into a module. + PortInsertions portInsertions; +}; + +struct GlobalState { + GlobalState(CircuitOp circuit) : circuitInfo(circuit) {} + + CircuitDomainInfo circuitInfo; + DenseMap moduleUpdateTable; +}; + +} // namespace + +//====-------------------------------------------------------------------------- +// Terms: Syntax for unifying domain and domain-rows. +//====-------------------------------------------------------------------------- + +namespace { + +/// The different sorts of terms in the unification engine. +enum class TermKind { + Variable, + Value, + Row, +}; + +/// A term in the unification engine. +struct Term { + constexpr Term(TermKind kind) : kind(kind) {} + TermKind kind; +}; + +/// Helper to define a term kind. +template +struct TermBase : Term { + static bool classof(const Term *term) { return term->kind == K; } + TermBase() : Term(K) {} +}; + +/// An unknown value. +struct VariableTerm : public TermBase { + VariableTerm() : leader(nullptr) {} + VariableTerm(Term *leader) : leader(leader) {} + Term *leader; +}; + +/// A concrete value defined in the IR. +struct ValueTerm : public TermBase { + ValueTerm(Value value) : value(value) {} + Value getValue() const { return value; } + Value value; +}; + +/// A row of domains. +struct RowTerm : public TermBase { + RowTerm(ArrayRef elements) : elements(elements) {} + ArrayRef elements; +}; + +/// A helper for assigning low numeric IDs to variables for user-facing output. +struct VariableIDTable { + size_t get(VariableTerm *term) { + auto [it, inserted] = table.insert({term, table.size() + 1}); + return it->second; + } + + DenseMap table; +}; + +#ifndef NDEBUG + +raw_ostream &dump(llvm::raw_ostream &out, const Term *term); + +// NOLINTNEXTLINE(misc-no-recursion) +raw_ostream &dump(raw_ostream &out, const VariableTerm *term) { + return out << "var@" << (void *)term << "{leader=" << term->leader << "}"; +} + +// NOLINTNEXTLINE(misc-no-recursion) +raw_ostream &dump(raw_ostream &out, const ValueTerm *term) { + return out << "val@" << term << "{" << term->value << "}"; +} + +// NOLINTNEXTLINE(misc-no-recursion) +raw_ostream &dump(raw_ostream &out, const RowTerm *term) { + out << "row@" << term << "{"; + bool first = true; + for (auto *element : term->elements) { + if (!first) + out << ", "; + dump(out, element); + first = false; + } + out << "}"; + return out; +} + +// NOLINTNEXTLINE(misc-no-recursion) +raw_ostream &dump(raw_ostream &out, const Term *term) { + if (!term) + return out << "null"; + if (auto *var = dyn_cast(term)) + return dump(out, var); + if (auto *val = dyn_cast(term)) + return dump(out, val); + if (auto *row = dyn_cast(term)) + return dump(out, row); + llvm_unreachable("unknown term"); +} +#endif // DEBUG + +// NOLINTNEXTLINE(misc-no-recursion) +Term *find(Term *x) { + if (!x) + return nullptr; + + if (auto *var = dyn_cast(x)) { + if (var->leader == nullptr) + return var; + + auto *leader = find(var->leader); + if (leader != var->leader) + var->leader = leader; + return leader; + } + + return x; +} + +LogicalResult unify(Term *lhs, Term *rhs); + +LogicalResult unify(VariableTerm *x, Term *y) { + x->leader = y; + return success(); +} + +LogicalResult unify(ValueTerm *xv, Term *y) { + if (auto *yv = dyn_cast(y)) { + yv->leader = xv; + return success(); + } + if (auto *yv = dyn_cast(y)) { + return success(xv == yv); + } + return failure(); +} + +// NOLINTNEXTLINE(misc-no-recursion) +LogicalResult unify(RowTerm *lhsRow, Term *rhs) { + if (auto *rhsVar = dyn_cast(rhs)) { + rhsVar->leader = lhsRow; + return success(); + } + if (auto *rhsRow = dyn_cast(rhs)) { + assert(lhsRow->elements.size() == rhsRow->elements.size()); + for (auto [x, y] : llvm::zip(lhsRow->elements, rhsRow->elements)) { + if (failed(unify(x, y))) + return failure(); + } + return success(); + } + + return failure(); +} + +// NOLINTNEXTLINE(misc-no-recursion) +LogicalResult unify(Term *lhs, Term *rhs) { + LLVM_DEBUG(auto &out = llvm::errs(); out << "unify x="; dump(out, lhs); + out << " y="; dump(out, rhs); out << "\n";); + if (!lhs || !rhs) + return success(); + lhs = find(lhs); + rhs = find(rhs); + if (lhs == rhs) + return success(); + if (auto *lhsVar = dyn_cast(lhs)) + return unify(lhsVar, rhs); + if (auto *lhsVal = dyn_cast(lhs)) + return unify(lhsVal, rhs); + if (auto *lhsRow = dyn_cast(lhs)) + return unify(lhsRow, rhs); + return failure(); +} + +void solve(Term *lhs, Term *rhs) { + auto result = unify(lhs, rhs); + (void)result; + assert(result.succeeded()); +} + +} // namespace + +//====-------------------------------------------------------------------------- +// InferModuleDomains: Primary workhorse for inferring domains on modules. +//====-------------------------------------------------------------------------- + +namespace { +class InferModuleDomains { +public: + /// Run infer-domains on a module. + static LogicalResult run(GlobalState &, FModuleOp); + +private: + /// Initialize module-level state. + InferModuleDomains(GlobalState &); + + /// Execute on the given module. + LogicalResult operator()(FModuleOp); + + /// Record the domain associations of hardware ports, and record the + /// underlying value of output domain ports. + LogicalResult processPorts(FModuleOp); + + /// Record the domain associations of hardware, and record the underlying + /// value of domains, defined within the body of the module. + LogicalResult processBody(FModuleOp); + + /// Record the domain associations of any operands or results, updating the op + /// if necessary. + LogicalResult processOp(Operation *); + LogicalResult processOp(InstanceOp); + LogicalResult processOp(InstanceChoiceOp); + LogicalResult processOp(UnsafeDomainCastOp); + LogicalResult processOp(DomainDefineOp); + + /// Apply the port changes of a module onto an instance-like op. + template + T updateInstancePorts(T op, const ModuleUpdateInfo &update); + + /// Record the domain associations of the ports of an instance-like op. + template + LogicalResult processInstancePorts(T op); + + LogicalResult updateModule(FModuleOp); + + /// Build a table of exported domains: a map from domains defined internally, + /// to their set of aliasing output ports. + void initializeExportTable(FModuleOp); + + /// After generalizing the module, all domains should be solved. Reflect the + /// solved domain associations into the port domain info attribute. + LogicalResult updatePortDomainAssociations(FModuleOp); + + /// After updating the port domain associations, walk the body of the module + /// to fix up any child instance modules. + LogicalResult updateDomainAssociationsInBody(FModuleOp); + LogicalResult updateOpDomainAssociations(Operation *); + + template + LogicalResult updateInstanceDomainAssociations(T op); + + /// Copy the domain associations from the module domain info attribute into a + /// small vector. + SmallVector copyPortDomainAssociations(ArrayAttr, size_t); + + /// Add domain ports for any uninferred domains associated to hardware. + /// Returns the inserted ports, which will be used later to generalize the + /// instances of this module. + void generalizeModule(FModuleOp); + + /// Unify the associated domain rows of two terms. + LogicalResult unifyAssociations(Operation *, Value, Value); + + /// If the domain value is an alias, returns the domain it aliases. + Value getUnderlyingDomain(Value); + + /// Record a mapping from domain in the IR to its corresponding term. + void setTermForDomain(Value, Term *); + + /// Get the corresponding term for a domain in the IR. + Term *getTermForDomain(Value); + + /// Get the corresponding term for a domain in the IR, or null if unset. + Term *getOptTermForDomain(Value) const; + + /// Record a mapping from a hardware value in the IR to a term which + /// represents the row of domains it is associated with. + void setDomainAssociation(Value, Term *); + + /// Get the associated domain row, forced to be at least a row. + RowTerm *getDomainAssociationAsRow(Value); + + /// For a hardware value, get the term which represents the row of associated + /// domains. If no mapping has been defined, allocate a variable to stand for + /// the row of domains. + Term *getDomainAssociation(Value); + + /// For a hardware value, get the term which represents the row of associated + /// domains. If no mapping has been defined, returns nullptr. + Term *getOptDomainAssociation(Value) const; + + /// Allocate a row, where each domain is a variable. + RowTerm *allocateRow(); + + /// Allocate a row. + RowTerm *allocateRow(ArrayRef); + + /// Allocate a term. + template + T *allocate(Args &&...); + + /// Allocate an array of terms. If any terms were left null, automatically + /// replace them with a new variable. + ArrayRef allocateArray(ArrayRef); + + /// Print a term in a user-friendly way. + void render(Diagnostic &, Term *) const; + void render(Diagnostic &, VariableIDTable &, Term *) const; + + template + void emitPortDomainCrossingError(T, size_t, DomainTypeID, Term *, + Term *) const; + + /// Emit an error when we fail to infer the concrete domain to drive to a + /// domain port. + template + void emitDomainPortInferenceError(T, size_t) const; + + /// Information about the domains in a circuit. + GlobalState &globals; + + /// Term allocator. + llvm::BumpPtrAllocator allocator; + + /// Map from domains in the IR to their underlying term. + DenseMap termTable; + + /// A map from hardware values to their associated row of domains, as a term. + DenseMap associationTable; + + /// A map from local domain definition to its aliasing output ports. + DenseMap> exportTable; +}; +} // namespace + +LogicalResult InferModuleDomains::run(GlobalState &globals, FModuleOp module) { + return InferModuleDomains(globals)(module); +} + +InferModuleDomains::InferModuleDomains(GlobalState &globals) + : globals(globals) {} + +LogicalResult InferModuleDomains::operator()(FModuleOp module) { + LLVM_DEBUG( + llvm::errs() << "================================================\n"; + llvm::errs() << "infer module domains: " << module.getModuleName() + << "\n"; + llvm::errs() << "================================================\n";); + + if (failed(processPorts(module))) + return failure(); + + if (failed(processBody(module))) + return failure(); + + LLVM_DEBUG(for (auto association : associationTable) { + llvm::errs() << "association:\n"; + llvm::errs() << " " << association.first << "\n"; + llvm::errs() << " " << association.second << "\n"; + }); + + return updateModule(module); +} + +LogicalResult InferModuleDomains::processPorts(FModuleOp module) { + auto portDomainInfo = module.getDomainInfoAttr(); + auto numPorts = module.getNumPorts(); + + // Process module ports - domain ports define explicit domains. + DenseMap domainTypeIDTable; + for (size_t i = 0; i < numPorts; ++i) { + BlockArgument port = module.getArgument(i); + + // This is a domain port. + if (isa(port.getType())) { + auto typeID = globals.circuitInfo.getDomainTypeID(portDomainInfo, i); + domainTypeIDTable[i] = typeID; + if (module.getPortDirection(i) == Direction::In) { + setTermForDomain(port, allocate(port)); + } + continue; + } + + // This is a port, which may have explicit domain information. + auto portDomains = getPortDomainAssociation(portDomainInfo, i); + if (portDomains.empty()) + continue; + + SmallVector elements(globals.circuitInfo.getNumDomains()); + for (auto domainPortIndexAttr : portDomains) { + auto domainPortIndex = domainPortIndexAttr.getUInt(); + auto domainTypeID = domainTypeIDTable[domainPortIndex]; + auto domainValue = module.getArgument(domainPortIndex); + auto *term = getTermForDomain(domainValue); + auto &slot = elements[domainTypeID]; + if (failed(unify(slot, term))) { + emitPortDomainCrossingError(module, i, domainTypeID, slot, term); + return failure(); + } + elements[domainTypeID] = term; + } + auto *row = allocateRow(elements); + setDomainAssociation(port, row); + } + + return success(); +} + +LogicalResult InferModuleDomains::processBody(FModuleOp module) { + LogicalResult result = success(); + module.getBody().walk([&](Operation *op) -> WalkResult { + if (failed(processOp(op))) { + result = failure(); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return result; +} + +LogicalResult InferModuleDomains::processOp(Operation *op) { + LLVM_DEBUG(llvm::errs() << "process op: " << *op << "\n"); + + if (auto instance = dyn_cast(op)) + return processOp(instance); + if (auto instance = dyn_cast(op)) + return processOp(instance); + if (auto cast = dyn_cast(op)) + return processOp(cast); + if (auto def = dyn_cast(op)) + return processOp(def); + + // For all other operations (including connections), propagate domains from + // operands to results. This is a conservative approach - all operands and + // results share the same domain associations. + Value lhs; + for (auto rhs : op->getOperands()) { + if (!isa(rhs.getType())) + continue; + if (auto *op = rhs.getDefiningOp(); + op && op->hasTrait()) + continue; + if (failed(unifyAssociations(op, lhs, rhs))) + return failure(); + lhs = rhs; + } + for (auto rhs : op->getResults()) { + if (!isa(rhs.getType())) + continue; + if (auto *op = rhs.getDefiningOp(); + op && op->hasTrait()) + continue; + if (failed(unifyAssociations(op, lhs, rhs))) + return failure(); + lhs = rhs; + } + return success(); +} + +LogicalResult InferModuleDomains::processOp(InstanceOp op) { + auto module = op.getReferencedModuleNameAttr(); + auto lookup = globals.moduleUpdateTable.find(module); + if (lookup != globals.moduleUpdateTable.end()) + op = updateInstancePorts(op, lookup->second); + return processInstancePorts(op); +} + +LogicalResult InferModuleDomains::processOp(InstanceChoiceOp op) { + auto module = op.getDefaultTargetAttr().getAttr(); + auto lookup = globals.moduleUpdateTable.find(module); + if (lookup != globals.moduleUpdateTable.end()) + op = updateInstancePorts(op, lookup->second); + return processInstancePorts(op); +} + +LogicalResult InferModuleDomains::processOp(UnsafeDomainCastOp op) { + auto domains = op.getDomains(); + if (domains.empty()) + return unifyAssociations(op, op.getInput(), op.getResult()); + + auto input = op.getInput(); + RowTerm *inputRow = getDomainAssociationAsRow(input); + SmallVector elements(inputRow->elements); + for (auto domain : op.getDomains()) { + auto typeID = globals.circuitInfo.getDomainTypeID(domain); + elements[typeID] = getTermForDomain(domain); + } + + auto *row = allocateRow(elements); + setDomainAssociation(op.getResult(), row); + return success(); +} + +LogicalResult InferModuleDomains::processOp(DomainDefineOp op) { + auto src = op.getSrc(); + auto dst = op.getDest(); + auto *srcTerm = getTermForDomain(src); + auto *dstTerm = getTermForDomain(dst); + if (failed(unify(dstTerm, srcTerm))) { + VariableIDTable idTable; + auto diag = op->emitOpError("failed to propagate source to destination"); + auto ¬e1 = diag.attachNote(); + note1 << "destination has underlying value: "; + render(note1, idTable, dstTerm); + + auto ¬e2 = diag.attachNote(src.getLoc()); + note2 << "source has underlying value: "; + render(note2, idTable, srcTerm); + } + return unify(dstTerm, srcTerm); +} + +template +T InferModuleDomains::updateInstancePorts(T op, + const ModuleUpdateInfo &update) { + auto clone = op.cloneWithInsertedPortsAndReplaceUses(update.portInsertions); + clone.setDomainInfoAttr(update.portDomainInfo); + op->erase(); + return clone; +} + +template +LogicalResult InferModuleDomains::processInstancePorts(T op) { + auto circuitInfo = globals.circuitInfo; + auto numDomainTypes = circuitInfo.getNumDomains(); + DenseMap domainPortTypeIDTable; + auto domainInfo = op.getDomainInfoAttr(); + for (size_t i = 0, e = op->getNumResults(); i < e; ++i) { + Value port = op.getResult(i); + + LLVM_DEBUG(llvm::errs() << "handling instance port: " << port << "\n"); + + if (isa(port.getType())) { + auto typeID = circuitInfo.getDomainTypeID(domainInfo, i); + domainPortTypeIDTable[i] = typeID; + if (op.getPortDirection(i) == Direction::Out) { + setTermForDomain(port, allocate(port)); + } + continue; + } + + if (!isa(port.getType())) + continue; + + // This is a port, which may have explicit domain information. Associate the + // port with a row of domains, where each element is derived from the domain + // associations recorded in the domain info attribute of the instance. + SmallVector elements(numDomainTypes); + auto associations = getPortDomainAssociation(domainInfo, i); + for (auto domainPortIndexAttr : associations) { + auto domainPortIndex = domainPortIndexAttr.getUInt(); + auto typeID = domainPortTypeIDTable[domainPortIndex]; + auto *term = getTermForDomain(op.getResult(domainPortIndex)); + elements[typeID] = term; + } + + // Confirm that we have complete domain information for the port. We can be + // missing information if, for example, this was an instance of an + // extmodule. + for (size_t domainTypeID = 0; domainTypeID < numDomainTypes; + ++domainTypeID) { + if (elements[domainTypeID]) + continue; + auto domainDecl = circuitInfo.getDomain(domainTypeID); + auto domainName = domainDecl.getNameAttr(); + auto portName = op.getPortNameAttr(i); + op->emitOpError() << "missing " << domainName << " association for port " + << portName; + return failure(); + } + + setDomainAssociation(port, allocateRow(elements)); + } + + return success(); +} + +LogicalResult InferModuleDomains::updateModule(FModuleOp op) { + initializeExportTable(op); + + generalizeModule(op); + if (failed(updatePortDomainAssociations(op))) + return failure(); + + if (failed(updateDomainAssociationsInBody(op))) + return failure(); + + return success(); +} + +void InferModuleDomains::initializeExportTable(FModuleOp module) { + size_t numPorts = module.getNumPorts(); + for (size_t i = 0; i < numPorts; ++i) { + auto port = module.getArgument(i); + auto type = port.getType(); + if (!isa(type)) + continue; + auto value = getUnderlyingDomain(port); + if (value) + exportTable[value].push_back(port); + } +} + +LogicalResult +InferModuleDomains::updatePortDomainAssociations(FModuleOp module) { + // At this point, all domain variables mentioned in ports have been + // solved by generalizing the module (adding input domain ports). Now, we have + // to form the new port domain information for the module by examining the + // the associated domains of each port. + auto *context = module.getContext(); + auto numDomains = globals.circuitInfo.getNumDomains(); + auto builder = OpBuilder::atBlockEnd(module.getBodyBlock()); + auto oldModuleDomainInfo = module.getDomainInfoAttr(); + auto numPorts = module.getNumPorts(); + SmallVector newModuleDomainInfo(numPorts); + + for (size_t i = 0; i < numPorts; ++i) { + auto port = module.getArgument(i); + auto type = port.getType(); + + // If the port is an output domain, we may need to drive the output with + // a value. If we don't know what value to drive to the port, error. + if (isa(type)) { + if (module.getPortDirection(i) == Direction::Out) { + bool driven = false; + for (auto *user : port.getUsers()) { + if (auto connect = dyn_cast(user)) { + if (connect.getDest() == port) { + driven = true; + break; + } + } + } + + // Get the underlying value of the output port. + auto *term = getTermForDomain(port); + term = find(term); + auto *val = dyn_cast(term); + if (!val) { + emitDomainPortInferenceError(module, i); + return failure(); + } + + // If the output port is not driven, drive it. + if (!driven) { + auto loc = port.getLoc(); + auto value = val->value; + DomainDefineOp::create(builder, loc, port, value); + } + } + + newModuleDomainInfo[i] = oldModuleDomainInfo[i]; + continue; + } + + if (isa(type)) { + auto associations = copyPortDomainAssociations(oldModuleDomainInfo, i); + auto *row = getDomainAssociationAsRow(port); + for (size_t domainTypeID = 0; domainTypeID < numDomains; ++domainTypeID) { + if (associations[domainTypeID]) + continue; + + auto domain = cast(find(row->elements[domainTypeID]))->value; + auto &exports = exportTable[domain]; + if (exports.empty()) { + auto portName = module.getPortNameAttr(i); + auto portLoc = module.getPortLocation(i); + auto domainDecl = globals.circuitInfo.getDomain(domainTypeID); + auto domainName = domainDecl.getNameAttr(); + auto diag = emitError(portLoc) + << "private " << domainName << " association for port " + << portName; + diag.attachNote(domain.getLoc()) << "associated domain: " << domain; + return failure(); + } + + if (exports.size() > 1) { + auto portName = module.getPortNameAttr(i); + auto portLoc = module.getPortLocation(i); + auto domainDecl = globals.circuitInfo.getDomain(domainTypeID); + auto domainName = domainDecl.getNameAttr(); + auto diag = emitError(portLoc) + << "ambiguous " << domainName << " association for port " + << portName; + for (auto arg : exports) { + auto name = module.getPortNameAttr(arg.getArgNumber()); + auto loc = module.getPortLocation(arg.getArgNumber()); + diag.attachNote(loc) << "candidate association " << name; + } + return failure(); + } + + auto argument = cast(exports[0]); + auto domainPortIndex = argument.getArgNumber(); + associations[domainTypeID] = IntegerAttr::get( + IntegerType::get(context, 32, IntegerType::Unsigned), + domainPortIndex); + } + + newModuleDomainInfo[i] = ArrayAttr::get(context, associations); + continue; + } + + newModuleDomainInfo[i] = oldModuleDomainInfo[i]; + } + + auto newModuleDomainInfoAttr = + ArrayAttr::get(module.getContext(), newModuleDomainInfo); + module.setDomainInfoAttr(newModuleDomainInfoAttr); + + // record the domain info for replaying on instances. + auto &update = globals.moduleUpdateTable[module.getNameAttr()]; + update.portDomainInfo = newModuleDomainInfoAttr; + + return success(); +} + +SmallVector +InferModuleDomains::copyPortDomainAssociations(ArrayAttr moduleDomainInfo, + size_t portIndex) { + SmallVector result(globals.circuitInfo.getNumDomains()); + auto oldAssociations = getPortDomainAssociation(moduleDomainInfo, portIndex); + for (auto domainPortIndexAttr : oldAssociations) { + auto domainPortIndex = domainPortIndexAttr.getUInt(); + auto domainTypeID = + globals.circuitInfo.getDomainTypeID(moduleDomainInfo, domainPortIndex); + result[domainTypeID] = domainPortIndexAttr; + }; + return result; +} + +void InferModuleDomains::generalizeModule(FModuleOp module) { + PortInsertions insertions; + // If the port is hardware, we have to check the associated row of + // domains. If any associated domain is a variable, we solve the variable + // by generalizing the module with an additional input domain port. If any + // associated domain is defined internally to the module, we have to add + // an output domain port, to allow the domain to escape. + DenseMap pendingSolutions; + llvm::MapVector pendingExports; + + size_t inserted = 0; + auto numPorts = module.getNumPorts(); + for (size_t i = 0; i < numPorts; ++i) { + auto port = module.getArgument(i); + auto type = port.getType(); + + if (!isa(type)) + continue; + + auto *row = getDomainAssociationAsRow(port); + for (auto [typeID, term] : llvm::enumerate(row->elements)) { + auto *domain = find(term); + + if (auto *val = dyn_cast(domain)) { + auto value = val->value; + // If the domain value is defined inside the module body, we must output + // export the domain, so it may appear in the signature of the + // module. + if (isPort(module, value)) + continue; + + // The domain is defined internally. If there value is already exported, + // or will be exported, we are done. + if (exportTable.contains(value) || pendingExports.contains(value)) + continue; + + // We must insert a new output domain port. + auto domainDecl = globals.circuitInfo.getDomain(typeID); + auto domainName = domainDecl.getNameAttr(); + + auto portInsertionPoint = i; + auto portName = domainName; + auto portType = DomainType::get(module.getContext()); + auto portDirection = Direction::Out; + auto portSym = StringAttr(); + auto portLoc = port.getLoc(); + auto portAnnos = std::nullopt; + auto portDomainInfo = FlatSymbolRefAttr::get(domainName); + PortInfo portInfo(portName, portType, portDirection, portSym, portLoc, + portAnnos, portDomainInfo); + insertions.push_back({portInsertionPoint, portInfo}); + + // Record the pending export. + auto exportedPortIndex = inserted + portInsertionPoint; + pendingExports[val->value] = exportedPortIndex; + ++inserted; + } + + if (auto *var = dyn_cast(domain)) { + if (pendingSolutions.contains(var)) + continue; + + // insert a new input domain port for the variable. + auto domainDecl = globals.circuitInfo.getDomain(typeID); + auto domainName = domainDecl.getNameAttr(); + + auto portInsertionPoint = i; + auto portName = domainName; + auto portType = DomainType::get(module.getContext()); + auto portDirection = Direction::In; + auto portSym = StringAttr(); + auto portLoc = port.getLoc(); + auto portAnnos = std::nullopt; + auto portDomainInfo = FlatSymbolRefAttr::get(domainName); + PortInfo portInfo(portName, portType, portDirection, portSym, portLoc, + portAnnos, portDomainInfo); + insertions.push_back({portInsertionPoint, portInfo}); + + // Record the pending solution. + auto solutionPortIndex = inserted + portInsertionPoint; + pendingSolutions[var] = solutionPortIndex; + ++inserted; + } + } + } + + // Put the domain ports in place. + module.insertPorts(insertions); + + // Solve the variables and record them as "self-exporting". + for (auto [var, portIndex] : pendingSolutions) { + auto port = module.getArgument(portIndex); + auto *solution = allocate(port); + solve(var, solution); + // The port is an export of itself. + exportTable[port].push_back(port); + } + + // Drive the pending exports. + auto builder = OpBuilder::atBlockEnd(module.getBodyBlock()); + for (auto [value, portIndex] : pendingExports) { + auto port = module.getArgument(portIndex); + DomainDefineOp::create(builder, port.getLoc(), port, value); + exportTable[value].push_back(port); + setTermForDomain(port, allocate(value)); + } + + // Record the insertions, so we can replay them on instances later. + auto &update = globals.moduleUpdateTable[module.getNameAttr()]; + update.portInsertions = std::move(insertions); +} + +LogicalResult +InferModuleDomains::updateDomainAssociationsInBody(FModuleOp module) { + auto result = success(); + module.getBodyBlock()->walk([&](Operation *op) -> WalkResult { + if (failed(updateOpDomainAssociations(op))) { + result = failure(); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return result; +} + +LogicalResult InferModuleDomains::updateOpDomainAssociations(Operation *op) { + if (auto instance = dyn_cast(op)) + return updateInstanceDomainAssociations(instance); + if (auto instance = dyn_cast(op)) + return updateInstanceDomainAssociations(instance); + return success(); +} + +template +LogicalResult InferModuleDomains::updateInstanceDomainAssociations(T op) { + auto *context = op.getContext(); + OpBuilder builder(context); + builder.setInsertionPointAfter(op); + auto numPorts = op->getNumResults(); + for (size_t i = 0; i < numPorts; ++i) { + auto port = op.getResult(i); + auto type = port.getType(); + auto direction = op.getPortDirection(i); + if (isa(type)) { + if (direction == Direction::In) { + bool driven = false; + for (auto *user : port.getUsers()) { + if (auto connect = dyn_cast(user)) { + if (connect.getDest() == port) { + driven = true; + break; + } + } + } + if (!driven) { + auto *term = getTermForDomain(port); + term = find(term); + if (auto *val = dyn_cast(term)) { + auto loc = port.getLoc(); + auto value = val->value; + DomainDefineOp::create(builder, loc, port, value); + } else { + emitDomainPortInferenceError(op, i); + return failure(); + } + } + } + } + } + return success(); +} + +LogicalResult InferModuleDomains::unifyAssociations(Operation *op, Value lhs, + Value rhs) { + LLVM_DEBUG(llvm::errs() << " unify associations of:\n"; + llvm::errs() << " lhs=" << lhs << "\n"; + llvm::errs() << " rhs=" << rhs << "\n";); + + if (!lhs || !rhs) + return success(); + + if (lhs == rhs) + return success(); + + auto *lhsTerm = getOptDomainAssociation(lhs); + auto *rhsTerm = getOptDomainAssociation(rhs); + + if (lhsTerm) { + if (rhsTerm) { + if (failed(unify(lhsTerm, rhsTerm))) { + auto diag = op->emitOpError("illegal domain crossing in operation"); + auto ¬e1 = diag.attachNote(lhs.getLoc()); + + note1 << "1st operand has domains: "; + VariableIDTable idTable; + render(note1, idTable, lhsTerm); + + auto ¬e2 = diag.attachNote(rhs.getLoc()); + note2 << "2nd operand has domains: "; + render(note2, idTable, rhsTerm); + + return failure(); + } + } + setDomainAssociation(rhs, lhsTerm); + return success(); + } + + if (rhsTerm) { + setDomainAssociation(lhs, rhsTerm); + return success(); + } + + auto *var = allocate(); + setDomainAssociation(lhs, var); + setDomainAssociation(rhs, var); + return success(); +} + +Value InferModuleDomains::getUnderlyingDomain(Value value) { + assert(isa(value.getType())); + auto *term = getOptTermForDomain(value); + if (auto *val = llvm::dyn_cast_if_present(term)) + return val->value; + return nullptr; +} + +Term *InferModuleDomains::getTermForDomain(Value value) { + assert(isa(value.getType())); + if (auto *term = getOptTermForDomain(value)) + return term; + auto *term = allocate(); + setTermForDomain(value, term); + return term; +} + +Term *InferModuleDomains::getOptTermForDomain(Value value) const { + assert(isa(value.getType())); + auto it = termTable.find(value); + if (it == termTable.end()) + return nullptr; + return find(it->second); +} + +void InferModuleDomains::setTermForDomain(Value value, Term *term) { + assert(isa(value.getType())); + assert(term); + assert(!termTable.contains(value)); + termTable.insert({value, term}); +} + +RowTerm *InferModuleDomains::getDomainAssociationAsRow(Value value) { + assert(isa(value.getType())); + auto *term = getOptDomainAssociation(value); + + // If the term is unknown, allocate a fresh row and set the association. + if (!term) { + auto *row = allocateRow(); + setDomainAssociation(value, row); + return row; + } + + // If the term is already a row, return it. + if (auto *row = dyn_cast(term)) + return row; + + // Otherwise, unify the term with a fresh row of domains. + if (auto *var = dyn_cast(term)) { + auto *row = allocateRow(); + solve(var, row); + return row; + } + + assert(false && "unhandled term type"); + return nullptr; +} + +Term *InferModuleDomains::getDomainAssociation(Value value) { + auto *term = getOptDomainAssociation(value); + if (term) + return term; + term = allocate(); + setDomainAssociation(value, term); + return term; +} + +Term *InferModuleDomains::getOptDomainAssociation(Value value) const { + assert(isa(value.getType())); + auto it = associationTable.find(value); + if (it == associationTable.end()) + return nullptr; + return find(it->second); +} + +void InferModuleDomains::setDomainAssociation(Value value, Term *term) { + assert(isa(value.getType())); + assert(term); + term = find(term); + associationTable.insert({value, term}); + LLVM_DEBUG(llvm::errs() << " set domain association: " << value; + llvm::errs() << " -> " << term << "\n";); +} + +RowTerm *InferModuleDomains::allocateRow() { + SmallVector elements; + elements.resize(globals.circuitInfo.getNumDomains()); + return allocateRow(elements); +} + +RowTerm *InferModuleDomains::allocateRow(ArrayRef elements) { + auto ds = allocateArray(elements); + return allocate(ds); +} + +template +T *InferModuleDomains::allocate(Args &&...args) { + static_assert(std::is_base_of_v, "T must be a term"); + return new (allocator) T(std::forward(args)...); +} + +ArrayRef InferModuleDomains::allocateArray(ArrayRef elements) { + auto size = elements.size(); + if (size == 0) + return {}; + + auto *result = allocator.Allocate(size); + llvm::uninitialized_copy(elements, result); + for (size_t i = 0; i < size; ++i) + if (!result[i]) + result[i] = allocate(); + + return ArrayRef(result, elements.size()); +} + +void InferModuleDomains::render(Diagnostic &out, Term *term) const { + VariableIDTable idTable; + render(out, idTable, term); +} + +// NOLINTNEXTLINE(misc-no-recursion) +void InferModuleDomains::render(Diagnostic &out, VariableIDTable &idTable, + Term *term) const { + term = find(term); + if (auto *var = dyn_cast(term)) { + out << "?" << idTable.get(var); + return; + } + if (auto *val = dyn_cast(term)) { + auto value = val->value; + auto [name, rooted] = getFieldName(FieldRef(value, 0), false); + out << name; + return; + } + if (auto *row = dyn_cast(term)) { + bool first = true; + out << "["; + for (size_t i = 0, e = globals.circuitInfo.getNumDomains(); i < e; ++i) { + auto domainOp = globals.circuitInfo.getDomain(i); + if (!first) { + out << ", "; + first = false; + } + out << domainOp.getName() << ": "; + render(out, idTable, row->elements[i]); + } + out << "]"; + return; + } +} + +template +void InferModuleDomains::emitPortDomainCrossingError(T op, size_t i, + size_t domainTypeID, + Term *term1, + Term *term2) const { + VariableIDTable idTable; + + auto portName = op.getPortNameAttr(i); + auto portLoc = op.getPortLocation(i); + auto domainDecl = globals.circuitInfo.getDomain(domainTypeID); + auto domainName = domainDecl.getNameAttr(); + + auto diag = emitError(portLoc); + diag << "illegal " << domainName << " crossing in port " << portName; + + auto ¬e1 = diag.attachNote(); + note1 << "1st instance: "; + render(note1, idTable, term1); + + auto ¬e2 = diag.attachNote(); + note2 << "2nd instance: "; + render(note2, idTable, term2); +} + +template +void InferModuleDomains::emitDomainPortInferenceError(T op, size_t i) const { + auto name = op.getPortNameAttr(i); + auto diag = emitError(op->getLoc()); + auto info = op.getDomainInfo(); + diag << "unable to infer value for domain port " << name; + for (size_t j = 0, e = op.getNumPorts(); j < e; ++j) { + if (auto assocs = dyn_cast(info[j])) { + for (auto assoc : assocs) { + if (i == cast(assoc).getValue()) { + auto name = op.getPortNameAttr(j); + auto loc = op.getPortLocation(j); + diag.attachNote(loc) << "associated with hardware port " << name; + break; + } + } + } + } +} + +//===--------------------------------------------------------------------------- +// InferDomainsPass: Top-level pass implementation. +//===--------------------------------------------------------------------------- + +namespace { +struct InferDomainsPass + : public circt::firrtl::impl::InferDomainsBase { + void runOnOperation() override; +}; +} // namespace + +void InferDomainsPass::runOnOperation() { + LLVM_DEBUG(debugPassHeader(this) << "\n"); + auto circuit = getOperation(); + auto &instanceGraph = getAnalysis(); + + GlobalState globals(circuit); + DenseSet visited; + for (auto *root : instanceGraph) { + for (auto *node : llvm::post_order_ext(root, visited)) { + auto module = dyn_cast(node->getOperation()); + if (!module) + continue; + + if (failed(InferModuleDomains::run(globals, module))) { + signalPassFailure(); + return; + } + } + } + LLVM_DEBUG(debugFooter() << "\n"); +} diff --git a/lib/Firtool/Firtool.cpp b/lib/Firtool/Firtool.cpp index dcf768817451..c9a99f0d9867 100644 --- a/lib/Firtool/Firtool.cpp +++ b/lib/Firtool/Firtool.cpp @@ -49,6 +49,9 @@ LogicalResult firtool::populatePreprocessTransforms(mlir::PassManager &pm, pm.nest().nest().addPass( firrtl::createLowerIntrinsics()); + if (opt.shouldInferDomains()) + pm.nest().addPass(firrtl::createInferDomains()); + return success(); } @@ -758,6 +761,11 @@ struct FirtoolCmdOptions { llvm::cl::desc("Emit bindfiles for private modules"), llvm::cl::init(false)}; + llvm::cl::opt inferDomains{ + "infer-domains", + llvm::cl::desc("Enable domain inference and checking"), + llvm::cl::init(false)}; + //===----------------------------------------------------------------------=== // Lint options //===----------------------------------------------------------------------=== @@ -809,7 +817,7 @@ circt::firtool::FirtoolOptions::FirtoolOptions() disableCSEinClasses(false), selectDefaultInstanceChoice(false), symbolicValueLowering(verif::SymbolicValueLowering::ExtModule), disableWireElimination(false), lintStaticAsserts(true), - lintXmrsInDesign(true), emitAllBindFiles(false) { + lintXmrsInDesign(true), emitAllBindFiles(false), inferDomains(false) { if (!clOptions.isConstructed()) return; outputFilename = clOptions->outputFilename; @@ -862,4 +870,5 @@ circt::firtool::FirtoolOptions::FirtoolOptions() lintStaticAsserts = clOptions->lintStaticAsserts; lintXmrsInDesign = clOptions->lintXmrsInDesign; emitAllBindFiles = clOptions->emitAllBindFiles; + inferDomains = clOptions->inferDomains; } diff --git a/test/Dialect/FIRRTL/infer-domains-errors.mlir b/test/Dialect/FIRRTL/infer-domains-errors.mlir new file mode 100644 index 000000000000..da419dcac465 --- /dev/null +++ b/test/Dialect/FIRRTL/infer-domains-errors.mlir @@ -0,0 +1,128 @@ +// RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains))' %s --verify-diagnostics --split-input-file + +// Port annotated with same domain type twice. +firrtl.circuit "DomainCrossOnPort" { + firrtl.domain @ClockDomain + firrtl.module @DomainCrossOnPort( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @ClockDomain, + // expected-error @below {{illegal "ClockDomain" crossing in port "p"}} + // expected-note @below {{1st instance: A}} + // expected-note @below {{2nd instance: B}} + in %p: !firrtl.uint<1> domains [%A, %B] + ) {} +} + +// ----- + +// Illegal domain crossing - connect op. +firrtl.circuit "IllegalDomainCrossing" { + firrtl.domain @ClockDomain + firrtl.module @IllegalDomainCrossing( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @ClockDomain, + // expected-note @below {{2nd operand has domains: [ClockDomain: A]}} + in %a: !firrtl.uint<1> domains [%A], + // expected-note @below {{1st operand has domains: [ClockDomain: B]}} + out %b: !firrtl.uint<1> domains [%B] + ) { + // expected-error @below {{illegal domain crossing in operation}} + firrtl.connect %b, %a : !firrtl.uint<1> + } +} + +// ----- + +// Illegal domain crossing at matchingconnect op. +firrtl.circuit "IllegalDomainCrossing" { + firrtl.domain @ClockDomain + firrtl.module @IllegalDomainCrossing( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @ClockDomain, + // expected-note @below {{2nd operand has domains: [ClockDomain: A]}} + in %a: !firrtl.uint<1> domains [%A], + // expected-note @below {{1st operand has domains: [ClockDomain: B]}} + out %b: !firrtl.uint<1> domains [%B] + ) { + // expected-error @below {{illegal domain crossing in operation}} + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + } +} + +// ----- + +// Unable to infer domain of port, when port is driven by constant. +firrtl.circuit "UnableToInferDomainOfPortDrivenByConstant" { + firrtl.domain @ClockDomain + firrtl.module @Foo(in %i: !firrtl.uint<1>) {} + + firrtl.module @UnableToInferDomainOfPortDrivenByConstant() { + %c0_ui1 = firrtl.constant 0 : !firrtl.uint<1> + // expected-error @below {{unable to infer value for domain port "ClockDomain"}} + // expected-note @below {{associated with hardware port "i"}} + %foo_i = firrtl.instance foo @Foo(in i: !firrtl.uint<1>) + firrtl.matchingconnect %foo_i, %c0_ui1 : !firrtl.uint<1> + } +} + +// ----- + +// Unable to infer domain of port, when port is driven by arithmetic on constant. +firrtl.circuit "UnableToInferDomainOfPortDrivenByConstantExpr" { + firrtl.domain @ClockDomain + firrtl.module @Foo(in %i: !firrtl.uint<2>) {} + + firrtl.module @UnableToInferDomainOfPortDrivenByConstantExpr() { + %c0_ui1 = firrtl.constant 0 : !firrtl.uint<1> + %0 = firrtl.add %c0_ui1, %c0_ui1 : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<2> + // expected-error @below {{unable to infer value for domain port "ClockDomain"}} + // expected-note @below {{associated with hardware port "i"}} + %foo_i = firrtl.instance foo @Foo(in i: !firrtl.uint<2>) + firrtl.matchingconnect %foo_i, %0 : !firrtl.uint<2> + } +} + +// ----- + +// Incomplete extmodule domain information. + +firrtl.circuit "IncompleteDomainInfoForExtModule" { + firrtl.domain @ClockDomain + + firrtl.extmodule @Foo(in i: !firrtl.uint<1>) + + firrtl.module @IncompleteDomainInfoForExtModule(in %i: !firrtl.uint<1>) { + // expected-error @below {{'firrtl.instance' op missing "ClockDomain" association for port "i"}} + %foo_i = firrtl.instance foo @Foo(in i: !firrtl.uint<1>) + firrtl.matchingconnect %foo_i, %i : !firrtl.uint<1> + } +} + +// ----- + +// Domain not exported like it should be. + +// ----- + +// Domain exported multiple times. Which do we choose? + +firrtl.circuit "DoubleExportOfDomain" { + firrtl.domain @ClockDomain + + firrtl.module @DoubleExportOfDomain( + // expected-note @below {{candidate association "DI"}} + in %DI : !firrtl.domain of @ClockDomain, + // expected-note @below {{candidate association "DO"}} + out %DO : !firrtl.domain of @ClockDomain, + in %i : !firrtl.uint<1> domains [%DO], + // expected-error @below {{ambiguous "ClockDomain" association for port "o"}} + out %o : !firrtl.uint<1> domains [] + ) { + // DI and DO are aliases + firrtl.domain.define %DO, %DI + + // o is on same domain as i + firrtl.matchingconnect %o, %i : !firrtl.uint<1> + } +} + diff --git a/test/Dialect/FIRRTL/infer-domains.mlir b/test/Dialect/FIRRTL/infer-domains.mlir new file mode 100644 index 000000000000..c2cc0d96a83e --- /dev/null +++ b/test/Dialect/FIRRTL/infer-domains.mlir @@ -0,0 +1,250 @@ +// RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains))' %s | FileCheck %s + +// Legal domain usage - no crossing. +firrtl.circuit "LegalDomains" { + firrtl.domain @ClockDomain + firrtl.module @LegalDomains( + in %A: !firrtl.domain of @ClockDomain, + in %a: !firrtl.uint<1> domains [%A], + out %b: !firrtl.uint<1> domains [%A] + ) { + // Connecting within the same domain is legal. + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + } +} +// CHECK-LABEL: firrtl.circuit "LegalDomains" + +// Domain inference through connections. +firrtl.circuit "DomainInference" { + firrtl.domain @ClockDomain + firrtl.module @DomainInference( + in %A: !firrtl.domain of @ClockDomain, + in %a: !firrtl.uint<1> domains [%A], + out %c: !firrtl.uint<1> + ) { + %b = firrtl.wire : !firrtl.uint<1> // No explicit domain + + // This should infer that %b is in domain %A. + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + + // This should be legal since %b is now inferred to be in domain %A. + firrtl.matchingconnect %c, %b : !firrtl.uint<1> + } +} +// CHECK-LABEL: firrtl.circuit "DomainInference" +// CHECK: out %c: !firrtl.uint<1> domains [%A] + +// Unsafe domain cast +firrtl.circuit "UnsafeDomainCast" { + firrtl.domain @ClockDomain + firrtl.module @UnsafeDomainCast( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @ClockDomain, + in %a: !firrtl.uint<1> domains [%A], + out %c: !firrtl.uint<1> domains [%B] + ) { + // Unsafe cast from domain A to domain B. + %b = firrtl.unsafe_domain_cast %a domains %B : !firrtl.uint<1> + + // This should be legal since we explicitly cast. + firrtl.matchingconnect %c, %b : !firrtl.uint<1> + } +} +// CHECK-LABEL: firrtl.circuit "UnsafeDomainCast" + +// Domain sequence matching. +firrtl.circuit "LegalSequences" { + firrtl.domain @ClockDomain + firrtl.domain @PowerDomain + firrtl.module @LegalSequences( + in %C: !firrtl.domain of @ClockDomain, + in %P: !firrtl.domain of @PowerDomain, + in %a: !firrtl.uint<1> domains [%C, %P], + out %b: !firrtl.uint<1> domains [%C, %P] + ) { + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + } +} + +// Domain sequence order equivalence - should be legal +firrtl.circuit "SequenceOrderEquivalence" { + firrtl.domain @ClockDomain + firrtl.domain @PowerDomain + firrtl.module @SequenceOrderEquivalence( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @PowerDomain, + in %a: !firrtl.uint<1> domains [%A, %B], + out %b: !firrtl.uint<1> domains [%B, %A] + ) { + // This should be legal since domain order doesn't matter in canonical representation + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + } +} +// CHECK-LABEL: firrtl.circuit "SequenceOrderEquivalence" + +// Domain sequence inference +firrtl.circuit "SequenceInference" { + firrtl.domain @ClockDomain + firrtl.domain @PowerDomain + firrtl.module @SequenceInference( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @PowerDomain, + in %a: !firrtl.uint<1> domains [%A, %B], + out %d: !firrtl.uint<1> + ) { + %c = firrtl.wire : !firrtl.uint<1> + + // %c should infer domain sequence [%A, %B] + firrtl.matchingconnect %c, %a : !firrtl.uint<1> + + // This should be legal since %c has inferred [%A, %B] + firrtl.matchingconnect %d, %c : !firrtl.uint<1> + } +} + +// Domain duplicate equivalence - should be legal. +firrtl.circuit "DuplicateDomainEquivalence" { + firrtl.domain @ClockDomain + firrtl.module @DuplicateDomainEquivalence( + in %A: !firrtl.domain of @ClockDomain, + in %a: !firrtl.uint<1> domains [%A, %A], + out %b: !firrtl.uint<1> domains [%A] + ) { + // This should be legal since duplicate domains are canonicalized. + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + } +} + +// Unsafe domain cast with sequences +firrtl.circuit "UnsafeSequenceCast" { + firrtl.domain @ClockDomain + firrtl.domain @PowerDomain + + firrtl.module @UnsafeSequenceCast( + in %C1: !firrtl.domain of @ClockDomain, + in %C2: !firrtl.domain of @ClockDomain, + in %P1: !firrtl.domain of @PowerDomain, + in %i: !firrtl.uint<1> domains [%C1, %P1], + out %o: !firrtl.uint<1> domains [%C2, %P1] + ) { + %0 = firrtl.unsafe_domain_cast %i domains %C2 : !firrtl.uint<1> + firrtl.matchingconnect %o, %0 : !firrtl.uint<1> + } +} + +// Different port types domain inference. + +// CHECK-LABEL: DifferentPortTypes +firrtl.circuit "DifferentPortTypes" { + firrtl.domain @ClockDomain + firrtl.module @DifferentPortTypes( + in %A: !firrtl.domain of @ClockDomain, + in %uint_input: !firrtl.uint<8> domains [%A], + in %sint_input: !firrtl.sint<4> domains [%A], + out %uint_output: !firrtl.uint<8>, + out %sint_output: !firrtl.sint<4> + ) { + firrtl.matchingconnect %uint_output, %uint_input : !firrtl.uint<8> + firrtl.matchingconnect %sint_output, %sint_input : !firrtl.sint<4> + } +} + +// Domain inference through wires. + +// CHECK-LABEL: DomainInferenceThroughWires +firrtl.circuit "DomainInferenceThroughWires" { + firrtl.domain @ClockDomain + firrtl.module @DomainInferenceThroughWires( + in %A: !firrtl.domain of @ClockDomain, + in %input: !firrtl.uint<1> domains [%A], + // CHECK: out %output: !firrtl.uint<1> domains [%A] + out %output: !firrtl.uint<1> + ) { + %wire1 = firrtl.wire : !firrtl.uint<1> + %wire2 = firrtl.wire : !firrtl.uint<1> + + firrtl.matchingconnect %wire1, %input : !firrtl.uint<1> + firrtl.matchingconnect %wire2, %wire1 : !firrtl.uint<1> + firrtl.matchingconnect %output, %wire2 : !firrtl.uint<1> + } +} + +// Register inference. + +// CHECK-LABEL: RegisterInference +firrtl.circuit "RegisterInference" { + firrtl.domain @ClockDomain + firrtl.module @RegisterInference( + in %A: !firrtl.domain of @ClockDomain, + in %clock: !firrtl.clock domains [%A], + // CHECK: in %d: !firrtl.uint<1> domains [%A] + in %d: !firrtl.uint<1>, + // CHECK: out %q: !firrtl.uint<1> domains [%A] + out %q: !firrtl.uint<1> + ) { + %r = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1> + firrtl.matchingconnect %r, %d : !firrtl.uint<1> + firrtl.matchingconnect %q, %r : !firrtl.uint<1> + } +} + +// Update domain on instance. + +// CHECK-LABEL: InstanceUpdate +firrtl.circuit "InstanceUpdate" { + firrtl.domain @ClockDomain + + firrtl.module @Foo(in %i : !firrtl.uint<1>) {} + + // CHECK: firrtl.module @InstanceUpdate(in %ClockDomain: !firrtl.domain of @ClockDomain, in %i: !firrtl.uint<1> domains [%ClockDomain]) { + // CHECK: %foo_ClockDomain, %foo_i = firrtl.instance foo @Foo(in ClockDomain: !firrtl.domain of @ClockDomain, in i: !firrtl.uint<1> domains [ClockDomain]) + // CHECK: firrtl.domain.define %foo_ClockDomain, %ClockDomain + // CHECK: firrtl.connect %foo_i, %i : !firrtl.uint<1> + // CHECK: } + firrtl.module @InstanceUpdate(in %i : !firrtl.uint<1>) { + %foo_i = firrtl.instance foo @Foo(in i: !firrtl.uint<1>) + firrtl.connect %foo_i, %i : !firrtl.uint<1>, !firrtl.uint<1> + } +} + +// CHECK-LABEL: InstanceChoiceUpdate +firrtl.circuit "InstanceChoiceUpdate" { + firrtl.domain @ClockDomain + + firrtl.option @Option { + firrtl.option_case @X + firrtl.option_case @Y + } + + firrtl.module @Foo(in %i : !firrtl.uint<1>) {} + firrtl.module @Bar(in %i : !firrtl.uint<1>) {} + firrtl.module @Baz(in %i : !firrtl.uint<1>) {} + + // CHECK: firrtl.module @InstanceChoiceUpdate(in %ClockDomain: !firrtl.domain of @ClockDomain, in %i: !firrtl.uint<1> domains [%ClockDomain]) { + // CHECK: %inst_ClockDomain, %inst_i = firrtl.instance_choice inst @Foo alternatives @Option { @X -> @Bar, @Y -> @Baz } (in ClockDomain: !firrtl.domain of @ClockDomain, in i: !firrtl.uint<1> domains [ClockDomain]) + // CHECK: firrtl.domain.define %inst_ClockDomain, %ClockDomain + // CHECK: firrtl.connect %inst_i, %i : !firrtl.uint<1> + // CHECK: } + firrtl.module @InstanceChoiceUpdate(in %i : !firrtl.uint<1>) { + %inst_i = firrtl.instance_choice inst @Foo alternatives @Option { @X -> @Bar, @Y -> @Baz } (in i : !firrtl.uint<1>) + firrtl.connect %inst_i, %i : !firrtl.uint<1>, !firrtl.uint<1> + } +} + +// CHECK-LABEL: ConstantInMultipleDomains +firrtl.circuit "ConstantInMultipleDomains" { + firrtl.domain @ClockDomain + + firrtl.extmodule @Foo(in A: !firrtl.domain of @ClockDomain, in i: !firrtl.uint<1> domains [A]) + + firrtl.module @ConstantInMultipleDomains(in %A: !firrtl.domain of @ClockDomain, in %B: !firrtl.domain of @ClockDomain) { + %c0_ui1 = firrtl.constant 0 : !firrtl.uint<1> + %x_A, %x_i = firrtl.instance x @Foo(in A: !firrtl.domain of @ClockDomain, in i: !firrtl.uint<1> domains [A]) + firrtl.domain.define %x_A, %A + firrtl.matchingconnect %x_i, %c0_ui1 : !firrtl.uint<1> + + %y_A, %y_i = firrtl.instance y @Foo(in A: !firrtl.domain of @ClockDomain, in i: !firrtl.uint<1> domains [A]) + firrtl.domain.define %y_A, %B + firrtl.matchingconnect %y_i, %c0_ui1 : !firrtl.uint<1> + } +} From 2b9d4cc9a962f7620b289145a23ae93b9111317d Mon Sep 17 00:00:00 2001 From: Robert Young Date: Tue, 21 Oct 2025 13:49:35 -0400 Subject: [PATCH 02/14] InferDomains: add infer-public flag, check interfaces of modules for completeness --- include/circt/Dialect/FIRRTL/Passes.td | 4 + .../FIRRTL/Transforms/InferDomains.cpp | 95 +++++++++++++++++-- test/Dialect/FIRRTL/infer-domains-errors.mlir | 28 ++++-- test/Dialect/FIRRTL/infer-domains.mlir | 29 +++++- 4 files changed, 138 insertions(+), 18 deletions(-) diff --git a/include/circt/Dialect/FIRRTL/Passes.td b/include/circt/Dialect/FIRRTL/Passes.td index caf8eeb7dffb..71e3d54e2fef 100644 --- a/include/circt/Dialect/FIRRTL/Passes.td +++ b/include/circt/Dialect/FIRRTL/Passes.td @@ -920,6 +920,10 @@ def InferDomains : Pass<"firrtl-infer-domains", "firrtl::CircuitOp"> { E.g., this pass can be used to check for illegal clock-domain-crossings if clock domains are specified for signals in the design. }]; + let options = [ + Option<"inferPublic", "infer-public", "bool", "false", + "Infer domains on public modules."> + ]; } def LowerDomains : Pass<"firrtl-lower-domains", "firrtl::CircuitOp"> { diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp index 1558857221cc..c8b7d3514161 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -352,6 +352,67 @@ void solve(Term *lhs, Term *rhs) { } // namespace +//====-------------------------------------------------------------------------- +// CheckModuleDomains +//====-------------------------------------------------------------------------- + +/// Check that a module has complete domain information. +static LogicalResult checkModuleDomains(GlobalState &globals, + FModuleLike module) { + auto numDomains = globals.circuitInfo.getNumDomains(); + auto domainInfo = module.getDomainInfoAttr(); + DenseMap typeIDTable; + for (size_t i = 0, e = module.getNumPorts(); i < e; ++i) { + auto type = module.getPortType(i); + + if (isa(type)) { + auto typeID = globals.circuitInfo.getDomainTypeID(domainInfo, i); + typeIDTable[i] = typeID; + continue; + } + + if (auto baseType = type_dyn_cast(type)) { + SmallVector associations(numDomains); + auto domains = getPortDomainAssociation(domainInfo, i); + for (auto index : domains) { + auto typeID = typeIDTable[index.getUInt()]; + auto &entry = associations[typeID]; + if (entry && entry != index) { + auto domainName = globals.circuitInfo.getDomain(typeID).getNameAttr(); + auto portName = module.getPortNameAttr(i); + auto diag = emitError(module.getPortLocation(i)) + << "ambiguous " << domainName << " association for port " + << portName; + + auto d1Loc = module.getPortLocation(entry.getUInt()); + auto d1Name = module.getPortNameAttr(entry.getUInt()); + diag.attachNote(d1Loc) + << "associated with " << domainName << " port " << d1Name; + + auto d2Loc = module.getPortLocation(index.getUInt()); + auto d2Name = module.getPortNameAttr(index.getUInt()); + diag.attachNote(d2Loc) + << "associated with " << domainName << " port " << d2Name; + } + entry = index; + } + + for (size_t typeID = 0; typeID < numDomains; ++typeID) { + auto association = associations[typeID]; + if (!association) { + auto domainName = globals.circuitInfo.getDomain(typeID).getNameAttr(); + auto portName = module.getPortNameAttr(i); + return emitError(module.getPortLocation(i)) + << "missing " << domainName << " association for port " + << portName; + } + } + } + } + + return success(); +} + //====-------------------------------------------------------------------------- // InferModuleDomains: Primary workhorse for inferring domains on modules. //====-------------------------------------------------------------------------- @@ -525,7 +586,7 @@ LogicalResult InferModuleDomains::operator()(FModuleOp module) { } LogicalResult InferModuleDomains::processPorts(FModuleOp module) { - auto portDomainInfo = module.getDomainInfoAttr(); + auto domainInfo = module.getDomainInfoAttr(); auto numPorts = module.getNumPorts(); // Process module ports - domain ports define explicit domains. @@ -535,7 +596,7 @@ LogicalResult InferModuleDomains::processPorts(FModuleOp module) { // This is a domain port. if (isa(port.getType())) { - auto typeID = globals.circuitInfo.getDomainTypeID(portDomainInfo, i); + auto typeID = globals.circuitInfo.getDomainTypeID(domainInfo, i); domainTypeIDTable[i] = typeID; if (module.getPortDirection(i) == Direction::In) { setTermForDomain(port, allocate(port)); @@ -544,7 +605,7 @@ LogicalResult InferModuleDomains::processPorts(FModuleOp module) { } // This is a port, which may have explicit domain information. - auto portDomains = getPortDomainAssociation(portDomainInfo, i); + auto portDomains = getPortDomainAssociation(domainInfo, i); if (portDomains.empty()) continue; @@ -1305,13 +1366,34 @@ void InferModuleDomains::emitDomainPortInferenceError(T op, size_t i) const { } } +static LogicalResult inferModuleDomains(GlobalState &globals, + FModuleOp module) { + return InferModuleDomains::run(globals, module); +} + //===--------------------------------------------------------------------------- // InferDomainsPass: Top-level pass implementation. //===--------------------------------------------------------------------------- +static LogicalResult runOnModuleLike(bool inferPublic, GlobalState &globals, + Operation *op) { + if (auto module = dyn_cast(op)) { + if (module.isPublic() && !inferPublic) + return checkModuleDomains(globals, module); + return inferModuleDomains(globals, module); + } + + if (auto extModule = dyn_cast(op)) { + return checkModuleDomains(globals, extModule); + } + + return success(); +} + namespace { struct InferDomainsPass : public circt::firrtl::impl::InferDomainsBase { + using InferDomainsBase::InferDomainsBase; void runOnOperation() override; }; } // namespace @@ -1320,16 +1402,11 @@ void InferDomainsPass::runOnOperation() { LLVM_DEBUG(debugPassHeader(this) << "\n"); auto circuit = getOperation(); auto &instanceGraph = getAnalysis(); - GlobalState globals(circuit); DenseSet visited; for (auto *root : instanceGraph) { for (auto *node : llvm::post_order_ext(root, visited)) { - auto module = dyn_cast(node->getOperation()); - if (!module) - continue; - - if (failed(InferModuleDomains::run(globals, module))) { + if (failed(runOnModuleLike(inferPublic, globals, node->getModule()))) { signalPassFailure(); return; } diff --git a/test/Dialect/FIRRTL/infer-domains-errors.mlir b/test/Dialect/FIRRTL/infer-domains-errors.mlir index da419dcac465..d62b13654c41 100644 --- a/test/Dialect/FIRRTL/infer-domains-errors.mlir +++ b/test/Dialect/FIRRTL/infer-domains-errors.mlir @@ -1,4 +1,4 @@ -// RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains))' %s --verify-diagnostics --split-input-file +// RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains{infer-public=true}))' %s --verify-diagnostics --split-input-file // Port annotated with same domain type twice. firrtl.circuit "DomainCrossOnPort" { @@ -86,16 +86,28 @@ firrtl.circuit "UnableToInferDomainOfPortDrivenByConstantExpr" { // Incomplete extmodule domain information. -firrtl.circuit "IncompleteDomainInfoForExtModule" { +firrtl.circuit "Top" { firrtl.domain @ClockDomain - firrtl.extmodule @Foo(in i: !firrtl.uint<1>) + // expected-error @below {{missing "ClockDomain" association for port "i"}} + firrtl.extmodule @Top(in i: !firrtl.uint<1>) +} - firrtl.module @IncompleteDomainInfoForExtModule(in %i: !firrtl.uint<1>) { - // expected-error @below {{'firrtl.instance' op missing "ClockDomain" association for port "i"}} - %foo_i = firrtl.instance foo @Foo(in i: !firrtl.uint<1>) - firrtl.matchingconnect %foo_i, %i : !firrtl.uint<1> - } +// ----- + +// Conflicting extmodule domain information. + +firrtl.circuit "Top" { + firrtl.domain @ClockDomain + + firrtl.extmodule @Top( + // expected-note @below {{associated with "ClockDomain" port "D1"}} + in D1 : !firrtl.domain of @ClockDomain, + // expected-note @below {{associated with "ClockDomain" port "D2"}} + in D2 : !firrtl.domain of @ClockDomain, + // expected-error @below {{ambiguous "ClockDomain" association for port "i"}} + in i: !firrtl.uint<1> domains [D1, D2] + ) } // ----- diff --git a/test/Dialect/FIRRTL/infer-domains.mlir b/test/Dialect/FIRRTL/infer-domains.mlir index c2cc0d96a83e..45e18aa14472 100644 --- a/test/Dialect/FIRRTL/infer-domains.mlir +++ b/test/Dialect/FIRRTL/infer-domains.mlir @@ -1,4 +1,4 @@ -// RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains))' %s | FileCheck %s +// RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains{infer-public=true}))' %s | FileCheck %s // Legal domain usage - no crossing. firrtl.circuit "LegalDomains" { @@ -248,3 +248,30 @@ firrtl.circuit "ConstantInMultipleDomains" { firrtl.matchingconnect %y_i, %c0_ui1 : !firrtl.uint<1> } } + +firrtl.circuit "Top" { + firrtl.domain @ClockDomain + firrtl.extmodule @Foo( + in ClockDomain : !firrtl.domain of @ClockDomain, + in i: !firrtl.uint<1> domains [ClockDomain], + out o : !firrtl.uint<1> domains [ClockDomain] + ) + + firrtl.module @Top(in %ClockDomain : !firrtl.domain of @ClockDomain ) { + %foo1_ClockDomain, %foo1_i, %foo1_o = firrtl.instance foo1 @Foo( + in ClockDomain : !firrtl.domain of @ClockDomain, + in i: !firrtl.uint<1> domains [ClockDomain], + out o : !firrtl.uint<1> domains [ClockDomain] + ) + + %foo2_ClockDomain, %foo2_i, %foo2_o = firrtl.instance foo2 @Foo( + in ClockDomain : !firrtl.domain of @ClockDomain, + in i: !firrtl.uint<1> domains [ClockDomain], + out o : !firrtl.uint<1> domains [ClockDomain] + ) + + firrtl.domain.define %foo1_ClockDomain, %ClockDomain + firrtl.matchingconnect %foo2_i, %foo1_o : !firrtl.uint<1> + firrtl.matchingconnect %foo1_i, %foo2_o : !firrtl.uint<1> + } +} \ No newline at end of file From 749eb86d3aad33ad3fce29bc7e9da065146759eb Mon Sep 17 00:00:00 2001 From: Robert Young Date: Wed, 29 Oct 2025 11:02:26 -0400 Subject: [PATCH 03/14] Add domains option --- include/circt/Dialect/FIRRTL/Passes.h | 20 ++++++ include/circt/Dialect/FIRRTL/Passes.td | 16 +++-- include/circt/Firtool/Firtool.h | 36 ++++++++-- include/circt/Support/InstanceGraph.h | 7 -- .../FIRRTL/Transforms/InferDomains.cpp | 22 ++++-- lib/Firtool/Firtool.cpp | 28 +++++--- test/Dialect/FIRRTL/infer-domains-errors.mlir | 58 +++++++++++++-- test/Dialect/FIRRTL/infer-domains.mlir | 70 +++++++++++++++---- 8 files changed, 210 insertions(+), 47 deletions(-) diff --git a/include/circt/Dialect/FIRRTL/Passes.h b/include/circt/Dialect/FIRRTL/Passes.h index a3cb2b8c9900..0515cbae6456 100644 --- a/include/circt/Dialect/FIRRTL/Passes.h +++ b/include/circt/Dialect/FIRRTL/Passes.h @@ -69,6 +69,26 @@ enum class CompanionMode { Drop, }; +/// The mode for the InferDomains pass. +enum class InferDomainsMode { + /// Check domains with inference for private modules (default). + Infer, + /// Check domains without inference. + Check, + /// Check domains with inference for both public and private modules. + InferAll, +}; + +/// True if the mode indicates we should infer domains on public modules. +constexpr bool shouldInferPublicModules(InferDomainsMode mode) { + return mode == InferDomainsMode::InferAll; +} + +/// True if the mode indicates we should infer domains on private modules. +constexpr bool shouldInferPrivateModules(InferDomainsMode mode) { + return mode == InferDomainsMode::Infer || mode == InferDomainsMode::InferAll; +} + #define GEN_PASS_DECL #include "circt/Dialect/FIRRTL/Passes.h.inc" diff --git a/include/circt/Dialect/FIRRTL/Passes.td b/include/circt/Dialect/FIRRTL/Passes.td index 71e3d54e2fef..3d7a47e84c5b 100644 --- a/include/circt/Dialect/FIRRTL/Passes.td +++ b/include/circt/Dialect/FIRRTL/Passes.td @@ -920,10 +920,18 @@ def InferDomains : Pass<"firrtl-infer-domains", "firrtl::CircuitOp"> { E.g., this pass can be used to check for illegal clock-domain-crossings if clock domains are specified for signals in the design. }]; - let options = [ - Option<"inferPublic", "infer-public", "bool", "false", - "Infer domains on public modules."> - ]; + let options = [Option<"mode", "mode", "InferDomainsMode", + "InferDomainsMode::Infer", "infer, check, infer-all.", + [{ + llvm::cl::values( + clEnumValN(InferDomainsMode::Infer, "infer", + "Check domains with inference for private modules"), + clEnumValN(InferDomainsMode::Check, "check", + "Check domains without inference"), + clEnumValN(InferDomainsMode::InferAll, "infer-all", + "Check domains with inference for both public and private " + "modules")) + }]>]; } def LowerDomains : Pass<"firrtl-lower-domains", "firrtl::CircuitOp"> { diff --git a/include/circt/Firtool/Firtool.h b/include/circt/Firtool/Firtool.h index 5069dca326d2..22d5a335d237 100644 --- a/include/circt/Firtool/Firtool.h +++ b/include/circt/Firtool/Firtool.h @@ -23,6 +23,34 @@ namespace circt { namespace firtool { + +enum class DomainMode { + /// Disable domain checking. + Disable, + /// Check domains with inference for private modules. + Infer, + /// Check domains without inference. + Check, + /// Check domains with inference for both public and private modules. + InferAll, +}; + +/// Convert the "domain mode" firtool option to a "firrtl::InferDomainsMode", +/// the configuration for a pass. +constexpr std::optional +toInferDomainsPassMode(DomainMode mode) { + switch (mode) { + case DomainMode::Disable: + return std::nullopt; + case DomainMode::Infer: + return firrtl::InferDomainsMode::Infer; + case DomainMode::Check: + return firrtl::InferDomainsMode::Check; + case DomainMode::InferAll: + return firrtl::InferDomainsMode::InferAll; + } +} + //===----------------------------------------------------------------------===// // FirtoolOptions //===----------------------------------------------------------------------===// @@ -144,7 +172,7 @@ class FirtoolOptions { bool getEmitAllBindFiles() const { return emitAllBindFiles; } - bool shouldInferDomains() const { return inferDomains; } + DomainMode getDomainMode() const { return domainMode; } // Setters, used by the CAPI FirtoolOptions &setOutputFilename(StringRef name) { @@ -395,8 +423,8 @@ class FirtoolOptions { return *this; } - FirtoolOptions &setInferDomains(bool value) { - inferDomains = value; + FirtoolOptions &setdomainMode(DomainMode value) { + domainMode = value; return *this; } @@ -454,7 +482,7 @@ class FirtoolOptions { bool lintStaticAsserts; bool lintXmrsInDesign; bool emitAllBindFiles; - bool inferDomains; + DomainMode domainMode; }; void registerFirtoolCLOptions(); diff --git a/include/circt/Support/InstanceGraph.h b/include/circt/Support/InstanceGraph.h index 70fdcf2e67f9..a26cf5ddce77 100644 --- a/include/circt/Support/InstanceGraph.h +++ b/include/circt/Support/InstanceGraph.h @@ -59,11 +59,6 @@ class InstanceGraphNode; class InstanceRecord : public llvm::ilist_node_with_parent { public: - /// Get the op that this is tracking. - Operation *getOperation() { - return instance.getOperation(); - } - /// Get the instance-like op that this is tracking. template auto getInstance() { @@ -118,8 +113,6 @@ class InstanceGraphNode : public llvm::ilist_node { public: InstanceGraphNode() : module(nullptr) {} - Operation *getOperation() { return module.getOperation(); } - /// Get the module that this node is tracking. template auto getModule() { diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp index c8b7d3514161..5f4ee678d3b5 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -45,6 +45,16 @@ using InstanceIterator = InstanceGraphNode::UseIterator; using InstanceRange = llvm::iterator_range; using PortInsertions = SmallVector>; +//====-------------------------------------------------------------------------- +// Domain Inference mode helper. +//====-------------------------------------------------------------------------- + +template +static bool shouldInfer(T op, InferDomainsMode mode) { + return op.isPublic() ? shouldInferPublicModules(mode) + : shouldInferPrivateModules(mode); +} + //====-------------------------------------------------------------------------- // Helpers for working with module or instance domain info. //====-------------------------------------------------------------------------- @@ -1375,12 +1385,12 @@ static LogicalResult inferModuleDomains(GlobalState &globals, // InferDomainsPass: Top-level pass implementation. //===--------------------------------------------------------------------------- -static LogicalResult runOnModuleLike(bool inferPublic, GlobalState &globals, - Operation *op) { +static LogicalResult runOnModuleLike(InferDomainsMode mode, + GlobalState &globals, Operation *op) { if (auto module = dyn_cast(op)) { - if (module.isPublic() && !inferPublic) - return checkModuleDomains(globals, module); - return inferModuleDomains(globals, module); + if (shouldInfer(module, mode)) + return inferModuleDomains(globals, module); + return checkModuleDomains(globals, module); } if (auto extModule = dyn_cast(op)) { @@ -1406,7 +1416,7 @@ void InferDomainsPass::runOnOperation() { DenseSet visited; for (auto *root : instanceGraph) { for (auto *node : llvm::post_order_ext(root, visited)) { - if (failed(runOnModuleLike(inferPublic, globals, node->getModule()))) { + if (failed(runOnModuleLike(mode, globals, node->getModule()))) { signalPassFailure(); return; } diff --git a/lib/Firtool/Firtool.cpp b/lib/Firtool/Firtool.cpp index c9a99f0d9867..b636df0a2418 100644 --- a/lib/Firtool/Firtool.cpp +++ b/lib/Firtool/Firtool.cpp @@ -49,9 +49,9 @@ LogicalResult firtool::populatePreprocessTransforms(mlir::PassManager &pm, pm.nest().nest().addPass( firrtl::createLowerIntrinsics()); - if (opt.shouldInferDomains()) - pm.nest().addPass(firrtl::createInferDomains()); - + if (auto mode = toInferDomainsPassMode(opt.getDomainMode())) { + pm.nest().addPass(firrtl::createInferDomains({*mode})); + } return success(); } @@ -761,10 +761,19 @@ struct FirtoolCmdOptions { llvm::cl::desc("Emit bindfiles for private modules"), llvm::cl::init(false)}; - llvm::cl::opt inferDomains{ - "infer-domains", - llvm::cl::desc("Enable domain inference and checking"), - llvm::cl::init(false)}; + llvm::cl::opt domainMode{ + "domain-mode", llvm::cl::desc("Enable domain inference and checking"), + llvm::cl::init(firtool::DomainMode::Disable), + llvm::cl::values( + clEnumValN(firtool::DomainMode::Disable, "disable", + "Disable domain checking"), + clEnumValN(firtool::DomainMode::Infer, "infer", + "Check domains with inference for private modules"), + clEnumValN(firtool::DomainMode::Check, "check", + "Check domains without inference"), + clEnumValN(firtool::DomainMode::InferAll, "infer-all", + "Check domains with inference for both public and private " + "modules"))}; //===----------------------------------------------------------------------=== // Lint options @@ -817,7 +826,8 @@ circt::firtool::FirtoolOptions::FirtoolOptions() disableCSEinClasses(false), selectDefaultInstanceChoice(false), symbolicValueLowering(verif::SymbolicValueLowering::ExtModule), disableWireElimination(false), lintStaticAsserts(true), - lintXmrsInDesign(true), emitAllBindFiles(false), inferDomains(false) { + lintXmrsInDesign(true), emitAllBindFiles(false), + domainMode(DomainMode::Disable) { if (!clOptions.isConstructed()) return; outputFilename = clOptions->outputFilename; @@ -870,5 +880,5 @@ circt::firtool::FirtoolOptions::FirtoolOptions() lintStaticAsserts = clOptions->lintStaticAsserts; lintXmrsInDesign = clOptions->lintXmrsInDesign; emitAllBindFiles = clOptions->emitAllBindFiles; - inferDomains = clOptions->inferDomains; + domainMode = clOptions->domainMode; } diff --git a/test/Dialect/FIRRTL/infer-domains-errors.mlir b/test/Dialect/FIRRTL/infer-domains-errors.mlir index d62b13654c41..1ee4bc726e53 100644 --- a/test/Dialect/FIRRTL/infer-domains-errors.mlir +++ b/test/Dialect/FIRRTL/infer-domains-errors.mlir @@ -1,4 +1,4 @@ -// RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains{infer-public=true}))' %s --verify-diagnostics --split-input-file +// RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains{mode=infer-all}))' %s --verify-diagnostics --split-input-file // Port annotated with same domain type twice. firrtl.circuit "DomainCrossOnPort" { @@ -112,18 +112,38 @@ firrtl.circuit "Top" { // ----- -// Domain not exported like it should be. +// Domain exported multiple times. Which do we choose? + +firrtl.circuit "DoubleExportOfDomain" { + firrtl.domain @ClockDomain + + firrtl.module @DoubleExportOfDomain( + // expected-note @below {{candidate association "DI"}} + in %DI : !firrtl.domain of @ClockDomain, + // expected-note @below {{candidate association "DO"}} + out %DO : !firrtl.domain of @ClockDomain, + in %i : !firrtl.uint<1> domains [%DO], + // expected-error @below {{ambiguous "ClockDomain" association for port "o"}} + out %o : !firrtl.uint<1> domains [] + ) { + // DI and DO are aliases + firrtl.domain.define %DO, %DI + + // o is on same domain as i + firrtl.matchingconnect %o, %i : !firrtl.uint<1> + } +} // ----- -// Domain exported multiple times. Which do we choose? +// Domain exported multiple times, this time with one input and one output. Which do we choose? firrtl.circuit "DoubleExportOfDomain" { - firrtl.domain @ClockDomain + firrtl.domain @ClockDomain firrtl.module @DoubleExportOfDomain( // expected-note @below {{candidate association "DI"}} - in %DI : !firrtl.domain of @ClockDomain, + out %DI : !firrtl.domain of @ClockDomain, // expected-note @below {{candidate association "DO"}} out %DO : !firrtl.domain of @ClockDomain, in %i : !firrtl.uint<1> domains [%DO], @@ -138,3 +158,31 @@ firrtl.circuit "DoubleExportOfDomain" { } } +// ----- + +// InstanceChoice: Each module has different domains inferred. + +firrtl.circuit "ConflictingInstanceChoiceDomains" { + firrtl.domain @ClockDomain + + firrtl.option @Option { + firrtl.option_case @X + firrtl.option_case @Y + } + + // Foo's "out" port takes on the domains of "in1". + firrtl.module @Foo(in %in1: !firrtl.uint<1>, in %in2: !firrtl.uint<1>, out %out: !firrtl.uint<1>) { + firrtl.connect %out, %in1 : !firrtl.uint<1> + } + + // Bar's "out" port takes on the domains of "in2". + firrtl.module @Bar(in %in1: !firrtl.uint<1>, in %in2: !firrtl.uint<1>, out %out: !firrtl.uint<1>) { + firrtl.connect %out, %in2 : !firrtl.uint<1> + } + + firrtl.module @ConflictingInstanceChoiceDomains(in %in1: !firrtl.uint<1>, in %in2: !firrtl.uint<1>) { + %inst_in1, %inst_in2, %inst_out = firrtl.instance_choice inst @Foo alternatives @Option { @X -> @Foo, @Y -> @Bar } (in in1: !firrtl.uint<1>, in in2: !firrtl.uint<1>, out out: !firrtl.uint<1>) + firrtl.connect %inst_in1, %in1 : !firrtl.uint<1>, !firrtl.uint<1> + firrtl.connect %inst_in2, %in2 : !firrtl.uint<1>, !firrtl.uint<1> + } +} diff --git a/test/Dialect/FIRRTL/infer-domains.mlir b/test/Dialect/FIRRTL/infer-domains.mlir index 45e18aa14472..1b7d465d34c7 100644 --- a/test/Dialect/FIRRTL/infer-domains.mlir +++ b/test/Dialect/FIRRTL/infer-domains.mlir @@ -1,6 +1,7 @@ -// RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains{infer-public=true}))' %s | FileCheck %s +// RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains{mode=infer-all}))' %s | FileCheck %s // Legal domain usage - no crossing. +// CHECK-LABEL: firrtl.circuit "LegalDomains" firrtl.circuit "LegalDomains" { firrtl.domain @ClockDomain firrtl.module @LegalDomains( @@ -12,14 +13,15 @@ firrtl.circuit "LegalDomains" { firrtl.matchingconnect %b, %a : !firrtl.uint<1> } } -// CHECK-LABEL: firrtl.circuit "LegalDomains" // Domain inference through connections. +// CHECK-LABEL: firrtl.circuit "DomainInference" firrtl.circuit "DomainInference" { firrtl.domain @ClockDomain firrtl.module @DomainInference( in %A: !firrtl.domain of @ClockDomain, in %a: !firrtl.uint<1> domains [%A], + // CHECK: out %c: !firrtl.uint<1> domains [%A] out %c: !firrtl.uint<1> ) { %b = firrtl.wire : !firrtl.uint<1> // No explicit domain @@ -31,10 +33,9 @@ firrtl.circuit "DomainInference" { firrtl.matchingconnect %c, %b : !firrtl.uint<1> } } -// CHECK-LABEL: firrtl.circuit "DomainInference" -// CHECK: out %c: !firrtl.uint<1> domains [%A] // Unsafe domain cast +// CHECK-LABEL: firrtl.circuit "UnsafeDomainCast" firrtl.circuit "UnsafeDomainCast" { firrtl.domain @ClockDomain firrtl.module @UnsafeDomainCast( @@ -50,9 +51,9 @@ firrtl.circuit "UnsafeDomainCast" { firrtl.matchingconnect %c, %b : !firrtl.uint<1> } } -// CHECK-LABEL: firrtl.circuit "UnsafeDomainCast" // Domain sequence matching. +// CHECK-LABEL: firrtl.circuit "LegalSequences" firrtl.circuit "LegalSequences" { firrtl.domain @ClockDomain firrtl.domain @PowerDomain @@ -67,6 +68,7 @@ firrtl.circuit "LegalSequences" { } // Domain sequence order equivalence - should be legal +// CHECK-LABEL: SequenceOrderEquivalence firrtl.circuit "SequenceOrderEquivalence" { firrtl.domain @ClockDomain firrtl.domain @PowerDomain @@ -80,9 +82,9 @@ firrtl.circuit "SequenceOrderEquivalence" { firrtl.matchingconnect %b, %a : !firrtl.uint<1> } } -// CHECK-LABEL: firrtl.circuit "SequenceOrderEquivalence" // Domain sequence inference +// CHECK-LABEL: SequenceInference firrtl.circuit "SequenceInference" { firrtl.domain @ClockDomain firrtl.domain @PowerDomain @@ -103,6 +105,7 @@ firrtl.circuit "SequenceInference" { } // Domain duplicate equivalence - should be legal. +// CHECK-LABEL: DuplicateDomainEquivalence firrtl.circuit "DuplicateDomainEquivalence" { firrtl.domain @ClockDomain firrtl.module @DuplicateDomainEquivalence( @@ -116,6 +119,7 @@ firrtl.circuit "DuplicateDomainEquivalence" { } // Unsafe domain cast with sequences +// CHECK-LABEL: UnsafeSequenceCast firrtl.circuit "UnsafeSequenceCast" { firrtl.domain @ClockDomain firrtl.domain @PowerDomain @@ -133,7 +137,6 @@ firrtl.circuit "UnsafeSequenceCast" { } // Different port types domain inference. - // CHECK-LABEL: DifferentPortTypes firrtl.circuit "DifferentPortTypes" { firrtl.domain @ClockDomain @@ -150,7 +153,6 @@ firrtl.circuit "DifferentPortTypes" { } // Domain inference through wires. - // CHECK-LABEL: DomainInferenceThroughWires firrtl.circuit "DomainInferenceThroughWires" { firrtl.domain @ClockDomain @@ -169,7 +171,53 @@ firrtl.circuit "DomainInferenceThroughWires" { } } -// Register inference. +// Export: add output domain port for domain created internally. +// CHECK-LABEL: ExportDomain +firrtl.circuit "ExportDomain" { + firrtl.domain @ClockDomain + + firrtl.extmodule @Foo( + out A: !firrtl.domain of @ClockDomain, + out o: !firrtl.uint<1> domains [A] + ) + + firrtl.module @ExportDomain( + // CHECK: out %ClockDomain: !firrtl.domain of @ClockDomain + // CHECK: out %o: !firrtl.uint<1> domains [%ClockDomain] + out %o: !firrtl.uint<1> + ) { + %foo_A, %foo_o = firrtl.instance foo @Foo( + out A: !firrtl.domain of @ClockDomain, + out o: !firrtl.uint<1> domains [A] + ) + firrtl.matchingconnect %o, %foo_o : !firrtl.uint<1> + // CHECK: firrtl.domain.define %ClockDomain, %foo_A + } +} + +// Export: Reuse already-exported domain. +// CHECK-LABEL: ReuseExportedDomain +firrtl.circuit "ReuseExportedDomain" { + firrtl.domain @ClockDomain + + firrtl.extmodule @Foo( + out A: !firrtl.domain of @ClockDomain, + out o: !firrtl.uint<1> domains [A] + ) + + firrtl.module @ReuseExportedDomain( + out %A: !firrtl.domain of @ClockDomain, + // CHECK: out %o: !firrtl.uint<1> domains [%A] + out %o: !firrtl.uint<1> + ) { + %foo_A, %foo_o = firrtl.instance foo @Foo( + out A: !firrtl.domain of @ClockDomain, + out o: !firrtl.uint<1> domains [A] + ) + firrtl.matchingconnect %o, %foo_o : !firrtl.uint<1> + firrtl.domain.define %A, %foo_A + } +} // CHECK-LABEL: RegisterInference firrtl.circuit "RegisterInference" { @@ -188,8 +236,6 @@ firrtl.circuit "RegisterInference" { } } -// Update domain on instance. - // CHECK-LABEL: InstanceUpdate firrtl.circuit "InstanceUpdate" { firrtl.domain @ClockDomain @@ -274,4 +320,4 @@ firrtl.circuit "Top" { firrtl.matchingconnect %foo2_i, %foo1_o : !firrtl.uint<1> firrtl.matchingconnect %foo1_i, %foo2_o : !firrtl.uint<1> } -} \ No newline at end of file +} From 4d852aef4356c51aa0b772c829d72e9471436d6e Mon Sep 17 00:00:00 2001 From: Robert Young Date: Thu, 6 Nov 2025 13:21:08 -0500 Subject: [PATCH 04/14] Fix tests --- test/Dialect/FIRRTL/infer-domains-errors.mlir | 23 ++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/test/Dialect/FIRRTL/infer-domains-errors.mlir b/test/Dialect/FIRRTL/infer-domains-errors.mlir index 1ee4bc726e53..a94f97a67eab 100644 --- a/test/Dialect/FIRRTL/infer-domains-errors.mlir +++ b/test/Dialect/FIRRTL/infer-domains-errors.mlir @@ -136,22 +136,26 @@ firrtl.circuit "DoubleExportOfDomain" { // ----- -// Domain exported multiple times, this time with one input and one output. Which do we choose? +// Domain exported multiple times, this time with two outputs. firrtl.circuit "DoubleExportOfDomain" { - firrtl.domain @ClockDomain + firrtl.domain @ClockDomain + + firrtl.extmodule @Generator(out D: !firrtl.domain of @ClockDomain) firrtl.module @DoubleExportOfDomain( - // expected-note @below {{candidate association "DI"}} - out %DI : !firrtl.domain of @ClockDomain, - // expected-note @below {{candidate association "DO"}} - out %DO : !firrtl.domain of @ClockDomain, - in %i : !firrtl.uint<1> domains [%DO], + // expected-note @below {{candidate association "D1"}} + out %D1 : !firrtl.domain of @ClockDomain, + // expected-note @below {{candidate association "D2"}} + out %D2 : !firrtl.domain of @ClockDomain, + in %i : !firrtl.uint<1> domains [%D1], // expected-error @below {{ambiguous "ClockDomain" association for port "o"}} out %o : !firrtl.uint<1> domains [] ) { + %gen_D = firrtl.instance gen @Generator(out D: !firrtl.domain of @ClockDomain) // DI and DO are aliases - firrtl.domain.define %DO, %DI + firrtl.domain.define %D1, %gen_D + firrtl.domain.define %D2, %gen_D // o is on same domain as i firrtl.matchingconnect %o, %i : !firrtl.uint<1> @@ -161,6 +165,7 @@ firrtl.circuit "DoubleExportOfDomain" { // ----- // InstanceChoice: Each module has different domains inferred. +// TODO: this just relies on the op-verifier for instance choice ops. firrtl.circuit "ConflictingInstanceChoiceDomains" { firrtl.domain @ClockDomain @@ -176,11 +181,13 @@ firrtl.circuit "ConflictingInstanceChoiceDomains" { } // Bar's "out" port takes on the domains of "in2". + // expected-note @below {{original module declared here}} firrtl.module @Bar(in %in1: !firrtl.uint<1>, in %in2: !firrtl.uint<1>, out %out: !firrtl.uint<1>) { firrtl.connect %out, %in2 : !firrtl.uint<1> } firrtl.module @ConflictingInstanceChoiceDomains(in %in1: !firrtl.uint<1>, in %in2: !firrtl.uint<1>) { + // expected-error @below {{'firrtl.instance_choice' op domain info for "out" must be [2 : ui32], but got [0 : ui32]}} %inst_in1, %inst_in2, %inst_out = firrtl.instance_choice inst @Foo alternatives @Option { @X -> @Foo, @Y -> @Bar } (in in1: !firrtl.uint<1>, in in2: !firrtl.uint<1>, out out: !firrtl.uint<1>) firrtl.connect %inst_in1, %in1 : !firrtl.uint<1>, !firrtl.uint<1> firrtl.connect %inst_in2, %in2 : !firrtl.uint<1>, !firrtl.uint<1> From c612ee127e85dab8ab7a13ba82990c463cf77f02 Mon Sep 17 00:00:00 2001 From: Robert Young Date: Thu, 6 Nov 2025 13:43:06 -0500 Subject: [PATCH 05/14] Review comments --- .../circt/Dialect/FIRRTL/FIRRTLOpInterfaces.h | 22 ------------------- .../FIRRTL/Transforms/InferDomains.cpp | 4 +++- lib/Firtool/Firtool.cpp | 4 ++-- test/Dialect/FIRRTL/infer-domains-errors.mlir | 2 +- 4 files changed, 6 insertions(+), 26 deletions(-) diff --git a/include/circt/Dialect/FIRRTL/FIRRTLOpInterfaces.h b/include/circt/Dialect/FIRRTL/FIRRTLOpInterfaces.h index 44a9903b5520..e8339dd7320b 100644 --- a/include/circt/Dialect/FIRRTL/FIRRTLOpInterfaces.h +++ b/include/circt/Dialect/FIRRTL/FIRRTLOpInterfaces.h @@ -85,28 +85,6 @@ struct PortInfo { annotations(annos), domains(domains) {} }; -inline bool operator==(const PortInfo &lhs, const PortInfo &rhs) { - if (lhs.name != rhs.name) - return false; - if (lhs.type != rhs.type) - return false; - if (lhs.direction != rhs.direction) - return false; - if (lhs.sym != rhs.sym) - return false; - if (lhs.loc != rhs.loc) - return false; - if (lhs.annotations != rhs.annotations) - return false; - if (lhs.domains != rhs.domains) - return false; - return true; -} - -inline bool operator!=(const PortInfo &lhs, const PortInfo &rhs) { - return !(lhs == rhs); -} - enum class ConnectBehaviorKind { /// Classic FIRRTL connections: last connect 'wins' across paths; /// conditionally applied under 'when'. diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp index 5f4ee678d3b5..b60042c1f84a 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -102,7 +102,8 @@ using DomainTypeID = size_t; /// type ID, which in this pass is the canonical way to reference the type /// of a domain. namespace { -struct CircuitDomainInfo { +class CircuitDomainInfo { +public: CircuitDomainInfo(CircuitOp circuit) { processCircuit(circuit); } ArrayRef getDomains() const { return domainTable; } @@ -302,6 +303,7 @@ Term *find(Term *x) { LogicalResult unify(Term *lhs, Term *rhs); LogicalResult unify(VariableTerm *x, Term *y) { + assert(!x->leader); x->leader = y; return success(); } diff --git a/lib/Firtool/Firtool.cpp b/lib/Firtool/Firtool.cpp index b636df0a2418..2e01581b914a 100644 --- a/lib/Firtool/Firtool.cpp +++ b/lib/Firtool/Firtool.cpp @@ -49,9 +49,9 @@ LogicalResult firtool::populatePreprocessTransforms(mlir::PassManager &pm, pm.nest().nest().addPass( firrtl::createLowerIntrinsics()); - if (auto mode = toInferDomainsPassMode(opt.getDomainMode())) { + if (auto mode = toInferDomainsPassMode(opt.getDomainMode())) pm.nest().addPass(firrtl::createInferDomains({*mode})); - } + return success(); } diff --git a/test/Dialect/FIRRTL/infer-domains-errors.mlir b/test/Dialect/FIRRTL/infer-domains-errors.mlir index a94f97a67eab..47a018ff6dc8 100644 --- a/test/Dialect/FIRRTL/infer-domains-errors.mlir +++ b/test/Dialect/FIRRTL/infer-domains-errors.mlir @@ -15,7 +15,7 @@ firrtl.circuit "DomainCrossOnPort" { // ----- -// Illegal domain crossing - connect op. +// Illegal domain crossing via connect op. firrtl.circuit "IllegalDomainCrossing" { firrtl.domain @ClockDomain firrtl.module @IllegalDomainCrossing( From bdb6e7b5f78687d9f5297846f3e945b8569e5bd3 Mon Sep 17 00:00:00 2001 From: Robert Young Date: Thu, 6 Nov 2025 14:27:23 -0500 Subject: [PATCH 06/14] Update comments --- lib/Dialect/FIRRTL/Transforms/InferDomains.cpp | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp index b60042c1f84a..895034529307 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -6,12 +6,9 @@ // //===----------------------------------------------------------------------===// // -// This pass implements FIRRTL domain inference and checking with canonical -// domain representation. Domain sequences are canonicalized by sorting and -// removing duplicates, making domain order irrelevant and allowing duplicate -// domains to be treated as equivalent. The result of this pass is either a -// correctly domain-inferred circuit or pass failure if the circuit contains -// illegal domain crossings. +// InferDomains implements FIRRTL domain inference and checking. This pass is a +// bottom-up transform acting on modules. For each module, we ensure there are +// no domain crossings, and we make explicit the domain associations of ports. // //===----------------------------------------------------------------------===// @@ -93,9 +90,10 @@ static bool isPort(FModuleOp module, Value value) { // Circuit-wide state. //====-------------------------------------------------------------------------- -/// Each declared domain in the circuit is assigned an index, based on the order -/// in which it appears. Domain associations for hardware values are represented -/// as a list of domains, sorted by the index of the domain type. +/// Each domain type declared in the circuit is assigned a type-id, based on the +/// order of declaration. Domain associations for hardware values are +/// represented as a list, or row, of domains. The domains in a row are ordered +/// according to their type's id. using DomainTypeID = size_t; /// Information about the domains in the circuit. Able to map domains to their From 9f7930d2c8041db956927845caf79e388977cf5d Mon Sep 17 00:00:00 2001 From: Robert Young Date: Thu, 6 Nov 2025 14:28:42 -0500 Subject: [PATCH 07/14] Remove unused typedefs --- lib/Dialect/FIRRTL/Transforms/InferDomains.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp index 895034529307..cfcdd0773871 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -38,8 +38,6 @@ namespace firrtl { using namespace circt; using namespace firrtl; -using InstanceIterator = InstanceGraphNode::UseIterator; -using InstanceRange = llvm::iterator_range; using PortInsertions = SmallVector>; //====-------------------------------------------------------------------------- From b77037e5db6e2ac06a98d8510b2daa8e40f5d00e Mon Sep 17 00:00:00 2001 From: Robert Young Date: Thu, 6 Nov 2025 15:06:48 -0500 Subject: [PATCH 08/14] Review comments - instanceGraph.walkPostOrder - circuit debug scoped pass logger --- .../FIRRTL/Transforms/InferDomains.cpp | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp index cfcdd0773871..8706ceeae330 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -106,9 +106,6 @@ class CircuitDomainInfo { size_t getNumDomains() const { return domainTable.size(); } DomainOp getDomain(DomainTypeID id) const { return domainTable[id]; } - DomainTypeID getDomainTypeID(DomainOp op) const { - return typeIDTable.at(op.getNameAttr()); - } DomainTypeID getDomainTypeID(StringAttr name) const { return typeIDTable.at(name); @@ -1407,18 +1404,14 @@ struct InferDomainsPass } // namespace void InferDomainsPass::runOnOperation() { - LLVM_DEBUG(debugPassHeader(this) << "\n"); + CIRCT_DEBUG_SCOPED_PASS_LOGGER(this); auto circuit = getOperation(); auto &instanceGraph = getAnalysis(); GlobalState globals(circuit); DenseSet visited; - for (auto *root : instanceGraph) { - for (auto *node : llvm::post_order_ext(root, visited)) { - if (failed(runOnModuleLike(mode, globals, node->getModule()))) { - signalPassFailure(); - return; - } - } - } - LLVM_DEBUG(debugFooter() << "\n"); + auto result = instanceGraph.walkPostOrder([&](auto &node) { + return runOnModuleLike(mode, globals, node.getModule()); + }); + if (failed(result)) + signalPassFailure(); } From 9e68dfa770683dd8753e6176faf193c9f67b9390 Mon Sep 17 00:00:00 2001 From: Robert Young Date: Fri, 7 Nov 2025 09:38:22 -0500 Subject: [PATCH 09/14] Add check-only error tests --- test/Dialect/FIRRTL/infer-domains-check-errors.mlir | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 test/Dialect/FIRRTL/infer-domains-check-errors.mlir diff --git a/test/Dialect/FIRRTL/infer-domains-check-errors.mlir b/test/Dialect/FIRRTL/infer-domains-check-errors.mlir new file mode 100644 index 000000000000..26c6b6848924 --- /dev/null +++ b/test/Dialect/FIRRTL/infer-domains-check-errors.mlir @@ -0,0 +1,13 @@ +// RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains{mode=check}))' %s --verify-diagnostics --split-input-file + +// CHECK-LABEL: IncompleteDomainInformation +firrtl.circuit "IncompleteDomainInformation" { + firrtl.domain @ClockDomain + + firrtl.module private @Foo( + // expected-error @below {{missing "ClockDomain" association for port "x"}} + in %x: !firrtl.uint<1> + ) {} + + firrtl.module @IncompleteDomainInformation() {} +} From babf2546d0952163ba81bea88a72fe768a1d95a6 Mon Sep 17 00:00:00 2001 From: Robert Young Date: Fri, 7 Nov 2025 09:38:35 -0500 Subject: [PATCH 10/14] Expand checking --- .../FIRRTL/Transforms/InferDomains.cpp | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp index 8706ceeae330..e230d6ec3e61 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -106,7 +106,6 @@ class CircuitDomainInfo { size_t getNumDomains() const { return domainTable.size(); } DomainOp getDomain(DomainTypeID id) const { return domainTable[id]; } - DomainTypeID getDomainTypeID(StringAttr name) const { return typeIDTable.at(name); } @@ -361,9 +360,9 @@ void solve(Term *lhs, Term *rhs) { // CheckModuleDomains //====-------------------------------------------------------------------------- -/// Check that a module has complete domain information. -static LogicalResult checkModuleDomains(GlobalState &globals, - FModuleLike module) { +/// Check that a module has complete domain information for its ports. +static LogicalResult checkPortDomains(GlobalState &globals, + FModuleLike module) { auto numDomains = globals.circuitInfo.getNumDomains(); auto domainInfo = module.getDomainInfoAttr(); DenseMap typeIDTable; @@ -418,6 +417,19 @@ static LogicalResult checkModuleDomains(GlobalState &globals, return success(); } +static LogicalResult checkModuleDomains(GlobalState &globals, + FModuleOp module) { + if (failed(checkPortDomains(globals, module))) + return failure(); + + return success(); +} + +static LogicalResult checkModuleDomains(GlobalState &globals, + FExtModuleOp module) { + return checkPortDomains(globals, module); +} + //====-------------------------------------------------------------------------- // InferModuleDomains: Primary workhorse for inferring domains on modules. //====-------------------------------------------------------------------------- From e3d47b840233bfc90f790c6bbb8c08363d1cec56 Mon Sep 17 00:00:00 2001 From: Robert Young Date: Tue, 25 Nov 2025 10:18:20 -0500 Subject: [PATCH 11/14] Fix up infer-domains --- .../FIRRTL/Transforms/InferDomains.cpp | 1621 +++++++++-------- 1 file changed, 826 insertions(+), 795 deletions(-) diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp index e230d6ec3e61..1617610b6239 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -26,7 +26,6 @@ #include "llvm/Support/Debug.h" #define DEBUG_TYPE "firrtl-infer-domains" -#undef NDEBUG namespace circt { namespace firrtl { @@ -38,25 +37,23 @@ namespace firrtl { using namespace circt; using namespace firrtl; -using PortInsertions = SmallVector>; - //====-------------------------------------------------------------------------- -// Domain Inference mode helper. +// Helpers. //====-------------------------------------------------------------------------- +namespace { + +using PortInsertions = SmallVector>; + template -static bool shouldInfer(T op, InferDomainsMode mode) { +bool shouldInfer(T op, InferDomainsMode mode) { return op.isPublic() ? shouldInferPublicModules(mode) : shouldInferPrivateModules(mode); } -//====-------------------------------------------------------------------------- -// Helpers for working with module or instance domain info. -//====-------------------------------------------------------------------------- - /// From a domain info attribute, get the domain-type of a domain value at /// index i. -static StringAttr getDomainPortTypeName(ArrayAttr info, size_t i) { +StringAttr getDomainPortTypeName(ArrayAttr info, size_t i) { if (info.empty()) return nullptr; auto ref = cast(info[i]); @@ -65,29 +62,33 @@ static StringAttr getDomainPortTypeName(ArrayAttr info, size_t i) { /// From a domain info attribute, get the row of associated domains for a /// hardware value at index i. -static auto getPortDomainAssociation(ArrayAttr info, size_t i) { +auto getPortDomainAssociation(ArrayAttr info, size_t i) { if (info.empty()) return info.getAsRange(); return cast(info[i]).getAsRange(); } /// Return true if the value is a port on the module. -static bool isPort(FModuleOp module, BlockArgument arg) { +bool isPort(FModuleOp module, BlockArgument arg) { return arg.getOwner()->getParentOp() == module; } /// Return true if the value is a port on the module. -static bool isPort(FModuleOp module, Value value) { +bool isPort(FModuleOp module, Value value) { auto arg = dyn_cast(value); if (!arg) return false; return isPort(module, arg); } +} // namespace + //====-------------------------------------------------------------------------- -// Circuit-wide state. +// Global State. //====-------------------------------------------------------------------------- +namespace { + /// Each domain type declared in the circuit is assigned a type-id, based on the /// order of declaration. Domain associations for hardware values are /// represented as a list, or row, of domains. The domains in a row are ordered @@ -96,11 +97,10 @@ using DomainTypeID = size_t; /// Information about the domains in the circuit. Able to map domains to their /// type ID, which in this pass is the canonical way to reference the type -/// of a domain. -namespace { -class CircuitDomainInfo { +/// of a domain, as well as provide fast access to domain ops +class DomainInfo { public: - CircuitDomainInfo(CircuitOp circuit) { processCircuit(circuit); } + DomainInfo(CircuitOp circuit) { processCircuit(circuit); } ArrayRef getDomains() const { return domainTable; } size_t getNumDomains() const { return domainTable.size(); } @@ -167,12 +167,16 @@ struct ModuleUpdateInfo { PortInsertions portInsertions; }; -struct GlobalState { - GlobalState(CircuitOp circuit) : circuitInfo(circuit) {} +using ModuleUpdateTable = DenseMap; - CircuitDomainInfo circuitInfo; - DenseMap moduleUpdateTable; -}; +/// Apply the port changes of a module onto an instance-like op. +template +T fixInstancePorts(T op, const ModuleUpdateInfo &update) { + auto clone = op.cloneWithInsertedPortsAndReplaceUses(update.portInsertions); + clone.setDomainInfoAttr(update.portDomainInfo); + op->erase(); + return clone; +} } // namespace @@ -222,6 +226,24 @@ struct RowTerm : public TermBase { ArrayRef elements; }; +// NOLINTNEXTLINE(misc-no-recursion) +Term *find(Term *x) { + if (!x) + return nullptr; + + if (auto *var = dyn_cast(x)) { + if (var->leader == nullptr) + return var; + + auto *leader = find(var->leader); + if (leader != var->leader) + var->leader = leader; + return leader; + } + + return x; +} + /// A helper for assigning low numeric IDs to variables for user-facing output. struct VariableIDTable { size_t get(VariableTerm *term) { @@ -232,6 +254,42 @@ struct VariableIDTable { DenseMap table; }; +// NOLINTNEXTLINE(misc-no-recursion) +void render(const DomainInfo &info, Diagnostic &out, VariableIDTable &idTable, + Term *term) { + term = find(term); + if (auto *var = dyn_cast(term)) { + out << "?" << idTable.get(var); + return; + } + if (auto *val = dyn_cast(term)) { + auto value = val->value; + auto [name, rooted] = getFieldName(FieldRef(value, 0), false); + out << name; + return; + } + if (auto *row = dyn_cast(term)) { + bool first = true; + out << "["; + for (size_t i = 0, e = info.getNumDomains(); i < e; ++i) { + auto domainOp = info.getDomain(i); + if (!first) { + out << ", "; + first = false; + } + out << domainOp.getName() << ": "; + render(info, out, idTable, row->elements[i]); + } + out << "]"; + return; + } +} + +void render(const DomainInfo &info, Diagnostic &out, Term *term) { + VariableIDTable idTable; + render(info, out, idTable, term); +} + #ifndef NDEBUG raw_ostream &dump(llvm::raw_ostream &out, const Term *term); @@ -274,24 +332,6 @@ raw_ostream &dump(raw_ostream &out, const Term *term) { } #endif // DEBUG -// NOLINTNEXTLINE(misc-no-recursion) -Term *find(Term *x) { - if (!x) - return nullptr; - - if (auto *var = dyn_cast(x)) { - if (var->leader == nullptr) - return var; - - auto *leader = find(var->leader); - if (leader != var->leader) - var->leader = leader; - return leader; - } - - return x; -} - LogicalResult unify(Term *lhs, Term *rhs); LogicalResult unify(VariableTerm *x, Term *y) { @@ -354,427 +394,349 @@ void solve(Term *lhs, Term *rhs) { assert(result.succeeded()); } -} // namespace +class TermAllocator { +public: + /// Allocate a row of fresh domain variables. + RowTerm *allocRow(size_t size) { + SmallVector elements; + elements.resize(size); + return allocRow(elements); + } -//====-------------------------------------------------------------------------- -// CheckModuleDomains -//====-------------------------------------------------------------------------- + /// Allocate a row of terms. + RowTerm *allocRow(ArrayRef elements) { + auto ds = allocArray(elements); + return alloc(ds); + } -/// Check that a module has complete domain information for its ports. -static LogicalResult checkPortDomains(GlobalState &globals, - FModuleLike module) { - auto numDomains = globals.circuitInfo.getNumDomains(); - auto domainInfo = module.getDomainInfoAttr(); - DenseMap typeIDTable; - for (size_t i = 0, e = module.getNumPorts(); i < e; ++i) { - auto type = module.getPortType(i); + /// Allocate a fresh variable. + VariableTerm *allocVar() { return alloc(); } - if (isa(type)) { - auto typeID = globals.circuitInfo.getDomainTypeID(domainInfo, i); - typeIDTable[i] = typeID; - continue; - } + /// Allocate a concrete domain. + ValueTerm *allocVal(Value value) { return alloc(value); } - if (auto baseType = type_dyn_cast(type)) { - SmallVector associations(numDomains); - auto domains = getPortDomainAssociation(domainInfo, i); - for (auto index : domains) { - auto typeID = typeIDTable[index.getUInt()]; - auto &entry = associations[typeID]; - if (entry && entry != index) { - auto domainName = globals.circuitInfo.getDomain(typeID).getNameAttr(); - auto portName = module.getPortNameAttr(i); - auto diag = emitError(module.getPortLocation(i)) - << "ambiguous " << domainName << " association for port " - << portName; +private: + template + T *alloc(Args &&...args) { + static_assert(std::is_base_of_v, "T must be a term"); + return new (allocator) T(std::forward(args)...); + } - auto d1Loc = module.getPortLocation(entry.getUInt()); - auto d1Name = module.getPortNameAttr(entry.getUInt()); - diag.attachNote(d1Loc) - << "associated with " << domainName << " port " << d1Name; + ArrayRef allocArray(ArrayRef elements) { + auto size = elements.size(); + if (size == 0) + return {}; - auto d2Loc = module.getPortLocation(index.getUInt()); - auto d2Name = module.getPortNameAttr(index.getUInt()); - diag.attachNote(d2Loc) - << "associated with " << domainName << " port " << d2Name; - } - entry = index; - } + auto *result = allocator.Allocate(size); + llvm::uninitialized_copy(elements, result); + for (size_t i = 0; i < size; ++i) + if (!result[i]) + result[i] = alloc(); - for (size_t typeID = 0; typeID < numDomains; ++typeID) { - auto association = associations[typeID]; - if (!association) { - auto domainName = globals.circuitInfo.getDomain(typeID).getNameAttr(); - auto portName = module.getPortNameAttr(i); - return emitError(module.getPortLocation(i)) - << "missing " << domainName << " association for port " - << portName; - } - } - } + return ArrayRef(result, elements.size()); } - return success(); -} - -static LogicalResult checkModuleDomains(GlobalState &globals, - FModuleOp module) { - if (failed(checkPortDomains(globals, module))) - return failure(); - - return success(); -} + llvm::BumpPtrAllocator allocator; +}; -static LogicalResult checkModuleDomains(GlobalState &globals, - FExtModuleOp module) { - return checkPortDomains(globals, module); -} +} // namespace //====-------------------------------------------------------------------------- -// InferModuleDomains: Primary workhorse for inferring domains on modules. +// DomainTable: A mapping from IR to terms. //====-------------------------------------------------------------------------- namespace { -class InferModuleDomains { -public: - /// Run infer-domains on a module. - static LogicalResult run(GlobalState &, FModuleOp); - -private: - /// Initialize module-level state. - InferModuleDomains(GlobalState &); - - /// Execute on the given module. - LogicalResult operator()(FModuleOp); - - /// Record the domain associations of hardware ports, and record the - /// underlying value of output domain ports. - LogicalResult processPorts(FModuleOp); - - /// Record the domain associations of hardware, and record the underlying - /// value of domains, defined within the body of the module. - LogicalResult processBody(FModuleOp); - - /// Record the domain associations of any operands or results, updating the op - /// if necessary. - LogicalResult processOp(Operation *); - LogicalResult processOp(InstanceOp); - LogicalResult processOp(InstanceChoiceOp); - LogicalResult processOp(UnsafeDomainCastOp); - LogicalResult processOp(DomainDefineOp); - - /// Apply the port changes of a module onto an instance-like op. - template - T updateInstancePorts(T op, const ModuleUpdateInfo &update); - - /// Record the domain associations of the ports of an instance-like op. - template - LogicalResult processInstancePorts(T op); - - LogicalResult updateModule(FModuleOp); - - /// Build a table of exported domains: a map from domains defined internally, - /// to their set of aliasing output ports. - void initializeExportTable(FModuleOp); - - /// After generalizing the module, all domains should be solved. Reflect the - /// solved domain associations into the port domain info attribute. - LogicalResult updatePortDomainAssociations(FModuleOp); - - /// After updating the port domain associations, walk the body of the module - /// to fix up any child instance modules. - LogicalResult updateDomainAssociationsInBody(FModuleOp); - LogicalResult updateOpDomainAssociations(Operation *); - - template - LogicalResult updateInstanceDomainAssociations(T op); - - /// Copy the domain associations from the module domain info attribute into a - /// small vector. - SmallVector copyPortDomainAssociations(ArrayAttr, size_t); - - /// Add domain ports for any uninferred domains associated to hardware. - /// Returns the inserted ports, which will be used later to generalize the - /// instances of this module. - void generalizeModule(FModuleOp); - - /// Unify the associated domain rows of two terms. - LogicalResult unifyAssociations(Operation *, Value, Value); +/// Tracks domain infomation for IR values. +class DomainTable { +public: /// If the domain value is an alias, returns the domain it aliases. - Value getUnderlyingDomain(Value); - - /// Record a mapping from domain in the IR to its corresponding term. - void setTermForDomain(Value, Term *); - - /// Get the corresponding term for a domain in the IR. - Term *getTermForDomain(Value); + Value getOptUnderlyingDomain(Value value) const { + assert(isa(value.getType())); + auto *term = getOptTermForDomain(value); + if (auto *val = llvm::dyn_cast_if_present(term)) + return val->value; + return nullptr; + } /// Get the corresponding term for a domain in the IR, or null if unset. - Term *getOptTermForDomain(Value) const; - - /// Record a mapping from a hardware value in the IR to a term which - /// represents the row of domains it is associated with. - void setDomainAssociation(Value, Term *); + Term *getOptTermForDomain(Value value) const { + assert(isa(value.getType())); + auto it = termTable.find(value); + if (it == termTable.end()) + return nullptr; + return find(it->second); + } - /// Get the associated domain row, forced to be at least a row. - RowTerm *getDomainAssociationAsRow(Value); + /// Get the corresponding term for a domain in the IR. + Term *getTermForDomain(Value value) const { + auto *term = getOptTermForDomain(value); + assert(term); + return term; + } - /// For a hardware value, get the term which represents the row of associated - /// domains. If no mapping has been defined, allocate a variable to stand for - /// the row of domains. - Term *getDomainAssociation(Value); + /// Record a mapping from domain in the IR to its corresponding term. + void setTermForDomain(Value value, Term *term) { + assert(isa(value.getType())); + assert(term); + assert(!termTable.contains(value)); + termTable.insert({value, term}); + } /// For a hardware value, get the term which represents the row of associated /// domains. If no mapping has been defined, returns nullptr. - Term *getOptDomainAssociation(Value) const; - - /// Allocate a row, where each domain is a variable. - RowTerm *allocateRow(); - - /// Allocate a row. - RowTerm *allocateRow(ArrayRef); - - /// Allocate a term. - template - T *allocate(Args &&...); - - /// Allocate an array of terms. If any terms were left null, automatically - /// replace them with a new variable. - ArrayRef allocateArray(ArrayRef); - - /// Print a term in a user-friendly way. - void render(Diagnostic &, Term *) const; - void render(Diagnostic &, VariableIDTable &, Term *) const; - - template - void emitPortDomainCrossingError(T, size_t, DomainTypeID, Term *, - Term *) const; - - /// Emit an error when we fail to infer the concrete domain to drive to a - /// domain port. - template - void emitDomainPortInferenceError(T, size_t) const; + Term *getOptDomainAssociation(Value value) const { + assert(isa(value.getType())); + auto it = associationTable.find(value); + if (it == associationTable.end()) + return nullptr; + return find(it->second); + } - /// Information about the domains in a circuit. - GlobalState &globals; + /// For a hardware value, get the term which represents the row of associated + /// domains. + Term *getDomainAssociation(Value value) const { + llvm::errs() << "value = " << value << "\n"; + auto *term = getOptDomainAssociation(value); + assert(term); + return term; + } - /// Term allocator. - llvm::BumpPtrAllocator allocator; + /// Record a mapping from a hardware value in the IR to a term which + /// represents the row of domains it is associated with. + void setDomainAssociation(Value value, Term *term) { + assert(isa(value.getType())); + assert(term); + term = find(term); + associationTable.insert({value, term}); + } +private: /// Map from domains in the IR to their underlying term. DenseMap termTable; /// A map from hardware values to their associated row of domains, as a term. DenseMap associationTable; - - /// A map from local domain definition to its aliasing output ports. - DenseMap> exportTable; }; + } // namespace -LogicalResult InferModuleDomains::run(GlobalState &globals, FModuleOp module) { - return InferModuleDomains(globals)(module); +//====-------------------------------------------------------------------------- +// Module processing: solve for the domain associations of hardware. +//====-------------------------------------------------------------------------- + +namespace { + +/// Get the corresponding term for a domain in the IR. If we don't know what the +/// term is, then map the domain in the IR to a variable term. +Term *getTermForDomain(TermAllocator &allocator, DomainTable &table, + Value value) { + assert(isa(value.getType())); + if (auto *term = table.getOptTermForDomain(value)) + return term; + auto *term = allocator.allocVar(); + table.setTermForDomain(value, term); + return term; } -InferModuleDomains::InferModuleDomains(GlobalState &globals) - : globals(globals) {} +/// Get the associated domain row, forced to be at least a row. +/// Get the row of domains that a hardware value in the IR is associated with. +/// If we don't know what the row is, associate the hardware value in the IR to +/// a variable term. +/// For a hardware value, get the term which represents the row of associated +/// domains. If no mapping has been defined, allocate a variable to stand for +/// the row of domains. +Term *getDomainAssociation(TermAllocator &allocator, DomainTable &table, + Value value) { + auto *term = table.getOptDomainAssociation(value); + if (term) + return term; + term = allocator.allocVar(); + table.setDomainAssociation(value, term); + return term; +} -LogicalResult InferModuleDomains::operator()(FModuleOp module) { - LLVM_DEBUG( - llvm::errs() << "================================================\n"; - llvm::errs() << "infer module domains: " << module.getModuleName() - << "\n"; - llvm::errs() << "================================================\n";); +/// Get the row of domains that a hardware value in the IR is associated with. +/// The returned term is forced to be at least a row. +RowTerm *getDomainAssociationAsRow(const DomainInfo &info, + TermAllocator &allocator, DomainTable &table, + Value value) { + assert(isa(value.getType())); + auto *term = table.getOptDomainAssociation(value); - if (failed(processPorts(module))) - return failure(); + // If the term is unknown, allocate a fresh row and set the association. + if (!term) { + auto *row = allocator.allocRow(info.getNumDomains()); + table.setDomainAssociation(value, row); + return row; + } - if (failed(processBody(module))) - return failure(); + // If the term is already a row, return it. + if (auto *row = dyn_cast(term)) + return row; - LLVM_DEBUG(for (auto association : associationTable) { - llvm::errs() << "association:\n"; - llvm::errs() << " " << association.first << "\n"; - llvm::errs() << " " << association.second << "\n"; - }); + // Otherwise, unify the term with a fresh row of domains. + if (auto *var = dyn_cast(term)) { + auto *row = allocator.allocRow(info.getNumDomains()); + solve(var, row); + return row; + } - return updateModule(module); + assert(false && "unhandled term type"); + return nullptr; } -LogicalResult InferModuleDomains::processPorts(FModuleOp module) { - auto domainInfo = module.getDomainInfoAttr(); - auto numPorts = module.getNumPorts(); +template +void emitPortDomainCrossingError(const DomainInfo &info, T op, size_t i, + size_t domainTypeID, Term *term1, + Term *term2) { + VariableIDTable idTable; - // Process module ports - domain ports define explicit domains. - DenseMap domainTypeIDTable; - for (size_t i = 0; i < numPorts; ++i) { - BlockArgument port = module.getArgument(i); + auto portName = op.getPortNameAttr(i); + auto portLoc = op.getPortLocation(i); + auto domainDecl = info.getDomain(domainTypeID); + auto domainName = domainDecl.getNameAttr(); - // This is a domain port. - if (isa(port.getType())) { - auto typeID = globals.circuitInfo.getDomainTypeID(domainInfo, i); - domainTypeIDTable[i] = typeID; - if (module.getPortDirection(i) == Direction::In) { - setTermForDomain(port, allocate(port)); - } - continue; - } + auto diag = emitError(portLoc); + diag << "illegal " << domainName << " crossing in port " << portName; - // This is a port, which may have explicit domain information. - auto portDomains = getPortDomainAssociation(domainInfo, i); - if (portDomains.empty()) - continue; + auto ¬e1 = diag.attachNote(); + note1 << "1st instance: "; + render(info, note1, idTable, term1); - SmallVector elements(globals.circuitInfo.getNumDomains()); - for (auto domainPortIndexAttr : portDomains) { - auto domainPortIndex = domainPortIndexAttr.getUInt(); - auto domainTypeID = domainTypeIDTable[domainPortIndex]; - auto domainValue = module.getArgument(domainPortIndex); - auto *term = getTermForDomain(domainValue); - auto &slot = elements[domainTypeID]; - if (failed(unify(slot, term))) { - emitPortDomainCrossingError(module, i, domainTypeID, slot, term); - return failure(); + auto ¬e2 = diag.attachNote(); + note2 << "2nd instance: "; + render(info, note2, idTable, term2); +} + +/// Emit an error when we fail to infer the concrete domain to drive to a +/// domain port. +template +void emitDomainPortInferenceError(T op, size_t i) { + auto name = op.getPortNameAttr(i); + auto diag = emitError(op->getLoc()); + auto info = op.getDomainInfo(); + diag << "unable to infer value for domain port " << name; + for (size_t j = 0, e = op.getNumPorts(); j < e; ++j) { + if (auto assocs = dyn_cast(info[j])) { + for (auto assoc : assocs) { + if (i == cast(assoc).getValue()) { + auto name = op.getPortNameAttr(j); + auto loc = op.getPortLocation(j); + diag.attachNote(loc) << "associated with hardware port " << name; + break; + } } - elements[domainTypeID] = term; } - auto *row = allocateRow(elements); - setDomainAssociation(port, row); } - - return success(); } -LogicalResult InferModuleDomains::processBody(FModuleOp module) { - LogicalResult result = success(); - module.getBody().walk([&](Operation *op) -> WalkResult { - if (failed(processOp(op))) { - result = failure(); - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - return result; -} +/// Unify the associated domain rows of two terms. +LogicalResult unifyAssociations(const DomainInfo &info, + TermAllocator &allocator, DomainTable &table, + Operation *op, Value lhs, Value rhs) { + LLVM_DEBUG(llvm::errs() << " unify associations of:\n"; + llvm::errs() << " lhs=" << lhs << "\n"; + llvm::errs() << " rhs=" << rhs << "\n";); -LogicalResult InferModuleDomains::processOp(Operation *op) { - LLVM_DEBUG(llvm::errs() << "process op: " << *op << "\n"); + if (!lhs || !rhs) + return success(); - if (auto instance = dyn_cast(op)) - return processOp(instance); - if (auto instance = dyn_cast(op)) - return processOp(instance); - if (auto cast = dyn_cast(op)) - return processOp(cast); - if (auto def = dyn_cast(op)) - return processOp(def); + if (lhs == rhs) + return success(); - // For all other operations (including connections), propagate domains from - // operands to results. This is a conservative approach - all operands and - // results share the same domain associations. - Value lhs; - for (auto rhs : op->getOperands()) { - if (!isa(rhs.getType())) - continue; - if (auto *op = rhs.getDefiningOp(); - op && op->hasTrait()) - continue; - if (failed(unifyAssociations(op, lhs, rhs))) - return failure(); - lhs = rhs; - } - for (auto rhs : op->getResults()) { - if (!isa(rhs.getType())) - continue; - if (auto *op = rhs.getDefiningOp(); - op && op->hasTrait()) - continue; - if (failed(unifyAssociations(op, lhs, rhs))) - return failure(); - lhs = rhs; - } - return success(); -} + auto *lhsTerm = table.getOptDomainAssociation(lhs); + auto *rhsTerm = table.getOptDomainAssociation(rhs); -LogicalResult InferModuleDomains::processOp(InstanceOp op) { - auto module = op.getReferencedModuleNameAttr(); - auto lookup = globals.moduleUpdateTable.find(module); - if (lookup != globals.moduleUpdateTable.end()) - op = updateInstancePorts(op, lookup->second); - return processInstancePorts(op); -} + if (lhsTerm) { + if (rhsTerm) { + if (failed(unify(lhsTerm, rhsTerm))) { + auto diag = op->emitOpError("illegal domain crossing in operation"); + auto ¬e1 = diag.attachNote(lhs.getLoc()); -LogicalResult InferModuleDomains::processOp(InstanceChoiceOp op) { - auto module = op.getDefaultTargetAttr().getAttr(); - auto lookup = globals.moduleUpdateTable.find(module); - if (lookup != globals.moduleUpdateTable.end()) - op = updateInstancePorts(op, lookup->second); - return processInstancePorts(op); -} + note1 << "1st operand has domains: "; + VariableIDTable idTable; + render(info, note1, idTable, lhsTerm); -LogicalResult InferModuleDomains::processOp(UnsafeDomainCastOp op) { - auto domains = op.getDomains(); - if (domains.empty()) - return unifyAssociations(op, op.getInput(), op.getResult()); + auto ¬e2 = diag.attachNote(rhs.getLoc()); + note2 << "2nd operand has domains: "; + render(info, note2, idTable, rhsTerm); - auto input = op.getInput(); - RowTerm *inputRow = getDomainAssociationAsRow(input); - SmallVector elements(inputRow->elements); - for (auto domain : op.getDomains()) { - auto typeID = globals.circuitInfo.getDomainTypeID(domain); - elements[typeID] = getTermForDomain(domain); + return failure(); + } + } + table.setDomainAssociation(rhs, lhsTerm); + return success(); } - auto *row = allocateRow(elements); - setDomainAssociation(op.getResult(), row); + if (rhsTerm) { + table.setDomainAssociation(lhs, rhsTerm); + return success(); + } + + auto *var = allocator.allocVar(); + table.setDomainAssociation(lhs, var); + table.setDomainAssociation(rhs, var); return success(); } -LogicalResult InferModuleDomains::processOp(DomainDefineOp op) { - auto src = op.getSrc(); - auto dst = op.getDest(); - auto *srcTerm = getTermForDomain(src); - auto *dstTerm = getTermForDomain(dst); - if (failed(unify(dstTerm, srcTerm))) { - VariableIDTable idTable; - auto diag = op->emitOpError("failed to propagate source to destination"); - auto ¬e1 = diag.attachNote(); - note1 << "destination has underlying value: "; - render(note1, idTable, dstTerm); +LogicalResult processModulePorts(const DomainInfo &info, + TermAllocator &allocator, DomainTable &table, + FModuleOp module) { + auto domainInfo = module.getDomainInfoAttr(); + auto numPorts = module.getNumPorts(); + DenseMap domainTypeIDTable; + for (size_t i = 0; i < numPorts; ++i) { + BlockArgument port = module.getArgument(i); - auto ¬e2 = diag.attachNote(src.getLoc()); - note2 << "source has underlying value: "; - render(note2, idTable, srcTerm); + if (isa(port.getType())) { + auto typeID = info.getDomainTypeID(domainInfo, i); + domainTypeIDTable[i] = typeID; + if (module.getPortDirection(i) == Direction::In) { + table.setTermForDomain(port, allocator.allocVal(port)); + } + continue; + } + + auto portDomains = getPortDomainAssociation(domainInfo, i); + if (portDomains.empty()) + continue; + + SmallVector elements(info.getNumDomains()); + for (auto domainPortIndexAttr : portDomains) { + auto domainPortIndex = domainPortIndexAttr.getUInt(); + auto domainTypeID = domainTypeIDTable[domainPortIndex]; + auto domainValue = module.getArgument(domainPortIndex); + auto *term = getTermForDomain(allocator, table, domainValue); + auto &slot = elements[domainTypeID]; + if (failed(unify(slot, term))) { + emitPortDomainCrossingError(info, module, i, domainTypeID, slot, term); + return failure(); + } + elements[domainTypeID] = term; + } + auto *row = allocator.allocRow(elements); + table.setDomainAssociation(port, row); } - return unify(dstTerm, srcTerm); -} -template -T InferModuleDomains::updateInstancePorts(T op, - const ModuleUpdateInfo &update) { - auto clone = op.cloneWithInsertedPortsAndReplaceUses(update.portInsertions); - clone.setDomainInfoAttr(update.portDomainInfo); - op->erase(); - return clone; + return success(); } template -LogicalResult InferModuleDomains::processInstancePorts(T op) { - auto circuitInfo = globals.circuitInfo; - auto numDomainTypes = circuitInfo.getNumDomains(); +LogicalResult processInstancePorts(const DomainInfo &info, + TermAllocator &allocator, DomainTable &table, + T op) { + auto numDomainTypes = info.getNumDomains(); DenseMap domainPortTypeIDTable; auto domainInfo = op.getDomainInfoAttr(); for (size_t i = 0, e = op->getNumResults(); i < e; ++i) { Value port = op.getResult(i); - LLVM_DEBUG(llvm::errs() << "handling instance port: " << port << "\n"); - if (isa(port.getType())) { - auto typeID = circuitInfo.getDomainTypeID(domainInfo, i); + auto typeID = info.getDomainTypeID(domainInfo, i); domainPortTypeIDTable[i] = typeID; if (op.getPortDirection(i) == Direction::Out) { - setTermForDomain(port, allocate(port)); + table.setTermForDomain(port, allocator.allocVal(port)); } continue; } @@ -790,7 +752,8 @@ LogicalResult InferModuleDomains::processInstancePorts(T op) { for (auto domainPortIndexAttr : associations) { auto domainPortIndex = domainPortIndexAttr.getUInt(); auto typeID = domainPortTypeIDTable[domainPortIndex]; - auto *term = getTermForDomain(op.getResult(domainPortIndex)); + auto *term = + getTermForDomain(allocator, table, op.getResult(domainPortIndex)); elements[typeID] = term; } @@ -801,7 +764,7 @@ LogicalResult InferModuleDomains::processInstancePorts(T op) { ++domainTypeID) { if (elements[domainTypeID]) continue; - auto domainDecl = circuitInfo.getDomain(domainTypeID); + auto domainDecl = info.getDomain(domainTypeID); auto domainName = domainDecl.getNameAttr(); auto portName = op.getPortNameAttr(i); op->emitOpError() << "missing " << domainName << " association for port " @@ -809,59 +772,195 @@ LogicalResult InferModuleDomains::processInstancePorts(T op) { return failure(); } - setDomainAssociation(port, allocateRow(elements)); + table.setDomainAssociation(port, allocator.allocRow(elements)); } return success(); } -LogicalResult InferModuleDomains::updateModule(FModuleOp op) { - initializeExportTable(op); +LogicalResult processOp(const DomainInfo &info, TermAllocator &allocator, + DomainTable &table, + const ModuleUpdateTable &updateTable, InstanceOp op) { + auto module = op.getReferencedModuleNameAttr(); + auto lookup = updateTable.find(module); + if (lookup != updateTable.end()) + op = fixInstancePorts(op, lookup->second); + return processInstancePorts(info, allocator, table, op); +} - generalizeModule(op); - if (failed(updatePortDomainAssociations(op))) - return failure(); +LogicalResult processOp(const DomainInfo &info, TermAllocator &allocator, + DomainTable &table, + const ModuleUpdateTable &updateTable, + InstanceChoiceOp op) { + auto module = op.getDefaultTargetAttr().getAttr(); + auto lookup = updateTable.find(module); + if (lookup != updateTable.end()) + op = fixInstancePorts(op, lookup->second); + return processInstancePorts(info, allocator, table, op); +} - if (failed(updateDomainAssociationsInBody(op))) - return failure(); +LogicalResult processOp(const DomainInfo &info, TermAllocator &allocator, + DomainTable &table, UnsafeDomainCastOp op) { + auto domains = op.getDomains(); + if (domains.empty()) + return unifyAssociations(info, allocator, table, op, op.getInput(), + op.getResult()); + + auto input = op.getInput(); + RowTerm *inputRow = getDomainAssociationAsRow(info, allocator, table, input); + SmallVector elements(inputRow->elements); + for (auto domain : op.getDomains()) { + auto typeID = info.getDomainTypeID(domain); + elements[typeID] = getTermForDomain(allocator, table, domain); + } + + auto *row = allocator.allocRow(elements); + table.setDomainAssociation(op.getResult(), row); + return success(); +} + +LogicalResult processOp(const DomainInfo &info, TermAllocator &allocator, + DomainTable &table, DomainDefineOp op) { + auto src = op.getSrc(); + auto dst = op.getDest(); + auto *srcTerm = getTermForDomain(allocator, table, src); + auto *dstTerm = getTermForDomain(allocator, table, dst); + if (failed(unify(dstTerm, srcTerm))) { + VariableIDTable idTable; + auto diag = op->emitOpError("failed to propagate source to destination"); + auto ¬e1 = diag.attachNote(); + note1 << "destination has underlying value: "; + render(info, note1, idTable, dstTerm); + + auto ¬e2 = diag.attachNote(src.getLoc()); + note2 << "source has underlying value: "; + render(info, note2, idTable, srcTerm); + } + return success(); +} +LogicalResult processOp(const DomainInfo &info, TermAllocator &allocator, + DomainTable &table, + const ModuleUpdateTable &updateTable, Operation *op) { + LLVM_DEBUG(llvm::errs() << "process op: " << *op << "\n"); + if (auto instance = dyn_cast(op)) + return processOp(info, allocator, table, updateTable, instance); + if (auto instance = dyn_cast(op)) + return processOp(info, allocator, table, updateTable, instance); + if (auto cast = dyn_cast(op)) + return processOp(info, allocator, table, cast); + if (auto def = dyn_cast(op)) + return processOp(info, allocator, table, def); + + // For all other operations (including connections), propagate domains from + // operands to results. This is a conservative approach - all operands and + // results share the same domain associations. + Value lhs; + for (auto rhs : op->getOperands()) { + if (!isa(rhs.getType())) + continue; + if (auto *op = rhs.getDefiningOp(); + op && op->hasTrait()) + continue; + if (failed(unifyAssociations(info, allocator, table, op, lhs, rhs))) + return failure(); + lhs = rhs; + } + for (auto rhs : op->getResults()) { + if (!isa(rhs.getType())) + continue; + if (auto *op = rhs.getDefiningOp(); + op && op->hasTrait()) + continue; + if (failed(unifyAssociations(info, allocator, table, op, lhs, rhs))) + return failure(); + lhs = rhs; + } return success(); } -void InferModuleDomains::initializeExportTable(FModuleOp module) { +LogicalResult processModuleBody(const DomainInfo &info, + TermAllocator &allocator, DomainTable &table, + const ModuleUpdateTable &updateTable, + FModuleOp module) { + LogicalResult result = success(); + module.getBody().walk([&](Operation *op) -> WalkResult { + if (failed(processOp(info, allocator, table, updateTable, op))) { + result = failure(); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return result; +} + +/// Populate the domain table by processing the module. If the module has any +/// domain crossing errors, return failure. +LogicalResult processModule(const DomainInfo &info, TermAllocator &allocator, + DomainTable &table, + const ModuleUpdateTable &updateTable, + FModuleOp module) { + if (failed(processModulePorts(info, allocator, table, module))) + return failure(); + + return processModuleBody(info, allocator, table, updateTable, module); +} + +} // namespace + +//====-------------------------------------------------------------------------- +// Module updating: write the computed domains back to the IR. +//====-------------------------------------------------------------------------- + +namespace { + +using ExportTable = DenseMap>; + +/// Build a table of exported domains: a map from domains defined internally, +/// to their set of aliasing output ports. +ExportTable initializeExportTable(const DomainTable &table, FModuleOp module) { + ExportTable exports; size_t numPorts = module.getNumPorts(); for (size_t i = 0; i < numPorts; ++i) { auto port = module.getArgument(i); auto type = port.getType(); if (!isa(type)) continue; - auto value = getUnderlyingDomain(port); + auto value = table.getOptUnderlyingDomain(port); if (value) - exportTable[value].push_back(port); + exports[value].push_back(port); } + + return exports; } -LogicalResult -InferModuleDomains::updatePortDomainAssociations(FModuleOp module) { - // At this point, all domain variables mentioned in ports have been - // solved by generalizing the module (adding input domain ports). Now, we have - // to form the new port domain information for the module by examining the - // the associated domains of each port. - auto *context = module.getContext(); - auto numDomains = globals.circuitInfo.getNumDomains(); - auto builder = OpBuilder::atBlockEnd(module.getBodyBlock()); - auto oldModuleDomainInfo = module.getDomainInfoAttr(); - auto numPorts = module.getNumPorts(); - SmallVector newModuleDomainInfo(numPorts); +/// Copy the domain associations from the module domain info attribute into a +/// small vector. +SmallVector copyPortDomainAssociations(const DomainInfo &info, + ArrayAttr moduleDomainInfo, + size_t portIndex) { + SmallVector result(info.getNumDomains()); + auto oldAssociations = getPortDomainAssociation(moduleDomainInfo, portIndex); + for (auto domainPortIndexAttr : oldAssociations) { + auto domainPortIndex = domainPortIndexAttr.getUInt(); + auto domainTypeID = info.getDomainTypeID(moduleDomainInfo, domainPortIndex); + result[domainTypeID] = domainPortIndexAttr; + }; + return result; +} +template +LogicalResult updateInstanceDomainAssociations(const DomainTable &table, T op) { + auto *context = op.getContext(); + OpBuilder builder(context); + builder.setInsertionPointAfter(op); + auto numPorts = op->getNumResults(); for (size_t i = 0; i < numPorts; ++i) { - auto port = module.getArgument(i); + auto port = op.getResult(i); auto type = port.getType(); - - // If the port is an output domain, we may need to drive the output with - // a value. If we don't know what value to drive to the port, error. + auto direction = op.getPortDirection(i); if (isa(type)) { - if (module.getPortDirection(i) == Direction::Out) { + if (direction == Direction::In) { bool driven = false; for (auto *user : port.getUsers()) { if (auto connect = dyn_cast(user)) { @@ -871,114 +970,38 @@ InferModuleDomains::updatePortDomainAssociations(FModuleOp module) { } } } - - // Get the underlying value of the output port. - auto *term = getTermForDomain(port); - term = find(term); - auto *val = dyn_cast(term); - if (!val) { - emitDomainPortInferenceError(module, i); - return failure(); - } - - // If the output port is not driven, drive it. if (!driven) { - auto loc = port.getLoc(); - auto value = val->value; - DomainDefineOp::create(builder, loc, port, value); - } - } - - newModuleDomainInfo[i] = oldModuleDomainInfo[i]; - continue; - } - - if (isa(type)) { - auto associations = copyPortDomainAssociations(oldModuleDomainInfo, i); - auto *row = getDomainAssociationAsRow(port); - for (size_t domainTypeID = 0; domainTypeID < numDomains; ++domainTypeID) { - if (associations[domainTypeID]) - continue; - - auto domain = cast(find(row->elements[domainTypeID]))->value; - auto &exports = exportTable[domain]; - if (exports.empty()) { - auto portName = module.getPortNameAttr(i); - auto portLoc = module.getPortLocation(i); - auto domainDecl = globals.circuitInfo.getDomain(domainTypeID); - auto domainName = domainDecl.getNameAttr(); - auto diag = emitError(portLoc) - << "private " << domainName << " association for port " - << portName; - diag.attachNote(domain.getLoc()) << "associated domain: " << domain; - return failure(); - } - - if (exports.size() > 1) { - auto portName = module.getPortNameAttr(i); - auto portLoc = module.getPortLocation(i); - auto domainDecl = globals.circuitInfo.getDomain(domainTypeID); - auto domainName = domainDecl.getNameAttr(); - auto diag = emitError(portLoc) - << "ambiguous " << domainName << " association for port " - << portName; - for (auto arg : exports) { - auto name = module.getPortNameAttr(arg.getArgNumber()); - auto loc = module.getPortLocation(arg.getArgNumber()); - diag.attachNote(loc) << "candidate association " << name; + auto *term = table.getTermForDomain(port); + term = find(term); + if (auto *val = dyn_cast(term)) { + auto loc = port.getLoc(); + auto value = val->value; + DomainDefineOp::create(builder, loc, port, value); + } else { + emitDomainPortInferenceError(op, i); + return failure(); } - return failure(); } - - auto argument = cast(exports[0]); - auto domainPortIndex = argument.getArgNumber(); - associations[domainTypeID] = IntegerAttr::get( - IntegerType::get(context, 32, IntegerType::Unsigned), - domainPortIndex); } - - newModuleDomainInfo[i] = ArrayAttr::get(context, associations); - continue; } - - newModuleDomainInfo[i] = oldModuleDomainInfo[i]; } - - auto newModuleDomainInfoAttr = - ArrayAttr::get(module.getContext(), newModuleDomainInfo); - module.setDomainInfoAttr(newModuleDomainInfoAttr); - - // record the domain info for replaying on instances. - auto &update = globals.moduleUpdateTable[module.getNameAttr()]; - update.portDomainInfo = newModuleDomainInfoAttr; - return success(); } -SmallVector -InferModuleDomains::copyPortDomainAssociations(ArrayAttr moduleDomainInfo, - size_t portIndex) { - SmallVector result(globals.circuitInfo.getNumDomains()); - auto oldAssociations = getPortDomainAssociation(moduleDomainInfo, portIndex); - for (auto domainPortIndexAttr : oldAssociations) { - auto domainPortIndex = domainPortIndexAttr.getUInt(); - auto domainTypeID = - globals.circuitInfo.getDomainTypeID(moduleDomainInfo, domainPortIndex); - result[domainTypeID] = domainPortIndexAttr; - }; - return result; -} - -void InferModuleDomains::generalizeModule(FModuleOp module) { - PortInsertions insertions; - // If the port is hardware, we have to check the associated row of - // domains. If any associated domain is a variable, we solve the variable - // by generalizing the module with an additional input domain port. If any - // associated domain is defined internally to the module, we have to add - // an output domain port, to allow the domain to escape. +/// Add domain ports for any uninferred domains associated to hardware. +/// Returns the inserted ports, which will be used later to generalize the +/// instances of this module. +/// +/// If the port is hardware, we have to check the associated row of +/// domains. If any associated domain is a variable, we solve the variable +/// by generalizing the module with an additional input domain port. If any +/// associated domain is defined internally to the module, we have to add +/// an output domain port, to allow the domain to escape. +void createModuleDomainPorts(const DomainInfo &info, TermAllocator &allocator, + DomainTable &table, ExportTable &exportTable, + PortInsertions &insertions, FModuleOp module) { DenseMap pendingSolutions; llvm::MapVector pendingExports; - size_t inserted = 0; auto numPorts = module.getNumPorts(); for (size_t i = 0; i < numPorts; ++i) { @@ -988,7 +1011,7 @@ void InferModuleDomains::generalizeModule(FModuleOp module) { if (!isa(type)) continue; - auto *row = getDomainAssociationAsRow(port); + auto *row = getDomainAssociationAsRow(info, allocator, table, port); for (auto [typeID, term] : llvm::enumerate(row->elements)) { auto *domain = find(term); @@ -1006,7 +1029,7 @@ void InferModuleDomains::generalizeModule(FModuleOp module) { continue; // We must insert a new output domain port. - auto domainDecl = globals.circuitInfo.getDomain(typeID); + auto domainDecl = info.getDomain(typeID); auto domainName = domainDecl.getNameAttr(); auto portInsertionPoint = i; @@ -1032,7 +1055,7 @@ void InferModuleDomains::generalizeModule(FModuleOp module) { continue; // insert a new input domain port for the variable. - auto domainDecl = globals.circuitInfo.getDomain(typeID); + auto domainDecl = info.getDomain(typeID); auto domainName = domainDecl.getNameAttr(); auto portInsertionPoint = i; @@ -1047,65 +1070,178 @@ void InferModuleDomains::generalizeModule(FModuleOp module) { portAnnos, portDomainInfo); insertions.push_back({portInsertionPoint, portInfo}); - // Record the pending solution. - auto solutionPortIndex = inserted + portInsertionPoint; - pendingSolutions[var] = solutionPortIndex; - ++inserted; + // Record the pending solution. + auto solutionPortIndex = inserted + portInsertionPoint; + pendingSolutions[var] = solutionPortIndex; + ++inserted; + } + } + } + + // Put the domain ports in place. + module.insertPorts(insertions); + + // Solve the variables and record them as "self-exporting". + for (auto [var, portIndex] : pendingSolutions) { + auto port = module.getArgument(portIndex); + auto *solution = allocator.allocVal(port); + solve(var, solution); + exportTable[port].push_back(port); + } + + // Drive the pending exports. + auto builder = OpBuilder::atBlockEnd(module.getBodyBlock()); + for (auto [value, portIndex] : pendingExports) { + auto port = module.getArgument(portIndex); + DomainDefineOp::create(builder, port.getLoc(), port, value); + exportTable[value].push_back(port); + table.setTermForDomain(port, allocator.allocVal(value)); + } +} + +/// After generalizing the module, all domains should be solved. Reflect the +/// solved domain associations into the port domain info attribute. +LogicalResult updateModuleDomainInfo(const DomainInfo &info, + const DomainTable &table, + const ExportTable &exportTable, + ArrayAttr &result, FModuleOp module) { + // At this point, all domain variables mentioned in ports have been + // solved by generalizing the module (adding input domain ports). Now, we have + // to form the new port domain information for the module by examining the + // the associated domains of each port. + auto *context = module.getContext(); + auto numDomains = info.getNumDomains(); + auto builder = OpBuilder::atBlockEnd(module.getBodyBlock()); + auto oldModuleDomainInfo = module.getDomainInfoAttr(); + auto numPorts = module.getNumPorts(); + SmallVector newModuleDomainInfo(numPorts); + + for (size_t i = 0; i < numPorts; ++i) { + auto port = module.getArgument(i); + auto type = port.getType(); + + // If the port is an output domain, we may need to drive the output with + // a value. If we don't know what value to drive to the port, error. + if (isa(type)) { + if (module.getPortDirection(i) == Direction::Out) { + bool driven = false; + for (auto *user : port.getUsers()) { + if (auto connect = dyn_cast(user)) { + if (connect.getDest() == port) { + driven = true; + break; + } + } + } + + // Get the underlying value of the output port. + auto *term = table.getTermForDomain(port); + term = find(term); + auto *val = dyn_cast(term); + if (!val) { + emitDomainPortInferenceError(module, i); + return failure(); + } + + // If the output port is not driven, drive it. + if (!driven) { + auto loc = port.getLoc(); + auto value = val->value; + DomainDefineOp::create(builder, loc, port, value); + } + } + + newModuleDomainInfo[i] = oldModuleDomainInfo[i]; + continue; + } + + if (isa(type)) { + auto associations = + copyPortDomainAssociations(info, oldModuleDomainInfo, i); + auto *row = cast(table.getDomainAssociation(port)); + for (size_t domainTypeID = 0; domainTypeID < numDomains; ++domainTypeID) { + if (associations[domainTypeID]) + continue; + + auto domain = cast(find(row->elements[domainTypeID]))->value; + auto &exports = exportTable.at(domain); + if (exports.empty()) { + auto portName = module.getPortNameAttr(i); + auto portLoc = module.getPortLocation(i); + auto domainDecl = info.getDomain(domainTypeID); + auto domainName = domainDecl.getNameAttr(); + auto diag = emitError(portLoc) + << "private " << domainName << " association for port " + << portName; + diag.attachNote(domain.getLoc()) << "associated domain: " << domain; + return failure(); + } + + if (exports.size() > 1) { + auto portName = module.getPortNameAttr(i); + auto portLoc = module.getPortLocation(i); + auto domainDecl = info.getDomain(domainTypeID); + auto domainName = domainDecl.getNameAttr(); + auto diag = emitError(portLoc) + << "ambiguous " << domainName << " association for port " + << portName; + for (auto arg : exports) { + auto name = module.getPortNameAttr(arg.getArgNumber()); + auto loc = module.getPortLocation(arg.getArgNumber()); + diag.attachNote(loc) << "candidate association " << name; + } + return failure(); + } + + auto argument = cast(exports[0]); + auto domainPortIndex = argument.getArgNumber(); + associations[domainTypeID] = IntegerAttr::get( + IntegerType::get(context, 32, IntegerType::Unsigned), + domainPortIndex); } - } - } - // Put the domain ports in place. - module.insertPorts(insertions); - - // Solve the variables and record them as "self-exporting". - for (auto [var, portIndex] : pendingSolutions) { - auto port = module.getArgument(portIndex); - auto *solution = allocate(port); - solve(var, solution); - // The port is an export of itself. - exportTable[port].push_back(port); - } + newModuleDomainInfo[i] = ArrayAttr::get(context, associations); + continue; + } - // Drive the pending exports. - auto builder = OpBuilder::atBlockEnd(module.getBodyBlock()); - for (auto [value, portIndex] : pendingExports) { - auto port = module.getArgument(portIndex); - DomainDefineOp::create(builder, port.getLoc(), port, value); - exportTable[value].push_back(port); - setTermForDomain(port, allocate(value)); + newModuleDomainInfo[i] = oldModuleDomainInfo[i]; } - // Record the insertions, so we can replay them on instances later. - auto &update = globals.moduleUpdateTable[module.getNameAttr()]; - update.portInsertions = std::move(insertions); + result = ArrayAttr::get(module.getContext(), newModuleDomainInfo); + module.setDomainInfoAttr(result); + return success(); } -LogicalResult -InferModuleDomains::updateDomainAssociationsInBody(FModuleOp module) { - auto result = success(); - module.getBodyBlock()->walk([&](Operation *op) -> WalkResult { - if (failed(updateOpDomainAssociations(op))) { - result = failure(); - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - return result; -} +/// Update the ports of the module and record the change in the module update +/// table. +LogicalResult updateModulePorts(const DomainInfo &info, + TermAllocator &allocator, DomainTable &table, + ModuleUpdateTable &updateTable, FModuleOp op) { + // The export table tracks how domains are exported by the ports of the + // module. Initialize the export table by scanning the current ports. + auto exportTable = initializeExportTable(table, op); -LogicalResult InferModuleDomains::updateOpDomainAssociations(Operation *op) { - if (auto instance = dyn_cast(op)) - return updateInstanceDomainAssociations(instance); - if (auto instance = dyn_cast(op)) - return updateInstanceDomainAssociations(instance); + // Now, create any necessary domain ports. + PortInsertions portInsertions; + createModuleDomainPorts(info, allocator, table, exportTable, portInsertions, + op); + + // Update the domain info for the module's ports. + ArrayAttr portDomainInfo; + if (failed( + updateModuleDomainInfo(info, table, exportTable, portDomainInfo, op))) + return failure(); + + // Record the updated interface change in the update table. + auto &entry = updateTable[op.getModuleNameAttr()]; + entry.portDomainInfo = portDomainInfo; + entry.portInsertions = portInsertions; return success(); } template -LogicalResult InferModuleDomains::updateInstanceDomainAssociations(T op) { - auto *context = op.getContext(); - OpBuilder builder(context); +LogicalResult updateInstance(const DomainTable &table, T op) { + OpBuilder builder(op); builder.setInsertionPointAfter(op); auto numPorts = op->getNumResults(); for (size_t i = 0; i < numPorts; ++i) { @@ -1124,7 +1260,7 @@ LogicalResult InferModuleDomains::updateInstanceDomainAssociations(T op) { } } if (!driven) { - auto *term = getTermForDomain(port); + auto *term = table.getTermForDomain(port); term = find(term); if (auto *val = dyn_cast(term)) { auto loc = port.getLoc(); @@ -1141,289 +1277,184 @@ LogicalResult InferModuleDomains::updateInstanceDomainAssociations(T op) { return success(); } -LogicalResult InferModuleDomains::unifyAssociations(Operation *op, Value lhs, - Value rhs) { - LLVM_DEBUG(llvm::errs() << " unify associations of:\n"; - llvm::errs() << " lhs=" << lhs << "\n"; - llvm::errs() << " rhs=" << rhs << "\n";); - - if (!lhs || !rhs) - return success(); - - if (lhs == rhs) - return success(); - - auto *lhsTerm = getOptDomainAssociation(lhs); - auto *rhsTerm = getOptDomainAssociation(rhs); - - if (lhsTerm) { - if (rhsTerm) { - if (failed(unify(lhsTerm, rhsTerm))) { - auto diag = op->emitOpError("illegal domain crossing in operation"); - auto ¬e1 = diag.attachNote(lhs.getLoc()); - - note1 << "1st operand has domains: "; - VariableIDTable idTable; - render(note1, idTable, lhsTerm); - - auto ¬e2 = diag.attachNote(rhs.getLoc()); - note2 << "2nd operand has domains: "; - render(note2, idTable, rhsTerm); - - return failure(); - } - } - setDomainAssociation(rhs, lhsTerm); - return success(); - } - - if (rhsTerm) { - setDomainAssociation(lhs, rhsTerm); - return success(); - } - - auto *var = allocate(); - setDomainAssociation(lhs, var); - setDomainAssociation(rhs, var); +LogicalResult updateOp(const DomainTable &table, Operation *op) { + if (auto instance = dyn_cast(op)) + return updateInstance(table, instance); + if (auto instance = dyn_cast(op)) + return updateInstance(table, instance); return success(); } -Value InferModuleDomains::getUnderlyingDomain(Value value) { - assert(isa(value.getType())); - auto *term = getOptTermForDomain(value); - if (auto *val = llvm::dyn_cast_if_present(term)) - return val->value; - return nullptr; +/// After updating the port domain associations, walk the body of the module +/// to fix up any child instance modules. +LogicalResult updateModuleBody(const DomainTable &table, FModuleOp module) { + auto result = success(); + module.getBodyBlock()->walk([&](Operation *op) -> WalkResult { + if (failed(updateOp(table, op))) { + result = failure(); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return result; } -Term *InferModuleDomains::getTermForDomain(Value value) { - assert(isa(value.getType())); - if (auto *term = getOptTermForDomain(value)) - return term; - auto *term = allocate(); - setTermForDomain(value, term); - return term; -} +/// Write the domain associations recorded in the domain table back to the IR. +LogicalResult updateModule(const DomainInfo &info, TermAllocator &allocator, + DomainTable &table, ModuleUpdateTable &updateTable, + FModuleOp op) { + if (failed(updateModulePorts(info, allocator, table, updateTable, op))) + return failure(); -Term *InferModuleDomains::getOptTermForDomain(Value value) const { - assert(isa(value.getType())); - auto it = termTable.find(value); - if (it == termTable.end()) - return nullptr; - return find(it->second); -} + if (failed(updateModuleBody(table, op))) + return failure(); -void InferModuleDomains::setTermForDomain(Value value, Term *term) { - assert(isa(value.getType())); - assert(term); - assert(!termTable.contains(value)); - termTable.insert({value, term}); + return success(); } -RowTerm *InferModuleDomains::getDomainAssociationAsRow(Value value) { - assert(isa(value.getType())); - auto *term = getOptDomainAssociation(value); - - // If the term is unknown, allocate a fresh row and set the association. - if (!term) { - auto *row = allocateRow(); - setDomainAssociation(value, row); - return row; - } +} // namespace - // If the term is already a row, return it. - if (auto *row = dyn_cast(term)) - return row; +//====-------------------------------------------------------------------------- +// Domain Inference: solve domains and check for correctness,then update the +// IR to reflect the solved domains. +//====-------------------------------------------------------------------------- - // Otherwise, unify the term with a fresh row of domains. - if (auto *var = dyn_cast(term)) { - auto *row = allocateRow(); - solve(var, row); - return row; - } +namespace { - assert(false && "unhandled term type"); - return nullptr; -} +/// Solve for domains and then write the domain associations back to the IR. +LogicalResult inferModule(const DomainInfo &info, + ModuleUpdateTable &updateTable, FModuleOp module) { + TermAllocator allocator; + DomainTable table; -Term *InferModuleDomains::getDomainAssociation(Value value) { - auto *term = getOptDomainAssociation(value); - if (term) - return term; - term = allocate(); - setDomainAssociation(value, term); - return term; -} + if (failed(processModule(info, allocator, table, updateTable, module))) + return failure(); -Term *InferModuleDomains::getOptDomainAssociation(Value value) const { - assert(isa(value.getType())); - auto it = associationTable.find(value); - if (it == associationTable.end()) - return nullptr; - return find(it->second); + return updateModule(info, allocator, table, updateTable, module); } -void InferModuleDomains::setDomainAssociation(Value value, Term *term) { - assert(isa(value.getType())); - assert(term); - term = find(term); - associationTable.insert({value, term}); - LLVM_DEBUG(llvm::errs() << " set domain association: " << value; - llvm::errs() << " -> " << term << "\n";); -} +} // namespace -RowTerm *InferModuleDomains::allocateRow() { - SmallVector elements; - elements.resize(globals.circuitInfo.getNumDomains()); - return allocateRow(elements); -} +//====-------------------------------------------------------------------------- +// Domain Checking: Solve for domains and check for correctness+completeness, +// without updating the IR. +//====-------------------------------------------------------------------------- -RowTerm *InferModuleDomains::allocateRow(ArrayRef elements) { - auto ds = allocateArray(elements); - return allocate(ds); -} +namespace { -template -T *InferModuleDomains::allocate(Args &&...args) { - static_assert(std::is_base_of_v, "T must be a term"); - return new (allocator) T(std::forward(args)...); -} +/// Check that a module has complete domain information for its ports. +LogicalResult checkPorts(const DomainInfo &info, FModuleLike module) { + auto numDomains = info.getNumDomains(); + auto domainInfo = module.getDomainInfoAttr(); + DenseMap typeIDTable; + for (size_t i = 0, e = module.getNumPorts(); i < e; ++i) { + auto type = module.getPortType(i); -ArrayRef InferModuleDomains::allocateArray(ArrayRef elements) { - auto size = elements.size(); - if (size == 0) - return {}; + if (isa(type)) { + auto typeID = info.getDomainTypeID(domainInfo, i); + typeIDTable[i] = typeID; + continue; + } - auto *result = allocator.Allocate(size); - llvm::uninitialized_copy(elements, result); - for (size_t i = 0; i < size; ++i) - if (!result[i]) - result[i] = allocate(); + if (auto baseType = type_dyn_cast(type)) { + SmallVector associations(numDomains); + auto domains = getPortDomainAssociation(domainInfo, i); + for (auto index : domains) { + auto typeID = typeIDTable[index.getUInt()]; + auto &entry = associations[typeID]; + if (entry && entry != index) { + auto domainName = info.getDomain(typeID).getNameAttr(); + auto portName = module.getPortNameAttr(i); + auto diag = emitError(module.getPortLocation(i)) + << "ambiguous " << domainName << " association for port " + << portName; - return ArrayRef(result, elements.size()); -} + auto d1Loc = module.getPortLocation(entry.getUInt()); + auto d1Name = module.getPortNameAttr(entry.getUInt()); + diag.attachNote(d1Loc) + << "associated with " << domainName << " port " << d1Name; -void InferModuleDomains::render(Diagnostic &out, Term *term) const { - VariableIDTable idTable; - render(out, idTable, term); -} + auto d2Loc = module.getPortLocation(index.getUInt()); + auto d2Name = module.getPortNameAttr(index.getUInt()); + diag.attachNote(d2Loc) + << "associated with " << domainName << " port " << d2Name; + } + entry = index; + } -// NOLINTNEXTLINE(misc-no-recursion) -void InferModuleDomains::render(Diagnostic &out, VariableIDTable &idTable, - Term *term) const { - term = find(term); - if (auto *var = dyn_cast(term)) { - out << "?" << idTable.get(var); - return; - } - if (auto *val = dyn_cast(term)) { - auto value = val->value; - auto [name, rooted] = getFieldName(FieldRef(value, 0), false); - out << name; - return; - } - if (auto *row = dyn_cast(term)) { - bool first = true; - out << "["; - for (size_t i = 0, e = globals.circuitInfo.getNumDomains(); i < e; ++i) { - auto domainOp = globals.circuitInfo.getDomain(i); - if (!first) { - out << ", "; - first = false; + for (size_t typeID = 0; typeID < numDomains; ++typeID) { + auto association = associations[typeID]; + if (!association) { + auto domainName = info.getDomain(typeID).getNameAttr(); + auto portName = module.getPortNameAttr(i); + return emitError(module.getPortLocation(i)) + << "missing " << domainName << " association for port " + << portName; + } } - out << domainOp.getName() << ": "; - render(out, idTable, row->elements[i]); } - out << "]"; - return; } -} - -template -void InferModuleDomains::emitPortDomainCrossingError(T op, size_t i, - size_t domainTypeID, - Term *term1, - Term *term2) const { - VariableIDTable idTable; - - auto portName = op.getPortNameAttr(i); - auto portLoc = op.getPortLocation(i); - auto domainDecl = globals.circuitInfo.getDomain(domainTypeID); - auto domainName = domainDecl.getNameAttr(); - auto diag = emitError(portLoc); - diag << "illegal " << domainName << " crossing in port " << portName; + return success(); +} - auto ¬e1 = diag.attachNote(); - note1 << "1st instance: "; - render(note1, idTable, term1); +/// Check that a module's ports are fully annotated, and check that there are no +/// domain crossing errors in the module's body, without modify the IR. +LogicalResult checkModule(const DomainInfo &info, + ModuleUpdateTable &updateTable, FModuleOp module) { + if (failed(checkPorts(info, module))) + return failure(); - auto ¬e2 = diag.attachNote(); - note2 << "2nd instance: "; - render(note2, idTable, term2); + DomainTable table; + TermAllocator allocator; + return processModule(info, allocator, table, updateTable, module); } -template -void InferModuleDomains::emitDomainPortInferenceError(T op, size_t i) const { - auto name = op.getPortNameAttr(i); - auto diag = emitError(op->getLoc()); - auto info = op.getDomainInfo(); - diag << "unable to infer value for domain port " << name; - for (size_t j = 0, e = op.getNumPorts(); j < e; ++j) { - if (auto assocs = dyn_cast(info[j])) { - for (auto assoc : assocs) { - if (i == cast(assoc).getValue()) { - auto name = op.getPortNameAttr(j); - auto loc = op.getPortLocation(j); - diag.attachNote(loc) << "associated with hardware port " << name; - break; - } - } - } - } +/// Check that an extmodule's ports are fully annotated. +LogicalResult checkModule(const DomainInfo &info, FExtModuleOp module) { + return checkPorts(info, module); } -static LogicalResult inferModuleDomains(GlobalState &globals, - FModuleOp module) { - return InferModuleDomains::run(globals, module); -} +} // namespace //===--------------------------------------------------------------------------- // InferDomainsPass: Top-level pass implementation. //===--------------------------------------------------------------------------- -static LogicalResult runOnModuleLike(InferDomainsMode mode, - GlobalState &globals, Operation *op) { +namespace { + +LogicalResult runOnModuleLike(InferDomainsMode mode, const DomainInfo &info, + ModuleUpdateTable &updateTable, Operation *op) { + llvm::errs() << "********\n"; + llvm::errs() << *op << "\n"; if (auto module = dyn_cast(op)) { if (shouldInfer(module, mode)) - return inferModuleDomains(globals, module); - return checkModuleDomains(globals, module); + return inferModule(info, updateTable, module); + return checkModule(info, updateTable, module); } - if (auto extModule = dyn_cast(op)) { - return checkModuleDomains(globals, extModule); - } + if (auto extModule = dyn_cast(op)) + return checkModule(info, extModule); return success(); } -namespace { struct InferDomainsPass : public circt::firrtl::impl::InferDomainsBase { using InferDomainsBase::InferDomainsBase; - void runOnOperation() override; + void runOnOperation() override { + CIRCT_DEBUG_SCOPED_PASS_LOGGER(this); + auto circuit = getOperation(); + auto &instanceGraph = getAnalysis(); + DomainInfo info(circuit); + DenseSet visited; + ModuleUpdateTable updateTable; + auto result = instanceGraph.walkPostOrder([&](auto &node) { + return runOnModuleLike(mode, info, updateTable, node.getModule()); + }); + if (failed(result)) + signalPassFailure(); + } }; -} // namespace -void InferDomainsPass::runOnOperation() { - CIRCT_DEBUG_SCOPED_PASS_LOGGER(this); - auto circuit = getOperation(); - auto &instanceGraph = getAnalysis(); - GlobalState globals(circuit); - DenseSet visited; - auto result = instanceGraph.walkPostOrder([&](auto &node) { - return runOnModuleLike(mode, globals, node.getModule()); - }); - if (failed(result)) - signalPassFailure(); -} +} // namespace From fb90ba27f7419e1cd06a2dfccc9f32e766300a86 Mon Sep 17 00:00:00 2001 From: Robert Young Date: Tue, 25 Nov 2025 14:33:43 -0500 Subject: [PATCH 12/14] More infer/check stuff --- .../FIRRTL/Transforms/InferDomains.cpp | 235 +++++++----------- .../FIRRTL/infer-domains-check-errors.mlir | 61 ++++- test/Dialect/FIRRTL/infer-domains-errors.mlir | 4 +- 3 files changed, 141 insertions(+), 159 deletions(-) diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp index 1617610b6239..790ce63617f3 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -37,12 +37,12 @@ namespace firrtl { using namespace circt; using namespace firrtl; +namespace { + //====-------------------------------------------------------------------------- // Helpers. //====-------------------------------------------------------------------------- -namespace { - using PortInsertions = SmallVector>; template @@ -81,14 +81,19 @@ bool isPort(FModuleOp module, Value value) { return isPort(module, arg); } -} // namespace +/// Returns true if the value is driven by a connect op. +bool isDriven(Value port) { + for (auto *user : port.getUsers()) + if (auto connect = dyn_cast(user)) + if (connect.getDest() == port) + return true; + return false; +} //====-------------------------------------------------------------------------- // Global State. //====-------------------------------------------------------------------------- -namespace { - /// Each domain type declared in the circuit is assigned a type-id, based on the /// order of declaration. Domain associations for hardware values are /// represented as a list, or row, of domains. The domains in a row are ordered @@ -178,14 +183,10 @@ T fixInstancePorts(T op, const ModuleUpdateInfo &update) { return clone; } -} // namespace - //====-------------------------------------------------------------------------- // Terms: Syntax for unifying domain and domain-rows. //====-------------------------------------------------------------------------- -namespace { - /// The different sorts of terms in the unification engine. enum class TermKind { Variable, @@ -285,11 +286,6 @@ void render(const DomainInfo &info, Diagnostic &out, VariableIDTable &idTable, } } -void render(const DomainInfo &info, Diagnostic &out, Term *term) { - VariableIDTable idTable; - render(info, out, idTable, term); -} - #ifndef NDEBUG raw_ostream &dump(llvm::raw_ostream &out, const Term *term); @@ -439,14 +435,10 @@ class TermAllocator { llvm::BumpPtrAllocator allocator; }; -} // namespace - //====-------------------------------------------------------------------------- // DomainTable: A mapping from IR to terms. //====-------------------------------------------------------------------------- -namespace { - /// Tracks domain infomation for IR values. class DomainTable { public: @@ -519,14 +511,10 @@ class DomainTable { DenseMap associationTable; }; -} // namespace - //====-------------------------------------------------------------------------- // Module processing: solve for the domain associations of hardware. //====-------------------------------------------------------------------------- -namespace { - /// Get the corresponding term for a domain in the IR. If we don't know what the /// term is, then map the domain in the IR to a variable term. Term *getTermForDomain(TermAllocator &allocator, DomainTable &table, @@ -539,23 +527,6 @@ Term *getTermForDomain(TermAllocator &allocator, DomainTable &table, return term; } -/// Get the associated domain row, forced to be at least a row. -/// Get the row of domains that a hardware value in the IR is associated with. -/// If we don't know what the row is, associate the hardware value in the IR to -/// a variable term. -/// For a hardware value, get the term which represents the row of associated -/// domains. If no mapping has been defined, allocate a variable to stand for -/// the row of domains. -Term *getDomainAssociation(TermAllocator &allocator, DomainTable &table, - Value value) { - auto *term = table.getOptDomainAssociation(value); - if (term) - return term; - term = allocator.allocVar(); - table.setDomainAssociation(value, term); - return term; -} - /// Get the row of domains that a hardware value in the IR is associated with. /// The returned term is forced to be at least a row. RowTerm *getDomainAssociationAsRow(const DomainInfo &info, @@ -616,7 +587,7 @@ void emitDomainPortInferenceError(T op, size_t i) { auto name = op.getPortNameAttr(i); auto diag = emitError(op->getLoc()); auto info = op.getDomainInfo(); - diag << "unable to infer value for domain port " << name; + diag << "unable to infer value for undriven domain port " << name; for (size_t j = 0, e = op.getNumPorts(); j < e; ++j) { if (auto assocs = dyn_cast(info[j])) { for (auto assoc : assocs) { @@ -631,6 +602,18 @@ void emitDomainPortInferenceError(T op, size_t i) { } } +template +void emitAmbiguousPortDomainAssociation(T op, size_t i) {} + +template +void emitMissingPortDomainAssociationError(const DomainInfo &info, T op, + size_t typeID, size_t i) { + auto domainName = info.getDomain(typeID).getNameAttr(); + auto portName = op.getPortNameAttr(i); + emitError(op.getPortLocation(i)) + << "missing " << domainName << " association for port " << portName; +} + /// Unify the associated domain rows of two terms. LogicalResult unifyAssociations(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, @@ -949,45 +932,6 @@ SmallVector copyPortDomainAssociations(const DomainInfo &info, return result; } -template -LogicalResult updateInstanceDomainAssociations(const DomainTable &table, T op) { - auto *context = op.getContext(); - OpBuilder builder(context); - builder.setInsertionPointAfter(op); - auto numPorts = op->getNumResults(); - for (size_t i = 0; i < numPorts; ++i) { - auto port = op.getResult(i); - auto type = port.getType(); - auto direction = op.getPortDirection(i); - if (isa(type)) { - if (direction == Direction::In) { - bool driven = false; - for (auto *user : port.getUsers()) { - if (auto connect = dyn_cast(user)) { - if (connect.getDest() == port) { - driven = true; - break; - } - } - } - if (!driven) { - auto *term = table.getTermForDomain(port); - term = find(term); - if (auto *val = dyn_cast(term)) { - auto loc = port.getLoc(); - auto value = val->value; - DomainDefineOp::create(builder, loc, port, value); - } else { - emitDomainPortInferenceError(op, i); - return failure(); - } - } - } - } - } - return success(); -} - /// Add domain ports for any uninferred domains associated to hardware. /// Returns the inserted ports, which will be used later to generalize the /// instances of this module. @@ -1123,32 +1067,19 @@ LogicalResult updateModuleDomainInfo(const DomainInfo &info, // If the port is an output domain, we may need to drive the output with // a value. If we don't know what value to drive to the port, error. if (isa(type)) { - if (module.getPortDirection(i) == Direction::Out) { - bool driven = false; - for (auto *user : port.getUsers()) { - if (auto connect = dyn_cast(user)) { - if (connect.getDest() == port) { - driven = true; - break; - } - } - } - + // If the output port is not driven, drive it. + if (module.getPortDirection(i) == Direction::Out && !isDriven(port)) { // Get the underlying value of the output port. - auto *term = table.getTermForDomain(port); - term = find(term); - auto *val = dyn_cast(term); + auto *term = table.getOptTermForDomain(port); + auto *val = llvm::dyn_cast_if_present(term); if (!val) { emitDomainPortInferenceError(module, i); return failure(); } - // If the output port is not driven, drive it. - if (!driven) { - auto loc = port.getLoc(); - auto value = val->value; - DomainDefineOp::create(builder, loc, port, value); - } + auto loc = port.getLoc(); + auto value = val->value; + DomainDefineOp::create(builder, loc, port, value); } newModuleDomainInfo[i] = oldModuleDomainInfo[i]; @@ -1241,37 +1172,28 @@ LogicalResult updateModulePorts(const DomainInfo &info, template LogicalResult updateInstance(const DomainTable &table, T op) { - OpBuilder builder(op); + auto *context = op.getContext(); + OpBuilder builder(context); builder.setInsertionPointAfter(op); auto numPorts = op->getNumResults(); for (size_t i = 0; i < numPorts; ++i) { auto port = op.getResult(i); auto type = port.getType(); auto direction = op.getPortDirection(i); - if (isa(type)) { - if (direction == Direction::In) { - bool driven = false; - for (auto *user : port.getUsers()) { - if (auto connect = dyn_cast(user)) { - if (connect.getDest() == port) { - driven = true; - break; - } - } - } - if (!driven) { - auto *term = table.getTermForDomain(port); - term = find(term); - if (auto *val = dyn_cast(term)) { - auto loc = port.getLoc(); - auto value = val->value; - DomainDefineOp::create(builder, loc, port, value); - } else { - emitDomainPortInferenceError(op, i); - return failure(); - } - } + // If the port is an input domain, we may need to drive the input with + // a value. If we don't know what value to drive to the port, error. + if (isa(type) && direction == Direction::In && + !isDriven(port)) { + auto *term = table.getOptTermForDomain(port); + auto *val = llvm::dyn_cast_if_present(term); + if (!val) { + emitDomainPortInferenceError(op, i); + return failure(); } + + auto loc = port.getLoc(); + auto value = val->value; + DomainDefineOp::create(builder, loc, port, value); } } return success(); @@ -1312,38 +1234,30 @@ LogicalResult updateModule(const DomainInfo &info, TermAllocator &allocator, return success(); } -} // namespace - //====-------------------------------------------------------------------------- // Domain Inference: solve domains and check for correctness,then update the // IR to reflect the solved domains. //====-------------------------------------------------------------------------- -namespace { - /// Solve for domains and then write the domain associations back to the IR. LogicalResult inferModule(const DomainInfo &info, ModuleUpdateTable &updateTable, FModuleOp module) { TermAllocator allocator; DomainTable table; - if (failed(processModule(info, allocator, table, updateTable, module))) return failure(); return updateModule(info, allocator, table, updateTable, module); } -} // namespace - //====-------------------------------------------------------------------------- -// Domain Checking: Solve for domains and check for correctness+completeness, -// without updating the IR. +// Domain Inference+Checking: Check that the interface of the module is fully +// annotated, before proceeding to run domain inference on the body of the +// module. //====-------------------------------------------------------------------------- -namespace { - /// Check that a module has complete domain information for its ports. -LogicalResult checkPorts(const DomainInfo &info, FModuleLike module) { +LogicalResult checkModulePorts(const DomainInfo &info, FModuleLike module) { auto numDomains = info.getNumDomains(); auto domainInfo = module.getDomainInfoAttr(); DenseMap typeIDTable; @@ -1385,11 +1299,8 @@ LogicalResult checkPorts(const DomainInfo &info, FModuleLike module) { for (size_t typeID = 0; typeID < numDomains; ++typeID) { auto association = associations[typeID]; if (!association) { - auto domainName = info.getDomain(typeID).getNameAttr(); - auto portName = module.getPortNameAttr(i); - return emitError(module.getPortLocation(i)) - << "missing " << domainName << " association for port " - << portName; + emitMissingPortDomainAssociationError(info, module, typeID, i); + return failure(); } } } @@ -1398,39 +1309,59 @@ LogicalResult checkPorts(const DomainInfo &info, FModuleLike module) { return success(); } -/// Check that a module's ports are fully annotated, and check that there are no -/// domain crossing errors in the module's body, without modify the IR. +LogicalResult checkModuleBody(const DomainInfo &info, + ModuleUpdateTable &updateTable, + FModuleOp module) { + return success(); +} + +/// Check that a module's ports are fully annotated, before performing domain +/// inference on the module. LogicalResult checkModule(const DomainInfo &info, ModuleUpdateTable &updateTable, FModuleOp module) { - if (failed(checkPorts(info, module))) + if (failed(checkModulePorts(info, module))) return failure(); - DomainTable table; - TermAllocator allocator; - return processModule(info, allocator, table, updateTable, module); + return checkModuleBody(info, updateTable, module); } /// Check that an extmodule's ports are fully annotated. LogicalResult checkModule(const DomainInfo &info, FExtModuleOp module) { - return checkPorts(info, module); + return checkModulePorts(info, module); } -} // namespace +//====-------------------------------------------------------------------------- +// Hybrid Mode: Check the interface, then infer. +//====-------------------------------------------------------------------------- + +/// Check that a module's ports are fully annotated, before performing domain +/// inference on the module. We use this when private module interfaces are +/// inferred but public module interfaces are checked. +LogicalResult checkAndInferModule(const DomainInfo &info, + ModuleUpdateTable &updateTable, + FModuleOp module) { + if (failed(checkModulePorts(info, module))) + return failure(); + + return inferModule(info, updateTable, module); +} //===--------------------------------------------------------------------------- // InferDomainsPass: Top-level pass implementation. //===--------------------------------------------------------------------------- -namespace { - LogicalResult runOnModuleLike(InferDomainsMode mode, const DomainInfo &info, ModuleUpdateTable &updateTable, Operation *op) { llvm::errs() << "********\n"; llvm::errs() << *op << "\n"; + if (auto module = dyn_cast(op)) { - if (shouldInfer(module, mode)) + if (mode == InferDomainsMode::Check) + return checkModule(info, updateTable, module); + if (mode == InferDomainsMode::InferAll || module.isPrivate()) return inferModule(info, updateTable, module); - return checkModule(info, updateTable, module); + + return checkAndInferModule(info, updateTable, module); } if (auto extModule = dyn_cast(op)) diff --git a/test/Dialect/FIRRTL/infer-domains-check-errors.mlir b/test/Dialect/FIRRTL/infer-domains-check-errors.mlir index 26c6b6848924..b90dced214e2 100644 --- a/test/Dialect/FIRRTL/infer-domains-check-errors.mlir +++ b/test/Dialect/FIRRTL/infer-domains-check-errors.mlir @@ -1,13 +1,64 @@ // RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains{mode=check}))' %s --verify-diagnostics --split-input-file -// CHECK-LABEL: IncompleteDomainInformation -firrtl.circuit "IncompleteDomainInformation" { +// In check-mode, we check that the interface of public modules is fully annotated +// with domain inference +// CHECK-LABEL: MissingDomain +firrtl.circuit "MissingDomain" { firrtl.domain @ClockDomain - firrtl.module private @Foo( + firrtl.module @MissingDomain( // expected-error @below {{missing "ClockDomain" association for port "x"}} - in %x: !firrtl.uint<1> + in %x: !firrtl.uint<1> ) {} +} + +// CHECK-LABEL: MissingSecondDomain +firrtl.circuit "MissingSecondDomain" { + firrtl.domain @ClockDomain + firrtl.domain @PowerDomain + + firrtl.module @MissingSecondDomain( + in %c : !firrtl.domain of @ClockDomain, + // expected-error @below {{missing "PowerDomain" association for port "x"}} + in %x : !firrtl.uint<1> domains [%c] + ) {} +} + +// CHECK-LABEL: UndrivenOutputDomain +firrtl.circuit "UndrivenOutputDomain" { + firrtl.domain @ClockDomain + + firrtl.module @UndrivenOutputDomain( + // expected-error @below {{unable to infer value for undriven domain port "c"}} + out %c : !firrtl.domain of @ClockDomain + ) {} +} + +// CHECK-LABEL: UndrivenInstanceDomainPort +firrtl.circuit "UndrivenInstanceDomainPort" { + firrtl.domain @ClockDomain + + firrtl.extmodule @Foo(in c : !firrtl.domain of @ClockDomain) + + firrtl.module @UndrivenInstanceDomainPort() { + // expected-error @below {{unable to infer value for undriven domain port "c"}} + %foo_c = firrtl.instance foo @Foo(in c : !firrtl.domain of @ClockDomain) + } +} + +// CHECK-LABEL: UndrivenInstanceChoiceDomainPort +firrtl.circuit "UndrivenInstanceChoiceDomainPort" { + firrtl.domain @ClockDomain + + firrtl.option @Option { + firrtl.option_case @X + } + + firrtl.extmodule @Foo(in c : !firrtl.domain of @ClockDomain) + firrtl.extmodule @Bar(in c : !firrtl.domain of @ClockDomain) - firrtl.module @IncompleteDomainInformation() {} + firrtl.module @UndrivenInstanceChoiceDomainPort() { + // expected-error @below {{unable to infer value for undriven domain port "c"}} + %inst_c = firrtl.instance_choice inst @Foo alternatives @Option { @X -> @Bar } (in c : !firrtl.domain of @ClockDomain) + } } diff --git a/test/Dialect/FIRRTL/infer-domains-errors.mlir b/test/Dialect/FIRRTL/infer-domains-errors.mlir index 47a018ff6dc8..dc3fa6ff5596 100644 --- a/test/Dialect/FIRRTL/infer-domains-errors.mlir +++ b/test/Dialect/FIRRTL/infer-domains-errors.mlir @@ -58,7 +58,7 @@ firrtl.circuit "UnableToInferDomainOfPortDrivenByConstant" { firrtl.module @UnableToInferDomainOfPortDrivenByConstant() { %c0_ui1 = firrtl.constant 0 : !firrtl.uint<1> - // expected-error @below {{unable to infer value for domain port "ClockDomain"}} + // expected-error @below {{unable to infer value for undriven domain port "ClockDomain"}} // expected-note @below {{associated with hardware port "i"}} %foo_i = firrtl.instance foo @Foo(in i: !firrtl.uint<1>) firrtl.matchingconnect %foo_i, %c0_ui1 : !firrtl.uint<1> @@ -75,7 +75,7 @@ firrtl.circuit "UnableToInferDomainOfPortDrivenByConstantExpr" { firrtl.module @UnableToInferDomainOfPortDrivenByConstantExpr() { %c0_ui1 = firrtl.constant 0 : !firrtl.uint<1> %0 = firrtl.add %c0_ui1, %c0_ui1 : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<2> - // expected-error @below {{unable to infer value for domain port "ClockDomain"}} + // expected-error @below {{unable to infer value for undriven domain port "ClockDomain"}} // expected-note @below {{associated with hardware port "i"}} %foo_i = firrtl.instance foo @Foo(in i: !firrtl.uint<2>) firrtl.matchingconnect %foo_i, %0 : !firrtl.uint<2> From e47849792f614b7fc26b733d6dba12efd4ba25d1 Mon Sep 17 00:00:00 2001 From: Robert Young Date: Tue, 25 Nov 2025 15:58:54 -0500 Subject: [PATCH 13/14] More changes --- .../FIRRTL/Transforms/InferDomains.cpp | 82 ++++- .../FIRRTL/infer-domains-check-errors.mlir | 30 +- test/Dialect/FIRRTL/infer-domains-errors.mlir | 195 ----------- test/Dialect/FIRRTL/infer-domains.mlir | 323 ------------------ 4 files changed, 94 insertions(+), 536 deletions(-) delete mode 100644 test/Dialect/FIRRTL/infer-domains-errors.mlir delete mode 100644 test/Dialect/FIRRTL/infer-domains.mlir diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp index 790ce63617f3..05899a7ab8c6 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -488,7 +488,6 @@ class DomainTable { /// For a hardware value, get the term which represents the row of associated /// domains. Term *getDomainAssociation(Value value) const { - llvm::errs() << "value = " << value << "\n"; auto *term = getOptDomainAssociation(value); assert(term); return term; @@ -1256,7 +1255,7 @@ LogicalResult inferModule(const DomainInfo &info, // module. //====-------------------------------------------------------------------------- -/// Check that a module has complete domain information for its ports. +/// Check that a module's hardware ports have complete domain associations. LogicalResult checkModulePorts(const DomainInfo &info, FModuleLike module) { auto numDomains = info.getNumDomains(); auto domainInfo = module.getDomainInfoAttr(); @@ -1309,20 +1308,79 @@ LogicalResult checkModulePorts(const DomainInfo &info, FModuleLike module) { return success(); } -LogicalResult checkModuleBody(const DomainInfo &info, - ModuleUpdateTable &updateTable, - FModuleOp module) { +/// Check that output domain ports are driven. +LogicalResult checkModuleDomainPortDrivers(const DomainInfo &info, + FModuleOp module) { + for (size_t i = 0, e = module.getNumPorts(); i < e; ++i) { + auto port = module.getArgument(i); + auto type = port.getType(); + if (!isa(type) || + module.getPortDirection(i) != Direction::Out || isDriven(port)) + continue; + + auto name = module.getPortNameAttr(i); + emitError(module.getPortLocation(i)) << "undriven domain port " << name; + return failure(); + } + + return success(); +} + +/// Check that the input domain ports are driven. +template +LogicalResult checkInstanceDomainPortDrivers(T op) { + for (size_t i = 0, e = op.getNumResults(); i < e; ++i) { + auto port = op.getResult(i); + auto type = port.getType(); + if (!isa(type) || op.getPortDirection(i) != Direction::In || + isDriven(port)) + continue; + + auto name = op.getPortNameAttr(i); + emitError(op.getPortLocation(i)) << "undriven domain port " << name; + return failure(); + } + + return success(); +} + +LogicalResult checkOp(Operation *op) { + if (auto inst = dyn_cast(op)) + return checkInstanceDomainPortDrivers(inst); + if (auto inst = dyn_cast(op)) + return checkInstanceDomainPortDrivers(inst); return success(); } +/// Check that instances under this module have driven domain input ports. +LogicalResult checkModuleBody(FModuleOp module) { + LogicalResult result = success(); + module.getBody().walk([&](Operation *op) -> WalkResult { + if (failed(checkOp(op))) { + result = failure(); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return result; +} + /// Check that a module's ports are fully annotated, before performing domain /// inference on the module. -LogicalResult checkModule(const DomainInfo &info, - ModuleUpdateTable &updateTable, FModuleOp module) { +LogicalResult checkModule(const DomainInfo &info, FModuleOp module) { if (failed(checkModulePorts(info, module))) return failure(); - return checkModuleBody(info, updateTable, module); + if (failed(checkModuleDomainPortDrivers(info, module))) + return failure(); + + if (failed(checkModuleBody(module))) + return failure(); + + TermAllocator allocator; + DomainTable table; + ModuleUpdateTable updateTable; + return processModule(info, allocator, table, updateTable, module); } /// Check that an extmodule's ports are fully annotated. @@ -1352,15 +1410,15 @@ LogicalResult checkAndInferModule(const DomainInfo &info, LogicalResult runOnModuleLike(InferDomainsMode mode, const DomainInfo &info, ModuleUpdateTable &updateTable, Operation *op) { - llvm::errs() << "********\n"; - llvm::errs() << *op << "\n"; + // llvm::errs() << "********\n"; + // llvm::errs() << *op << "\n"; if (auto module = dyn_cast(op)) { if (mode == InferDomainsMode::Check) - return checkModule(info, updateTable, module); + return checkModule(info, module); if (mode == InferDomainsMode::InferAll || module.isPrivate()) return inferModule(info, updateTable, module); - + return checkAndInferModule(info, updateTable, module); } diff --git a/test/Dialect/FIRRTL/infer-domains-check-errors.mlir b/test/Dialect/FIRRTL/infer-domains-check-errors.mlir index b90dced214e2..59782b1b7dd0 100644 --- a/test/Dialect/FIRRTL/infer-domains-check-errors.mlir +++ b/test/Dialect/FIRRTL/infer-domains-check-errors.mlir @@ -1,7 +1,8 @@ -// RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains{mode=check}))' %s --verify-diagnostics --split-input-file +// RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains{mode=check}))' %s --verify-diagnostics + +// in "check" mode, infer-domains will require that all ops are fully annotated +// with domains. No inference is run. -// In check-mode, we check that the interface of public modules is fully annotated -// with domain inference // CHECK-LABEL: MissingDomain firrtl.circuit "MissingDomain" { firrtl.domain @ClockDomain @@ -29,7 +30,7 @@ firrtl.circuit "UndrivenOutputDomain" { firrtl.domain @ClockDomain firrtl.module @UndrivenOutputDomain( - // expected-error @below {{unable to infer value for undriven domain port "c"}} + // expected-error @below {{undriven domain port "c"}} out %c : !firrtl.domain of @ClockDomain ) {} } @@ -41,7 +42,7 @@ firrtl.circuit "UndrivenInstanceDomainPort" { firrtl.extmodule @Foo(in c : !firrtl.domain of @ClockDomain) firrtl.module @UndrivenInstanceDomainPort() { - // expected-error @below {{unable to infer value for undriven domain port "c"}} + // expected-error @below {{undriven domain port "c"}} %foo_c = firrtl.instance foo @Foo(in c : !firrtl.domain of @ClockDomain) } } @@ -58,7 +59,24 @@ firrtl.circuit "UndrivenInstanceChoiceDomainPort" { firrtl.extmodule @Bar(in c : !firrtl.domain of @ClockDomain) firrtl.module @UndrivenInstanceChoiceDomainPort() { - // expected-error @below {{unable to infer value for undriven domain port "c"}} + // expected-error @below {{undriven domain port "c"}} %inst_c = firrtl.instance_choice inst @Foo alternatives @Option { @X -> @Bar } (in c : !firrtl.domain of @ClockDomain) } } + +// Test that domain crossing errors are still caught when in check-only mode. +// Catching this involves processing the module without writing back to the IR. +firrtl.circuit "IllegalDomainCrossing" { + firrtl.domain @ClockDomain + firrtl.module @IllegalDomainCrossing( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @ClockDomain, + // expected-note @below {{2nd operand has domains: [ClockDomain: A]}} + in %a: !firrtl.uint<1> domains [%A], + // expected-note @below {{1st operand has domains: [ClockDomain: B]}} + out %b: !firrtl.uint<1> domains [%B] + ) { + // expected-error @below {{illegal domain crossing in operation}} + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + } +} diff --git a/test/Dialect/FIRRTL/infer-domains-errors.mlir b/test/Dialect/FIRRTL/infer-domains-errors.mlir deleted file mode 100644 index dc3fa6ff5596..000000000000 --- a/test/Dialect/FIRRTL/infer-domains-errors.mlir +++ /dev/null @@ -1,195 +0,0 @@ -// RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains{mode=infer-all}))' %s --verify-diagnostics --split-input-file - -// Port annotated with same domain type twice. -firrtl.circuit "DomainCrossOnPort" { - firrtl.domain @ClockDomain - firrtl.module @DomainCrossOnPort( - in %A: !firrtl.domain of @ClockDomain, - in %B: !firrtl.domain of @ClockDomain, - // expected-error @below {{illegal "ClockDomain" crossing in port "p"}} - // expected-note @below {{1st instance: A}} - // expected-note @below {{2nd instance: B}} - in %p: !firrtl.uint<1> domains [%A, %B] - ) {} -} - -// ----- - -// Illegal domain crossing via connect op. -firrtl.circuit "IllegalDomainCrossing" { - firrtl.domain @ClockDomain - firrtl.module @IllegalDomainCrossing( - in %A: !firrtl.domain of @ClockDomain, - in %B: !firrtl.domain of @ClockDomain, - // expected-note @below {{2nd operand has domains: [ClockDomain: A]}} - in %a: !firrtl.uint<1> domains [%A], - // expected-note @below {{1st operand has domains: [ClockDomain: B]}} - out %b: !firrtl.uint<1> domains [%B] - ) { - // expected-error @below {{illegal domain crossing in operation}} - firrtl.connect %b, %a : !firrtl.uint<1> - } -} - -// ----- - -// Illegal domain crossing at matchingconnect op. -firrtl.circuit "IllegalDomainCrossing" { - firrtl.domain @ClockDomain - firrtl.module @IllegalDomainCrossing( - in %A: !firrtl.domain of @ClockDomain, - in %B: !firrtl.domain of @ClockDomain, - // expected-note @below {{2nd operand has domains: [ClockDomain: A]}} - in %a: !firrtl.uint<1> domains [%A], - // expected-note @below {{1st operand has domains: [ClockDomain: B]}} - out %b: !firrtl.uint<1> domains [%B] - ) { - // expected-error @below {{illegal domain crossing in operation}} - firrtl.matchingconnect %b, %a : !firrtl.uint<1> - } -} - -// ----- - -// Unable to infer domain of port, when port is driven by constant. -firrtl.circuit "UnableToInferDomainOfPortDrivenByConstant" { - firrtl.domain @ClockDomain - firrtl.module @Foo(in %i: !firrtl.uint<1>) {} - - firrtl.module @UnableToInferDomainOfPortDrivenByConstant() { - %c0_ui1 = firrtl.constant 0 : !firrtl.uint<1> - // expected-error @below {{unable to infer value for undriven domain port "ClockDomain"}} - // expected-note @below {{associated with hardware port "i"}} - %foo_i = firrtl.instance foo @Foo(in i: !firrtl.uint<1>) - firrtl.matchingconnect %foo_i, %c0_ui1 : !firrtl.uint<1> - } -} - -// ----- - -// Unable to infer domain of port, when port is driven by arithmetic on constant. -firrtl.circuit "UnableToInferDomainOfPortDrivenByConstantExpr" { - firrtl.domain @ClockDomain - firrtl.module @Foo(in %i: !firrtl.uint<2>) {} - - firrtl.module @UnableToInferDomainOfPortDrivenByConstantExpr() { - %c0_ui1 = firrtl.constant 0 : !firrtl.uint<1> - %0 = firrtl.add %c0_ui1, %c0_ui1 : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<2> - // expected-error @below {{unable to infer value for undriven domain port "ClockDomain"}} - // expected-note @below {{associated with hardware port "i"}} - %foo_i = firrtl.instance foo @Foo(in i: !firrtl.uint<2>) - firrtl.matchingconnect %foo_i, %0 : !firrtl.uint<2> - } -} - -// ----- - -// Incomplete extmodule domain information. - -firrtl.circuit "Top" { - firrtl.domain @ClockDomain - - // expected-error @below {{missing "ClockDomain" association for port "i"}} - firrtl.extmodule @Top(in i: !firrtl.uint<1>) -} - -// ----- - -// Conflicting extmodule domain information. - -firrtl.circuit "Top" { - firrtl.domain @ClockDomain - - firrtl.extmodule @Top( - // expected-note @below {{associated with "ClockDomain" port "D1"}} - in D1 : !firrtl.domain of @ClockDomain, - // expected-note @below {{associated with "ClockDomain" port "D2"}} - in D2 : !firrtl.domain of @ClockDomain, - // expected-error @below {{ambiguous "ClockDomain" association for port "i"}} - in i: !firrtl.uint<1> domains [D1, D2] - ) -} - -// ----- - -// Domain exported multiple times. Which do we choose? - -firrtl.circuit "DoubleExportOfDomain" { - firrtl.domain @ClockDomain - - firrtl.module @DoubleExportOfDomain( - // expected-note @below {{candidate association "DI"}} - in %DI : !firrtl.domain of @ClockDomain, - // expected-note @below {{candidate association "DO"}} - out %DO : !firrtl.domain of @ClockDomain, - in %i : !firrtl.uint<1> domains [%DO], - // expected-error @below {{ambiguous "ClockDomain" association for port "o"}} - out %o : !firrtl.uint<1> domains [] - ) { - // DI and DO are aliases - firrtl.domain.define %DO, %DI - - // o is on same domain as i - firrtl.matchingconnect %o, %i : !firrtl.uint<1> - } -} - -// ----- - -// Domain exported multiple times, this time with two outputs. - -firrtl.circuit "DoubleExportOfDomain" { - firrtl.domain @ClockDomain - - firrtl.extmodule @Generator(out D: !firrtl.domain of @ClockDomain) - - firrtl.module @DoubleExportOfDomain( - // expected-note @below {{candidate association "D1"}} - out %D1 : !firrtl.domain of @ClockDomain, - // expected-note @below {{candidate association "D2"}} - out %D2 : !firrtl.domain of @ClockDomain, - in %i : !firrtl.uint<1> domains [%D1], - // expected-error @below {{ambiguous "ClockDomain" association for port "o"}} - out %o : !firrtl.uint<1> domains [] - ) { - %gen_D = firrtl.instance gen @Generator(out D: !firrtl.domain of @ClockDomain) - // DI and DO are aliases - firrtl.domain.define %D1, %gen_D - firrtl.domain.define %D2, %gen_D - - // o is on same domain as i - firrtl.matchingconnect %o, %i : !firrtl.uint<1> - } -} - -// ----- - -// InstanceChoice: Each module has different domains inferred. -// TODO: this just relies on the op-verifier for instance choice ops. - -firrtl.circuit "ConflictingInstanceChoiceDomains" { - firrtl.domain @ClockDomain - - firrtl.option @Option { - firrtl.option_case @X - firrtl.option_case @Y - } - - // Foo's "out" port takes on the domains of "in1". - firrtl.module @Foo(in %in1: !firrtl.uint<1>, in %in2: !firrtl.uint<1>, out %out: !firrtl.uint<1>) { - firrtl.connect %out, %in1 : !firrtl.uint<1> - } - - // Bar's "out" port takes on the domains of "in2". - // expected-note @below {{original module declared here}} - firrtl.module @Bar(in %in1: !firrtl.uint<1>, in %in2: !firrtl.uint<1>, out %out: !firrtl.uint<1>) { - firrtl.connect %out, %in2 : !firrtl.uint<1> - } - - firrtl.module @ConflictingInstanceChoiceDomains(in %in1: !firrtl.uint<1>, in %in2: !firrtl.uint<1>) { - // expected-error @below {{'firrtl.instance_choice' op domain info for "out" must be [2 : ui32], but got [0 : ui32]}} - %inst_in1, %inst_in2, %inst_out = firrtl.instance_choice inst @Foo alternatives @Option { @X -> @Foo, @Y -> @Bar } (in in1: !firrtl.uint<1>, in in2: !firrtl.uint<1>, out out: !firrtl.uint<1>) - firrtl.connect %inst_in1, %in1 : !firrtl.uint<1>, !firrtl.uint<1> - firrtl.connect %inst_in2, %in2 : !firrtl.uint<1>, !firrtl.uint<1> - } -} diff --git a/test/Dialect/FIRRTL/infer-domains.mlir b/test/Dialect/FIRRTL/infer-domains.mlir deleted file mode 100644 index 1b7d465d34c7..000000000000 --- a/test/Dialect/FIRRTL/infer-domains.mlir +++ /dev/null @@ -1,323 +0,0 @@ -// RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains{mode=infer-all}))' %s | FileCheck %s - -// Legal domain usage - no crossing. -// CHECK-LABEL: firrtl.circuit "LegalDomains" -firrtl.circuit "LegalDomains" { - firrtl.domain @ClockDomain - firrtl.module @LegalDomains( - in %A: !firrtl.domain of @ClockDomain, - in %a: !firrtl.uint<1> domains [%A], - out %b: !firrtl.uint<1> domains [%A] - ) { - // Connecting within the same domain is legal. - firrtl.matchingconnect %b, %a : !firrtl.uint<1> - } -} - -// Domain inference through connections. -// CHECK-LABEL: firrtl.circuit "DomainInference" -firrtl.circuit "DomainInference" { - firrtl.domain @ClockDomain - firrtl.module @DomainInference( - in %A: !firrtl.domain of @ClockDomain, - in %a: !firrtl.uint<1> domains [%A], - // CHECK: out %c: !firrtl.uint<1> domains [%A] - out %c: !firrtl.uint<1> - ) { - %b = firrtl.wire : !firrtl.uint<1> // No explicit domain - - // This should infer that %b is in domain %A. - firrtl.matchingconnect %b, %a : !firrtl.uint<1> - - // This should be legal since %b is now inferred to be in domain %A. - firrtl.matchingconnect %c, %b : !firrtl.uint<1> - } -} - -// Unsafe domain cast -// CHECK-LABEL: firrtl.circuit "UnsafeDomainCast" -firrtl.circuit "UnsafeDomainCast" { - firrtl.domain @ClockDomain - firrtl.module @UnsafeDomainCast( - in %A: !firrtl.domain of @ClockDomain, - in %B: !firrtl.domain of @ClockDomain, - in %a: !firrtl.uint<1> domains [%A], - out %c: !firrtl.uint<1> domains [%B] - ) { - // Unsafe cast from domain A to domain B. - %b = firrtl.unsafe_domain_cast %a domains %B : !firrtl.uint<1> - - // This should be legal since we explicitly cast. - firrtl.matchingconnect %c, %b : !firrtl.uint<1> - } -} - -// Domain sequence matching. -// CHECK-LABEL: firrtl.circuit "LegalSequences" -firrtl.circuit "LegalSequences" { - firrtl.domain @ClockDomain - firrtl.domain @PowerDomain - firrtl.module @LegalSequences( - in %C: !firrtl.domain of @ClockDomain, - in %P: !firrtl.domain of @PowerDomain, - in %a: !firrtl.uint<1> domains [%C, %P], - out %b: !firrtl.uint<1> domains [%C, %P] - ) { - firrtl.matchingconnect %b, %a : !firrtl.uint<1> - } -} - -// Domain sequence order equivalence - should be legal -// CHECK-LABEL: SequenceOrderEquivalence -firrtl.circuit "SequenceOrderEquivalence" { - firrtl.domain @ClockDomain - firrtl.domain @PowerDomain - firrtl.module @SequenceOrderEquivalence( - in %A: !firrtl.domain of @ClockDomain, - in %B: !firrtl.domain of @PowerDomain, - in %a: !firrtl.uint<1> domains [%A, %B], - out %b: !firrtl.uint<1> domains [%B, %A] - ) { - // This should be legal since domain order doesn't matter in canonical representation - firrtl.matchingconnect %b, %a : !firrtl.uint<1> - } -} - -// Domain sequence inference -// CHECK-LABEL: SequenceInference -firrtl.circuit "SequenceInference" { - firrtl.domain @ClockDomain - firrtl.domain @PowerDomain - firrtl.module @SequenceInference( - in %A: !firrtl.domain of @ClockDomain, - in %B: !firrtl.domain of @PowerDomain, - in %a: !firrtl.uint<1> domains [%A, %B], - out %d: !firrtl.uint<1> - ) { - %c = firrtl.wire : !firrtl.uint<1> - - // %c should infer domain sequence [%A, %B] - firrtl.matchingconnect %c, %a : !firrtl.uint<1> - - // This should be legal since %c has inferred [%A, %B] - firrtl.matchingconnect %d, %c : !firrtl.uint<1> - } -} - -// Domain duplicate equivalence - should be legal. -// CHECK-LABEL: DuplicateDomainEquivalence -firrtl.circuit "DuplicateDomainEquivalence" { - firrtl.domain @ClockDomain - firrtl.module @DuplicateDomainEquivalence( - in %A: !firrtl.domain of @ClockDomain, - in %a: !firrtl.uint<1> domains [%A, %A], - out %b: !firrtl.uint<1> domains [%A] - ) { - // This should be legal since duplicate domains are canonicalized. - firrtl.matchingconnect %b, %a : !firrtl.uint<1> - } -} - -// Unsafe domain cast with sequences -// CHECK-LABEL: UnsafeSequenceCast -firrtl.circuit "UnsafeSequenceCast" { - firrtl.domain @ClockDomain - firrtl.domain @PowerDomain - - firrtl.module @UnsafeSequenceCast( - in %C1: !firrtl.domain of @ClockDomain, - in %C2: !firrtl.domain of @ClockDomain, - in %P1: !firrtl.domain of @PowerDomain, - in %i: !firrtl.uint<1> domains [%C1, %P1], - out %o: !firrtl.uint<1> domains [%C2, %P1] - ) { - %0 = firrtl.unsafe_domain_cast %i domains %C2 : !firrtl.uint<1> - firrtl.matchingconnect %o, %0 : !firrtl.uint<1> - } -} - -// Different port types domain inference. -// CHECK-LABEL: DifferentPortTypes -firrtl.circuit "DifferentPortTypes" { - firrtl.domain @ClockDomain - firrtl.module @DifferentPortTypes( - in %A: !firrtl.domain of @ClockDomain, - in %uint_input: !firrtl.uint<8> domains [%A], - in %sint_input: !firrtl.sint<4> domains [%A], - out %uint_output: !firrtl.uint<8>, - out %sint_output: !firrtl.sint<4> - ) { - firrtl.matchingconnect %uint_output, %uint_input : !firrtl.uint<8> - firrtl.matchingconnect %sint_output, %sint_input : !firrtl.sint<4> - } -} - -// Domain inference through wires. -// CHECK-LABEL: DomainInferenceThroughWires -firrtl.circuit "DomainInferenceThroughWires" { - firrtl.domain @ClockDomain - firrtl.module @DomainInferenceThroughWires( - in %A: !firrtl.domain of @ClockDomain, - in %input: !firrtl.uint<1> domains [%A], - // CHECK: out %output: !firrtl.uint<1> domains [%A] - out %output: !firrtl.uint<1> - ) { - %wire1 = firrtl.wire : !firrtl.uint<1> - %wire2 = firrtl.wire : !firrtl.uint<1> - - firrtl.matchingconnect %wire1, %input : !firrtl.uint<1> - firrtl.matchingconnect %wire2, %wire1 : !firrtl.uint<1> - firrtl.matchingconnect %output, %wire2 : !firrtl.uint<1> - } -} - -// Export: add output domain port for domain created internally. -// CHECK-LABEL: ExportDomain -firrtl.circuit "ExportDomain" { - firrtl.domain @ClockDomain - - firrtl.extmodule @Foo( - out A: !firrtl.domain of @ClockDomain, - out o: !firrtl.uint<1> domains [A] - ) - - firrtl.module @ExportDomain( - // CHECK: out %ClockDomain: !firrtl.domain of @ClockDomain - // CHECK: out %o: !firrtl.uint<1> domains [%ClockDomain] - out %o: !firrtl.uint<1> - ) { - %foo_A, %foo_o = firrtl.instance foo @Foo( - out A: !firrtl.domain of @ClockDomain, - out o: !firrtl.uint<1> domains [A] - ) - firrtl.matchingconnect %o, %foo_o : !firrtl.uint<1> - // CHECK: firrtl.domain.define %ClockDomain, %foo_A - } -} - -// Export: Reuse already-exported domain. -// CHECK-LABEL: ReuseExportedDomain -firrtl.circuit "ReuseExportedDomain" { - firrtl.domain @ClockDomain - - firrtl.extmodule @Foo( - out A: !firrtl.domain of @ClockDomain, - out o: !firrtl.uint<1> domains [A] - ) - - firrtl.module @ReuseExportedDomain( - out %A: !firrtl.domain of @ClockDomain, - // CHECK: out %o: !firrtl.uint<1> domains [%A] - out %o: !firrtl.uint<1> - ) { - %foo_A, %foo_o = firrtl.instance foo @Foo( - out A: !firrtl.domain of @ClockDomain, - out o: !firrtl.uint<1> domains [A] - ) - firrtl.matchingconnect %o, %foo_o : !firrtl.uint<1> - firrtl.domain.define %A, %foo_A - } -} - -// CHECK-LABEL: RegisterInference -firrtl.circuit "RegisterInference" { - firrtl.domain @ClockDomain - firrtl.module @RegisterInference( - in %A: !firrtl.domain of @ClockDomain, - in %clock: !firrtl.clock domains [%A], - // CHECK: in %d: !firrtl.uint<1> domains [%A] - in %d: !firrtl.uint<1>, - // CHECK: out %q: !firrtl.uint<1> domains [%A] - out %q: !firrtl.uint<1> - ) { - %r = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1> - firrtl.matchingconnect %r, %d : !firrtl.uint<1> - firrtl.matchingconnect %q, %r : !firrtl.uint<1> - } -} - -// CHECK-LABEL: InstanceUpdate -firrtl.circuit "InstanceUpdate" { - firrtl.domain @ClockDomain - - firrtl.module @Foo(in %i : !firrtl.uint<1>) {} - - // CHECK: firrtl.module @InstanceUpdate(in %ClockDomain: !firrtl.domain of @ClockDomain, in %i: !firrtl.uint<1> domains [%ClockDomain]) { - // CHECK: %foo_ClockDomain, %foo_i = firrtl.instance foo @Foo(in ClockDomain: !firrtl.domain of @ClockDomain, in i: !firrtl.uint<1> domains [ClockDomain]) - // CHECK: firrtl.domain.define %foo_ClockDomain, %ClockDomain - // CHECK: firrtl.connect %foo_i, %i : !firrtl.uint<1> - // CHECK: } - firrtl.module @InstanceUpdate(in %i : !firrtl.uint<1>) { - %foo_i = firrtl.instance foo @Foo(in i: !firrtl.uint<1>) - firrtl.connect %foo_i, %i : !firrtl.uint<1>, !firrtl.uint<1> - } -} - -// CHECK-LABEL: InstanceChoiceUpdate -firrtl.circuit "InstanceChoiceUpdate" { - firrtl.domain @ClockDomain - - firrtl.option @Option { - firrtl.option_case @X - firrtl.option_case @Y - } - - firrtl.module @Foo(in %i : !firrtl.uint<1>) {} - firrtl.module @Bar(in %i : !firrtl.uint<1>) {} - firrtl.module @Baz(in %i : !firrtl.uint<1>) {} - - // CHECK: firrtl.module @InstanceChoiceUpdate(in %ClockDomain: !firrtl.domain of @ClockDomain, in %i: !firrtl.uint<1> domains [%ClockDomain]) { - // CHECK: %inst_ClockDomain, %inst_i = firrtl.instance_choice inst @Foo alternatives @Option { @X -> @Bar, @Y -> @Baz } (in ClockDomain: !firrtl.domain of @ClockDomain, in i: !firrtl.uint<1> domains [ClockDomain]) - // CHECK: firrtl.domain.define %inst_ClockDomain, %ClockDomain - // CHECK: firrtl.connect %inst_i, %i : !firrtl.uint<1> - // CHECK: } - firrtl.module @InstanceChoiceUpdate(in %i : !firrtl.uint<1>) { - %inst_i = firrtl.instance_choice inst @Foo alternatives @Option { @X -> @Bar, @Y -> @Baz } (in i : !firrtl.uint<1>) - firrtl.connect %inst_i, %i : !firrtl.uint<1>, !firrtl.uint<1> - } -} - -// CHECK-LABEL: ConstantInMultipleDomains -firrtl.circuit "ConstantInMultipleDomains" { - firrtl.domain @ClockDomain - - firrtl.extmodule @Foo(in A: !firrtl.domain of @ClockDomain, in i: !firrtl.uint<1> domains [A]) - - firrtl.module @ConstantInMultipleDomains(in %A: !firrtl.domain of @ClockDomain, in %B: !firrtl.domain of @ClockDomain) { - %c0_ui1 = firrtl.constant 0 : !firrtl.uint<1> - %x_A, %x_i = firrtl.instance x @Foo(in A: !firrtl.domain of @ClockDomain, in i: !firrtl.uint<1> domains [A]) - firrtl.domain.define %x_A, %A - firrtl.matchingconnect %x_i, %c0_ui1 : !firrtl.uint<1> - - %y_A, %y_i = firrtl.instance y @Foo(in A: !firrtl.domain of @ClockDomain, in i: !firrtl.uint<1> domains [A]) - firrtl.domain.define %y_A, %B - firrtl.matchingconnect %y_i, %c0_ui1 : !firrtl.uint<1> - } -} - -firrtl.circuit "Top" { - firrtl.domain @ClockDomain - firrtl.extmodule @Foo( - in ClockDomain : !firrtl.domain of @ClockDomain, - in i: !firrtl.uint<1> domains [ClockDomain], - out o : !firrtl.uint<1> domains [ClockDomain] - ) - - firrtl.module @Top(in %ClockDomain : !firrtl.domain of @ClockDomain ) { - %foo1_ClockDomain, %foo1_i, %foo1_o = firrtl.instance foo1 @Foo( - in ClockDomain : !firrtl.domain of @ClockDomain, - in i: !firrtl.uint<1> domains [ClockDomain], - out o : !firrtl.uint<1> domains [ClockDomain] - ) - - %foo2_ClockDomain, %foo2_i, %foo2_o = firrtl.instance foo2 @Foo( - in ClockDomain : !firrtl.domain of @ClockDomain, - in i: !firrtl.uint<1> domains [ClockDomain], - out o : !firrtl.uint<1> domains [ClockDomain] - ) - - firrtl.domain.define %foo1_ClockDomain, %ClockDomain - firrtl.matchingconnect %foo2_i, %foo1_o : !firrtl.uint<1> - firrtl.matchingconnect %foo1_i, %foo2_o : !firrtl.uint<1> - } -} From b9f0e228bc74f1d50da6c4b7fc30f7345a76c7d0 Mon Sep 17 00:00:00 2001 From: Robert Young Date: Tue, 25 Nov 2025 16:15:44 -0500 Subject: [PATCH 14/14] Add test --- .../FIRRTL/Transforms/InferDomains.cpp | 28 +- .../infer-domains-infer-all-errors.mlir | 208 +++++++++++ .../FIRRTL/infer-domains-infer-all.mlir | 323 ++++++++++++++++++ .../FIRRTL/infer-domains-infer-errors.mlir | 48 +++ test/Dialect/FIRRTL/infer-domains-infer.mlir | 38 +++ 5 files changed, 630 insertions(+), 15 deletions(-) create mode 100644 test/Dialect/FIRRTL/infer-domains-infer-all-errors.mlir create mode 100644 test/Dialect/FIRRTL/infer-domains-infer-all.mlir create mode 100644 test/Dialect/FIRRTL/infer-domains-infer-errors.mlir create mode 100644 test/Dialect/FIRRTL/infer-domains-infer.mlir diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp index 05899a7ab8c6..22487f5e8822 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -211,7 +211,7 @@ struct TermBase : Term { struct VariableTerm : public TermBase { VariableTerm() : leader(nullptr) {} VariableTerm(Term *leader) : leader(leader) {} - Term *leader; + Term *leader = nullptr; }; /// A concrete value defined in the IR. @@ -246,12 +246,14 @@ Term *find(Term *x) { } /// A helper for assigning low numeric IDs to variables for user-facing output. -struct VariableIDTable { +class VariableIDTable { +public: size_t get(VariableTerm *term) { auto [it, inserted] = table.insert({term, table.size() + 1}); return it->second; } +private: DenseMap table; }; @@ -303,13 +305,8 @@ raw_ostream &dump(raw_ostream &out, const ValueTerm *term) { // NOLINTNEXTLINE(misc-no-recursion) raw_ostream &dump(raw_ostream &out, const RowTerm *term) { out << "row@" << term << "{"; - bool first = true; - for (auto *element : term->elements) { - if (!first) - out << ", "; - dump(out, element); - first = false; - } + llvm::interleaveComma(term->elements, out, + [&](auto element) { dump(out, element); }); out << "}"; return out; } @@ -708,6 +705,7 @@ template LogicalResult processInstancePorts(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, T op) { + llvm::errs() << "ins=" << op << "\n"; auto numDomainTypes = info.getNumDomains(); DenseMap domainPortTypeIDTable; auto domainInfo = op.getDomainInfoAttr(); @@ -966,7 +964,7 @@ void createModuleDomainPorts(const DomainInfo &info, TermAllocator &allocator, if (isPort(module, value)) continue; - // The domain is defined internally. If there value is already exported, + // The domain is defined internally. If the value is already exported, // or will be exported, we are done. if (exportTable.contains(value) || pendingExports.contains(value)) continue; @@ -1134,7 +1132,7 @@ LogicalResult updateModuleDomainInfo(const DomainInfo &info, continue; } - newModuleDomainInfo[i] = oldModuleDomainInfo[i]; + newModuleDomainInfo[i] = ArrayAttr::get(context, {}); } result = ArrayAttr::get(module.getContext(), newModuleDomainInfo); @@ -1410,12 +1408,12 @@ LogicalResult checkAndInferModule(const DomainInfo &info, LogicalResult runOnModuleLike(InferDomainsMode mode, const DomainInfo &info, ModuleUpdateTable &updateTable, Operation *op) { - // llvm::errs() << "********\n"; - // llvm::errs() << *op << "\n"; - - if (auto module = dyn_cast(op)) { + + llvm::errs() << *op << "\n"; + if (auto module = dyn_cast(op)) { if (mode == InferDomainsMode::Check) return checkModule(info, module); + if (mode == InferDomainsMode::InferAll || module.isPrivate()) return inferModule(info, updateTable, module); diff --git a/test/Dialect/FIRRTL/infer-domains-infer-all-errors.mlir b/test/Dialect/FIRRTL/infer-domains-infer-all-errors.mlir new file mode 100644 index 000000000000..53f3b58a52e5 --- /dev/null +++ b/test/Dialect/FIRRTL/infer-domains-infer-all-errors.mlir @@ -0,0 +1,208 @@ +// RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains{mode=infer-all}))' %s --verify-diagnostics + +// Port annotated with same domain type twice. +firrtl.circuit "DomainCrossOnPort" { + firrtl.domain @ClockDomain + firrtl.module @DomainCrossOnPort( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @ClockDomain, + // expected-error @below {{illegal "ClockDomain" crossing in port "p"}} + // expected-note @below {{1st instance: A}} + // expected-note @below {{2nd instance: B}} + in %p: !firrtl.uint<1> domains [%A, %B] + ) {} +} + +// Illegal domain crossing via connect op. +firrtl.circuit "IllegalDomainCrossing" { + firrtl.domain @ClockDomain + firrtl.module @IllegalDomainCrossing( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @ClockDomain, + // expected-note @below {{2nd operand has domains: [ClockDomain: A]}} + in %a: !firrtl.uint<1> domains [%A], + // expected-note @below {{1st operand has domains: [ClockDomain: B]}} + out %b: !firrtl.uint<1> domains [%B] + ) { + // expected-error @below {{illegal domain crossing in operation}} + firrtl.connect %b, %a : !firrtl.uint<1> + } +} + +// Illegal domain crossing at matchingconnect op. +firrtl.circuit "IllegalDomainCrossing" { + firrtl.domain @ClockDomain + firrtl.module @IllegalDomainCrossing( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @ClockDomain, + // expected-note @below {{2nd operand has domains: [ClockDomain: A]}} + in %a: !firrtl.uint<1> domains [%A], + // expected-note @below {{1st operand has domains: [ClockDomain: B]}} + out %b: !firrtl.uint<1> domains [%B] + ) { + // expected-error @below {{illegal domain crossing in operation}} + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + } +} + +// Unable to infer domain of port, when port is driven by constant. +firrtl.circuit "UnableToInferDomainOfPortDrivenByConstant" { + firrtl.domain @ClockDomain + firrtl.module @Foo(in %i: !firrtl.uint<1>) {} + + firrtl.module @UnableToInferDomainOfPortDrivenByConstant() { + %c0_ui1 = firrtl.constant 0 : !firrtl.uint<1> + // expected-error @below {{unable to infer value for undriven domain port "ClockDomain"}} + // expected-note @below {{associated with hardware port "i"}} + %foo_i = firrtl.instance foo @Foo(in i: !firrtl.uint<1>) + firrtl.matchingconnect %foo_i, %c0_ui1 : !firrtl.uint<1> + } +} + +// Unable to infer domain of port, when port is driven by arithmetic on constant. +firrtl.circuit "UnableToInferDomainOfPortDrivenByConstantExpr" { + firrtl.domain @ClockDomain + firrtl.module @Foo(in %i: !firrtl.uint<2>) {} + + firrtl.module @UnableToInferDomainOfPortDrivenByConstantExpr() { + %c0_ui1 = firrtl.constant 0 : !firrtl.uint<1> + %0 = firrtl.add %c0_ui1, %c0_ui1 : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<2> + // expected-error @below {{unable to infer value for undriven domain port "ClockDomain"}} + // expected-note @below {{associated with hardware port "i"}} + %foo_i = firrtl.instance foo @Foo(in i: !firrtl.uint<2>) + firrtl.matchingconnect %foo_i, %0 : !firrtl.uint<2> + } +} + +// Incomplete extmodule domain information. + +firrtl.circuit "Top" { + firrtl.domain @ClockDomain + + // expected-error @below {{missing "ClockDomain" association for port "i"}} + firrtl.extmodule @Top(in i: !firrtl.uint<1>) +} + +// Conflicting extmodule domain information. + +firrtl.circuit "Top" { + firrtl.domain @ClockDomain + + firrtl.extmodule @Top( + // expected-note @below {{associated with "ClockDomain" port "D1"}} + in D1 : !firrtl.domain of @ClockDomain, + // expected-note @below {{associated with "ClockDomain" port "D2"}} + in D2 : !firrtl.domain of @ClockDomain, + // expected-error @below {{ambiguous "ClockDomain" association for port "i"}} + in i: !firrtl.uint<1> domains [D1, D2] + ) +} + +// ----- + +// Domain exported multiple times. Which do we choose? + +firrtl.circuit "DoubleExportOfDomain" { + firrtl.domain @ClockDomain + + firrtl.module @DoubleExportOfDomain( + // expected-note @below {{candidate association "DI"}} + in %DI : !firrtl.domain of @ClockDomain, + // expected-note @below {{candidate association "DO"}} + out %DO : !firrtl.domain of @ClockDomain, + in %i : !firrtl.uint<1> domains [%DO], + // expected-error @below {{ambiguous "ClockDomain" association for port "o"}} + out %o : !firrtl.uint<1> domains [] + ) { + // DI and DO are aliases + firrtl.domain.define %DO, %DI + + // o is on same domain as i + firrtl.matchingconnect %o, %i : !firrtl.uint<1> + } +} + +// Domain exported multiple times, this time with two outputs. + +firrtl.circuit "DoubleExportOfDomain" { + firrtl.domain @ClockDomain + + firrtl.extmodule @Generator(out D: !firrtl.domain of @ClockDomain) + + firrtl.module @DoubleExportOfDomain( + // expected-note @below {{candidate association "D1"}} + out %D1 : !firrtl.domain of @ClockDomain, + // expected-note @below {{candidate association "D2"}} + out %D2 : !firrtl.domain of @ClockDomain, + in %i : !firrtl.uint<1> domains [%D1], + // expected-error @below {{ambiguous "ClockDomain" association for port "o"}} + out %o : !firrtl.uint<1> domains [] + ) { + %gen_D = firrtl.instance gen @Generator(out D: !firrtl.domain of @ClockDomain) + // DI and DO are aliases + firrtl.domain.define %D1, %gen_D + firrtl.domain.define %D2, %gen_D + + // o is on same domain as i + firrtl.matchingconnect %o, %i : !firrtl.uint<1> + } +} + +// CHECK-LABEL: UndrivenInstanceDomainPort +firrtl.circuit "UndrivenInstanceDomainPort" { + firrtl.domain @ClockDomain + + firrtl.extmodule @Foo(in c : !firrtl.domain of @ClockDomain) + + firrtl.module @UndrivenInstanceDomainPort() { + // expected-error @below {{unable to infer value for undriven domain port "c"}} + %foo_c = firrtl.instance foo @Foo(in c : !firrtl.domain of @ClockDomain) + } +} + +// CHECK-LABEL: UndrivenInstanceChoiceDomainPort +firrtl.circuit "UndrivenInstanceChoiceDomainPort" { + firrtl.domain @ClockDomain + + firrtl.option @Option { + firrtl.option_case @X + } + + firrtl.extmodule @Foo(in c : !firrtl.domain of @ClockDomain) + firrtl.extmodule @Bar(in c : !firrtl.domain of @ClockDomain) + + firrtl.module @UndrivenInstanceChoiceDomainPort() { + // expected-error @below {{unable to infer value for undriven domain port "c"}} + %inst_c = firrtl.instance_choice inst @Foo alternatives @Option { @X -> @Bar } (in c : !firrtl.domain of @ClockDomain) + } +} + +// InstanceChoice: Each module has different domains inferred. +// TODO: this just relies on the op-verifier for instance choice ops. + +firrtl.circuit "ConflictingInstanceChoiceDomains" { + firrtl.domain @ClockDomain + + firrtl.option @Option { + firrtl.option_case @X + firrtl.option_case @Y + } + + // Foo's "out" port takes on the domains of "in1". + firrtl.module @Foo(in %in1: !firrtl.uint<1>, in %in2: !firrtl.uint<1>, out %out: !firrtl.uint<1>) { + firrtl.connect %out, %in1 : !firrtl.uint<1> + } + + // Bar's "out" port takes on the domains of "in2". + // expected-note @below {{original module declared here}} + firrtl.module @Bar(in %in1: !firrtl.uint<1>, in %in2: !firrtl.uint<1>, out %out: !firrtl.uint<1>) { + firrtl.connect %out, %in2 : !firrtl.uint<1> + } + + firrtl.module @ConflictingInstanceChoiceDomains(in %in1: !firrtl.uint<1>, in %in2: !firrtl.uint<1>) { + // expected-error @below {{'firrtl.instance_choice' op domain info for "out" must be [2 : ui32], but got [0 : ui32]}} + %inst_in1, %inst_in2, %inst_out = firrtl.instance_choice inst @Foo alternatives @Option { @X -> @Foo, @Y -> @Bar } (in in1: !firrtl.uint<1>, in in2: !firrtl.uint<1>, out out: !firrtl.uint<1>) + firrtl.connect %inst_in1, %in1 : !firrtl.uint<1>, !firrtl.uint<1> + firrtl.connect %inst_in2, %in2 : !firrtl.uint<1>, !firrtl.uint<1> + } +} diff --git a/test/Dialect/FIRRTL/infer-domains-infer-all.mlir b/test/Dialect/FIRRTL/infer-domains-infer-all.mlir new file mode 100644 index 000000000000..1b7d465d34c7 --- /dev/null +++ b/test/Dialect/FIRRTL/infer-domains-infer-all.mlir @@ -0,0 +1,323 @@ +// RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains{mode=infer-all}))' %s | FileCheck %s + +// Legal domain usage - no crossing. +// CHECK-LABEL: firrtl.circuit "LegalDomains" +firrtl.circuit "LegalDomains" { + firrtl.domain @ClockDomain + firrtl.module @LegalDomains( + in %A: !firrtl.domain of @ClockDomain, + in %a: !firrtl.uint<1> domains [%A], + out %b: !firrtl.uint<1> domains [%A] + ) { + // Connecting within the same domain is legal. + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + } +} + +// Domain inference through connections. +// CHECK-LABEL: firrtl.circuit "DomainInference" +firrtl.circuit "DomainInference" { + firrtl.domain @ClockDomain + firrtl.module @DomainInference( + in %A: !firrtl.domain of @ClockDomain, + in %a: !firrtl.uint<1> domains [%A], + // CHECK: out %c: !firrtl.uint<1> domains [%A] + out %c: !firrtl.uint<1> + ) { + %b = firrtl.wire : !firrtl.uint<1> // No explicit domain + + // This should infer that %b is in domain %A. + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + + // This should be legal since %b is now inferred to be in domain %A. + firrtl.matchingconnect %c, %b : !firrtl.uint<1> + } +} + +// Unsafe domain cast +// CHECK-LABEL: firrtl.circuit "UnsafeDomainCast" +firrtl.circuit "UnsafeDomainCast" { + firrtl.domain @ClockDomain + firrtl.module @UnsafeDomainCast( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @ClockDomain, + in %a: !firrtl.uint<1> domains [%A], + out %c: !firrtl.uint<1> domains [%B] + ) { + // Unsafe cast from domain A to domain B. + %b = firrtl.unsafe_domain_cast %a domains %B : !firrtl.uint<1> + + // This should be legal since we explicitly cast. + firrtl.matchingconnect %c, %b : !firrtl.uint<1> + } +} + +// Domain sequence matching. +// CHECK-LABEL: firrtl.circuit "LegalSequences" +firrtl.circuit "LegalSequences" { + firrtl.domain @ClockDomain + firrtl.domain @PowerDomain + firrtl.module @LegalSequences( + in %C: !firrtl.domain of @ClockDomain, + in %P: !firrtl.domain of @PowerDomain, + in %a: !firrtl.uint<1> domains [%C, %P], + out %b: !firrtl.uint<1> domains [%C, %P] + ) { + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + } +} + +// Domain sequence order equivalence - should be legal +// CHECK-LABEL: SequenceOrderEquivalence +firrtl.circuit "SequenceOrderEquivalence" { + firrtl.domain @ClockDomain + firrtl.domain @PowerDomain + firrtl.module @SequenceOrderEquivalence( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @PowerDomain, + in %a: !firrtl.uint<1> domains [%A, %B], + out %b: !firrtl.uint<1> domains [%B, %A] + ) { + // This should be legal since domain order doesn't matter in canonical representation + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + } +} + +// Domain sequence inference +// CHECK-LABEL: SequenceInference +firrtl.circuit "SequenceInference" { + firrtl.domain @ClockDomain + firrtl.domain @PowerDomain + firrtl.module @SequenceInference( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @PowerDomain, + in %a: !firrtl.uint<1> domains [%A, %B], + out %d: !firrtl.uint<1> + ) { + %c = firrtl.wire : !firrtl.uint<1> + + // %c should infer domain sequence [%A, %B] + firrtl.matchingconnect %c, %a : !firrtl.uint<1> + + // This should be legal since %c has inferred [%A, %B] + firrtl.matchingconnect %d, %c : !firrtl.uint<1> + } +} + +// Domain duplicate equivalence - should be legal. +// CHECK-LABEL: DuplicateDomainEquivalence +firrtl.circuit "DuplicateDomainEquivalence" { + firrtl.domain @ClockDomain + firrtl.module @DuplicateDomainEquivalence( + in %A: !firrtl.domain of @ClockDomain, + in %a: !firrtl.uint<1> domains [%A, %A], + out %b: !firrtl.uint<1> domains [%A] + ) { + // This should be legal since duplicate domains are canonicalized. + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + } +} + +// Unsafe domain cast with sequences +// CHECK-LABEL: UnsafeSequenceCast +firrtl.circuit "UnsafeSequenceCast" { + firrtl.domain @ClockDomain + firrtl.domain @PowerDomain + + firrtl.module @UnsafeSequenceCast( + in %C1: !firrtl.domain of @ClockDomain, + in %C2: !firrtl.domain of @ClockDomain, + in %P1: !firrtl.domain of @PowerDomain, + in %i: !firrtl.uint<1> domains [%C1, %P1], + out %o: !firrtl.uint<1> domains [%C2, %P1] + ) { + %0 = firrtl.unsafe_domain_cast %i domains %C2 : !firrtl.uint<1> + firrtl.matchingconnect %o, %0 : !firrtl.uint<1> + } +} + +// Different port types domain inference. +// CHECK-LABEL: DifferentPortTypes +firrtl.circuit "DifferentPortTypes" { + firrtl.domain @ClockDomain + firrtl.module @DifferentPortTypes( + in %A: !firrtl.domain of @ClockDomain, + in %uint_input: !firrtl.uint<8> domains [%A], + in %sint_input: !firrtl.sint<4> domains [%A], + out %uint_output: !firrtl.uint<8>, + out %sint_output: !firrtl.sint<4> + ) { + firrtl.matchingconnect %uint_output, %uint_input : !firrtl.uint<8> + firrtl.matchingconnect %sint_output, %sint_input : !firrtl.sint<4> + } +} + +// Domain inference through wires. +// CHECK-LABEL: DomainInferenceThroughWires +firrtl.circuit "DomainInferenceThroughWires" { + firrtl.domain @ClockDomain + firrtl.module @DomainInferenceThroughWires( + in %A: !firrtl.domain of @ClockDomain, + in %input: !firrtl.uint<1> domains [%A], + // CHECK: out %output: !firrtl.uint<1> domains [%A] + out %output: !firrtl.uint<1> + ) { + %wire1 = firrtl.wire : !firrtl.uint<1> + %wire2 = firrtl.wire : !firrtl.uint<1> + + firrtl.matchingconnect %wire1, %input : !firrtl.uint<1> + firrtl.matchingconnect %wire2, %wire1 : !firrtl.uint<1> + firrtl.matchingconnect %output, %wire2 : !firrtl.uint<1> + } +} + +// Export: add output domain port for domain created internally. +// CHECK-LABEL: ExportDomain +firrtl.circuit "ExportDomain" { + firrtl.domain @ClockDomain + + firrtl.extmodule @Foo( + out A: !firrtl.domain of @ClockDomain, + out o: !firrtl.uint<1> domains [A] + ) + + firrtl.module @ExportDomain( + // CHECK: out %ClockDomain: !firrtl.domain of @ClockDomain + // CHECK: out %o: !firrtl.uint<1> domains [%ClockDomain] + out %o: !firrtl.uint<1> + ) { + %foo_A, %foo_o = firrtl.instance foo @Foo( + out A: !firrtl.domain of @ClockDomain, + out o: !firrtl.uint<1> domains [A] + ) + firrtl.matchingconnect %o, %foo_o : !firrtl.uint<1> + // CHECK: firrtl.domain.define %ClockDomain, %foo_A + } +} + +// Export: Reuse already-exported domain. +// CHECK-LABEL: ReuseExportedDomain +firrtl.circuit "ReuseExportedDomain" { + firrtl.domain @ClockDomain + + firrtl.extmodule @Foo( + out A: !firrtl.domain of @ClockDomain, + out o: !firrtl.uint<1> domains [A] + ) + + firrtl.module @ReuseExportedDomain( + out %A: !firrtl.domain of @ClockDomain, + // CHECK: out %o: !firrtl.uint<1> domains [%A] + out %o: !firrtl.uint<1> + ) { + %foo_A, %foo_o = firrtl.instance foo @Foo( + out A: !firrtl.domain of @ClockDomain, + out o: !firrtl.uint<1> domains [A] + ) + firrtl.matchingconnect %o, %foo_o : !firrtl.uint<1> + firrtl.domain.define %A, %foo_A + } +} + +// CHECK-LABEL: RegisterInference +firrtl.circuit "RegisterInference" { + firrtl.domain @ClockDomain + firrtl.module @RegisterInference( + in %A: !firrtl.domain of @ClockDomain, + in %clock: !firrtl.clock domains [%A], + // CHECK: in %d: !firrtl.uint<1> domains [%A] + in %d: !firrtl.uint<1>, + // CHECK: out %q: !firrtl.uint<1> domains [%A] + out %q: !firrtl.uint<1> + ) { + %r = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1> + firrtl.matchingconnect %r, %d : !firrtl.uint<1> + firrtl.matchingconnect %q, %r : !firrtl.uint<1> + } +} + +// CHECK-LABEL: InstanceUpdate +firrtl.circuit "InstanceUpdate" { + firrtl.domain @ClockDomain + + firrtl.module @Foo(in %i : !firrtl.uint<1>) {} + + // CHECK: firrtl.module @InstanceUpdate(in %ClockDomain: !firrtl.domain of @ClockDomain, in %i: !firrtl.uint<1> domains [%ClockDomain]) { + // CHECK: %foo_ClockDomain, %foo_i = firrtl.instance foo @Foo(in ClockDomain: !firrtl.domain of @ClockDomain, in i: !firrtl.uint<1> domains [ClockDomain]) + // CHECK: firrtl.domain.define %foo_ClockDomain, %ClockDomain + // CHECK: firrtl.connect %foo_i, %i : !firrtl.uint<1> + // CHECK: } + firrtl.module @InstanceUpdate(in %i : !firrtl.uint<1>) { + %foo_i = firrtl.instance foo @Foo(in i: !firrtl.uint<1>) + firrtl.connect %foo_i, %i : !firrtl.uint<1>, !firrtl.uint<1> + } +} + +// CHECK-LABEL: InstanceChoiceUpdate +firrtl.circuit "InstanceChoiceUpdate" { + firrtl.domain @ClockDomain + + firrtl.option @Option { + firrtl.option_case @X + firrtl.option_case @Y + } + + firrtl.module @Foo(in %i : !firrtl.uint<1>) {} + firrtl.module @Bar(in %i : !firrtl.uint<1>) {} + firrtl.module @Baz(in %i : !firrtl.uint<1>) {} + + // CHECK: firrtl.module @InstanceChoiceUpdate(in %ClockDomain: !firrtl.domain of @ClockDomain, in %i: !firrtl.uint<1> domains [%ClockDomain]) { + // CHECK: %inst_ClockDomain, %inst_i = firrtl.instance_choice inst @Foo alternatives @Option { @X -> @Bar, @Y -> @Baz } (in ClockDomain: !firrtl.domain of @ClockDomain, in i: !firrtl.uint<1> domains [ClockDomain]) + // CHECK: firrtl.domain.define %inst_ClockDomain, %ClockDomain + // CHECK: firrtl.connect %inst_i, %i : !firrtl.uint<1> + // CHECK: } + firrtl.module @InstanceChoiceUpdate(in %i : !firrtl.uint<1>) { + %inst_i = firrtl.instance_choice inst @Foo alternatives @Option { @X -> @Bar, @Y -> @Baz } (in i : !firrtl.uint<1>) + firrtl.connect %inst_i, %i : !firrtl.uint<1>, !firrtl.uint<1> + } +} + +// CHECK-LABEL: ConstantInMultipleDomains +firrtl.circuit "ConstantInMultipleDomains" { + firrtl.domain @ClockDomain + + firrtl.extmodule @Foo(in A: !firrtl.domain of @ClockDomain, in i: !firrtl.uint<1> domains [A]) + + firrtl.module @ConstantInMultipleDomains(in %A: !firrtl.domain of @ClockDomain, in %B: !firrtl.domain of @ClockDomain) { + %c0_ui1 = firrtl.constant 0 : !firrtl.uint<1> + %x_A, %x_i = firrtl.instance x @Foo(in A: !firrtl.domain of @ClockDomain, in i: !firrtl.uint<1> domains [A]) + firrtl.domain.define %x_A, %A + firrtl.matchingconnect %x_i, %c0_ui1 : !firrtl.uint<1> + + %y_A, %y_i = firrtl.instance y @Foo(in A: !firrtl.domain of @ClockDomain, in i: !firrtl.uint<1> domains [A]) + firrtl.domain.define %y_A, %B + firrtl.matchingconnect %y_i, %c0_ui1 : !firrtl.uint<1> + } +} + +firrtl.circuit "Top" { + firrtl.domain @ClockDomain + firrtl.extmodule @Foo( + in ClockDomain : !firrtl.domain of @ClockDomain, + in i: !firrtl.uint<1> domains [ClockDomain], + out o : !firrtl.uint<1> domains [ClockDomain] + ) + + firrtl.module @Top(in %ClockDomain : !firrtl.domain of @ClockDomain ) { + %foo1_ClockDomain, %foo1_i, %foo1_o = firrtl.instance foo1 @Foo( + in ClockDomain : !firrtl.domain of @ClockDomain, + in i: !firrtl.uint<1> domains [ClockDomain], + out o : !firrtl.uint<1> domains [ClockDomain] + ) + + %foo2_ClockDomain, %foo2_i, %foo2_o = firrtl.instance foo2 @Foo( + in ClockDomain : !firrtl.domain of @ClockDomain, + in i: !firrtl.uint<1> domains [ClockDomain], + out o : !firrtl.uint<1> domains [ClockDomain] + ) + + firrtl.domain.define %foo1_ClockDomain, %ClockDomain + firrtl.matchingconnect %foo2_i, %foo1_o : !firrtl.uint<1> + firrtl.matchingconnect %foo1_i, %foo2_o : !firrtl.uint<1> + } +} diff --git a/test/Dialect/FIRRTL/infer-domains-infer-errors.mlir b/test/Dialect/FIRRTL/infer-domains-infer-errors.mlir new file mode 100644 index 000000000000..7e6f8ca17735 --- /dev/null +++ b/test/Dialect/FIRRTL/infer-domains-infer-errors.mlir @@ -0,0 +1,48 @@ +// RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains{mode=infer}))' %s --verify-diagnostics + +// in "infer" mode, infer-domains requires that the interfaces of public +// modules are fully annotated with domain associations, but will still +// perform domain inference on the body of a public module, and will do full +// inference for private modules. + +// This test suite checks for errors which do not occur when the mode is +// infer-all. + +// CHECK-LABEL: MissingDomain +firrtl.circuit "MissingDomain" { + firrtl.domain @ClockDomain + + firrtl.module @MissingDomain( + // expected-error @below {{missing "ClockDomain" association for port "x"}} + in %x: !firrtl.uint<1> + ) {} +} + +// CHECK-LABEL: MissingSecondDomain +firrtl.circuit "MissingSecondDomain" { + firrtl.domain @ClockDomain + firrtl.domain @PowerDomain + + firrtl.module @MissingSecondDomain( + in %c : !firrtl.domain of @ClockDomain, + // expected-error @below {{missing "PowerDomain" association for port "x"}} + in %x : !firrtl.uint<1> domains [%c] + ) {} +} + +// Test that domain crossing errors are still caught when in infer mode. +// Catching this involves processing the module without writing back to the IR. +firrtl.circuit "IllegalDomainCrossing" { + firrtl.domain @ClockDomain + firrtl.module @IllegalDomainCrossing( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @ClockDomain, + // expected-note @below {{2nd operand has domains: [ClockDomain: A]}} + in %a: !firrtl.uint<1> domains [%A], + // expected-note @below {{1st operand has domains: [ClockDomain: B]}} + out %b: !firrtl.uint<1> domains [%B] + ) { + // expected-error @below {{illegal domain crossing in operation}} + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + } +} diff --git a/test/Dialect/FIRRTL/infer-domains-infer.mlir b/test/Dialect/FIRRTL/infer-domains-infer.mlir new file mode 100644 index 000000000000..d2778aefd35f --- /dev/null +++ b/test/Dialect/FIRRTL/infer-domains-infer.mlir @@ -0,0 +1,38 @@ +// RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains{mode=infer}))' %s | FileCheck %s + +firrtl.circuit "InferOutputDomain" { + firrtl.domain @ClockDomain + + firrtl.extmodule @Foo(out D: !firrtl.domain of @ClockDomain, out x: !firrtl.uint<1> domains [D]) + + // CHECK: firrtl.module private @Bar(out %ClockDomain: !firrtl.domain of @ClockDomain, out %x: !firrtl.uint<1> domains [%ClockDomain]) { + // CHECK: %foo_D, %foo_x = firrtl.instance foo @Foo(out D: !firrtl.domain of @ClockDomain, out x: !firrtl.uint<1> domains [D]) + // CHECK: firrtl.matchingconnect %x, %foo_x : !firrtl.uint<1> + // CHECK: firrtl.domain.define %ClockDomain, %foo_D + // CHECK: } + firrtl.module private @Bar(out %x : !firrtl.uint<1>) { + %foo_D, %foo_x = firrtl.instance foo @Foo(out D: !firrtl.domain of @ClockDomain, out x: !firrtl.uint<1> domains [D]) + firrtl.matchingconnect %x, %foo_x : !firrtl.uint<1> + } + + // CHECK: firrtl.module @InferOutputDomain(out %D: !firrtl.domain of @ClockDomain, out %x: !firrtl.uint<1> domains [%D]) { + // CHECK: %bar_ClockDomain, %bar_x = firrtl.instance bar @Bar(out ClockDomain: !firrtl.domain of @ClockDomain, out x: !firrtl.uint<1> domains [ClockDomain]) + // CHECK: firrtl.matchingconnect %x, %bar_x : !firrtl.uint<1> + // CHECK: firrtl.domain.define %D, %bar_ClockDomain + // CHECK: } + firrtl.module @InferOutputDomain(out %D: !firrtl.domain of @ClockDomain, out %x: !firrtl.uint<1> domains [%D]) { + %bar_x = firrtl.instance bar @Bar(out x : !firrtl.uint<1>) + firrtl.matchingconnect %x, %bar_x : !firrtl.uint<1> + } +} + +// do not crash the InferDomains pass. This stems from the fact that "no domain +// information" can be represented as both an empty array `[]` and an empty +// array of arrays `[[]]`. +firrtl.circuit "EmptyDomainInfo" { + firrtl.domain @DomainKind + firrtl.module @EmptyDomainInfo(out %x: !firrtl.integer) { + %0 = firrtl.integer 5 + firrtl.propassign %x, %0 : !firrtl.integer + } +} \ No newline at end of file