Skip to content

Commit 12be8c0

Browse files
Generalize ColsRef macro
- allow using traits other than ShaConfig for const parameters - support fields types that derive AlignedBorrow (not recursed into) - support nested-array-type fields with literal array lengths - support nested-array-type fields with constant array lengths
1 parent 92dc96f commit 12be8c0

File tree

10 files changed

+282
-62
lines changed

10 files changed

+282
-62
lines changed

crates/circuits/sha-macros/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ proc-macro2 = "1.0"
1414

1515
[dev-dependencies]
1616
openvm-sha-air = { workspace = true }
17-
ndarray = "0.16"
17+
openvm-circuit-primitives-derive = { workspace = true }
18+
ndarray.workspace = true
1819

1920
[lib]
2021
proc-macro = true

crates/circuits/sha-macros/src/lib.rs

Lines changed: 169 additions & 61 deletions
Large diffs are not rendered by default.
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
use openvm_sha_air::{Sha256Config, ShaConfig};
2+
use openvm_sha_macros::ColsRef;
3+
4+
#[derive(ColsRef)]
5+
#[config(ShaConfig)]
6+
struct ArrayTest<T> {
7+
a: T,
8+
b: [T; 4],
9+
c: [[T; 4]; 4],
10+
}
11+
12+
#[test]
13+
fn arrays() {
14+
let input = [1; 1 + 4 + 4 * 4];
15+
let test: ArrayTestRef<u32> = ArrayTestRef::from::<Sha256Config>(&input);
16+
println!("{}", test.a);
17+
println!("{}", test.b);
18+
println!("{}", test.c);
19+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
use openvm_sha_macros::ColsRef;
2+
3+
const ONE: usize = 1;
4+
const TWO: usize = 2;
5+
const THREE: usize = 3;
6+
7+
mod test_config;
8+
use test_config::{TestConfig, TestConfigImpl};
9+
10+
#[derive(ColsRef)]
11+
#[config(TestConfig)]
12+
struct ConstLenArrayTest<T, const N: usize> {
13+
a: T,
14+
b: [T; N],
15+
c: [[T; ONE]; TWO],
16+
d: [[[T; ONE]; TWO]; THREE],
17+
}
18+
19+
#[test]
20+
fn const_len_arrays() {
21+
let input = [1; 1 + TestConfigImpl::N * 2 + 1 * 2 * 3];
22+
let test: ConstLenArrayTestRef<u32> = ConstLenArrayTestRef::from::<TestConfigImpl>(&input);
23+
println!("{}", test.a);
24+
println!("{}", test.b);
25+
println!("{}", test.c);
26+
println!("{}", test.d);
27+
}

crates/circuits/sha-macros/tests/flags.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use openvm_sha_macros::ColsRef;
33

44
#[repr(C)]
55
#[derive(Clone, Copy, Debug, ColsRef)]
6+
#[config(ShaConfig)]
67
pub struct ShaFlagsCols<T, const ROW_VAR_CNT: usize> {
78
pub is_round_row: T,
89
/// A flag that indicates if the current row is among the first 4 rows of a block

crates/circuits/sha-macros/tests/nested.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@ use openvm_sha_air::{Sha256Config, ShaConfig};
22
use openvm_sha_macros::ColsRef;
33

44
#[derive(ColsRef)]
5+
#[config(ShaConfig)]
56
struct Test1Cols<T, const WORD_BITS: usize> {
67
pub a: T,
78
pub nested: Test2Cols<T, WORD_BITS>,
89
}
910

1011
#[derive(ColsRef)]
12+
#[config(ShaConfig)]
1113
struct Test2Cols<T, const WORD_BITS: usize> {
1214
pub b: T,
1315
pub c: [T; WORD_BITS],
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
use openvm_circuit_primitives_derive::AlignedBorrow;
2+
use openvm_sha_macros::ColsRef;
3+
4+
mod test_config;
5+
use test_config::{TestConfig, TestConfigImpl};
6+
7+
#[derive(ColsRef)]
8+
#[config(TestConfig)]
9+
struct TestCols<T, const N: usize> {
10+
a: [T; N],
11+
// Forces the field to be treated as a struct that derives AlignedBorrow.
12+
// In particular, ignores the fact that it ends with `Cols` and doesn't
13+
// expect a `PlainTestColsRef` type.
14+
#[plain]
15+
b: PlainCols<T>,
16+
}
17+
18+
#[derive(Clone, Copy, Debug, AlignedBorrow)]
19+
struct PlainCols<T> {
20+
a: T,
21+
b: [T; 4],
22+
}
23+
24+
#[test]
25+
fn plain() {
26+
let input = [1; TestConfigImpl::N + 1 + 4];
27+
let test: TestColsRef<u32> = TestColsRef::from::<TestConfigImpl>(&input);
28+
println!("{}", test.a);
29+
println!("{:?}", test.b);
30+
}
31+
32+
#[test]
33+
fn plain_mut() {
34+
let mut input = [1; TestConfigImpl::N + 1 + 4];
35+
let mut test: TestColsRefMut<u32> = TestColsRefMut::from::<TestConfigImpl>(&mut input);
36+
test.a[0] = 1;
37+
test.b.a = 1;
38+
test.b.b[0] = 1;
39+
println!("{}", test.a);
40+
println!("{:?}", test.b);
41+
}
42+
43+
#[test]
44+
fn plain_from_mut() {
45+
let mut input = [1; TestConfigImpl::N + 1 + 4];
46+
let mut test: TestColsRefMut<u32> = TestColsRefMut::from::<TestConfigImpl>(&mut input);
47+
test.a[0] = 1;
48+
test.b.a = 1;
49+
test.b.b[0] = 1;
50+
let test2: TestColsRef<u32> = TestColsRef::from_mut::<TestConfigImpl>(&mut test);
51+
println!("{}", test2.a);
52+
println!("{:?}", test2.b);
53+
}

crates/circuits/sha-macros/tests/simple.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use openvm_sha_air::{Sha256Config, ShaConfig};
22
use openvm_sha_macros::ColsRef;
33

44
#[derive(ColsRef)]
5+
#[config(ShaConfig)]
56
struct Test<T, const WORD_BITS: usize, const ROUNDS_PER_ROW: usize, const WORD_U16S: usize> {
67
a: T,
78
b: [T; WORD_BITS],
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
pub trait TestConfig {
2+
const N: usize;
3+
}
4+
pub struct TestConfigImpl;
5+
impl TestConfig for TestConfigImpl {
6+
const N: usize = 4;
7+
}

crates/circuits/sha-macros/tests/work-vars.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use openvm_sha_macros::ColsRef;
33

44
#[repr(C)]
55
#[derive(Clone, Copy, Debug, ColsRef)]
6+
#[config(ShaConfig)]
67
pub struct ShaWorkVarsCols<
78
T,
89
const WORD_BITS: usize,

0 commit comments

Comments
 (0)