diff --git a/include/circt/Dialect/FIRRTL/Passes.h b/include/circt/Dialect/FIRRTL/Passes.h index a3cb2b8c9900..0515cbae6456 100644 --- a/include/circt/Dialect/FIRRTL/Passes.h +++ b/include/circt/Dialect/FIRRTL/Passes.h @@ -69,6 +69,26 @@ enum class CompanionMode { Drop, }; +/// The mode for the InferDomains pass. +enum class InferDomainsMode { + /// Check domains with inference for private modules (default). + Infer, + /// Check domains without inference. + Check, + /// Check domains with inference for both public and private modules. + InferAll, +}; + +/// True if the mode indicates we should infer domains on public modules. +constexpr bool shouldInferPublicModules(InferDomainsMode mode) { + return mode == InferDomainsMode::InferAll; +} + +/// True if the mode indicates we should infer domains on private modules. +constexpr bool shouldInferPrivateModules(InferDomainsMode mode) { + return mode == InferDomainsMode::Infer || mode == InferDomainsMode::InferAll; +} + #define GEN_PASS_DECL #include "circt/Dialect/FIRRTL/Passes.h.inc" diff --git a/include/circt/Dialect/FIRRTL/Passes.td b/include/circt/Dialect/FIRRTL/Passes.td index 15cf686d618c..3d7a47e84c5b 100644 --- a/include/circt/Dialect/FIRRTL/Passes.td +++ b/include/circt/Dialect/FIRRTL/Passes.td @@ -909,6 +909,31 @@ def CheckLayers : Pass<"firrtl-check-layers", "firrtl::CircuitOp"> { }]; } +def InferDomains : Pass<"firrtl-infer-domains", "firrtl::CircuitOp"> { + let summary = "Infer and type check all firrtl domains"; + let description = [{ + This pass does domain inference on a FIRRTL circuit. The end result of this + is either a corrrctly domain-checked FIRRTL circuit or failure with verbose + error messages indicating why the FIRRTL circuit has illegal domain + constructs. + + E.g., this pass can be used to check for illegal clock-domain-crossings if + clock domains are specified for signals in the design. + }]; + let options = [Option<"mode", "mode", "InferDomainsMode", + "InferDomainsMode::Infer", "infer, check, infer-all.", + [{ + llvm::cl::values( + clEnumValN(InferDomainsMode::Infer, "infer", + "Check domains with inference for private modules"), + clEnumValN(InferDomainsMode::Check, "check", + "Check domains without inference"), + clEnumValN(InferDomainsMode::InferAll, "infer-all", + "Check domains with inference for both public and private " + "modules")) + }]>]; +} + def LowerDomains : Pass<"firrtl-lower-domains", "firrtl::CircuitOp"> { let summary = "lower domain information to properties"; let description = [{ diff --git a/include/circt/Firtool/Firtool.h b/include/circt/Firtool/Firtool.h index 28e5f6d5d897..22d5a335d237 100644 --- a/include/circt/Firtool/Firtool.h +++ b/include/circt/Firtool/Firtool.h @@ -23,6 +23,34 @@ namespace circt { namespace firtool { + +enum class DomainMode { + /// Disable domain checking. + Disable, + /// Check domains with inference for private modules. + Infer, + /// Check domains without inference. + Check, + /// Check domains with inference for both public and private modules. + InferAll, +}; + +/// Convert the "domain mode" firtool option to a "firrtl::InferDomainsMode", +/// the configuration for a pass. +constexpr std::optional +toInferDomainsPassMode(DomainMode mode) { + switch (mode) { + case DomainMode::Disable: + return std::nullopt; + case DomainMode::Infer: + return firrtl::InferDomainsMode::Infer; + case DomainMode::Check: + return firrtl::InferDomainsMode::Check; + case DomainMode::InferAll: + return firrtl::InferDomainsMode::InferAll; + } +} + //===----------------------------------------------------------------------===// // FirtoolOptions //===----------------------------------------------------------------------===// @@ -144,6 +172,8 @@ class FirtoolOptions { bool getEmitAllBindFiles() const { return emitAllBindFiles; } + DomainMode getDomainMode() const { return domainMode; } + // Setters, used by the CAPI FirtoolOptions &setOutputFilename(StringRef name) { outputFilename = name; @@ -393,6 +423,11 @@ class FirtoolOptions { return *this; } + FirtoolOptions &setdomainMode(DomainMode value) { + domainMode = value; + return *this; + } + private: std::string outputFilename; @@ -447,6 +482,7 @@ class FirtoolOptions { bool lintStaticAsserts; bool lintXmrsInDesign; bool emitAllBindFiles; + DomainMode domainMode; }; void registerFirtoolCLOptions(); diff --git a/lib/Dialect/FIRRTL/Transforms/CMakeLists.txt b/lib/Dialect/FIRRTL/Transforms/CMakeLists.txt index f8a1ec9e4045..d24f11c7563b 100755 --- a/lib/Dialect/FIRRTL/Transforms/CMakeLists.txt +++ b/lib/Dialect/FIRRTL/Transforms/CMakeLists.txt @@ -17,6 +17,7 @@ add_circt_dialect_library(CIRCTFIRRTLTransforms GrandCentral.cpp IMConstProp.cpp IMDeadCodeElim.cpp + InferDomains.cpp InferReadWrite.cpp InferResets.cpp InferWidths.cpp diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp new file mode 100644 index 000000000000..22487f5e8822 --- /dev/null +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -0,0 +1,1447 @@ +//===- InferDomains.cpp - Infer and Check FIRRTL Domains ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// InferDomains implements FIRRTL domain inference and checking. This pass is a +// bottom-up transform acting on modules. For each module, we ensure there are +// no domain crossings, and we make explicit the domain associations of ports. +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/FIRRTL/FIRRTLInstanceGraph.h" +#include "circt/Dialect/FIRRTL/FIRRTLOps.h" +#include "circt/Dialect/FIRRTL/FIRRTLUtils.h" +#include "circt/Dialect/FIRRTL/Passes.h" +#include "circt/Support/Debug.h" +#include "circt/Support/Namespace.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TinyPtrVector.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "firrtl-infer-domains" + +namespace circt { +namespace firrtl { +#define GEN_PASS_DEF_INFERDOMAINS +#include "circt/Dialect/FIRRTL/Passes.h.inc" +} // namespace firrtl +} // namespace circt + +using namespace circt; +using namespace firrtl; + +namespace { + +//====-------------------------------------------------------------------------- +// Helpers. +//====-------------------------------------------------------------------------- + +using PortInsertions = SmallVector>; + +template +bool shouldInfer(T op, InferDomainsMode mode) { + return op.isPublic() ? shouldInferPublicModules(mode) + : shouldInferPrivateModules(mode); +} + +/// From a domain info attribute, get the domain-type of a domain value at +/// index i. +StringAttr getDomainPortTypeName(ArrayAttr info, size_t i) { + if (info.empty()) + return nullptr; + auto ref = cast(info[i]); + return ref.getAttr(); +} + +/// From a domain info attribute, get the row of associated domains for a +/// hardware value at index i. +auto getPortDomainAssociation(ArrayAttr info, size_t i) { + if (info.empty()) + return info.getAsRange(); + return cast(info[i]).getAsRange(); +} + +/// Return true if the value is a port on the module. +bool isPort(FModuleOp module, BlockArgument arg) { + return arg.getOwner()->getParentOp() == module; +} + +/// Return true if the value is a port on the module. +bool isPort(FModuleOp module, Value value) { + auto arg = dyn_cast(value); + if (!arg) + return false; + return isPort(module, arg); +} + +/// Returns true if the value is driven by a connect op. +bool isDriven(Value port) { + for (auto *user : port.getUsers()) + if (auto connect = dyn_cast(user)) + if (connect.getDest() == port) + return true; + return false; +} + +//====-------------------------------------------------------------------------- +// Global State. +//====-------------------------------------------------------------------------- + +/// Each domain type declared in the circuit is assigned a type-id, based on the +/// order of declaration. Domain associations for hardware values are +/// represented as a list, or row, of domains. The domains in a row are ordered +/// according to their type's id. +using DomainTypeID = size_t; + +/// Information about the domains in the circuit. Able to map domains to their +/// type ID, which in this pass is the canonical way to reference the type +/// of a domain, as well as provide fast access to domain ops +class DomainInfo { +public: + DomainInfo(CircuitOp circuit) { processCircuit(circuit); } + + ArrayRef getDomains() const { return domainTable; } + size_t getNumDomains() const { return domainTable.size(); } + DomainOp getDomain(DomainTypeID id) const { return domainTable[id]; } + + DomainTypeID getDomainTypeID(StringAttr name) const { + return typeIDTable.at(name); + } + + DomainTypeID getDomainTypeID(FlatSymbolRefAttr ref) const { + return getDomainTypeID(ref.getAttr()); + } + + DomainTypeID getDomainTypeID(ArrayAttr info, size_t i) const { + auto name = getDomainPortTypeName(info, i); + return getDomainTypeID(name); + } + + DomainTypeID getDomainTypeID(Value value) const { + assert(isa(value.getType())); + if (auto arg = dyn_cast(value)) { + auto *block = arg.getOwner(); + auto *owner = block->getParentOp(); + auto module = cast(owner); + auto info = module.getDomainInfoAttr(); + auto i = arg.getArgNumber(); + return getDomainTypeID(info, i); + } + + auto result = dyn_cast(value); + auto *owner = result.getOwner(); + auto instance = cast(owner); + auto info = instance.getDomainInfoAttr(); + auto i = result.getResultNumber(); + return getDomainTypeID(info, i); + } + +private: + void processDomain(DomainOp op) { + auto index = domainTable.size(); + auto name = op.getNameAttr(); + domainTable.push_back(op); + typeIDTable.insert({name, index}); + } + + void processCircuit(CircuitOp circuit) { + for (auto decl : circuit.getOps()) + processDomain(decl); + } + + /// A map from domain type ID to op. + SmallVector domainTable; + + /// A map from domain name to type ID. + DenseMap typeIDTable; +}; + +/// Information about the changes made to the interface of a module, which can +/// be replayed onto an instance. +struct ModuleUpdateInfo { + /// The updated domain information for a module. + ArrayAttr portDomainInfo; + /// The domain ports which have been inserted into a module. + PortInsertions portInsertions; +}; + +using ModuleUpdateTable = DenseMap; + +/// Apply the port changes of a module onto an instance-like op. +template +T fixInstancePorts(T op, const ModuleUpdateInfo &update) { + auto clone = op.cloneWithInsertedPortsAndReplaceUses(update.portInsertions); + clone.setDomainInfoAttr(update.portDomainInfo); + op->erase(); + return clone; +} + +//====-------------------------------------------------------------------------- +// Terms: Syntax for unifying domain and domain-rows. +//====-------------------------------------------------------------------------- + +/// The different sorts of terms in the unification engine. +enum class TermKind { + Variable, + Value, + Row, +}; + +/// A term in the unification engine. +struct Term { + constexpr Term(TermKind kind) : kind(kind) {} + TermKind kind; +}; + +/// Helper to define a term kind. +template +struct TermBase : Term { + static bool classof(const Term *term) { return term->kind == K; } + TermBase() : Term(K) {} +}; + +/// An unknown value. +struct VariableTerm : public TermBase { + VariableTerm() : leader(nullptr) {} + VariableTerm(Term *leader) : leader(leader) {} + Term *leader = nullptr; +}; + +/// A concrete value defined in the IR. +struct ValueTerm : public TermBase { + ValueTerm(Value value) : value(value) {} + Value getValue() const { return value; } + Value value; +}; + +/// A row of domains. +struct RowTerm : public TermBase { + RowTerm(ArrayRef elements) : elements(elements) {} + ArrayRef elements; +}; + +// NOLINTNEXTLINE(misc-no-recursion) +Term *find(Term *x) { + if (!x) + return nullptr; + + if (auto *var = dyn_cast(x)) { + if (var->leader == nullptr) + return var; + + auto *leader = find(var->leader); + if (leader != var->leader) + var->leader = leader; + return leader; + } + + return x; +} + +/// A helper for assigning low numeric IDs to variables for user-facing output. +class VariableIDTable { +public: + size_t get(VariableTerm *term) { + auto [it, inserted] = table.insert({term, table.size() + 1}); + return it->second; + } + +private: + DenseMap table; +}; + +// NOLINTNEXTLINE(misc-no-recursion) +void render(const DomainInfo &info, Diagnostic &out, VariableIDTable &idTable, + Term *term) { + term = find(term); + if (auto *var = dyn_cast(term)) { + out << "?" << idTable.get(var); + return; + } + if (auto *val = dyn_cast(term)) { + auto value = val->value; + auto [name, rooted] = getFieldName(FieldRef(value, 0), false); + out << name; + return; + } + if (auto *row = dyn_cast(term)) { + bool first = true; + out << "["; + for (size_t i = 0, e = info.getNumDomains(); i < e; ++i) { + auto domainOp = info.getDomain(i); + if (!first) { + out << ", "; + first = false; + } + out << domainOp.getName() << ": "; + render(info, out, idTable, row->elements[i]); + } + out << "]"; + return; + } +} + +#ifndef NDEBUG + +raw_ostream &dump(llvm::raw_ostream &out, const Term *term); + +// NOLINTNEXTLINE(misc-no-recursion) +raw_ostream &dump(raw_ostream &out, const VariableTerm *term) { + return out << "var@" << (void *)term << "{leader=" << term->leader << "}"; +} + +// NOLINTNEXTLINE(misc-no-recursion) +raw_ostream &dump(raw_ostream &out, const ValueTerm *term) { + return out << "val@" << term << "{" << term->value << "}"; +} + +// NOLINTNEXTLINE(misc-no-recursion) +raw_ostream &dump(raw_ostream &out, const RowTerm *term) { + out << "row@" << term << "{"; + llvm::interleaveComma(term->elements, out, + [&](auto element) { dump(out, element); }); + out << "}"; + return out; +} + +// NOLINTNEXTLINE(misc-no-recursion) +raw_ostream &dump(raw_ostream &out, const Term *term) { + if (!term) + return out << "null"; + if (auto *var = dyn_cast(term)) + return dump(out, var); + if (auto *val = dyn_cast(term)) + return dump(out, val); + if (auto *row = dyn_cast(term)) + return dump(out, row); + llvm_unreachable("unknown term"); +} +#endif // DEBUG + +LogicalResult unify(Term *lhs, Term *rhs); + +LogicalResult unify(VariableTerm *x, Term *y) { + assert(!x->leader); + x->leader = y; + return success(); +} + +LogicalResult unify(ValueTerm *xv, Term *y) { + if (auto *yv = dyn_cast(y)) { + yv->leader = xv; + return success(); + } + if (auto *yv = dyn_cast(y)) { + return success(xv == yv); + } + return failure(); +} + +// NOLINTNEXTLINE(misc-no-recursion) +LogicalResult unify(RowTerm *lhsRow, Term *rhs) { + if (auto *rhsVar = dyn_cast(rhs)) { + rhsVar->leader = lhsRow; + return success(); + } + if (auto *rhsRow = dyn_cast(rhs)) { + assert(lhsRow->elements.size() == rhsRow->elements.size()); + for (auto [x, y] : llvm::zip(lhsRow->elements, rhsRow->elements)) { + if (failed(unify(x, y))) + return failure(); + } + return success(); + } + + return failure(); +} + +// NOLINTNEXTLINE(misc-no-recursion) +LogicalResult unify(Term *lhs, Term *rhs) { + LLVM_DEBUG(auto &out = llvm::errs(); out << "unify x="; dump(out, lhs); + out << " y="; dump(out, rhs); out << "\n";); + if (!lhs || !rhs) + return success(); + lhs = find(lhs); + rhs = find(rhs); + if (lhs == rhs) + return success(); + if (auto *lhsVar = dyn_cast(lhs)) + return unify(lhsVar, rhs); + if (auto *lhsVal = dyn_cast(lhs)) + return unify(lhsVal, rhs); + if (auto *lhsRow = dyn_cast(lhs)) + return unify(lhsRow, rhs); + return failure(); +} + +void solve(Term *lhs, Term *rhs) { + auto result = unify(lhs, rhs); + (void)result; + assert(result.succeeded()); +} + +class TermAllocator { +public: + /// Allocate a row of fresh domain variables. + RowTerm *allocRow(size_t size) { + SmallVector elements; + elements.resize(size); + return allocRow(elements); + } + + /// Allocate a row of terms. + RowTerm *allocRow(ArrayRef elements) { + auto ds = allocArray(elements); + return alloc(ds); + } + + /// Allocate a fresh variable. + VariableTerm *allocVar() { return alloc(); } + + /// Allocate a concrete domain. + ValueTerm *allocVal(Value value) { return alloc(value); } + +private: + template + T *alloc(Args &&...args) { + static_assert(std::is_base_of_v, "T must be a term"); + return new (allocator) T(std::forward(args)...); + } + + ArrayRef allocArray(ArrayRef elements) { + auto size = elements.size(); + if (size == 0) + return {}; + + auto *result = allocator.Allocate(size); + llvm::uninitialized_copy(elements, result); + for (size_t i = 0; i < size; ++i) + if (!result[i]) + result[i] = alloc(); + + return ArrayRef(result, elements.size()); + } + + llvm::BumpPtrAllocator allocator; +}; + +//====-------------------------------------------------------------------------- +// DomainTable: A mapping from IR to terms. +//====-------------------------------------------------------------------------- + +/// Tracks domain infomation for IR values. +class DomainTable { +public: + /// If the domain value is an alias, returns the domain it aliases. + Value getOptUnderlyingDomain(Value value) const { + assert(isa(value.getType())); + auto *term = getOptTermForDomain(value); + if (auto *val = llvm::dyn_cast_if_present(term)) + return val->value; + return nullptr; + } + + /// Get the corresponding term for a domain in the IR, or null if unset. + Term *getOptTermForDomain(Value value) const { + assert(isa(value.getType())); + auto it = termTable.find(value); + if (it == termTable.end()) + return nullptr; + return find(it->second); + } + + /// Get the corresponding term for a domain in the IR. + Term *getTermForDomain(Value value) const { + auto *term = getOptTermForDomain(value); + assert(term); + return term; + } + + /// Record a mapping from domain in the IR to its corresponding term. + void setTermForDomain(Value value, Term *term) { + assert(isa(value.getType())); + assert(term); + assert(!termTable.contains(value)); + termTable.insert({value, term}); + } + + /// For a hardware value, get the term which represents the row of associated + /// domains. If no mapping has been defined, returns nullptr. + Term *getOptDomainAssociation(Value value) const { + assert(isa(value.getType())); + auto it = associationTable.find(value); + if (it == associationTable.end()) + return nullptr; + return find(it->second); + } + + /// For a hardware value, get the term which represents the row of associated + /// domains. + Term *getDomainAssociation(Value value) const { + auto *term = getOptDomainAssociation(value); + assert(term); + return term; + } + + /// Record a mapping from a hardware value in the IR to a term which + /// represents the row of domains it is associated with. + void setDomainAssociation(Value value, Term *term) { + assert(isa(value.getType())); + assert(term); + term = find(term); + associationTable.insert({value, term}); + } + +private: + /// Map from domains in the IR to their underlying term. + DenseMap termTable; + + /// A map from hardware values to their associated row of domains, as a term. + DenseMap associationTable; +}; + +//====-------------------------------------------------------------------------- +// Module processing: solve for the domain associations of hardware. +//====-------------------------------------------------------------------------- + +/// Get the corresponding term for a domain in the IR. If we don't know what the +/// term is, then map the domain in the IR to a variable term. +Term *getTermForDomain(TermAllocator &allocator, DomainTable &table, + Value value) { + assert(isa(value.getType())); + if (auto *term = table.getOptTermForDomain(value)) + return term; + auto *term = allocator.allocVar(); + table.setTermForDomain(value, term); + return term; +} + +/// Get the row of domains that a hardware value in the IR is associated with. +/// The returned term is forced to be at least a row. +RowTerm *getDomainAssociationAsRow(const DomainInfo &info, + TermAllocator &allocator, DomainTable &table, + Value value) { + assert(isa(value.getType())); + auto *term = table.getOptDomainAssociation(value); + + // If the term is unknown, allocate a fresh row and set the association. + if (!term) { + auto *row = allocator.allocRow(info.getNumDomains()); + table.setDomainAssociation(value, row); + return row; + } + + // If the term is already a row, return it. + if (auto *row = dyn_cast(term)) + return row; + + // Otherwise, unify the term with a fresh row of domains. + if (auto *var = dyn_cast(term)) { + auto *row = allocator.allocRow(info.getNumDomains()); + solve(var, row); + return row; + } + + assert(false && "unhandled term type"); + return nullptr; +} + +template +void emitPortDomainCrossingError(const DomainInfo &info, T op, size_t i, + size_t domainTypeID, Term *term1, + Term *term2) { + VariableIDTable idTable; + + auto portName = op.getPortNameAttr(i); + auto portLoc = op.getPortLocation(i); + auto domainDecl = info.getDomain(domainTypeID); + auto domainName = domainDecl.getNameAttr(); + + auto diag = emitError(portLoc); + diag << "illegal " << domainName << " crossing in port " << portName; + + auto ¬e1 = diag.attachNote(); + note1 << "1st instance: "; + render(info, note1, idTable, term1); + + auto ¬e2 = diag.attachNote(); + note2 << "2nd instance: "; + render(info, note2, idTable, term2); +} + +/// Emit an error when we fail to infer the concrete domain to drive to a +/// domain port. +template +void emitDomainPortInferenceError(T op, size_t i) { + auto name = op.getPortNameAttr(i); + auto diag = emitError(op->getLoc()); + auto info = op.getDomainInfo(); + diag << "unable to infer value for undriven domain port " << name; + for (size_t j = 0, e = op.getNumPorts(); j < e; ++j) { + if (auto assocs = dyn_cast(info[j])) { + for (auto assoc : assocs) { + if (i == cast(assoc).getValue()) { + auto name = op.getPortNameAttr(j); + auto loc = op.getPortLocation(j); + diag.attachNote(loc) << "associated with hardware port " << name; + break; + } + } + } + } +} + +template +void emitAmbiguousPortDomainAssociation(T op, size_t i) {} + +template +void emitMissingPortDomainAssociationError(const DomainInfo &info, T op, + size_t typeID, size_t i) { + auto domainName = info.getDomain(typeID).getNameAttr(); + auto portName = op.getPortNameAttr(i); + emitError(op.getPortLocation(i)) + << "missing " << domainName << " association for port " << portName; +} + +/// Unify the associated domain rows of two terms. +LogicalResult unifyAssociations(const DomainInfo &info, + TermAllocator &allocator, DomainTable &table, + Operation *op, Value lhs, Value rhs) { + LLVM_DEBUG(llvm::errs() << " unify associations of:\n"; + llvm::errs() << " lhs=" << lhs << "\n"; + llvm::errs() << " rhs=" << rhs << "\n";); + + if (!lhs || !rhs) + return success(); + + if (lhs == rhs) + return success(); + + auto *lhsTerm = table.getOptDomainAssociation(lhs); + auto *rhsTerm = table.getOptDomainAssociation(rhs); + + if (lhsTerm) { + if (rhsTerm) { + if (failed(unify(lhsTerm, rhsTerm))) { + auto diag = op->emitOpError("illegal domain crossing in operation"); + auto ¬e1 = diag.attachNote(lhs.getLoc()); + + note1 << "1st operand has domains: "; + VariableIDTable idTable; + render(info, note1, idTable, lhsTerm); + + auto ¬e2 = diag.attachNote(rhs.getLoc()); + note2 << "2nd operand has domains: "; + render(info, note2, idTable, rhsTerm); + + return failure(); + } + } + table.setDomainAssociation(rhs, lhsTerm); + return success(); + } + + if (rhsTerm) { + table.setDomainAssociation(lhs, rhsTerm); + return success(); + } + + auto *var = allocator.allocVar(); + table.setDomainAssociation(lhs, var); + table.setDomainAssociation(rhs, var); + return success(); +} + +LogicalResult processModulePorts(const DomainInfo &info, + TermAllocator &allocator, DomainTable &table, + FModuleOp module) { + auto domainInfo = module.getDomainInfoAttr(); + auto numPorts = module.getNumPorts(); + DenseMap domainTypeIDTable; + for (size_t i = 0; i < numPorts; ++i) { + BlockArgument port = module.getArgument(i); + + if (isa(port.getType())) { + auto typeID = info.getDomainTypeID(domainInfo, i); + domainTypeIDTable[i] = typeID; + if (module.getPortDirection(i) == Direction::In) { + table.setTermForDomain(port, allocator.allocVal(port)); + } + continue; + } + + auto portDomains = getPortDomainAssociation(domainInfo, i); + if (portDomains.empty()) + continue; + + SmallVector elements(info.getNumDomains()); + for (auto domainPortIndexAttr : portDomains) { + auto domainPortIndex = domainPortIndexAttr.getUInt(); + auto domainTypeID = domainTypeIDTable[domainPortIndex]; + auto domainValue = module.getArgument(domainPortIndex); + auto *term = getTermForDomain(allocator, table, domainValue); + auto &slot = elements[domainTypeID]; + if (failed(unify(slot, term))) { + emitPortDomainCrossingError(info, module, i, domainTypeID, slot, term); + return failure(); + } + elements[domainTypeID] = term; + } + auto *row = allocator.allocRow(elements); + table.setDomainAssociation(port, row); + } + + return success(); +} + +template +LogicalResult processInstancePorts(const DomainInfo &info, + TermAllocator &allocator, DomainTable &table, + T op) { + llvm::errs() << "ins=" << op << "\n"; + auto numDomainTypes = info.getNumDomains(); + DenseMap domainPortTypeIDTable; + auto domainInfo = op.getDomainInfoAttr(); + for (size_t i = 0, e = op->getNumResults(); i < e; ++i) { + Value port = op.getResult(i); + + if (isa(port.getType())) { + auto typeID = info.getDomainTypeID(domainInfo, i); + domainPortTypeIDTable[i] = typeID; + if (op.getPortDirection(i) == Direction::Out) { + table.setTermForDomain(port, allocator.allocVal(port)); + } + continue; + } + + if (!isa(port.getType())) + continue; + + // This is a port, which may have explicit domain information. Associate the + // port with a row of domains, where each element is derived from the domain + // associations recorded in the domain info attribute of the instance. + SmallVector elements(numDomainTypes); + auto associations = getPortDomainAssociation(domainInfo, i); + for (auto domainPortIndexAttr : associations) { + auto domainPortIndex = domainPortIndexAttr.getUInt(); + auto typeID = domainPortTypeIDTable[domainPortIndex]; + auto *term = + getTermForDomain(allocator, table, op.getResult(domainPortIndex)); + elements[typeID] = term; + } + + // Confirm that we have complete domain information for the port. We can be + // missing information if, for example, this was an instance of an + // extmodule. + for (size_t domainTypeID = 0; domainTypeID < numDomainTypes; + ++domainTypeID) { + if (elements[domainTypeID]) + continue; + auto domainDecl = info.getDomain(domainTypeID); + auto domainName = domainDecl.getNameAttr(); + auto portName = op.getPortNameAttr(i); + op->emitOpError() << "missing " << domainName << " association for port " + << portName; + return failure(); + } + + table.setDomainAssociation(port, allocator.allocRow(elements)); + } + + return success(); +} + +LogicalResult processOp(const DomainInfo &info, TermAllocator &allocator, + DomainTable &table, + const ModuleUpdateTable &updateTable, InstanceOp op) { + auto module = op.getReferencedModuleNameAttr(); + auto lookup = updateTable.find(module); + if (lookup != updateTable.end()) + op = fixInstancePorts(op, lookup->second); + return processInstancePorts(info, allocator, table, op); +} + +LogicalResult processOp(const DomainInfo &info, TermAllocator &allocator, + DomainTable &table, + const ModuleUpdateTable &updateTable, + InstanceChoiceOp op) { + auto module = op.getDefaultTargetAttr().getAttr(); + auto lookup = updateTable.find(module); + if (lookup != updateTable.end()) + op = fixInstancePorts(op, lookup->second); + return processInstancePorts(info, allocator, table, op); +} + +LogicalResult processOp(const DomainInfo &info, TermAllocator &allocator, + DomainTable &table, UnsafeDomainCastOp op) { + auto domains = op.getDomains(); + if (domains.empty()) + return unifyAssociations(info, allocator, table, op, op.getInput(), + op.getResult()); + + auto input = op.getInput(); + RowTerm *inputRow = getDomainAssociationAsRow(info, allocator, table, input); + SmallVector elements(inputRow->elements); + for (auto domain : op.getDomains()) { + auto typeID = info.getDomainTypeID(domain); + elements[typeID] = getTermForDomain(allocator, table, domain); + } + + auto *row = allocator.allocRow(elements); + table.setDomainAssociation(op.getResult(), row); + return success(); +} + +LogicalResult processOp(const DomainInfo &info, TermAllocator &allocator, + DomainTable &table, DomainDefineOp op) { + auto src = op.getSrc(); + auto dst = op.getDest(); + auto *srcTerm = getTermForDomain(allocator, table, src); + auto *dstTerm = getTermForDomain(allocator, table, dst); + if (failed(unify(dstTerm, srcTerm))) { + VariableIDTable idTable; + auto diag = op->emitOpError("failed to propagate source to destination"); + auto ¬e1 = diag.attachNote(); + note1 << "destination has underlying value: "; + render(info, note1, idTable, dstTerm); + + auto ¬e2 = diag.attachNote(src.getLoc()); + note2 << "source has underlying value: "; + render(info, note2, idTable, srcTerm); + } + return success(); +} + +LogicalResult processOp(const DomainInfo &info, TermAllocator &allocator, + DomainTable &table, + const ModuleUpdateTable &updateTable, Operation *op) { + LLVM_DEBUG(llvm::errs() << "process op: " << *op << "\n"); + if (auto instance = dyn_cast(op)) + return processOp(info, allocator, table, updateTable, instance); + if (auto instance = dyn_cast(op)) + return processOp(info, allocator, table, updateTable, instance); + if (auto cast = dyn_cast(op)) + return processOp(info, allocator, table, cast); + if (auto def = dyn_cast(op)) + return processOp(info, allocator, table, def); + + // For all other operations (including connections), propagate domains from + // operands to results. This is a conservative approach - all operands and + // results share the same domain associations. + Value lhs; + for (auto rhs : op->getOperands()) { + if (!isa(rhs.getType())) + continue; + if (auto *op = rhs.getDefiningOp(); + op && op->hasTrait()) + continue; + if (failed(unifyAssociations(info, allocator, table, op, lhs, rhs))) + return failure(); + lhs = rhs; + } + for (auto rhs : op->getResults()) { + if (!isa(rhs.getType())) + continue; + if (auto *op = rhs.getDefiningOp(); + op && op->hasTrait()) + continue; + if (failed(unifyAssociations(info, allocator, table, op, lhs, rhs))) + return failure(); + lhs = rhs; + } + return success(); +} + +LogicalResult processModuleBody(const DomainInfo &info, + TermAllocator &allocator, DomainTable &table, + const ModuleUpdateTable &updateTable, + FModuleOp module) { + LogicalResult result = success(); + module.getBody().walk([&](Operation *op) -> WalkResult { + if (failed(processOp(info, allocator, table, updateTable, op))) { + result = failure(); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return result; +} + +/// Populate the domain table by processing the module. If the module has any +/// domain crossing errors, return failure. +LogicalResult processModule(const DomainInfo &info, TermAllocator &allocator, + DomainTable &table, + const ModuleUpdateTable &updateTable, + FModuleOp module) { + if (failed(processModulePorts(info, allocator, table, module))) + return failure(); + + return processModuleBody(info, allocator, table, updateTable, module); +} + +} // namespace + +//====-------------------------------------------------------------------------- +// Module updating: write the computed domains back to the IR. +//====-------------------------------------------------------------------------- + +namespace { + +using ExportTable = DenseMap>; + +/// Build a table of exported domains: a map from domains defined internally, +/// to their set of aliasing output ports. +ExportTable initializeExportTable(const DomainTable &table, FModuleOp module) { + ExportTable exports; + size_t numPorts = module.getNumPorts(); + for (size_t i = 0; i < numPorts; ++i) { + auto port = module.getArgument(i); + auto type = port.getType(); + if (!isa(type)) + continue; + auto value = table.getOptUnderlyingDomain(port); + if (value) + exports[value].push_back(port); + } + + return exports; +} + +/// Copy the domain associations from the module domain info attribute into a +/// small vector. +SmallVector copyPortDomainAssociations(const DomainInfo &info, + ArrayAttr moduleDomainInfo, + size_t portIndex) { + SmallVector result(info.getNumDomains()); + auto oldAssociations = getPortDomainAssociation(moduleDomainInfo, portIndex); + for (auto domainPortIndexAttr : oldAssociations) { + auto domainPortIndex = domainPortIndexAttr.getUInt(); + auto domainTypeID = info.getDomainTypeID(moduleDomainInfo, domainPortIndex); + result[domainTypeID] = domainPortIndexAttr; + }; + return result; +} + +/// Add domain ports for any uninferred domains associated to hardware. +/// Returns the inserted ports, which will be used later to generalize the +/// instances of this module. +/// +/// If the port is hardware, we have to check the associated row of +/// domains. If any associated domain is a variable, we solve the variable +/// by generalizing the module with an additional input domain port. If any +/// associated domain is defined internally to the module, we have to add +/// an output domain port, to allow the domain to escape. +void createModuleDomainPorts(const DomainInfo &info, TermAllocator &allocator, + DomainTable &table, ExportTable &exportTable, + PortInsertions &insertions, FModuleOp module) { + DenseMap pendingSolutions; + llvm::MapVector pendingExports; + size_t inserted = 0; + auto numPorts = module.getNumPorts(); + for (size_t i = 0; i < numPorts; ++i) { + auto port = module.getArgument(i); + auto type = port.getType(); + + if (!isa(type)) + continue; + + auto *row = getDomainAssociationAsRow(info, allocator, table, port); + for (auto [typeID, term] : llvm::enumerate(row->elements)) { + auto *domain = find(term); + + if (auto *val = dyn_cast(domain)) { + auto value = val->value; + // If the domain value is defined inside the module body, we must output + // export the domain, so it may appear in the signature of the + // module. + if (isPort(module, value)) + continue; + + // The domain is defined internally. If the value is already exported, + // or will be exported, we are done. + if (exportTable.contains(value) || pendingExports.contains(value)) + continue; + + // We must insert a new output domain port. + auto domainDecl = info.getDomain(typeID); + auto domainName = domainDecl.getNameAttr(); + + auto portInsertionPoint = i; + auto portName = domainName; + auto portType = DomainType::get(module.getContext()); + auto portDirection = Direction::Out; + auto portSym = StringAttr(); + auto portLoc = port.getLoc(); + auto portAnnos = std::nullopt; + auto portDomainInfo = FlatSymbolRefAttr::get(domainName); + PortInfo portInfo(portName, portType, portDirection, portSym, portLoc, + portAnnos, portDomainInfo); + insertions.push_back({portInsertionPoint, portInfo}); + + // Record the pending export. + auto exportedPortIndex = inserted + portInsertionPoint; + pendingExports[val->value] = exportedPortIndex; + ++inserted; + } + + if (auto *var = dyn_cast(domain)) { + if (pendingSolutions.contains(var)) + continue; + + // insert a new input domain port for the variable. + auto domainDecl = info.getDomain(typeID); + auto domainName = domainDecl.getNameAttr(); + + auto portInsertionPoint = i; + auto portName = domainName; + auto portType = DomainType::get(module.getContext()); + auto portDirection = Direction::In; + auto portSym = StringAttr(); + auto portLoc = port.getLoc(); + auto portAnnos = std::nullopt; + auto portDomainInfo = FlatSymbolRefAttr::get(domainName); + PortInfo portInfo(portName, portType, portDirection, portSym, portLoc, + portAnnos, portDomainInfo); + insertions.push_back({portInsertionPoint, portInfo}); + + // Record the pending solution. + auto solutionPortIndex = inserted + portInsertionPoint; + pendingSolutions[var] = solutionPortIndex; + ++inserted; + } + } + } + + // Put the domain ports in place. + module.insertPorts(insertions); + + // Solve the variables and record them as "self-exporting". + for (auto [var, portIndex] : pendingSolutions) { + auto port = module.getArgument(portIndex); + auto *solution = allocator.allocVal(port); + solve(var, solution); + exportTable[port].push_back(port); + } + + // Drive the pending exports. + auto builder = OpBuilder::atBlockEnd(module.getBodyBlock()); + for (auto [value, portIndex] : pendingExports) { + auto port = module.getArgument(portIndex); + DomainDefineOp::create(builder, port.getLoc(), port, value); + exportTable[value].push_back(port); + table.setTermForDomain(port, allocator.allocVal(value)); + } +} + +/// After generalizing the module, all domains should be solved. Reflect the +/// solved domain associations into the port domain info attribute. +LogicalResult updateModuleDomainInfo(const DomainInfo &info, + const DomainTable &table, + const ExportTable &exportTable, + ArrayAttr &result, FModuleOp module) { + // At this point, all domain variables mentioned in ports have been + // solved by generalizing the module (adding input domain ports). Now, we have + // to form the new port domain information for the module by examining the + // the associated domains of each port. + auto *context = module.getContext(); + auto numDomains = info.getNumDomains(); + auto builder = OpBuilder::atBlockEnd(module.getBodyBlock()); + auto oldModuleDomainInfo = module.getDomainInfoAttr(); + auto numPorts = module.getNumPorts(); + SmallVector newModuleDomainInfo(numPorts); + + for (size_t i = 0; i < numPorts; ++i) { + auto port = module.getArgument(i); + auto type = port.getType(); + + // If the port is an output domain, we may need to drive the output with + // a value. If we don't know what value to drive to the port, error. + if (isa(type)) { + // If the output port is not driven, drive it. + if (module.getPortDirection(i) == Direction::Out && !isDriven(port)) { + // Get the underlying value of the output port. + auto *term = table.getOptTermForDomain(port); + auto *val = llvm::dyn_cast_if_present(term); + if (!val) { + emitDomainPortInferenceError(module, i); + return failure(); + } + + auto loc = port.getLoc(); + auto value = val->value; + DomainDefineOp::create(builder, loc, port, value); + } + + newModuleDomainInfo[i] = oldModuleDomainInfo[i]; + continue; + } + + if (isa(type)) { + auto associations = + copyPortDomainAssociations(info, oldModuleDomainInfo, i); + auto *row = cast(table.getDomainAssociation(port)); + for (size_t domainTypeID = 0; domainTypeID < numDomains; ++domainTypeID) { + if (associations[domainTypeID]) + continue; + + auto domain = cast(find(row->elements[domainTypeID]))->value; + auto &exports = exportTable.at(domain); + if (exports.empty()) { + auto portName = module.getPortNameAttr(i); + auto portLoc = module.getPortLocation(i); + auto domainDecl = info.getDomain(domainTypeID); + auto domainName = domainDecl.getNameAttr(); + auto diag = emitError(portLoc) + << "private " << domainName << " association for port " + << portName; + diag.attachNote(domain.getLoc()) << "associated domain: " << domain; + return failure(); + } + + if (exports.size() > 1) { + auto portName = module.getPortNameAttr(i); + auto portLoc = module.getPortLocation(i); + auto domainDecl = info.getDomain(domainTypeID); + auto domainName = domainDecl.getNameAttr(); + auto diag = emitError(portLoc) + << "ambiguous " << domainName << " association for port " + << portName; + for (auto arg : exports) { + auto name = module.getPortNameAttr(arg.getArgNumber()); + auto loc = module.getPortLocation(arg.getArgNumber()); + diag.attachNote(loc) << "candidate association " << name; + } + return failure(); + } + + auto argument = cast(exports[0]); + auto domainPortIndex = argument.getArgNumber(); + associations[domainTypeID] = IntegerAttr::get( + IntegerType::get(context, 32, IntegerType::Unsigned), + domainPortIndex); + } + + newModuleDomainInfo[i] = ArrayAttr::get(context, associations); + continue; + } + + newModuleDomainInfo[i] = ArrayAttr::get(context, {}); + } + + result = ArrayAttr::get(module.getContext(), newModuleDomainInfo); + module.setDomainInfoAttr(result); + return success(); +} + +/// Update the ports of the module and record the change in the module update +/// table. +LogicalResult updateModulePorts(const DomainInfo &info, + TermAllocator &allocator, DomainTable &table, + ModuleUpdateTable &updateTable, FModuleOp op) { + // The export table tracks how domains are exported by the ports of the + // module. Initialize the export table by scanning the current ports. + auto exportTable = initializeExportTable(table, op); + + // Now, create any necessary domain ports. + PortInsertions portInsertions; + createModuleDomainPorts(info, allocator, table, exportTable, portInsertions, + op); + + // Update the domain info for the module's ports. + ArrayAttr portDomainInfo; + if (failed( + updateModuleDomainInfo(info, table, exportTable, portDomainInfo, op))) + return failure(); + + // Record the updated interface change in the update table. + auto &entry = updateTable[op.getModuleNameAttr()]; + entry.portDomainInfo = portDomainInfo; + entry.portInsertions = portInsertions; + return success(); +} + +template +LogicalResult updateInstance(const DomainTable &table, T op) { + auto *context = op.getContext(); + OpBuilder builder(context); + builder.setInsertionPointAfter(op); + auto numPorts = op->getNumResults(); + for (size_t i = 0; i < numPorts; ++i) { + auto port = op.getResult(i); + auto type = port.getType(); + auto direction = op.getPortDirection(i); + // If the port is an input domain, we may need to drive the input with + // a value. If we don't know what value to drive to the port, error. + if (isa(type) && direction == Direction::In && + !isDriven(port)) { + auto *term = table.getOptTermForDomain(port); + auto *val = llvm::dyn_cast_if_present(term); + if (!val) { + emitDomainPortInferenceError(op, i); + return failure(); + } + + auto loc = port.getLoc(); + auto value = val->value; + DomainDefineOp::create(builder, loc, port, value); + } + } + return success(); +} + +LogicalResult updateOp(const DomainTable &table, Operation *op) { + if (auto instance = dyn_cast(op)) + return updateInstance(table, instance); + if (auto instance = dyn_cast(op)) + return updateInstance(table, instance); + return success(); +} + +/// After updating the port domain associations, walk the body of the module +/// to fix up any child instance modules. +LogicalResult updateModuleBody(const DomainTable &table, FModuleOp module) { + auto result = success(); + module.getBodyBlock()->walk([&](Operation *op) -> WalkResult { + if (failed(updateOp(table, op))) { + result = failure(); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return result; +} + +/// Write the domain associations recorded in the domain table back to the IR. +LogicalResult updateModule(const DomainInfo &info, TermAllocator &allocator, + DomainTable &table, ModuleUpdateTable &updateTable, + FModuleOp op) { + if (failed(updateModulePorts(info, allocator, table, updateTable, op))) + return failure(); + + if (failed(updateModuleBody(table, op))) + return failure(); + + return success(); +} + +//====-------------------------------------------------------------------------- +// Domain Inference: solve domains and check for correctness,then update the +// IR to reflect the solved domains. +//====-------------------------------------------------------------------------- + +/// Solve for domains and then write the domain associations back to the IR. +LogicalResult inferModule(const DomainInfo &info, + ModuleUpdateTable &updateTable, FModuleOp module) { + TermAllocator allocator; + DomainTable table; + if (failed(processModule(info, allocator, table, updateTable, module))) + return failure(); + + return updateModule(info, allocator, table, updateTable, module); +} + +//====-------------------------------------------------------------------------- +// Domain Inference+Checking: Check that the interface of the module is fully +// annotated, before proceeding to run domain inference on the body of the +// module. +//====-------------------------------------------------------------------------- + +/// Check that a module's hardware ports have complete domain associations. +LogicalResult checkModulePorts(const DomainInfo &info, FModuleLike module) { + auto numDomains = info.getNumDomains(); + auto domainInfo = module.getDomainInfoAttr(); + DenseMap typeIDTable; + for (size_t i = 0, e = module.getNumPorts(); i < e; ++i) { + auto type = module.getPortType(i); + + if (isa(type)) { + auto typeID = info.getDomainTypeID(domainInfo, i); + typeIDTable[i] = typeID; + continue; + } + + if (auto baseType = type_dyn_cast(type)) { + SmallVector associations(numDomains); + auto domains = getPortDomainAssociation(domainInfo, i); + for (auto index : domains) { + auto typeID = typeIDTable[index.getUInt()]; + auto &entry = associations[typeID]; + if (entry && entry != index) { + auto domainName = info.getDomain(typeID).getNameAttr(); + auto portName = module.getPortNameAttr(i); + auto diag = emitError(module.getPortLocation(i)) + << "ambiguous " << domainName << " association for port " + << portName; + + auto d1Loc = module.getPortLocation(entry.getUInt()); + auto d1Name = module.getPortNameAttr(entry.getUInt()); + diag.attachNote(d1Loc) + << "associated with " << domainName << " port " << d1Name; + + auto d2Loc = module.getPortLocation(index.getUInt()); + auto d2Name = module.getPortNameAttr(index.getUInt()); + diag.attachNote(d2Loc) + << "associated with " << domainName << " port " << d2Name; + } + entry = index; + } + + for (size_t typeID = 0; typeID < numDomains; ++typeID) { + auto association = associations[typeID]; + if (!association) { + emitMissingPortDomainAssociationError(info, module, typeID, i); + return failure(); + } + } + } + } + + return success(); +} + +/// Check that output domain ports are driven. +LogicalResult checkModuleDomainPortDrivers(const DomainInfo &info, + FModuleOp module) { + for (size_t i = 0, e = module.getNumPorts(); i < e; ++i) { + auto port = module.getArgument(i); + auto type = port.getType(); + if (!isa(type) || + module.getPortDirection(i) != Direction::Out || isDriven(port)) + continue; + + auto name = module.getPortNameAttr(i); + emitError(module.getPortLocation(i)) << "undriven domain port " << name; + return failure(); + } + + return success(); +} + +/// Check that the input domain ports are driven. +template +LogicalResult checkInstanceDomainPortDrivers(T op) { + for (size_t i = 0, e = op.getNumResults(); i < e; ++i) { + auto port = op.getResult(i); + auto type = port.getType(); + if (!isa(type) || op.getPortDirection(i) != Direction::In || + isDriven(port)) + continue; + + auto name = op.getPortNameAttr(i); + emitError(op.getPortLocation(i)) << "undriven domain port " << name; + return failure(); + } + + return success(); +} + +LogicalResult checkOp(Operation *op) { + if (auto inst = dyn_cast(op)) + return checkInstanceDomainPortDrivers(inst); + if (auto inst = dyn_cast(op)) + return checkInstanceDomainPortDrivers(inst); + return success(); +} + +/// Check that instances under this module have driven domain input ports. +LogicalResult checkModuleBody(FModuleOp module) { + LogicalResult result = success(); + module.getBody().walk([&](Operation *op) -> WalkResult { + if (failed(checkOp(op))) { + result = failure(); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return result; +} + +/// Check that a module's ports are fully annotated, before performing domain +/// inference on the module. +LogicalResult checkModule(const DomainInfo &info, FModuleOp module) { + if (failed(checkModulePorts(info, module))) + return failure(); + + if (failed(checkModuleDomainPortDrivers(info, module))) + return failure(); + + if (failed(checkModuleBody(module))) + return failure(); + + TermAllocator allocator; + DomainTable table; + ModuleUpdateTable updateTable; + return processModule(info, allocator, table, updateTable, module); +} + +/// Check that an extmodule's ports are fully annotated. +LogicalResult checkModule(const DomainInfo &info, FExtModuleOp module) { + return checkModulePorts(info, module); +} + +//====-------------------------------------------------------------------------- +// Hybrid Mode: Check the interface, then infer. +//====-------------------------------------------------------------------------- + +/// Check that a module's ports are fully annotated, before performing domain +/// inference on the module. We use this when private module interfaces are +/// inferred but public module interfaces are checked. +LogicalResult checkAndInferModule(const DomainInfo &info, + ModuleUpdateTable &updateTable, + FModuleOp module) { + if (failed(checkModulePorts(info, module))) + return failure(); + + return inferModule(info, updateTable, module); +} + +//===--------------------------------------------------------------------------- +// InferDomainsPass: Top-level pass implementation. +//===--------------------------------------------------------------------------- + +LogicalResult runOnModuleLike(InferDomainsMode mode, const DomainInfo &info, + ModuleUpdateTable &updateTable, Operation *op) { + + llvm::errs() << *op << "\n"; + if (auto module = dyn_cast(op)) { + if (mode == InferDomainsMode::Check) + return checkModule(info, module); + + if (mode == InferDomainsMode::InferAll || module.isPrivate()) + return inferModule(info, updateTable, module); + + return checkAndInferModule(info, updateTable, module); + } + + if (auto extModule = dyn_cast(op)) + return checkModule(info, extModule); + + return success(); +} + +struct InferDomainsPass + : public circt::firrtl::impl::InferDomainsBase { + using InferDomainsBase::InferDomainsBase; + void runOnOperation() override { + CIRCT_DEBUG_SCOPED_PASS_LOGGER(this); + auto circuit = getOperation(); + auto &instanceGraph = getAnalysis(); + DomainInfo info(circuit); + DenseSet visited; + ModuleUpdateTable updateTable; + auto result = instanceGraph.walkPostOrder([&](auto &node) { + return runOnModuleLike(mode, info, updateTable, node.getModule()); + }); + if (failed(result)) + signalPassFailure(); + } +}; + +} // namespace diff --git a/lib/Firtool/Firtool.cpp b/lib/Firtool/Firtool.cpp index dcf768817451..2e01581b914a 100644 --- a/lib/Firtool/Firtool.cpp +++ b/lib/Firtool/Firtool.cpp @@ -49,6 +49,9 @@ LogicalResult firtool::populatePreprocessTransforms(mlir::PassManager &pm, pm.nest().nest().addPass( firrtl::createLowerIntrinsics()); + if (auto mode = toInferDomainsPassMode(opt.getDomainMode())) + pm.nest().addPass(firrtl::createInferDomains({*mode})); + return success(); } @@ -758,6 +761,20 @@ struct FirtoolCmdOptions { llvm::cl::desc("Emit bindfiles for private modules"), llvm::cl::init(false)}; + llvm::cl::opt domainMode{ + "domain-mode", llvm::cl::desc("Enable domain inference and checking"), + llvm::cl::init(firtool::DomainMode::Disable), + llvm::cl::values( + clEnumValN(firtool::DomainMode::Disable, "disable", + "Disable domain checking"), + clEnumValN(firtool::DomainMode::Infer, "infer", + "Check domains with inference for private modules"), + clEnumValN(firtool::DomainMode::Check, "check", + "Check domains without inference"), + clEnumValN(firtool::DomainMode::InferAll, "infer-all", + "Check domains with inference for both public and private " + "modules"))}; + //===----------------------------------------------------------------------=== // Lint options //===----------------------------------------------------------------------=== @@ -809,7 +826,8 @@ circt::firtool::FirtoolOptions::FirtoolOptions() disableCSEinClasses(false), selectDefaultInstanceChoice(false), symbolicValueLowering(verif::SymbolicValueLowering::ExtModule), disableWireElimination(false), lintStaticAsserts(true), - lintXmrsInDesign(true), emitAllBindFiles(false) { + lintXmrsInDesign(true), emitAllBindFiles(false), + domainMode(DomainMode::Disable) { if (!clOptions.isConstructed()) return; outputFilename = clOptions->outputFilename; @@ -862,4 +880,5 @@ circt::firtool::FirtoolOptions::FirtoolOptions() lintStaticAsserts = clOptions->lintStaticAsserts; lintXmrsInDesign = clOptions->lintXmrsInDesign; emitAllBindFiles = clOptions->emitAllBindFiles; + domainMode = clOptions->domainMode; } diff --git a/test/Dialect/FIRRTL/infer-domains-check-errors.mlir b/test/Dialect/FIRRTL/infer-domains-check-errors.mlir new file mode 100644 index 000000000000..59782b1b7dd0 --- /dev/null +++ b/test/Dialect/FIRRTL/infer-domains-check-errors.mlir @@ -0,0 +1,82 @@ +// RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains{mode=check}))' %s --verify-diagnostics + +// in "check" mode, infer-domains will require that all ops are fully annotated +// with domains. No inference is run. + +// CHECK-LABEL: MissingDomain +firrtl.circuit "MissingDomain" { + firrtl.domain @ClockDomain + + firrtl.module @MissingDomain( + // expected-error @below {{missing "ClockDomain" association for port "x"}} + in %x: !firrtl.uint<1> + ) {} +} + +// CHECK-LABEL: MissingSecondDomain +firrtl.circuit "MissingSecondDomain" { + firrtl.domain @ClockDomain + firrtl.domain @PowerDomain + + firrtl.module @MissingSecondDomain( + in %c : !firrtl.domain of @ClockDomain, + // expected-error @below {{missing "PowerDomain" association for port "x"}} + in %x : !firrtl.uint<1> domains [%c] + ) {} +} + +// CHECK-LABEL: UndrivenOutputDomain +firrtl.circuit "UndrivenOutputDomain" { + firrtl.domain @ClockDomain + + firrtl.module @UndrivenOutputDomain( + // expected-error @below {{undriven domain port "c"}} + out %c : !firrtl.domain of @ClockDomain + ) {} +} + +// CHECK-LABEL: UndrivenInstanceDomainPort +firrtl.circuit "UndrivenInstanceDomainPort" { + firrtl.domain @ClockDomain + + firrtl.extmodule @Foo(in c : !firrtl.domain of @ClockDomain) + + firrtl.module @UndrivenInstanceDomainPort() { + // expected-error @below {{undriven domain port "c"}} + %foo_c = firrtl.instance foo @Foo(in c : !firrtl.domain of @ClockDomain) + } +} + +// CHECK-LABEL: UndrivenInstanceChoiceDomainPort +firrtl.circuit "UndrivenInstanceChoiceDomainPort" { + firrtl.domain @ClockDomain + + firrtl.option @Option { + firrtl.option_case @X + } + + firrtl.extmodule @Foo(in c : !firrtl.domain of @ClockDomain) + firrtl.extmodule @Bar(in c : !firrtl.domain of @ClockDomain) + + firrtl.module @UndrivenInstanceChoiceDomainPort() { + // expected-error @below {{undriven domain port "c"}} + %inst_c = firrtl.instance_choice inst @Foo alternatives @Option { @X -> @Bar } (in c : !firrtl.domain of @ClockDomain) + } +} + +// Test that domain crossing errors are still caught when in check-only mode. +// Catching this involves processing the module without writing back to the IR. +firrtl.circuit "IllegalDomainCrossing" { + firrtl.domain @ClockDomain + firrtl.module @IllegalDomainCrossing( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @ClockDomain, + // expected-note @below {{2nd operand has domains: [ClockDomain: A]}} + in %a: !firrtl.uint<1> domains [%A], + // expected-note @below {{1st operand has domains: [ClockDomain: B]}} + out %b: !firrtl.uint<1> domains [%B] + ) { + // expected-error @below {{illegal domain crossing in operation}} + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + } +} diff --git a/test/Dialect/FIRRTL/infer-domains-infer-all-errors.mlir b/test/Dialect/FIRRTL/infer-domains-infer-all-errors.mlir new file mode 100644 index 000000000000..53f3b58a52e5 --- /dev/null +++ b/test/Dialect/FIRRTL/infer-domains-infer-all-errors.mlir @@ -0,0 +1,208 @@ +// RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains{mode=infer-all}))' %s --verify-diagnostics + +// Port annotated with same domain type twice. +firrtl.circuit "DomainCrossOnPort" { + firrtl.domain @ClockDomain + firrtl.module @DomainCrossOnPort( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @ClockDomain, + // expected-error @below {{illegal "ClockDomain" crossing in port "p"}} + // expected-note @below {{1st instance: A}} + // expected-note @below {{2nd instance: B}} + in %p: !firrtl.uint<1> domains [%A, %B] + ) {} +} + +// Illegal domain crossing via connect op. +firrtl.circuit "IllegalDomainCrossing" { + firrtl.domain @ClockDomain + firrtl.module @IllegalDomainCrossing( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @ClockDomain, + // expected-note @below {{2nd operand has domains: [ClockDomain: A]}} + in %a: !firrtl.uint<1> domains [%A], + // expected-note @below {{1st operand has domains: [ClockDomain: B]}} + out %b: !firrtl.uint<1> domains [%B] + ) { + // expected-error @below {{illegal domain crossing in operation}} + firrtl.connect %b, %a : !firrtl.uint<1> + } +} + +// Illegal domain crossing at matchingconnect op. +firrtl.circuit "IllegalDomainCrossing" { + firrtl.domain @ClockDomain + firrtl.module @IllegalDomainCrossing( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @ClockDomain, + // expected-note @below {{2nd operand has domains: [ClockDomain: A]}} + in %a: !firrtl.uint<1> domains [%A], + // expected-note @below {{1st operand has domains: [ClockDomain: B]}} + out %b: !firrtl.uint<1> domains [%B] + ) { + // expected-error @below {{illegal domain crossing in operation}} + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + } +} + +// Unable to infer domain of port, when port is driven by constant. +firrtl.circuit "UnableToInferDomainOfPortDrivenByConstant" { + firrtl.domain @ClockDomain + firrtl.module @Foo(in %i: !firrtl.uint<1>) {} + + firrtl.module @UnableToInferDomainOfPortDrivenByConstant() { + %c0_ui1 = firrtl.constant 0 : !firrtl.uint<1> + // expected-error @below {{unable to infer value for undriven domain port "ClockDomain"}} + // expected-note @below {{associated with hardware port "i"}} + %foo_i = firrtl.instance foo @Foo(in i: !firrtl.uint<1>) + firrtl.matchingconnect %foo_i, %c0_ui1 : !firrtl.uint<1> + } +} + +// Unable to infer domain of port, when port is driven by arithmetic on constant. +firrtl.circuit "UnableToInferDomainOfPortDrivenByConstantExpr" { + firrtl.domain @ClockDomain + firrtl.module @Foo(in %i: !firrtl.uint<2>) {} + + firrtl.module @UnableToInferDomainOfPortDrivenByConstantExpr() { + %c0_ui1 = firrtl.constant 0 : !firrtl.uint<1> + %0 = firrtl.add %c0_ui1, %c0_ui1 : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<2> + // expected-error @below {{unable to infer value for undriven domain port "ClockDomain"}} + // expected-note @below {{associated with hardware port "i"}} + %foo_i = firrtl.instance foo @Foo(in i: !firrtl.uint<2>) + firrtl.matchingconnect %foo_i, %0 : !firrtl.uint<2> + } +} + +// Incomplete extmodule domain information. + +firrtl.circuit "Top" { + firrtl.domain @ClockDomain + + // expected-error @below {{missing "ClockDomain" association for port "i"}} + firrtl.extmodule @Top(in i: !firrtl.uint<1>) +} + +// Conflicting extmodule domain information. + +firrtl.circuit "Top" { + firrtl.domain @ClockDomain + + firrtl.extmodule @Top( + // expected-note @below {{associated with "ClockDomain" port "D1"}} + in D1 : !firrtl.domain of @ClockDomain, + // expected-note @below {{associated with "ClockDomain" port "D2"}} + in D2 : !firrtl.domain of @ClockDomain, + // expected-error @below {{ambiguous "ClockDomain" association for port "i"}} + in i: !firrtl.uint<1> domains [D1, D2] + ) +} + +// ----- + +// Domain exported multiple times. Which do we choose? + +firrtl.circuit "DoubleExportOfDomain" { + firrtl.domain @ClockDomain + + firrtl.module @DoubleExportOfDomain( + // expected-note @below {{candidate association "DI"}} + in %DI : !firrtl.domain of @ClockDomain, + // expected-note @below {{candidate association "DO"}} + out %DO : !firrtl.domain of @ClockDomain, + in %i : !firrtl.uint<1> domains [%DO], + // expected-error @below {{ambiguous "ClockDomain" association for port "o"}} + out %o : !firrtl.uint<1> domains [] + ) { + // DI and DO are aliases + firrtl.domain.define %DO, %DI + + // o is on same domain as i + firrtl.matchingconnect %o, %i : !firrtl.uint<1> + } +} + +// Domain exported multiple times, this time with two outputs. + +firrtl.circuit "DoubleExportOfDomain" { + firrtl.domain @ClockDomain + + firrtl.extmodule @Generator(out D: !firrtl.domain of @ClockDomain) + + firrtl.module @DoubleExportOfDomain( + // expected-note @below {{candidate association "D1"}} + out %D1 : !firrtl.domain of @ClockDomain, + // expected-note @below {{candidate association "D2"}} + out %D2 : !firrtl.domain of @ClockDomain, + in %i : !firrtl.uint<1> domains [%D1], + // expected-error @below {{ambiguous "ClockDomain" association for port "o"}} + out %o : !firrtl.uint<1> domains [] + ) { + %gen_D = firrtl.instance gen @Generator(out D: !firrtl.domain of @ClockDomain) + // DI and DO are aliases + firrtl.domain.define %D1, %gen_D + firrtl.domain.define %D2, %gen_D + + // o is on same domain as i + firrtl.matchingconnect %o, %i : !firrtl.uint<1> + } +} + +// CHECK-LABEL: UndrivenInstanceDomainPort +firrtl.circuit "UndrivenInstanceDomainPort" { + firrtl.domain @ClockDomain + + firrtl.extmodule @Foo(in c : !firrtl.domain of @ClockDomain) + + firrtl.module @UndrivenInstanceDomainPort() { + // expected-error @below {{unable to infer value for undriven domain port "c"}} + %foo_c = firrtl.instance foo @Foo(in c : !firrtl.domain of @ClockDomain) + } +} + +// CHECK-LABEL: UndrivenInstanceChoiceDomainPort +firrtl.circuit "UndrivenInstanceChoiceDomainPort" { + firrtl.domain @ClockDomain + + firrtl.option @Option { + firrtl.option_case @X + } + + firrtl.extmodule @Foo(in c : !firrtl.domain of @ClockDomain) + firrtl.extmodule @Bar(in c : !firrtl.domain of @ClockDomain) + + firrtl.module @UndrivenInstanceChoiceDomainPort() { + // expected-error @below {{unable to infer value for undriven domain port "c"}} + %inst_c = firrtl.instance_choice inst @Foo alternatives @Option { @X -> @Bar } (in c : !firrtl.domain of @ClockDomain) + } +} + +// InstanceChoice: Each module has different domains inferred. +// TODO: this just relies on the op-verifier for instance choice ops. + +firrtl.circuit "ConflictingInstanceChoiceDomains" { + firrtl.domain @ClockDomain + + firrtl.option @Option { + firrtl.option_case @X + firrtl.option_case @Y + } + + // Foo's "out" port takes on the domains of "in1". + firrtl.module @Foo(in %in1: !firrtl.uint<1>, in %in2: !firrtl.uint<1>, out %out: !firrtl.uint<1>) { + firrtl.connect %out, %in1 : !firrtl.uint<1> + } + + // Bar's "out" port takes on the domains of "in2". + // expected-note @below {{original module declared here}} + firrtl.module @Bar(in %in1: !firrtl.uint<1>, in %in2: !firrtl.uint<1>, out %out: !firrtl.uint<1>) { + firrtl.connect %out, %in2 : !firrtl.uint<1> + } + + firrtl.module @ConflictingInstanceChoiceDomains(in %in1: !firrtl.uint<1>, in %in2: !firrtl.uint<1>) { + // expected-error @below {{'firrtl.instance_choice' op domain info for "out" must be [2 : ui32], but got [0 : ui32]}} + %inst_in1, %inst_in2, %inst_out = firrtl.instance_choice inst @Foo alternatives @Option { @X -> @Foo, @Y -> @Bar } (in in1: !firrtl.uint<1>, in in2: !firrtl.uint<1>, out out: !firrtl.uint<1>) + firrtl.connect %inst_in1, %in1 : !firrtl.uint<1>, !firrtl.uint<1> + firrtl.connect %inst_in2, %in2 : !firrtl.uint<1>, !firrtl.uint<1> + } +} diff --git a/test/Dialect/FIRRTL/infer-domains-infer-all.mlir b/test/Dialect/FIRRTL/infer-domains-infer-all.mlir new file mode 100644 index 000000000000..1b7d465d34c7 --- /dev/null +++ b/test/Dialect/FIRRTL/infer-domains-infer-all.mlir @@ -0,0 +1,323 @@ +// RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains{mode=infer-all}))' %s | FileCheck %s + +// Legal domain usage - no crossing. +// CHECK-LABEL: firrtl.circuit "LegalDomains" +firrtl.circuit "LegalDomains" { + firrtl.domain @ClockDomain + firrtl.module @LegalDomains( + in %A: !firrtl.domain of @ClockDomain, + in %a: !firrtl.uint<1> domains [%A], + out %b: !firrtl.uint<1> domains [%A] + ) { + // Connecting within the same domain is legal. + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + } +} + +// Domain inference through connections. +// CHECK-LABEL: firrtl.circuit "DomainInference" +firrtl.circuit "DomainInference" { + firrtl.domain @ClockDomain + firrtl.module @DomainInference( + in %A: !firrtl.domain of @ClockDomain, + in %a: !firrtl.uint<1> domains [%A], + // CHECK: out %c: !firrtl.uint<1> domains [%A] + out %c: !firrtl.uint<1> + ) { + %b = firrtl.wire : !firrtl.uint<1> // No explicit domain + + // This should infer that %b is in domain %A. + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + + // This should be legal since %b is now inferred to be in domain %A. + firrtl.matchingconnect %c, %b : !firrtl.uint<1> + } +} + +// Unsafe domain cast +// CHECK-LABEL: firrtl.circuit "UnsafeDomainCast" +firrtl.circuit "UnsafeDomainCast" { + firrtl.domain @ClockDomain + firrtl.module @UnsafeDomainCast( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @ClockDomain, + in %a: !firrtl.uint<1> domains [%A], + out %c: !firrtl.uint<1> domains [%B] + ) { + // Unsafe cast from domain A to domain B. + %b = firrtl.unsafe_domain_cast %a domains %B : !firrtl.uint<1> + + // This should be legal since we explicitly cast. + firrtl.matchingconnect %c, %b : !firrtl.uint<1> + } +} + +// Domain sequence matching. +// CHECK-LABEL: firrtl.circuit "LegalSequences" +firrtl.circuit "LegalSequences" { + firrtl.domain @ClockDomain + firrtl.domain @PowerDomain + firrtl.module @LegalSequences( + in %C: !firrtl.domain of @ClockDomain, + in %P: !firrtl.domain of @PowerDomain, + in %a: !firrtl.uint<1> domains [%C, %P], + out %b: !firrtl.uint<1> domains [%C, %P] + ) { + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + } +} + +// Domain sequence order equivalence - should be legal +// CHECK-LABEL: SequenceOrderEquivalence +firrtl.circuit "SequenceOrderEquivalence" { + firrtl.domain @ClockDomain + firrtl.domain @PowerDomain + firrtl.module @SequenceOrderEquivalence( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @PowerDomain, + in %a: !firrtl.uint<1> domains [%A, %B], + out %b: !firrtl.uint<1> domains [%B, %A] + ) { + // This should be legal since domain order doesn't matter in canonical representation + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + } +} + +// Domain sequence inference +// CHECK-LABEL: SequenceInference +firrtl.circuit "SequenceInference" { + firrtl.domain @ClockDomain + firrtl.domain @PowerDomain + firrtl.module @SequenceInference( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @PowerDomain, + in %a: !firrtl.uint<1> domains [%A, %B], + out %d: !firrtl.uint<1> + ) { + %c = firrtl.wire : !firrtl.uint<1> + + // %c should infer domain sequence [%A, %B] + firrtl.matchingconnect %c, %a : !firrtl.uint<1> + + // This should be legal since %c has inferred [%A, %B] + firrtl.matchingconnect %d, %c : !firrtl.uint<1> + } +} + +// Domain duplicate equivalence - should be legal. +// CHECK-LABEL: DuplicateDomainEquivalence +firrtl.circuit "DuplicateDomainEquivalence" { + firrtl.domain @ClockDomain + firrtl.module @DuplicateDomainEquivalence( + in %A: !firrtl.domain of @ClockDomain, + in %a: !firrtl.uint<1> domains [%A, %A], + out %b: !firrtl.uint<1> domains [%A] + ) { + // This should be legal since duplicate domains are canonicalized. + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + } +} + +// Unsafe domain cast with sequences +// CHECK-LABEL: UnsafeSequenceCast +firrtl.circuit "UnsafeSequenceCast" { + firrtl.domain @ClockDomain + firrtl.domain @PowerDomain + + firrtl.module @UnsafeSequenceCast( + in %C1: !firrtl.domain of @ClockDomain, + in %C2: !firrtl.domain of @ClockDomain, + in %P1: !firrtl.domain of @PowerDomain, + in %i: !firrtl.uint<1> domains [%C1, %P1], + out %o: !firrtl.uint<1> domains [%C2, %P1] + ) { + %0 = firrtl.unsafe_domain_cast %i domains %C2 : !firrtl.uint<1> + firrtl.matchingconnect %o, %0 : !firrtl.uint<1> + } +} + +// Different port types domain inference. +// CHECK-LABEL: DifferentPortTypes +firrtl.circuit "DifferentPortTypes" { + firrtl.domain @ClockDomain + firrtl.module @DifferentPortTypes( + in %A: !firrtl.domain of @ClockDomain, + in %uint_input: !firrtl.uint<8> domains [%A], + in %sint_input: !firrtl.sint<4> domains [%A], + out %uint_output: !firrtl.uint<8>, + out %sint_output: !firrtl.sint<4> + ) { + firrtl.matchingconnect %uint_output, %uint_input : !firrtl.uint<8> + firrtl.matchingconnect %sint_output, %sint_input : !firrtl.sint<4> + } +} + +// Domain inference through wires. +// CHECK-LABEL: DomainInferenceThroughWires +firrtl.circuit "DomainInferenceThroughWires" { + firrtl.domain @ClockDomain + firrtl.module @DomainInferenceThroughWires( + in %A: !firrtl.domain of @ClockDomain, + in %input: !firrtl.uint<1> domains [%A], + // CHECK: out %output: !firrtl.uint<1> domains [%A] + out %output: !firrtl.uint<1> + ) { + %wire1 = firrtl.wire : !firrtl.uint<1> + %wire2 = firrtl.wire : !firrtl.uint<1> + + firrtl.matchingconnect %wire1, %input : !firrtl.uint<1> + firrtl.matchingconnect %wire2, %wire1 : !firrtl.uint<1> + firrtl.matchingconnect %output, %wire2 : !firrtl.uint<1> + } +} + +// Export: add output domain port for domain created internally. +// CHECK-LABEL: ExportDomain +firrtl.circuit "ExportDomain" { + firrtl.domain @ClockDomain + + firrtl.extmodule @Foo( + out A: !firrtl.domain of @ClockDomain, + out o: !firrtl.uint<1> domains [A] + ) + + firrtl.module @ExportDomain( + // CHECK: out %ClockDomain: !firrtl.domain of @ClockDomain + // CHECK: out %o: !firrtl.uint<1> domains [%ClockDomain] + out %o: !firrtl.uint<1> + ) { + %foo_A, %foo_o = firrtl.instance foo @Foo( + out A: !firrtl.domain of @ClockDomain, + out o: !firrtl.uint<1> domains [A] + ) + firrtl.matchingconnect %o, %foo_o : !firrtl.uint<1> + // CHECK: firrtl.domain.define %ClockDomain, %foo_A + } +} + +// Export: Reuse already-exported domain. +// CHECK-LABEL: ReuseExportedDomain +firrtl.circuit "ReuseExportedDomain" { + firrtl.domain @ClockDomain + + firrtl.extmodule @Foo( + out A: !firrtl.domain of @ClockDomain, + out o: !firrtl.uint<1> domains [A] + ) + + firrtl.module @ReuseExportedDomain( + out %A: !firrtl.domain of @ClockDomain, + // CHECK: out %o: !firrtl.uint<1> domains [%A] + out %o: !firrtl.uint<1> + ) { + %foo_A, %foo_o = firrtl.instance foo @Foo( + out A: !firrtl.domain of @ClockDomain, + out o: !firrtl.uint<1> domains [A] + ) + firrtl.matchingconnect %o, %foo_o : !firrtl.uint<1> + firrtl.domain.define %A, %foo_A + } +} + +// CHECK-LABEL: RegisterInference +firrtl.circuit "RegisterInference" { + firrtl.domain @ClockDomain + firrtl.module @RegisterInference( + in %A: !firrtl.domain of @ClockDomain, + in %clock: !firrtl.clock domains [%A], + // CHECK: in %d: !firrtl.uint<1> domains [%A] + in %d: !firrtl.uint<1>, + // CHECK: out %q: !firrtl.uint<1> domains [%A] + out %q: !firrtl.uint<1> + ) { + %r = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1> + firrtl.matchingconnect %r, %d : !firrtl.uint<1> + firrtl.matchingconnect %q, %r : !firrtl.uint<1> + } +} + +// CHECK-LABEL: InstanceUpdate +firrtl.circuit "InstanceUpdate" { + firrtl.domain @ClockDomain + + firrtl.module @Foo(in %i : !firrtl.uint<1>) {} + + // CHECK: firrtl.module @InstanceUpdate(in %ClockDomain: !firrtl.domain of @ClockDomain, in %i: !firrtl.uint<1> domains [%ClockDomain]) { + // CHECK: %foo_ClockDomain, %foo_i = firrtl.instance foo @Foo(in ClockDomain: !firrtl.domain of @ClockDomain, in i: !firrtl.uint<1> domains [ClockDomain]) + // CHECK: firrtl.domain.define %foo_ClockDomain, %ClockDomain + // CHECK: firrtl.connect %foo_i, %i : !firrtl.uint<1> + // CHECK: } + firrtl.module @InstanceUpdate(in %i : !firrtl.uint<1>) { + %foo_i = firrtl.instance foo @Foo(in i: !firrtl.uint<1>) + firrtl.connect %foo_i, %i : !firrtl.uint<1>, !firrtl.uint<1> + } +} + +// CHECK-LABEL: InstanceChoiceUpdate +firrtl.circuit "InstanceChoiceUpdate" { + firrtl.domain @ClockDomain + + firrtl.option @Option { + firrtl.option_case @X + firrtl.option_case @Y + } + + firrtl.module @Foo(in %i : !firrtl.uint<1>) {} + firrtl.module @Bar(in %i : !firrtl.uint<1>) {} + firrtl.module @Baz(in %i : !firrtl.uint<1>) {} + + // CHECK: firrtl.module @InstanceChoiceUpdate(in %ClockDomain: !firrtl.domain of @ClockDomain, in %i: !firrtl.uint<1> domains [%ClockDomain]) { + // CHECK: %inst_ClockDomain, %inst_i = firrtl.instance_choice inst @Foo alternatives @Option { @X -> @Bar, @Y -> @Baz } (in ClockDomain: !firrtl.domain of @ClockDomain, in i: !firrtl.uint<1> domains [ClockDomain]) + // CHECK: firrtl.domain.define %inst_ClockDomain, %ClockDomain + // CHECK: firrtl.connect %inst_i, %i : !firrtl.uint<1> + // CHECK: } + firrtl.module @InstanceChoiceUpdate(in %i : !firrtl.uint<1>) { + %inst_i = firrtl.instance_choice inst @Foo alternatives @Option { @X -> @Bar, @Y -> @Baz } (in i : !firrtl.uint<1>) + firrtl.connect %inst_i, %i : !firrtl.uint<1>, !firrtl.uint<1> + } +} + +// CHECK-LABEL: ConstantInMultipleDomains +firrtl.circuit "ConstantInMultipleDomains" { + firrtl.domain @ClockDomain + + firrtl.extmodule @Foo(in A: !firrtl.domain of @ClockDomain, in i: !firrtl.uint<1> domains [A]) + + firrtl.module @ConstantInMultipleDomains(in %A: !firrtl.domain of @ClockDomain, in %B: !firrtl.domain of @ClockDomain) { + %c0_ui1 = firrtl.constant 0 : !firrtl.uint<1> + %x_A, %x_i = firrtl.instance x @Foo(in A: !firrtl.domain of @ClockDomain, in i: !firrtl.uint<1> domains [A]) + firrtl.domain.define %x_A, %A + firrtl.matchingconnect %x_i, %c0_ui1 : !firrtl.uint<1> + + %y_A, %y_i = firrtl.instance y @Foo(in A: !firrtl.domain of @ClockDomain, in i: !firrtl.uint<1> domains [A]) + firrtl.domain.define %y_A, %B + firrtl.matchingconnect %y_i, %c0_ui1 : !firrtl.uint<1> + } +} + +firrtl.circuit "Top" { + firrtl.domain @ClockDomain + firrtl.extmodule @Foo( + in ClockDomain : !firrtl.domain of @ClockDomain, + in i: !firrtl.uint<1> domains [ClockDomain], + out o : !firrtl.uint<1> domains [ClockDomain] + ) + + firrtl.module @Top(in %ClockDomain : !firrtl.domain of @ClockDomain ) { + %foo1_ClockDomain, %foo1_i, %foo1_o = firrtl.instance foo1 @Foo( + in ClockDomain : !firrtl.domain of @ClockDomain, + in i: !firrtl.uint<1> domains [ClockDomain], + out o : !firrtl.uint<1> domains [ClockDomain] + ) + + %foo2_ClockDomain, %foo2_i, %foo2_o = firrtl.instance foo2 @Foo( + in ClockDomain : !firrtl.domain of @ClockDomain, + in i: !firrtl.uint<1> domains [ClockDomain], + out o : !firrtl.uint<1> domains [ClockDomain] + ) + + firrtl.domain.define %foo1_ClockDomain, %ClockDomain + firrtl.matchingconnect %foo2_i, %foo1_o : !firrtl.uint<1> + firrtl.matchingconnect %foo1_i, %foo2_o : !firrtl.uint<1> + } +} diff --git a/test/Dialect/FIRRTL/infer-domains-infer-errors.mlir b/test/Dialect/FIRRTL/infer-domains-infer-errors.mlir new file mode 100644 index 000000000000..7e6f8ca17735 --- /dev/null +++ b/test/Dialect/FIRRTL/infer-domains-infer-errors.mlir @@ -0,0 +1,48 @@ +// RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains{mode=infer}))' %s --verify-diagnostics + +// in "infer" mode, infer-domains requires that the interfaces of public +// modules are fully annotated with domain associations, but will still +// perform domain inference on the body of a public module, and will do full +// inference for private modules. + +// This test suite checks for errors which do not occur when the mode is +// infer-all. + +// CHECK-LABEL: MissingDomain +firrtl.circuit "MissingDomain" { + firrtl.domain @ClockDomain + + firrtl.module @MissingDomain( + // expected-error @below {{missing "ClockDomain" association for port "x"}} + in %x: !firrtl.uint<1> + ) {} +} + +// CHECK-LABEL: MissingSecondDomain +firrtl.circuit "MissingSecondDomain" { + firrtl.domain @ClockDomain + firrtl.domain @PowerDomain + + firrtl.module @MissingSecondDomain( + in %c : !firrtl.domain of @ClockDomain, + // expected-error @below {{missing "PowerDomain" association for port "x"}} + in %x : !firrtl.uint<1> domains [%c] + ) {} +} + +// Test that domain crossing errors are still caught when in infer mode. +// Catching this involves processing the module without writing back to the IR. +firrtl.circuit "IllegalDomainCrossing" { + firrtl.domain @ClockDomain + firrtl.module @IllegalDomainCrossing( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @ClockDomain, + // expected-note @below {{2nd operand has domains: [ClockDomain: A]}} + in %a: !firrtl.uint<1> domains [%A], + // expected-note @below {{1st operand has domains: [ClockDomain: B]}} + out %b: !firrtl.uint<1> domains [%B] + ) { + // expected-error @below {{illegal domain crossing in operation}} + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + } +} diff --git a/test/Dialect/FIRRTL/infer-domains-infer.mlir b/test/Dialect/FIRRTL/infer-domains-infer.mlir new file mode 100644 index 000000000000..d2778aefd35f --- /dev/null +++ b/test/Dialect/FIRRTL/infer-domains-infer.mlir @@ -0,0 +1,38 @@ +// RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains{mode=infer}))' %s | FileCheck %s + +firrtl.circuit "InferOutputDomain" { + firrtl.domain @ClockDomain + + firrtl.extmodule @Foo(out D: !firrtl.domain of @ClockDomain, out x: !firrtl.uint<1> domains [D]) + + // CHECK: firrtl.module private @Bar(out %ClockDomain: !firrtl.domain of @ClockDomain, out %x: !firrtl.uint<1> domains [%ClockDomain]) { + // CHECK: %foo_D, %foo_x = firrtl.instance foo @Foo(out D: !firrtl.domain of @ClockDomain, out x: !firrtl.uint<1> domains [D]) + // CHECK: firrtl.matchingconnect %x, %foo_x : !firrtl.uint<1> + // CHECK: firrtl.domain.define %ClockDomain, %foo_D + // CHECK: } + firrtl.module private @Bar(out %x : !firrtl.uint<1>) { + %foo_D, %foo_x = firrtl.instance foo @Foo(out D: !firrtl.domain of @ClockDomain, out x: !firrtl.uint<1> domains [D]) + firrtl.matchingconnect %x, %foo_x : !firrtl.uint<1> + } + + // CHECK: firrtl.module @InferOutputDomain(out %D: !firrtl.domain of @ClockDomain, out %x: !firrtl.uint<1> domains [%D]) { + // CHECK: %bar_ClockDomain, %bar_x = firrtl.instance bar @Bar(out ClockDomain: !firrtl.domain of @ClockDomain, out x: !firrtl.uint<1> domains [ClockDomain]) + // CHECK: firrtl.matchingconnect %x, %bar_x : !firrtl.uint<1> + // CHECK: firrtl.domain.define %D, %bar_ClockDomain + // CHECK: } + firrtl.module @InferOutputDomain(out %D: !firrtl.domain of @ClockDomain, out %x: !firrtl.uint<1> domains [%D]) { + %bar_x = firrtl.instance bar @Bar(out x : !firrtl.uint<1>) + firrtl.matchingconnect %x, %bar_x : !firrtl.uint<1> + } +} + +// do not crash the InferDomains pass. This stems from the fact that "no domain +// information" can be represented as both an empty array `[]` and an empty +// array of arrays `[[]]`. +firrtl.circuit "EmptyDomainInfo" { + firrtl.domain @DomainKind + firrtl.module @EmptyDomainInfo(out %x: !firrtl.integer) { + %0 = firrtl.integer 5 + firrtl.propassign %x, %0 : !firrtl.integer + } +} \ No newline at end of file