Skip to content

Commit 9c1dbbd

Browse files
Updated sha256 air to use sha-macros (ColsRef structs)
1 parent 1c55a8b commit 9c1dbbd

File tree

7 files changed

+1030
-642
lines changed

7 files changed

+1030
-642
lines changed

crates/circuits/sha256-air/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ openvm-circuit-primitives = { workspace = true }
99
openvm-stark-backend = { workspace = true }
1010
sha2 = { version = "0.10", features = ["compress"] }
1111
rand.workspace = true
12+
openvm-sha-macros = { workspace = true }
13+
ndarray = "0.16"
1214

1315
[dev-dependencies]
1416
openvm-stark-sdk = { workspace = true }

crates/circuits/sha256-air/src/air.rs

Lines changed: 259 additions & 196 deletions
Large diffs are not rendered by default.
Lines changed: 86 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
//! WARNING: the order of fields in the structs is important, do not change it
22
3-
use openvm_circuit_primitives::{utils::not, AlignedBorrow};
3+
use openvm_circuit_primitives::utils::not;
4+
use openvm_sha_macros::ColsRef;
45
use openvm_stark_backend::p3_field::FieldAlgebra;
56

6-
use super::{
7-
SHA256_HASH_WORDS, SHA256_ROUNDS_PER_ROW, SHA256_ROW_VAR_CNT, SHA256_WORD_BITS,
8-
SHA256_WORD_U16S, SHA256_WORD_U8S,
9-
};
7+
use crate::ShaConfig;
108

119
/// In each SHA256 block:
1210
/// - First 16 rows use Sha256RoundCols
@@ -21,82 +19,119 @@ use super::{
2119
/// 1. Common constraints to work on either struct type by accessing these shared fields
2220
/// 2. Specific constraints to use the appropriate struct, with flags helping to do conditional constraints
2321
#[repr(C)]
24-
#[derive(Clone, Copy, Debug, AlignedBorrow)]
25-
pub struct Sha256RoundCols<T> {
26-
pub flags: Sha256FlagsCols<T>,
27-
pub work_vars: Sha256WorkVarsCols<T>,
28-
pub schedule_helper: Sha256MessageHelperCols<T>,
29-
pub message_schedule: Sha256MessageScheduleCols<T>,
22+
#[derive(Clone, Copy, Debug, ColsRef)]
23+
pub struct ShaRoundCols<
24+
T,
25+
const WORD_BITS: usize,
26+
const WORD_U8S: usize,
27+
const WORD_U16S: usize,
28+
const ROUNDS_PER_ROW: usize,
29+
const ROUNDS_PER_ROW_MINUS_ONE: usize,
30+
const ROW_VAR_CNT: usize,
31+
> {
32+
pub flags: ShaFlagsCols<T, ROW_VAR_CNT>,
33+
pub work_vars: ShaWorkVarsCols<T, WORD_BITS, ROUNDS_PER_ROW, WORD_U16S>,
34+
pub schedule_helper:
35+
ShaMessageHelperCols<T, WORD_U16S, ROUNDS_PER_ROW, ROUNDS_PER_ROW_MINUS_ONE>,
36+
pub message_schedule: ShaMessageScheduleCols<T, WORD_BITS, ROUNDS_PER_ROW, WORD_U8S>,
3037
}
3138

3239
#[repr(C)]
33-
#[derive(Clone, Copy, Debug, AlignedBorrow)]
34-
pub struct Sha256DigestCols<T> {
35-
pub flags: Sha256FlagsCols<T>,
40+
#[derive(Clone, Copy, Debug, ColsRef)]
41+
pub struct ShaDigestCols<
42+
T,
43+
const WORD_BITS: usize,
44+
const WORD_U8S: usize,
45+
const WORD_U16S: usize,
46+
const HASH_WORDS: usize,
47+
const ROUNDS_PER_ROW: usize,
48+
const ROUNDS_PER_ROW_MINUS_ONE: usize,
49+
const ROW_VAR_CNT: usize,
50+
> {
51+
pub flags: ShaFlagsCols<T, ROW_VAR_CNT>,
3652
/// Will serve as previous hash values for the next block
37-
pub hash: Sha256WorkVarsCols<T>,
38-
pub schedule_helper: Sha256MessageHelperCols<T>,
53+
pub hash: ShaWorkVarsCols<T, WORD_BITS, ROUNDS_PER_ROW, WORD_U16S>,
54+
pub schedule_helper:
55+
ShaMessageHelperCols<T, WORD_U16S, ROUNDS_PER_ROW, ROUNDS_PER_ROW_MINUS_ONE>,
3956
/// The actual final hash values of the given block
4057
/// Note: the above `hash` will be equal to `final_hash` unless we are on the last block
41-
pub final_hash: [[T; SHA256_WORD_U8S]; SHA256_HASH_WORDS],
58+
pub final_hash: [[T; WORD_U8S]; HASH_WORDS],
4259
/// The final hash of the previous block
4360
/// Note: will be constrained using interactions with the chip itself
44-
pub prev_hash: [[T; SHA256_WORD_U16S]; SHA256_HASH_WORDS],
61+
pub prev_hash: [[T; WORD_U16S]; HASH_WORDS],
4562
}
4663

4764
#[repr(C)]
48-
#[derive(Clone, Copy, Debug, AlignedBorrow)]
49-
pub struct Sha256MessageScheduleCols<T> {
65+
#[derive(Clone, Copy, Debug, ColsRef)]
66+
pub struct ShaMessageScheduleCols<
67+
T,
68+
const WORD_BITS: usize,
69+
const ROUNDS_PER_ROW: usize,
70+
const WORD_U8S: usize,
71+
> {
5072
/// The message schedule words as 32-bit intergers
51-
pub w: [[T; SHA256_WORD_BITS]; SHA256_ROUNDS_PER_ROW],
73+
pub w: [[T; WORD_BITS]; ROUNDS_PER_ROW],
5274
/// Will be message schedule carries for rows 4..16 and a buffer for rows 0..4 to be used freely by wrapper chips
5375
/// Note: carries are represented as 2 bit numbers
54-
pub carry_or_buffer: [[T; SHA256_WORD_U8S]; SHA256_ROUNDS_PER_ROW],
76+
pub carry_or_buffer: [[T; WORD_U8S]; ROUNDS_PER_ROW],
5577
}
5678

5779
#[repr(C)]
58-
#[derive(Clone, Copy, Debug, AlignedBorrow)]
59-
pub struct Sha256WorkVarsCols<T> {
80+
#[derive(Clone, Copy, Debug, ColsRef)]
81+
pub struct ShaWorkVarsCols<
82+
T,
83+
const WORD_BITS: usize,
84+
const ROUNDS_PER_ROW: usize,
85+
const WORD_U16S: usize,
86+
> {
6087
/// `a` and `e` after each iteration as 32-bits
61-
pub a: [[T; SHA256_WORD_BITS]; SHA256_ROUNDS_PER_ROW],
62-
pub e: [[T; SHA256_WORD_BITS]; SHA256_ROUNDS_PER_ROW],
88+
pub a: [[T; WORD_BITS]; ROUNDS_PER_ROW],
89+
pub e: [[T; WORD_BITS]; ROUNDS_PER_ROW],
6390
/// The carry's used for addition during each iteration when computing `a` and `e`
64-
pub carry_a: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW],
65-
pub carry_e: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW],
91+
pub carry_a: [[T; WORD_U16S]; ROUNDS_PER_ROW],
92+
pub carry_e: [[T; WORD_U16S]; ROUNDS_PER_ROW],
6693
}
6794

6895
/// These are the columns that are used to help with the message schedule additions
6996
/// Note: these need to be correctly assigned for every row even on padding rows
7097
#[repr(C)]
71-
#[derive(Clone, Copy, Debug, AlignedBorrow)]
72-
pub struct Sha256MessageHelperCols<T> {
98+
#[derive(Clone, Copy, Debug, ColsRef)]
99+
pub struct ShaMessageHelperCols<
100+
T,
101+
const WORD_U16S: usize,
102+
const ROUNDS_PER_ROW: usize,
103+
const ROUNDS_PER_ROW_MINUS_ONE: usize,
104+
> {
73105
/// The following are used to move data forward to constrain the message schedule additions
74106
/// The value of `w` from 3 rounds ago
75-
pub w_3: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW - 1],
107+
pub w_3: [[T; WORD_U16S]; ROUNDS_PER_ROW_MINUS_ONE],
76108
/// Here intermediate(i) = w_i + sig_0(w_{i+1})
77109
/// Intermed_t represents the intermediate t rounds ago
78-
pub intermed_4: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW],
79-
pub intermed_8: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW],
80-
pub intermed_12: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW],
110+
pub intermed_4: [[T; WORD_U16S]; ROUNDS_PER_ROW],
111+
pub intermed_8: [[T; WORD_U16S]; ROUNDS_PER_ROW],
112+
pub intermed_12: [[T; WORD_U16S]; ROUNDS_PER_ROW],
81113
}
82114

83115
#[repr(C)]
84-
#[derive(Clone, Copy, Debug, AlignedBorrow)]
85-
pub struct Sha256FlagsCols<T> {
116+
#[derive(Clone, Copy, Debug, ColsRef)]
117+
pub struct ShaFlagsCols<T, const ROW_VAR_CNT: usize> {
86118
pub is_round_row: T,
87119
/// A flag that indicates if the current row is among the first 4 rows of a block
88120
pub is_first_4_rows: T,
89121
pub is_digest_row: T,
90122
pub is_last_block: T,
91123
/// We will encode the row index [0..17) using 5 cells
92-
pub row_idx: [T; SHA256_ROW_VAR_CNT],
124+
//#[length(ROW_VAR_CNT)]
125+
pub row_idx: [T; ROW_VAR_CNT],
93126
/// The global index of the current block
94127
pub global_block_idx: T,
95128
/// Will store the index of the current block in the current message starting from 0
96129
pub local_block_idx: T,
97130
}
98131

99-
impl<O, T: Copy + core::ops::Add<Output = O>> Sha256FlagsCols<T> {
132+
impl<O, T: Copy + core::ops::Add<Output = O>, const ROW_VAR_CNT: usize>
133+
ShaFlagsCols<T, ROW_VAR_CNT>
134+
{
100135
pub fn is_not_padding_row(&self) -> O {
101136
self.is_round_row + self.is_digest_row
102137
}
@@ -108,3 +143,16 @@ impl<O, T: Copy + core::ops::Add<Output = O>> Sha256FlagsCols<T> {
108143
not(self.is_not_padding_row())
109144
}
110145
}
146+
147+
impl<'a, O, T: Copy + core::ops::Add<Output = O>> ShaFlagsColsRef<'a, T> {
148+
pub fn is_not_padding_row(&self) -> O {
149+
*self.is_round_row + *self.is_digest_row
150+
}
151+
152+
pub fn is_padding_row(&self) -> O
153+
where
154+
O: FieldAlgebra,
155+
{
156+
not(self.is_not_padding_row())
157+
}
158+
}

crates/circuits/sha256-air/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
44
mod air;
55
mod columns;
6+
mod config;
67
mod trace;
78
mod utils;
89

910
pub use air::*;
1011
pub use columns::*;
12+
pub use config::*;
1113
pub use trace::*;
1214
pub use utils::*;
1315

crates/circuits/sha256-air/src/tests.rs

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,46 +12,57 @@ use openvm_stark_backend::{
1212
config::{StarkGenericConfig, Val},
1313
interaction::InteractionBuilder,
1414
p3_air::{Air, BaseAir},
15-
p3_field::{Field, PrimeField32},
15+
p3_field::{Field, FieldAlgebra, PrimeField32},
1616
prover::types::AirProofInput,
1717
rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir},
1818
AirRef, Chip, ChipUsageGetter,
1919
};
20-
use openvm_stark_sdk::utils::create_seeded_rng;
20+
use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng};
2121
use rand::Rng;
22+
use sha2::Sha256;
2223

2324
use crate::{
24-
Sha256Air, SHA256_BLOCK_U8S, SHA256_DIGEST_WIDTH, SHA256_ROUND_WIDTH, SHA256_ROWS_PER_BLOCK,
25+
Sha256Config, Sha512Config, ShaAir, ShaConfig, ShaFlagsColsRef, ShaFlagsColsRefMut,
26+
ShaPrecomputedValues,
2527
};
2628

2729
// A wrapper AIR purely for testing purposes
2830
#[derive(Clone, Debug)]
29-
pub struct Sha256TestAir {
30-
pub sub_air: Sha256Air,
31+
pub struct ShaTestAir<C: ShaConfig + ShaPrecomputedValues<C::Word>> {
32+
pub sub_air: ShaAir<C>,
3133
}
3234

33-
impl<F: Field> BaseAirWithPublicValues<F> for Sha256TestAir {}
34-
impl<F: Field> PartitionedBaseAir<F> for Sha256TestAir {}
35-
impl<F: Field> BaseAir<F> for Sha256TestAir {
35+
impl<F: Field, C: ShaConfig + ShaPrecomputedValues<C::Word>> BaseAirWithPublicValues<F>
36+
for ShaTestAir<C>
37+
{
38+
}
39+
impl<F: Field, C: ShaConfig + ShaPrecomputedValues<C::Word>> PartitionedBaseAir<F>
40+
for ShaTestAir<C>
41+
{
42+
}
43+
impl<F: Field, C: ShaConfig + ShaPrecomputedValues<C::Word>> BaseAir<F> for ShaTestAir<C> {
3644
fn width(&self) -> usize {
37-
<Sha256Air as BaseAir<F>>::width(&self.sub_air)
45+
<ShaAir<C> as BaseAir<F>>::width(&self.sub_air)
3846
}
3947
}
4048

41-
impl<AB: InteractionBuilder> Air<AB> for Sha256TestAir {
49+
impl<AB: InteractionBuilder, C: ShaConfig + ShaPrecomputedValues<C::Word>> Air<AB>
50+
for ShaTestAir<C>
51+
{
4252
fn eval(&self, builder: &mut AB) {
4353
self.sub_air.eval(builder, 0);
4454
}
4555
}
4656

4757
// A wrapper Chip purely for testing purposes
48-
pub struct Sha256TestChip {
49-
pub air: Sha256TestAir,
58+
pub struct ShaTestChip<C: ShaConfig + ShaPrecomputedValues<C::Word>> {
59+
pub air: ShaTestAir<C>,
5060
pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>,
51-
pub records: Vec<([u8; SHA256_BLOCK_U8S], bool)>,
61+
pub records: Vec<(Vec<u8>, bool)>, // length of inner vec is BLOCK_U8S
5262
}
5363

54-
impl<SC: StarkGenericConfig> Chip<SC> for Sha256TestChip
64+
impl<SC: StarkGenericConfig, C: ShaConfig + ShaPrecomputedValues<C::Word> + 'static> Chip<SC>
65+
for ShaTestChip<C>
5566
where
5667
Val<SC>: PrimeField32,
5768
{
@@ -60,7 +71,7 @@ where
6071
}
6172

6273
fn generate_air_proof_input(self) -> AirProofInput<SC> {
63-
let trace = crate::generate_trace::<Val<SC>>(
74+
let trace = crate::generate_trace::<Val<SC>, C>(
6475
&self.air.sub_air,
6576
self.bitwise_lookup_chip.clone(),
6677
self.records,
@@ -69,33 +80,39 @@ where
6980
}
7081
}
7182

72-
impl ChipUsageGetter for Sha256TestChip {
83+
impl<C: ShaConfig + ShaPrecomputedValues<C::Word>> ChipUsageGetter for ShaTestChip<C> {
7384
fn air_name(&self) -> String {
7485
get_air_name(&self.air)
7586
}
7687
fn current_trace_height(&self) -> usize {
77-
self.records.len() * SHA256_ROWS_PER_BLOCK
88+
self.records.len() * C::ROWS_PER_BLOCK
7889
}
7990

8091
fn trace_width(&self) -> usize {
81-
max(SHA256_ROUND_WIDTH, SHA256_DIGEST_WIDTH)
92+
max(C::ROUND_WIDTH, C::DIGEST_WIDTH)
8293
}
8394
}
8495

8596
const SELF_BUS_IDX: usize = 28;
86-
#[test]
87-
fn rand_sha256_test() {
97+
fn rand_sha_test<C: ShaConfig + ShaPrecomputedValues<C::Word> + 'static>() {
8898
let mut rng = create_seeded_rng();
8999
let tester = VmChipTestBuilder::default();
90100
let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS);
91101
let bitwise_chip = SharedBitwiseOperationLookupChip::<RV32_CELL_BITS>::new(bitwise_bus);
92102
let len = rng.gen_range(1..100);
93103
let random_records: Vec<_> = (0..len)
94-
.map(|_| (array::from_fn(|_| rng.gen::<u8>()), true))
104+
.map(|_| {
105+
(
106+
(0..C::BLOCK_U8S)
107+
.map(|_| rng.gen::<u8>())
108+
.collect::<Vec<_>>(),
109+
true,
110+
)
111+
})
95112
.collect();
96-
let chip = Sha256TestChip {
97-
air: Sha256TestAir {
98-
sub_air: Sha256Air::new(bitwise_bus, SELF_BUS_IDX),
113+
let chip = ShaTestChip {
114+
air: ShaTestAir {
115+
sub_air: ShaAir::<C>::new(bitwise_bus, SELF_BUS_IDX),
99116
},
100117
bitwise_lookup_chip: bitwise_chip.clone(),
101118
records: random_records,
@@ -104,3 +121,8 @@ fn rand_sha256_test() {
104121
let tester = tester.build().load(chip).load(bitwise_chip).finalize();
105122
tester.simple_test().expect("Verification failed");
106123
}
124+
125+
#[test]
126+
fn rand_sha256_test() {
127+
rand_sha_test::<Sha256Config>();
128+
}

0 commit comments

Comments
 (0)