Skip to content

Commit 69a6cc1

Browse files
authored
Add llzk.fields attribute (#301)
1 parent 84bcd82 commit 69a6cc1

File tree

29 files changed

+1064
-154
lines changed

29 files changed

+1064
-154
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
added:
2+
- Add an optional `llzk.fields` attribute on modules for defining supported prime fields
3+
- Add optional field specifier on `felt.type`

include/llzk-c/Constants.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,15 @@ extern "C" {
2424
extern const char *LLZK_FUNC_NAME_COMPUTE;
2525
extern const char *LLZK_FUNC_NAME_CONSTRAIN;
2626

27-
/// Name of the attribute on the top-level ModuleOp that specifies the IR language name.
27+
/// Name of the attribute on the top-level ModuleOp that identifies the ModuleOp as the
28+
/// root module and specifies the frontend language name that the IR was compiled from, if
29+
/// available.
2830
extern const char *LLZK_LANG_ATTR_NAME;
2931

32+
/// Name of the attribute on the top-level ModuleOp that defines prime fields
33+
/// used in the circuit.
34+
extern const char *LLZK_FIELD_ATTR_NAME;
35+
3036
/// Name of the attribute on the top-level ModuleOp that specifies the type of the main struct.
3137
/// This attribute can appear zero or one times on the top-level ModuleOp and is associated with
3238
/// a `TypeAttr` specifying the `StructType` of the main struct.

include/llzk-c/Dialect/Felt.h

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,28 +28,82 @@ extern "C" {
2828

2929
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Felt, llzk__felt);
3030

31-
/// Creates a llzk::felt::FeltConstAttr.
31+
//===----------------------------------------------------------------------===//
32+
// FeltConstAttr
33+
//===----------------------------------------------------------------------===//
34+
35+
/// Creates a llzk::felt::FeltConstAttr with an unspecified field.
3236
MLIR_CAPI_EXPORTED MlirAttribute llzkFeltConstAttrGet(MlirContext context, int64_t value);
3337

34-
/// Creates a llzk::felt::FeltConstAttr with a set bit length.
38+
/// Creates a llzk::felt::FeltConstAttr with a specified field.
39+
MLIR_CAPI_EXPORTED MlirAttribute
40+
llzkFeltConstAttrGetWithField(MlirContext context, int64_t value, MlirStringRef fieldName);
41+
42+
/// Creates a llzk::felt::FeltConstAttr with a set bit length in an unspecified field.
3543
MLIR_CAPI_EXPORTED MlirAttribute
3644
llzkFeltConstAttrGetWithBits(MlirContext ctx, unsigned numBits, int64_t value);
3745

46+
/// Creates a llzk::felt::FeltConstAttr with a set bit length in a specified field.
47+
MLIR_CAPI_EXPORTED MlirAttribute llzkFeltConstAttrGetWithBitsWithField(
48+
MlirContext ctx, unsigned numBits, int64_t value, MlirStringRef fieldName
49+
);
50+
3851
/// Creates a llzk::felt::FeltConstAttr from a base-10 representation of a number.
52+
/// in an unspecified field.
3953
MLIR_CAPI_EXPORTED MlirAttribute
4054
llzkFeltConstAttrGetFromString(MlirContext context, unsigned numBits, MlirStringRef str);
4155

42-
/// Creates a llzk::felt::FeltConstAttr from an array of big-integer parts in LSB order.
56+
/// Creates a llzk::felt::FeltConstAttr from a base-10 representation of a number.
57+
/// in a specified field.
58+
MLIR_CAPI_EXPORTED MlirAttribute llzkFeltConstAttrGetFromStringWithField(
59+
MlirContext context, unsigned numBits, MlirStringRef str, MlirStringRef fieldName
60+
);
61+
62+
/// Creates a llzk::felt::FeltConstAttr from an array of big-integer parts in LSB order
63+
/// in an unspecified field.
4364
MLIR_CAPI_EXPORTED MlirAttribute llzkFeltConstAttrGetFromParts(
4465
MlirContext context, unsigned numBits, const uint64_t *parts, intptr_t nParts
4566
);
4667

68+
/// Creates a llzk::felt::FeltConstAttr from an array of big-integer parts in LSB order
69+
/// in a specified field.
70+
MLIR_CAPI_EXPORTED MlirAttribute llzkFeltConstAttrGetFromPartsWithField(
71+
MlirContext context, unsigned numBits, const uint64_t *parts, intptr_t nParts,
72+
MlirStringRef fieldName
73+
);
74+
4775
/// Returns true if the attribute is a FeltConstAttr.
4876
LLZK_DECLARE_ATTR_ISA(FeltConstAttr);
4977

50-
/// Creates a llzk::felt::FeltType.
78+
/// Get the underlying felt type of the FeltConstAttr.
79+
MLIR_CAPI_EXPORTED MlirType llzkFeltConstAttrGetType(MlirAttribute attr);
80+
81+
//===----------------------------------------------------------------------===//
82+
// FieldSpecAttr
83+
//===----------------------------------------------------------------------===//
84+
85+
/// Creates a llzk::felt::FieldSpecAttr from a base-10 representation of the prime.
86+
MLIR_CAPI_EXPORTED MlirAttribute llzkFieldSpecAttrGetFromString(
87+
MlirContext context, MlirStringRef fieldName, unsigned numBits, MlirStringRef primeStr
88+
);
89+
90+
/// Creates a llzk::felt::FieldSpecAttr from an array of big-integer parts in LSB order representing
91+
/// the prime.
92+
MLIR_CAPI_EXPORTED MlirAttribute llzkFieldSpecAttrGetFromParts(
93+
MlirContext context, MlirStringRef fieldName, unsigned numBits, const uint64_t *parts,
94+
intptr_t nParts
95+
);
96+
97+
//===----------------------------------------------------------------------===//
98+
// FeltType
99+
//===----------------------------------------------------------------------===//
100+
101+
/// Creates a llzk::felt::FeltType with an unspecified field.
51102
MLIR_CAPI_EXPORTED MlirType llzkFeltTypeGet(MlirContext context);
52103

104+
/// Creates a llzk::felt::FeltType in a given field.
105+
MLIR_CAPI_EXPORTED MlirType llzkFeltTypeGetWithField(MlirContext context, MlirStringRef fieldName);
106+
53107
/// Returns true if the type is a FeltType.
54108
LLZK_DECLARE_TYPE_ISA(FeltType);
55109

include/llzk/Analysis/IntervalAnalysis.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
#include "llzk/Analysis/AnalysisWrappers.h"
1414
#include "llzk/Analysis/ConstraintDependencyGraph.h"
1515
#include "llzk/Analysis/DenseAnalysis.h"
16-
#include "llzk/Analysis/Field.h"
1716
#include "llzk/Analysis/Intervals.h"
1817
#include "llzk/Analysis/SparseAnalysis.h"
1918
#include "llzk/Dialect/Array/IR/Ops.h"
@@ -25,6 +24,7 @@
2524
#include "llzk/Dialect/Global/IR/Ops.h"
2625
#include "llzk/Dialect/Polymorphic/IR/Ops.h"
2726
#include "llzk/Util/Compare.h"
27+
#include "llzk/Util/Field.h"
2828

2929
#include <mlir/IR/BuiltinOps.h>
3030
#include <mlir/Pass/AnalysisManager.h>

include/llzk/Analysis/Intervals.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
#pragma once
1111

12-
#include "llzk/Analysis/Field.h"
12+
#include "llzk/Util/Field.h"
1313

1414
#include <mlir/Support/LogicalResult.h>
1515

include/llzk/Dialect/Bool/IR/Ops.td

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,9 @@ def LLZK_AssertOp
115115
}
116116

117117
// Match format of Index comparisons (for now)
118-
def LLZK_CmpOp : BoolDialectOp<"cmp", [Pure]> {
118+
def LLZK_CmpOp
119+
: NaryOpBase<BoolDialect, "cmp",
120+
LLZK_FeltType.builderCall, [Pure, TypesUnify<"lhs", "rhs">]> {
119121
let summary = "compare field element values";
120122
let description = [{
121123
This operation takes two field element values and compares them according to the
@@ -145,13 +147,17 @@ def LLZK_CmpOp : BoolDialectOp<"cmp", [Pure]> {
145147

146148
// Not equal comparison.
147149
%2 = bool.cmp ne(%a, %b)
150+
151+
// Not equal comparison for felts in a specified field
152+
%3 = bool.cmp ne(%c, %d) : !felt.type<"babybear">, !felt.type<"babybear">
148153
```
149154
}];
150155

151156
let arguments = (ins LLZK_CmpPredicateAttr:$predicate, LLZK_FeltType:$lhs,
152157
LLZK_FeltType:$rhs);
153158
let results = (outs I1:$result);
154-
let assemblyFormat = [{ `` $predicate `(` $lhs `,` $rhs `)` attr-dict }];
159+
let assemblyFormat =
160+
[{ `` $predicate `(` $lhs `,` $rhs `)` `` custom<InferredOrParsedType>(type($lhs), "true") `` custom<InferredOrParsedType>(type($rhs), "false") attr-dict }];
155161
}
156162

157163
#endif // LLZK_BOOLEAN_OPS

include/llzk/Dialect/Cast/IR/Ops.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
#include "llzk/Dialect/Felt/IR/Types.h"
1414
#include "llzk/Dialect/Function/IR/OpTraits.h"
1515

16+
// MLIR interfaces used by generated ops
17+
#include <mlir/Interfaces/InferTypeOpInterface.h>
18+
1619
// Include TableGen'd declarations
1720
#define GET_OP_CLASSES
1821
#include "llzk/Dialect/Cast/IR/Ops.h.inc"

include/llzk/Dialect/Cast/IR/Ops.td

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,18 @@
1313
include "llzk/Dialect/Shared/Types.td"
1414
include "llzk/Dialect/Cast/IR/Dialect.td"
1515
include "llzk/Dialect/Felt/IR/Types.td"
16+
include "llzk/Dialect/Shared/OpsBase.td"
17+
include "mlir/Interfaces/InferTypeOpInterface.td"
1618
include "llzk/Dialect/Function/IR/OpTraits.td"
1719

1820
include "mlir/IR/OpBase.td"
1921
include "mlir/Interfaces/SideEffectInterfaces.td"
2022
include "mlir/IR/SymbolInterfaces.td"
2123

22-
def LLZK_IntToFeltOp : Op<CastDialect, "tofelt", [Pure]> {
24+
class CastOp<string mnemonic, list<Trait> traits = []>
25+
: NaryOpBase<CastDialect, mnemonic, LLZK_FeltType.builderCall, traits>;
26+
27+
def LLZK_IntToFeltOp : CastOp<"tofelt", [Pure]> {
2328
let summary = "convert an integer into a field element";
2429
let description = [{
2530
This operation converts a supported integer type value into a field element value.
@@ -34,10 +39,11 @@ def LLZK_IntToFeltOp : Op<CastDialect, "tofelt", [Pure]> {
3439

3540
let arguments = (ins AnyLLZKIntType:$value);
3641
let results = (outs LLZK_FeltType:$result);
37-
let assemblyFormat = [{ $value `:` type($value) attr-dict }];
42+
let assemblyFormat =
43+
[{ $value `:` type($value) `` custom<InferredOrParsedType>(type($result), "false") attr-dict }];
3844
}
3945

40-
def LLZK_FeltToIndexOp : Op<CastDialect, "toindex", [Pure, NotFieldNative]> {
46+
def LLZK_FeltToIndexOp : CastOp<"toindex", [Pure, NotFieldNative]> {
4147
let summary = "convert a field element into an index";
4248
let description = [{
4349
This operation converts a field element value into an index value to allow use
@@ -52,7 +58,8 @@ def LLZK_FeltToIndexOp : Op<CastDialect, "toindex", [Pure, NotFieldNative]> {
5258

5359
let arguments = (ins LLZK_FeltType:$value);
5460
let results = (outs Index:$result);
55-
let assemblyFormat = [{ $value attr-dict }];
61+
let assemblyFormat =
62+
[{ $value `` custom<InferredOrParsedType>(type($value), "true") attr-dict }];
5663
}
5764

5865
#endif // LLZK_CAST_OPS

include/llzk/Dialect/Felt/IR/Attrs.td

Lines changed: 70 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,25 +24,85 @@ def LLZK_FeltConstAttr
2424
A felt attribute represents a finite field element.
2525
}];
2626

27-
let parameters = (ins APIntParameter<"The felt constant value">:$value);
27+
let parameters = (ins APIntParameter<"The felt constant value">:$value,
28+
OptionalParameter<"::mlir::StringAttr">:$fieldName);
2829

29-
let returnType = "::llvm::APInt";
30-
let convertFromStorage = "$_self.getValue()";
31-
32-
let assemblyFormat = [{ $value }];
30+
let assemblyFormat = [{ $value ( ` ` `<` $fieldName^ `>` )? }];
3331

3432
let builders =
35-
[AttrBuilder<(ins "unsigned":$numBits, "::llvm::StringRef":$str), [{
36-
return $_get(context, ::llvm::APInt(numBits, str, 10));
37-
}]>,
33+
[AttrBuilder<(ins "unsigned":$numBits, "::llvm::StringRef":$str,
34+
"::llvm::StringRef":$fieldName),
35+
[{
36+
return $_get(context, ::llvm::APInt(numBits, str, 10), ::mlir::StringAttr::get(context, fieldName));
37+
}]>,
38+
AttrBuilder<(ins "unsigned":$numBits,
39+
"::llvm::ArrayRef<uint64_t>":$parts,
40+
"::llvm::StringRef":$fieldName),
41+
[{
42+
return $_get(context, ::llvm::APInt(numBits, parts), ::mlir::StringAttr::get(context, fieldName));
43+
}]>,
44+
AttrBuilder<(ins "unsigned":$numBits, "::llvm::StringRef":$str), [{
45+
return $_get(context, ::llvm::APInt(numBits, str, 10), ::mlir::StringAttr());
46+
}]>,
3847
AttrBuilder<
3948
(ins "unsigned":$numBits, "::llvm::ArrayRef<uint64_t>":$parts), [{
40-
return $_get(context, ::llvm::APInt(numBits, parts));
41-
}]>];
49+
return $_get(context, ::llvm::APInt(numBits, parts), ::mlir::StringAttr());
50+
}]>,
51+
AttrBuilder<(ins "::llvm::APInt":$value, "::llvm::StringRef":$fieldName),
52+
[{
53+
return $_get(context, value, ::mlir::StringAttr::get(context, fieldName));
54+
}]>,
55+
AttrBuilder<(ins "::llvm::APInt":$value), [{
56+
return $_get(context, value, ::mlir::StringAttr());
57+
}]>];
4258

4359
let extraClassDeclaration = [{
4460
::mlir::Type getType() const;
61+
62+
operator ::llvm::APInt() const;
63+
}];
64+
65+
let genVerifyDecl = 1;
66+
}
67+
68+
def LLZK_FieldSpecAttr : AttrDef<FeltDialect, "FieldSpec"> {
69+
let mnemonic = "field";
70+
let summary = "prime field specification";
71+
let description = [{
72+
A specification of a prime field for use by felt types.
73+
74+
These specifications are provided in the `llzk.fields` attribute on the root
75+
module as either a single element or a flat array, for example:
76+
77+
module attributes {llzk.lang, llzk.fields = field<foo, 7> { ... }
78+
module attributes {llzk.lang, llzk.fields = [field<>]} { ... }
79+
80+
Specifications should not be provided for built-in fields, which include:
81+
- babybear
82+
- bn128/bn254
83+
- goldilocks
84+
- koalabear
85+
- mersenne31
4586
}];
87+
88+
let parameters = (ins "::mlir::StringAttr":$fieldName,
89+
APIntParameter<"The prime modulus">:$prime);
90+
91+
// Format is [{ `<` $fieldName `,` $prime `>` }], but with custom parsing
92+
// to enable caching of the Field object definition.
93+
let hasCustomAssemblyFormat = 1;
94+
95+
let builders =
96+
[AttrBuilder<(ins "::llvm::StringRef":$fieldName, "unsigned":$numBits,
97+
"::llvm::StringRef":$primeStr),
98+
[{
99+
return $_get(context, ::mlir::StringAttr::get(context, fieldName), ::llvm::APInt(numBits, primeStr, 10));
100+
}]>,
101+
AttrBuilder<(ins "::llvm::StringRef":$fieldName, "unsigned":$numBits,
102+
"::llvm::ArrayRef<uint64_t>":$parts),
103+
[{
104+
return $_get(context, ::mlir::StringAttr::get(context, fieldName), ::llvm::APInt(numBits, parts));
105+
}]>];
46106
}
47107

48108
#endif // LLZK_FELT_ATTRS

0 commit comments

Comments
 (0)