Skip to content

Commit 68630ef

Browse files
Add ColsRef macro
1 parent 3c80007 commit 68630ef

File tree

6 files changed

+1225
-2
lines changed

6 files changed

+1225
-2
lines changed

crates/circuits/primitives/derive/Cargo.toml

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@ license.workspace = true
1212
proc-macro = true
1313

1414
[dependencies]
15-
syn = { version = "2.0", features = ["parsing"] }
15+
syn = { version = "2.0", features = ["parsing", "extra-traits"] }
1616
quote = "1.0"
17-
itertools = { workspace = true }
17+
itertools = { workspace = true, default-features = true }
18+
proc-macro2 = "1.0"
19+
20+
[dev-dependencies]
21+
ndarray.workspace = true
22+
23+
[package.metadata.cargo-shear]
24+
ignored = ["ndarray"]
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# ColsRef macro
2+
3+
The `ColsRef` procedural macro is used in constraint generation to create column structs that have dynamic sizes.
4+
5+
Note: this macro was originally created for use in the SHA-2 VM extension, where we reuse the same constraint generation code for three different circuits (SHA-256, SHA-512, and SHA-384).
6+
See the [SHA-2 VM extension](../../../../../../extensions/sha2/circuit/src/sha2_chip/air.rs) for an example of how to use the `ColsRef` macro to reuse constraint generation code over multiple circuits.
7+
8+
## Overview
9+
10+
As an illustrative example, consider the following columns struct:
11+
```rust
12+
struct ExampleCols<T, const N: usize> {
13+
arr: [T; N],
14+
sum: T,
15+
}
16+
```
17+
Let's say we want to constrain `sum` to be the sum of the elements of `arr`, and `N` can be either 5 or 10.
18+
We can define a trait that stores the config parameters.
19+
```rust
20+
pub trait ExampleConfig {
21+
const N: usize;
22+
}
23+
```
24+
and then implement it for the two different configs.
25+
```rust
26+
pub struct ExampleConfigImplA;
27+
impl ExampleConfig for ExampleConfigImplA {
28+
const N: usize = 5;
29+
}
30+
pub struct ExampleConfigImplB;
31+
impl ExampleConfig for ExampleConfigImplB {
32+
const N: usize = 10;
33+
}
34+
```
35+
Then we can use the `ColsRef` macro like this
36+
```rust
37+
#[derive(ColsRef)]
38+
#[config(ExampleConfig)]
39+
struct ExampleCols<T, const N: usize> {
40+
arr: [T; N],
41+
sum: T,
42+
}
43+
```
44+
which will generate a columns struct that uses references to the fields.
45+
```rust
46+
struct ExampleColsRef<'a, T, const N: usize> {
47+
arr: ndarray::ArrayView1<'a, T>, // an n-dimensional view into the input slice (ArrayView2 for 2D arrays, etc.)
48+
sum: &'a T,
49+
}
50+
```
51+
The `ColsRef` macro will also generate a `from` method that takes a slice of the correct length and returns an instance of the columns struct.
52+
The `from` method is parameterized by a struct that implements the `ExampleConfig` trait, and it uses the associated constants to determine how to split the input slice into the fields of the columns struct.
53+
54+
So, the constraint generation code can be written as
55+
```rust
56+
impl<AB: InteractionBuilder, C: ExampleConfig> Air<AB> for ExampleAir<C> {
57+
fn eval(&self, builder: &mut AB) {
58+
let main = builder.main();
59+
let (local, _) = (main.row_slice(0), main.row_slice(1));
60+
let local_cols = ExampleColsRef::<AB::Var>::from::<C>(&local[..C::N + 1]);
61+
let sum = local_cols.arr.iter().sum();
62+
builder.assert_eq(local_cols.sum, sum);
63+
}
64+
}
65+
```
66+
Notes:
67+
- the `arr` and `sum` fields of `ExampleColsRef` are references to the elements of the `local` slice.
68+
- the name, `N`, of the const generic parameter must match the name of the associated constant `N` in the `ExampleConfig` trait.
69+
70+
The `ColsRef` macro also generates a `ExampleColsRefMut` struct that stores mutable references to the fields, for use in trace generation.
71+
72+
The `ColsRef` macro supports more than just variable-length array fields.
73+
The field types can also be:
74+
- any type that derives `AlignedBorrow` via `#[derive(AlignedBorrow)]`
75+
- any type that derives `ColsRef` via `#[derive(ColsRef)]`
76+
- (possibly nested) arrays of `T` or (possibly nested) arrays of a type that derives `AlignedBorrow`
77+
78+
Note that we currently do not support arrays of types that derive `ColsRef`.
79+
80+
## Specification
81+
82+
Annotating a struct named `ExampleCols` with `#[derive(ColsRef)]` and `#[config(ExampleConfig)]` produces two structs, `ExampleColsRef` and `ExampleColsRefMut`.
83+
- we assume `ExampleCols` has exactly one generic type parameter, typically named `T`, and any number of const generic parameters. Each const generic parameter must have a name that matches an associated constant in the `ExampleConfig` trait
84+
85+
The fields of `ExampleColsRef` have the same names as the fields of `ExampleCols`, but their types are transformed as follows:
86+
- type `T` becomes `&T`
87+
- type `[T; LEN]` becomes `&ArrayView1<T>` (see [ndarray](https://docs.rs/ndarray/latest/ndarray/index.html)) where `LEN` is an associated constant in `ExampleConfig`
88+
- the `ExampleColsRef::from` method will correctly infer the length of the array from the config
89+
- fields with names that end in `Cols` are assumed to be a columns struct that derives `ColsRef` and are transformed into the appropriate `ColsRef` type recursively
90+
- one restriction is that any nested `ColsRef` type must have the same config as the outer `ColsRef` type
91+
- fields that are annotated with `#[aligned_borrow]` are assumed to derive `AlignedBorrow` and are borrowed from the input slice. The new type is a reference to the `AlignedBorrow` type
92+
- if a field whose name ends in `Cols` is annotated with `#[aligned_borrow]`, then the aligned borrow takes precedence, and the field is not transformed into an `ArrayView`
93+
- nested arrays of `U` become `&ArrayViewX<U>` where `X` is the number of dimensions in the nested array type
94+
- `U` can be either the generic type `T` or a type that derives `AlignedBorrow`. In the latter case, the field must be annotated with `#[aligned_borrow]`
95+
- the `ArrayViewX` type provides a `X`-dimensional view into the row slice
96+
97+
The fields of `ExampleColsRefMut` are almost the same as the fields of `ExampleColsRef`, but they are mutable references.
98+
- the `ArrayViewMutX` type is used instead of `ArrayViewX` for the array fields.
99+
- fields that derive `ColsRef` are transformed into the appropriate `ColsRefMut` type recursively.
100+
101+
Each of the `ExampleColsRef` and `ExampleColsRefMut` types has the following methods implemented:
102+
```rust
103+
// Takes a slice of the correct length and returns an instance of the columns struct.
104+
pub const fn from<C: ExampleConfig>(slice: &[T]) -> Self;
105+
// Returns the number of cells in the struct
106+
pub const fn width<C: ExampleConfig>() -> usize;
107+
```
108+
Note that the `width` method on both structs returns the same value.
109+
110+
Additionally, the `ExampleColsRef` struct has a `from_mut` method that takes a `ExampleColsRefMut` and returns a `ExampleColsRef`.
111+
This may be useful in trace generation to pass a `ExampleColsRefMut` to a function that expects a `ExampleColsRef`.
112+
113+
See the [tests](../../tests/test_cols_ref.rs) for concrete examples of how the `ColsRef` macro handles each of the supported field types.

0 commit comments

Comments
 (0)