Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions changelogs/unreleased/th__specify_main_via_attr.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
changed:
- The main struct is now specified via the "llzk.main" attribute on the top-level module, rather than being hardcoded to a struct named "Main".
12 changes: 5 additions & 7 deletions include/llzk-c/Constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,6 @@ extern "C" {
/// to a circom signal or AIR/PLONK column, opposed to intermediate values or other expressions.
extern const char *LLZK_COMPONENT_NAME_SIGNAL;

/// Symbol name for the main entry point struct/component (if any). There are additional
/// restrictions on the struct with this name:
/// 1. It cannot have struct parameters.
/// 2. The parameter types of its functions (besides the required "self" parameter) can
/// only be `struct<Signal>` or `array<.. x struct<Signal>>`.
extern const char *LLZK_COMPONENT_NAME_MAIN;

/// Symbol name for the witness generation (and resp. constraint generation) functions within a
/// component.
extern const char *LLZK_FUNC_NAME_COMPUTE;
Expand All @@ -38,6 +31,11 @@ extern const char *LLZK_FUNC_NAME_CONSTRAIN;
/// Name of the attribute on the top-level ModuleOp that specifies the IR language name.
extern const char *LLZK_LANG_ATTR_NAME;

/// Name of the attribute on the top-level ModuleOp that specifies the type of the main struct.
/// This attribute can appear zero or one times on the top-level ModuleOp and is associated with
/// a `TypeAttr` specifying the `StructType` of the main struct.
extern const char *LLZK_MAIN_ATTR_NAME;

#ifdef __cplusplus
}
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def StructCleanupMode
// use-def chain from some concrete struct are deleted.
I32EnumAttrCase<"ConcreteAsRoot", 2, "concrete-as-root">,
// MainAsRoot: All structs that cannot be reached by a
// use-def chain from the "Main" struct are deleted.
// use-def chain from the main struct are deleted.
I32EnumAttrCase<"MainAsRoot", 3, "main-as-root">,
]> {
let cppNamespace = "::llzk::polymorphic";
Expand Down
2 changes: 1 addition & 1 deletion include/llzk/Dialect/Struct/IR/Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def LLZK_StructDefOp
/// any surrounding module scopes.
::mlir::SymbolRefAttr getFullyQualifiedName();

/// Return `true` iff this StructDefOp is named "Main".
/// Return `true` iff this StructDefOp is the main struct. See `llzk::MAIN_ATTR_NAME`.
bool isMainComponent();
}];

Expand Down
12 changes: 5 additions & 7 deletions include/llzk/Util/Constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,6 @@ namespace llzk {
/// to a circom signal or AIR/PLONK column, opposed to intermediate values or other expressions.
constexpr char COMPONENT_NAME_SIGNAL[] = "Signal";

/// Symbol name for the main entry point struct/component (if any). There are additional
/// restrictions on the struct with this name:
/// 1. It cannot have struct parameters.
/// 2. The parameter types of its functions (besides the required "self" parameter) can
/// only be `struct<Signal>` or `array<.. x struct<Signal>>`.
constexpr char COMPONENT_NAME_MAIN[] = "Main";

/// Symbol name for the witness generation (and resp. constraint generation) functions within a
/// component.
constexpr char FUNC_NAME_COMPUTE[] = "compute";
Expand All @@ -31,4 +24,9 @@ constexpr char FUNC_NAME_PRODUCT[] = "product";
/// Name of the attribute on the top-level ModuleOp that specifies the IR language name.
constexpr char LANG_ATTR_NAME[] = "veridise.lang";

/// Name of the attribute on the top-level ModuleOp that specifies the type of the main struct.
/// This attribute can appear zero or one times on the top-level ModuleOp and is associated with
/// a `TypeAttr` specifying the `StructType` of the main struct.
constexpr char MAIN_ATTR_NAME[] = "llzk.main";

} // namespace llzk
13 changes: 13 additions & 0 deletions include/llzk/Util/SymbolHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,19 @@ getPathFromTopRoot(component::FieldDefOp &to, mlir::ModuleOp *foundRoot = nullpt
mlir::FailureOr<mlir::SymbolRefAttr>
getPathFromTopRoot(function::FuncDefOp &to, mlir::ModuleOp *foundRoot = nullptr);

/// @brief Lookup the `StructType` of the main instance.
///
/// This is specified by a `TypeAttr` on the top-level module with the key `LLZK_MAIN_ATTR_NAME`
/// and is optional, in which case the result will be `success(nullptr)`.
mlir::FailureOr<llzk::component::StructType> getMainInstanceType(mlir::Operation *lookupFrom);

/// @brief Lookup the `StructDefOp` of the main instance.
///
/// This is specified by a `TypeAttr` on the top-level module with the key `LLZK_MAIN_ATTR_NAME`
/// and is optional, in which case the result will be `success(nullptr)`.
mlir::FailureOr<SymbolLookupResult<llzk::component::StructDefOp>>
getMainInstanceDef(mlir::SymbolTableCollection &symbolTable, mlir::Operation *lookupFrom);

/// @brief Based on mlir::CallOpInterface::resolveCallable, but using LLZK lookup helpers
/// @tparam T the type of symbol being resolved (e.g., function::FuncDefOp)
/// @param symbolTable
Expand Down
2 changes: 1 addition & 1 deletion lib/CAPI/Constants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#include "llzk-c/Constants.h"

const char *LLZK_COMPONENT_NAME_SIGNAL = llzk::COMPONENT_NAME_SIGNAL;
const char *LLZK_COMPONENT_NAME_MAIN = llzk::COMPONENT_NAME_MAIN;
const char *LLZK_FUNC_NAME_COMPUTE = llzk::FUNC_NAME_COMPUTE;
const char *LLZK_FUNC_NAME_CONSTRAIN = llzk::FUNC_NAME_CONSTRAIN;
const char *LLZK_LANG_ATTR_NAME = llzk::LANG_ATTR_NAME;
const char *LLZK_MAIN_ATTR_NAME = llzk::MAIN_ATTR_NAME;
20 changes: 12 additions & 8 deletions lib/Dialect/Polymorphic/Transforms/FlatteningPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1765,7 +1765,7 @@ class FlatteningPass : public llzk::polymorphic::impl::FlatteningPassBase<Flatte
}

inline LogicalResult runOn(ModuleOp modOp) {
// If the cleanup mode is set to remove anything not reachable from the "Main" struct, do an
// If the cleanup mode is set to remove anything not reachable from the main struct, do an
// initial pass to remove things that are not reachable (as an optimization) because creating
// an instantiated version of a struct will not cause something to become reachable that was
// not already reachable in parameterized form.
Expand Down Expand Up @@ -1895,22 +1895,26 @@ class FlatteningPass : public llzk::polymorphic::impl::FlatteningPassBase<Flatte
Step5_Cleanup::FromKeepSet cleaner(
rootMod, getAnalysis<SymbolDefTree>(), getAnalysis<SymbolUseGraph>()
);
StructDefOp main =
cleaner.tables.getSymbolTable(rootMod).lookup<StructDefOp>(COMPONENT_NAME_MAIN);
FailureOr<SymbolLookupResult<StructDefOp>> mainOpt =
getMainInstanceDef(cleaner.tables, rootMod.getOperation());
if (failed(mainOpt)) {
return failure();
}
SymbolLookupResult<StructDefOp> main = mainOpt.value();
if (emitWarning && !main) {
// Emit warning if there is no "Main" because all structs may be removed (only structs that
// are reachable from a global def or free function will be preserved since those constructs
// are not candidate for removal in this pass).
// Emit warning if there is no main specified because all structs may be removed (only
// structs that are reachable from a global def or free function will be preserved since
// those constructs are not candidate for removal in this pass).
rootMod.emitWarning()
.append(
"using option '", cleanupMode.getArgStr(), '=',
stringifyStructCleanupMode(StructCleanupMode::MainAsRoot), "' with no \"",
COMPONENT_NAME_MAIN, "\" struct may remove all structs!"
MAIN_ATTR_NAME, "\" attribute on the top-level module may remove all structs!"
)
.report();
}
return cleaner.eraseUnreachableFrom(
main ? ArrayRef<StructDefOp> {main} : ArrayRef<StructDefOp> {}
main ? ArrayRef<StructDefOp> {*main} : ArrayRef<StructDefOp> {}
);
}
};
Expand Down
46 changes: 24 additions & 22 deletions lib/Dialect/Struct/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,8 @@ LogicalResult StructDefOp::verifySymbolUses(SymbolTableCollection &tables) {

namespace {

inline LogicalResult checkMainFuncParamType(Type pType, FuncDefOp inFunc, bool appendSelf) {
inline LogicalResult
checkMainFuncParamType(Type pType, FuncDefOp inFunc, std::optional<StructType> appendSelfType) {
if (isSignalType(pType)) {
return success();
} else if (auto arrayParamTy = llvm::dyn_cast<ArrayType>(pType)) {
Expand All @@ -227,11 +228,11 @@ inline LogicalResult checkMainFuncParamType(Type pType, FuncDefOp inFunc, bool a
}
}

std::string message = buildStringViaCallback([&inFunc, appendSelf](llvm::raw_ostream &ss) {
ss << "\"@" << COMPONENT_NAME_MAIN << "\" component \"@" << inFunc.getSymName()
std::string message = buildStringViaCallback([&inFunc, appendSelfType](llvm::raw_ostream &ss) {
ss << "main entry component \"@" << inFunc.getSymName()
<< "\" function parameters must be one of: {";
if (appendSelf) {
ss << "!" << StructType::name << "<@" << COMPONENT_NAME_MAIN << ">, ";
if (appendSelfType.has_value()) {
ss << appendSelfType.value() << ", ";
}
ss << "!" << StructType::name << "<@" << COMPONENT_NAME_SIGNAL << ">, ";
ss << "!" << ArrayType::name << "<.. x !" << StructType::name << "<@" << COMPONENT_NAME_SIGNAL
Expand All @@ -254,22 +255,17 @@ inline LogicalResult verifyStructComputeConstrain(
ArrayRef<Type> computeParams = computeFunc.getFunctionType().getInputs();
ArrayRef<Type> constrainParams = constrainFunc.getFunctionType().getInputs().drop_front();
if (structDef.isMainComponent()) {
// Verify that the Struct has no parameters.
if (!isNullOrEmpty(structDef.getConstParamsAttr())) {
return structDef.emitError().append(
"The \"@", COMPONENT_NAME_MAIN, "\" component must have no parameters"
);
}
// Verify the input parameter types are legal. The error message is explicit about what types
// are allowed so there is no benefit to report multiple errors if more than one parameter in
// the referenced function has an illegal type.
for (Type t : computeParams) {
if (failed(checkMainFuncParamType(t, computeFunc, false))) {
if (failed(checkMainFuncParamType(t, computeFunc, std::nullopt))) {
return failure(); // checkMainFuncParamType() already emits a sufficient error message
}
}
auto appendSelf = std::make_optional(structDef.getType());
for (Type t : constrainParams) {
if (failed(checkMainFuncParamType(t, constrainFunc, true))) {
if (failed(checkMainFuncParamType(t, constrainFunc, appendSelf))) {
return failure(); // checkMainFuncParamType() already emits a sufficient error message
}
}
Expand All @@ -295,16 +291,14 @@ inline LogicalResult verifyStructProduct(StructDefOp structDef, FuncDefOp produc
assert(productFunc.hasAllowWitnessAttr());

// Verify parameter types are valid
ArrayRef<Type> productParams = productFunc.getFunctionType().getInputs();
if (structDef.isMainComponent()) {
if (!isNullOrEmpty(structDef.getConstParamsAttr())) {
return structDef.emitError().append(
"The \"@", COMPONENT_NAME_MAIN, "\" component must have no parameters"
);
}
ArrayRef<Type> productParams = productFunc.getFunctionType().getInputs();
// Verify the input parameter types are legal. The error message is explicit about what types
// are allowed so there is no benefit to report multiple errors if more than one parameter in
// the referenced function has an illegal type.
for (Type t : productParams) {
if (failed(checkMainFuncParamType(t, productFunc, false))) {
return failure();
if (failed(checkMainFuncParamType(t, productFunc, std::nullopt))) {
return failure(); // checkMainFuncParamType() already emits a sufficient error message
}
}
}
Expand Down Expand Up @@ -448,7 +442,15 @@ FuncDefOp StructDefOp::getConstrainOrProductFuncOp() {
return llvm::dyn_cast_if_present<FuncDefOp>(lookupSymbol(FUNC_NAME_PRODUCT));
}

bool StructDefOp::isMainComponent() { return COMPONENT_NAME_MAIN == this->getSymName(); }
bool StructDefOp::isMainComponent() {
FailureOr<StructType> mainTypeOpt = getMainInstanceType(this->getOperation());
if (succeeded(mainTypeOpt)) {
if (StructType mainType = mainTypeOpt.value()) {
return structTypesUnify(mainType, this->getType());
}
}
return false;
}

//===------------------------------------------------------------------===//
// FieldDefOp
Expand Down
32 changes: 32 additions & 0 deletions lib/Util/SymbolHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,38 @@ FailureOr<SymbolRefAttr> getPathFromTopRoot(FuncDefOp &to, ModuleOp *foundRoot)
return RootPathBuilder(RootSelector::FURTHEST, to, foundRoot).getPathFromRootToFunc(to);
}

FailureOr<StructType> getMainInstanceType(Operation *lookupFrom) {
FailureOr<ModuleOp> rootOpt = getRootModule(lookupFrom);
if (failed(rootOpt)) {
return failure();
}
ModuleOp root = rootOpt.value();
if (Attribute a = root->getAttr(MAIN_ATTR_NAME)) {
// If the attribute is present, it must be a TypeAttr of StructType.
if (TypeAttr ta = llvm::dyn_cast<TypeAttr>(a)) {
return success(llvm::dyn_cast<StructType>(ta.getValue()));
}
return rootOpt->emitError().append(
'"', MAIN_ATTR_NAME, "\" on top-level module must be a '", StructType::name, "' attribute"
);
}
// The attribute is optional so it's okay if not present.
return success(nullptr);
}

FailureOr<SymbolLookupResult<StructDefOp>>
getMainInstanceDef(SymbolTableCollection &symbolTable, Operation *lookupFrom) {
FailureOr<StructType> mainStructTypeOpt = getMainInstanceType(lookupFrom);
if (failed(mainStructTypeOpt)) {
return failure();
}
if (StructType st = mainStructTypeOpt.value()) {
return st.getDefinition(symbolTable, lookupFrom);
} else {
return success(nullptr);
}
}

LogicalResult verifyParamOfType(
SymbolTableCollection &tables, SymbolRefAttr param, Type parameterizedType, Operation *origin
) {
Expand Down
21 changes: 4 additions & 17 deletions test/Dialect/Struct/structs_fail.llzk
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,9 @@ module attributes {veridise.lang = "llzk"} {
}
}
// -----
module attributes {veridise.lang = "llzk"} {
module attributes {llzk.main = !struct.type<@Main>, veridise.lang = "llzk"} {
struct.def @Main {
// expected-error@+1 {{"@Main" component "@compute" function parameters must be one of: {!struct.type<@Signal>, !array.type<.. x !struct.type<@Signal>>}}}
// expected-error@+1 {{main entry component "@compute" function parameters must be one of: {!struct.type<@Signal>, !array.type<.. x !struct.type<@Signal>>}}}
function.def @compute(%0: i1) -> !struct.type<@Main> {
%self = struct.new : !struct.type<@Main>
function.return %self : !struct.type<@Main>
Expand All @@ -325,32 +325,19 @@ module attributes {veridise.lang = "llzk"} {
}
}
// -----
module attributes {veridise.lang = "llzk"} {
module attributes {llzk.main = !struct.type<@Main>, veridise.lang = "llzk"} {
struct.def @Main {
function.def @compute() -> !struct.type<@Main> {
%self = struct.new : !struct.type<@Main>
function.return %self : !struct.type<@Main>
}
// expected-error@+1 {{"@Main" component "@constrain" function parameters must be one of: {!struct.type<@Main>, !struct.type<@Signal>, !array.type<.. x !struct.type<@Signal>>}}}
// expected-error@+1 {{main entry component "@constrain" function parameters must be one of: {!struct.type<@Main>, !struct.type<@Signal>, !array.type<.. x !struct.type<@Signal>>}}}
function.def @constrain(%self: !struct.type<@Main>, %0: i1) {
function.return
}
}
}
// -----
module attributes {veridise.lang = "llzk"} {
// expected-error@+1 {{The "@Main" component must have no parameters}}
struct.def @Main<[@T]> {
function.def @compute() -> !struct.type<@Main<[@T]>> {
%self = struct.new : !struct.type<@Main<[@T]>>
function.return %self : !struct.type<@Main<[@T]>>
}
function.def @constrain(%self: !struct.type<@Main<[@T]>>) {
function.return
}
}
}
// -----
module attributes {veridise.lang = "llzk"} {
struct.def @Signal {
struct.field @reg : !felt.type {llzk.pub}
Expand Down
Loading