Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions changelogs/unreleased/th__specify_main_via_attr.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
changed:
- The main struct is now specified via the "llzk.main" attribute on the top-level module, rather than being hardcoded to a struct named "Main".
12 changes: 5 additions & 7 deletions include/llzk-c/Constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,6 @@ extern "C" {
/// to a circom signal or AIR/PLONK column, opposed to intermediate values or other expressions.
extern const char *LLZK_COMPONENT_NAME_SIGNAL;

/// Symbol name for the main entry point struct/component (if any). There are additional
/// restrictions on the struct with this name:
/// 1. It cannot have struct parameters.
/// 2. The parameter types of its functions (besides the required "self" parameter) can
/// only be `struct<Signal>` or `array<.. x struct<Signal>>`.
extern const char *LLZK_COMPONENT_NAME_MAIN;

/// Symbol name for the witness generation (and resp. constraint generation) functions within a
/// component.
extern const char *LLZK_FUNC_NAME_COMPUTE;
Expand All @@ -38,6 +31,11 @@ extern const char *LLZK_FUNC_NAME_CONSTRAIN;
/// Name of the attribute on the top-level ModuleOp that specifies the IR language name.
extern const char *LLZK_LANG_ATTR_NAME;

/// Name of the attribute on the top-level ModuleOp that specifies the type of the main struct.
/// This attribute can appear zero or one times on the top-level ModuleOp and is associated with
/// a `TypeAttr` specifying the `StructType` of the main struct.
extern const char *LLZK_MAIN_ATTR_NAME;

#ifdef __cplusplus
}
#endif
Expand Down
6 changes: 5 additions & 1 deletion include/llzk-c/Typing.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@ MLIR_CAPI_EXPORTED bool llzkIsValidArrayElemType(MlirType type);
/// Checks if the type is a LLZK Array and it also contains a valid LLZK type.
MLIR_CAPI_EXPORTED bool llzkIsValidArrayType(MlirType type);

/// Return `false` iff the type contains any `TypeVarType`
/// Return `false` if the type contains any of the following:
/// - `TypeVarType`
/// - `SymbolRefAttr`
/// - `AffineMapAttr`
/// - `StructType` with parameters if `allowStructParams==false`
MLIR_CAPI_EXPORTED bool llzkIsConcreteType(MlirType type, bool allowStructParams);

/// Return `true` iff the given type is a StructType referencing the `COMPONENT_NAME_SIGNAL`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def StructCleanupMode
// use-def chain from some concrete struct are deleted.
I32EnumAttrCase<"ConcreteAsRoot", 2, "concrete-as-root">,
// MainAsRoot: All structs that cannot be reached by a
// use-def chain from the "Main" struct are deleted.
// use-def chain from the main struct are deleted.
I32EnumAttrCase<"MainAsRoot", 3, "main-as-root">,
]> {
let cppNamespace = "::llzk::polymorphic";
Expand Down
2 changes: 1 addition & 1 deletion include/llzk/Dialect/Struct/IR/Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def LLZK_StructDefOp
/// any surrounding module scopes.
::mlir::SymbolRefAttr getFullyQualifiedName();

/// Return `true` iff this StructDefOp is named "Main".
/// Return `true` iff this StructDefOp is the main struct. See `llzk::MAIN_ATTR_NAME`.
bool isMainComponent();
}];

Expand Down
14 changes: 7 additions & 7 deletions include/llzk/Util/Constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,10 @@ namespace llzk {

/// Symbol name for the struct/component representing a signal. A "signal" has direct correspondence
/// to a circom signal or AIR/PLONK column, opposed to intermediate values or other expressions.
///
/// DEPRECATED: to be removed in favor of plain `llzk::felt::FeltType`.
constexpr char COMPONENT_NAME_SIGNAL[] = "Signal";

/// Symbol name for the main entry point struct/component (if any). There are additional
/// restrictions on the struct with this name:
/// 1. It cannot have struct parameters.
/// 2. The parameter types of its functions (besides the required "self" parameter) can
/// only be `struct<Signal>` or `array<.. x struct<Signal>>`.
constexpr char COMPONENT_NAME_MAIN[] = "Main";

/// Symbol name for the witness generation (and resp. constraint generation) functions within a
/// component.
constexpr char FUNC_NAME_COMPUTE[] = "compute";
Expand All @@ -31,4 +26,9 @@ constexpr char FUNC_NAME_PRODUCT[] = "product";
/// Name of the attribute on the top-level ModuleOp that specifies the IR language name.
constexpr char LANG_ATTR_NAME[] = "veridise.lang";

/// Name of the attribute on the top-level ModuleOp that specifies the type of the main struct.
/// This attribute can appear zero or one times on the top-level ModuleOp and is associated with
/// a `TypeAttr` specifying the `StructType` of the main struct.
constexpr char MAIN_ATTR_NAME[] = "llzk.main";

} // namespace llzk
13 changes: 13 additions & 0 deletions include/llzk/Util/SymbolHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,19 @@ getPathFromTopRoot(component::FieldDefOp &to, mlir::ModuleOp *foundRoot = nullpt
mlir::FailureOr<mlir::SymbolRefAttr>
getPathFromTopRoot(function::FuncDefOp &to, mlir::ModuleOp *foundRoot = nullptr);

/// @brief Lookup the `StructType` of the main instance.
///
/// This is specified by a `TypeAttr` on the top-level module with the key `LLZK_MAIN_ATTR_NAME`
/// and is optional, in which case the result will be `success(nullptr)`.
mlir::FailureOr<llzk::component::StructType> getMainInstanceType(mlir::Operation *lookupFrom);

/// @brief Lookup the `StructDefOp` of the main instance.
///
/// This is specified by a `TypeAttr` on the top-level module with the key `LLZK_MAIN_ATTR_NAME`
/// and is optional, in which case the result will be `success(nullptr)`.
mlir::FailureOr<SymbolLookupResult<llzk::component::StructDefOp>>
getMainInstanceDef(mlir::SymbolTableCollection &symbolTable, mlir::Operation *lookupFrom);

/// @brief Based on mlir::CallOpInterface::resolveCallable, but using LLZK lookup helpers
/// @tparam T the type of symbol being resolved (e.g., function::FuncDefOp)
/// @param symbolTable
Expand Down
6 changes: 5 additions & 1 deletion include/llzk/Util/TypeHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,11 @@ bool isValidArrayElemType(mlir::Type type);
/// Checks if the type is a LLZK Array and it also contains a valid LLZK type.
bool isValidArrayType(mlir::Type type);

/// Return `false` iff the type contains any `TypeVarType`
/// Return `false` if the type contains any of the following:
/// - `TypeVarType`
/// - `SymbolRefAttr`
/// - `AffineMapAttr`
/// - `StructType` with parameters if `allowStructParams==false`
bool isConcreteType(mlir::Type type, bool allowStructParams = true);

inline mlir::LogicalResult checkValidType(EmitErrorFn emitError, mlir::Type type) {
Expand Down
2 changes: 1 addition & 1 deletion lib/CAPI/Constants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#include "llzk-c/Constants.h"

const char *LLZK_COMPONENT_NAME_SIGNAL = llzk::COMPONENT_NAME_SIGNAL;
const char *LLZK_COMPONENT_NAME_MAIN = llzk::COMPONENT_NAME_MAIN;
const char *LLZK_FUNC_NAME_COMPUTE = llzk::FUNC_NAME_COMPUTE;
const char *LLZK_FUNC_NAME_CONSTRAIN = llzk::FUNC_NAME_CONSTRAIN;
const char *LLZK_LANG_ATTR_NAME = llzk::LANG_ATTR_NAME;
const char *LLZK_MAIN_ATTR_NAME = llzk::MAIN_ATTR_NAME;
20 changes: 12 additions & 8 deletions lib/Dialect/Polymorphic/Transforms/FlatteningPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1765,7 +1765,7 @@ class FlatteningPass : public llzk::polymorphic::impl::FlatteningPassBase<Flatte
}

inline LogicalResult runOn(ModuleOp modOp) {
// If the cleanup mode is set to remove anything not reachable from the "Main" struct, do an
// If the cleanup mode is set to remove anything not reachable from the main struct, do an
// initial pass to remove things that are not reachable (as an optimization) because creating
// an instantiated version of a struct will not cause something to become reachable that was
// not already reachable in parameterized form.
Expand Down Expand Up @@ -1895,22 +1895,26 @@ class FlatteningPass : public llzk::polymorphic::impl::FlatteningPassBase<Flatte
Step5_Cleanup::FromKeepSet cleaner(
rootMod, getAnalysis<SymbolDefTree>(), getAnalysis<SymbolUseGraph>()
);
StructDefOp main =
cleaner.tables.getSymbolTable(rootMod).lookup<StructDefOp>(COMPONENT_NAME_MAIN);
FailureOr<SymbolLookupResult<StructDefOp>> mainOpt =
getMainInstanceDef(cleaner.tables, rootMod.getOperation());
if (failed(mainOpt)) {
return failure();
}
SymbolLookupResult<StructDefOp> main = mainOpt.value();
if (emitWarning && !main) {
// Emit warning if there is no "Main" because all structs may be removed (only structs that
// are reachable from a global def or free function will be preserved since those constructs
// are not candidate for removal in this pass).
// Emit warning if there is no main specified because all structs may be removed (only
// structs that are reachable from a global def or free function will be preserved since
// those constructs are not candidate for removal in this pass).
rootMod.emitWarning()
.append(
"using option '", cleanupMode.getArgStr(), '=',
stringifyStructCleanupMode(StructCleanupMode::MainAsRoot), "' with no \"",
COMPONENT_NAME_MAIN, "\" struct may remove all structs!"
MAIN_ATTR_NAME, "\" attribute on the top-level module may remove all structs!"
)
.report();
}
return cleaner.eraseUnreachableFrom(
main ? ArrayRef<StructDefOp> {main} : ArrayRef<StructDefOp> {}
main ? ArrayRef<StructDefOp> {*main} : ArrayRef<StructDefOp> {}
);
}
};
Expand Down
57 changes: 30 additions & 27 deletions lib/Dialect/Struct/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
//===----------------------------------------------------------------------===//

#include "llzk/Dialect/Array/IR/Types.h"
#include "llzk/Dialect/Felt/IR/Types.h"
#include "llzk/Dialect/Function/IR/Ops.h"
#include "llzk/Dialect/LLZK/IR/AttributeHelper.h"
#include "llzk/Dialect/Struct/IR/Ops.h"
Expand All @@ -33,6 +34,7 @@
#include "llzk/Dialect/Struct/IR/Ops.cpp.inc"

using namespace mlir;
using namespace llzk::felt;
using namespace llzk::array;
using namespace llzk::function;

Expand Down Expand Up @@ -218,24 +220,24 @@ LogicalResult StructDefOp::verifySymbolUses(SymbolTableCollection &tables) {

namespace {

inline LogicalResult checkMainFuncParamType(Type pType, FuncDefOp inFunc, bool appendSelf) {
if (isSignalType(pType)) {
inline LogicalResult
checkMainFuncParamType(Type pType, FuncDefOp inFunc, std::optional<StructType> appendSelfType) {
if (llvm::isa<FeltType>(pType)) {
return success();
} else if (auto arrayParamTy = llvm::dyn_cast<ArrayType>(pType)) {
if (isSignalType(arrayParamTy.getElementType())) {
if (llvm::isa<FeltType>(arrayParamTy.getElementType())) {
return success();
}
}

std::string message = buildStringViaCallback([&inFunc, appendSelf](llvm::raw_ostream &ss) {
ss << "\"@" << COMPONENT_NAME_MAIN << "\" component \"@" << inFunc.getSymName()
std::string message = buildStringViaCallback([&inFunc, appendSelfType](llvm::raw_ostream &ss) {
ss << "main entry component \"@" << inFunc.getSymName()
<< "\" function parameters must be one of: {";
if (appendSelf) {
ss << "!" << StructType::name << "<@" << COMPONENT_NAME_MAIN << ">, ";
if (appendSelfType.has_value()) {
ss << appendSelfType.value() << ", ";
}
ss << "!" << StructType::name << "<@" << COMPONENT_NAME_SIGNAL << ">, ";
ss << "!" << ArrayType::name << "<.. x !" << StructType::name << "<@" << COMPONENT_NAME_SIGNAL
<< ">>}";
ss << '!' << FeltType::name << ", ";
ss << '!' << ArrayType::name << "<.. x !" << FeltType::name << ">}";
});
return inFunc.emitError(message);
}
Expand All @@ -254,22 +256,17 @@ inline LogicalResult verifyStructComputeConstrain(
ArrayRef<Type> computeParams = computeFunc.getFunctionType().getInputs();
ArrayRef<Type> constrainParams = constrainFunc.getFunctionType().getInputs().drop_front();
if (structDef.isMainComponent()) {
// Verify that the Struct has no parameters.
if (!isNullOrEmpty(structDef.getConstParamsAttr())) {
return structDef.emitError().append(
"The \"@", COMPONENT_NAME_MAIN, "\" component must have no parameters"
);
}
// Verify the input parameter types are legal. The error message is explicit about what types
// are allowed so there is no benefit to report multiple errors if more than one parameter in
// the referenced function has an illegal type.
for (Type t : computeParams) {
if (failed(checkMainFuncParamType(t, computeFunc, false))) {
if (failed(checkMainFuncParamType(t, computeFunc, std::nullopt))) {
return failure(); // checkMainFuncParamType() already emits a sufficient error message
}
}
auto appendSelf = std::make_optional(structDef.getType());
for (Type t : constrainParams) {
if (failed(checkMainFuncParamType(t, constrainFunc, true))) {
if (failed(checkMainFuncParamType(t, constrainFunc, appendSelf))) {
return failure(); // checkMainFuncParamType() already emits a sufficient error message
}
}
Expand All @@ -295,16 +292,14 @@ inline LogicalResult verifyStructProduct(StructDefOp structDef, FuncDefOp produc
assert(productFunc.hasAllowWitnessAttr());

// Verify parameter types are valid
ArrayRef<Type> productParams = productFunc.getFunctionType().getInputs();
if (structDef.isMainComponent()) {
if (!isNullOrEmpty(structDef.getConstParamsAttr())) {
return structDef.emitError().append(
"The \"@", COMPONENT_NAME_MAIN, "\" component must have no parameters"
);
}
ArrayRef<Type> productParams = productFunc.getFunctionType().getInputs();
// Verify the input parameter types are legal. The error message is explicit about what types
// are allowed so there is no benefit to report multiple errors if more than one parameter in
// the referenced function has an illegal type.
for (Type t : productParams) {
if (failed(checkMainFuncParamType(t, productFunc, false))) {
return failure();
if (failed(checkMainFuncParamType(t, productFunc, std::nullopt))) {
return failure(); // checkMainFuncParamType() already emits a sufficient error message
}
}
}
Expand Down Expand Up @@ -448,7 +443,15 @@ FuncDefOp StructDefOp::getConstrainOrProductFuncOp() {
return llvm::dyn_cast_if_present<FuncDefOp>(lookupSymbol(FUNC_NAME_PRODUCT));
}

bool StructDefOp::isMainComponent() { return COMPONENT_NAME_MAIN == this->getSymName(); }
bool StructDefOp::isMainComponent() {
FailureOr<StructType> mainTypeOpt = getMainInstanceType(this->getOperation());
if (succeeded(mainTypeOpt)) {
if (StructType mainType = mainTypeOpt.value()) {
return structTypesUnify(mainType, this->getType());
}
}
return false;
}

//===------------------------------------------------------------------===//
// FieldDefOp
Expand Down
37 changes: 37 additions & 0 deletions lib/Util/SymbolHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,43 @@ FailureOr<SymbolRefAttr> getPathFromTopRoot(FuncDefOp &to, ModuleOp *foundRoot)
return RootPathBuilder(RootSelector::FURTHEST, to, foundRoot).getPathFromRootToFunc(to);
}

FailureOr<StructType> getMainInstanceType(Operation *lookupFrom) {
FailureOr<ModuleOp> rootOpt = getRootModule(lookupFrom);
if (failed(rootOpt)) {
return failure();
}
ModuleOp root = rootOpt.value();
if (Attribute a = root->getAttr(MAIN_ATTR_NAME)) {
// If the attribute is present, it must be a TypeAttr of concrete StructType.
if (TypeAttr ta = llvm::dyn_cast<TypeAttr>(a)) {
if (StructType st = llvm::dyn_cast<StructType>(ta.getValue())) {
if (isConcreteType(st)) {
return success(st);
}
}
}
return rootOpt->emitError().append(
'"', MAIN_ATTR_NAME, "\" on top-level module must be a concrete '", StructType::name,
"' attribute. Found: ", a
);
}
// The attribute is optional so it's okay if not present.
return success(nullptr);
}

FailureOr<SymbolLookupResult<StructDefOp>>
getMainInstanceDef(SymbolTableCollection &symbolTable, Operation *lookupFrom) {
FailureOr<StructType> mainStructTypeOpt = getMainInstanceType(lookupFrom);
if (failed(mainStructTypeOpt)) {
return failure();
}
if (StructType st = mainStructTypeOpt.value()) {
return st.getDefinition(symbolTable, lookupFrom);
} else {
return success(nullptr);
}
}

LogicalResult verifyParamOfType(
SymbolTableCollection &tables, SymbolRefAttr param, Type parameterizedType, Operation *origin
) {
Expand Down
Loading