Skip to content

Commit f856c44

Browse files
Updated sha256 air to use sha-macros (ColsRef structs)
1 parent 33bf3e6 commit f856c44

File tree

8 files changed

+1073
-681
lines changed

8 files changed

+1073
-681
lines changed

crates/circuits/sha256-air/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ openvm-circuit-primitives = { workspace = true }
99
openvm-stark-backend = { workspace = true }
1010
sha2 = { version = "0.10", features = ["compress"] }
1111
rand.workspace = true
12+
openvm-sha-macros = { workspace = true }
13+
ndarray = "0.16"
1214

1315
[dev-dependencies]
1416
openvm-stark-sdk = { workspace = true }

crates/circuits/sha256-air/src/air.rs

Lines changed: 265 additions & 202 deletions
Large diffs are not rendered by default.
Lines changed: 92 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
//! WARNING: the order of fields in the structs is important, do not change it
22
3-
use openvm_circuit_primitives::{utils::not, AlignedBorrow};
3+
use openvm_circuit_primitives::utils::not;
4+
use openvm_sha_macros::ColsRef;
45
use openvm_stark_backend::p3_field::FieldAlgebra;
56

6-
use super::{
7-
SHA256_HASH_WORDS, SHA256_ROUNDS_PER_ROW, SHA256_ROW_VAR_CNT, SHA256_WORD_BITS,
8-
SHA256_WORD_U16S, SHA256_WORD_U8S,
9-
};
7+
use crate::ShaConfig;
108

119
/// In each SHA256 block:
1210
/// - First 16 rows use Sha256RoundCols
@@ -26,77 +24,111 @@ use super::{
2624
///
2725
/// Note that the `Sha256WorkVarsCols` field it is used for different purposes in the two structs.
2826
#[repr(C)]
29-
#[derive(Clone, Copy, Debug, AlignedBorrow)]
30-
pub struct Sha256RoundCols<T> {
31-
pub flags: Sha256FlagsCols<T>,
27+
#[derive(Clone, Copy, Debug, ColsRef)]
28+
pub struct ShaRoundCols<
29+
T,
30+
const WORD_BITS: usize,
31+
const WORD_U8S: usize,
32+
const WORD_U16S: usize,
33+
const ROUNDS_PER_ROW: usize,
34+
const ROUNDS_PER_ROW_MINUS_ONE: usize,
35+
const ROW_VAR_CNT: usize,
36+
> {
37+
pub flags: ShaFlagsCols<T, ROW_VAR_CNT>,
3238
/// Stores the current state of the working variables
33-
pub work_vars: Sha256WorkVarsCols<T>,
34-
pub schedule_helper: Sha256MessageHelperCols<T>,
35-
pub message_schedule: Sha256MessageScheduleCols<T>,
39+
pub work_vars: ShaWorkVarsCols<T, WORD_BITS, ROUNDS_PER_ROW, WORD_U16S>,
40+
pub schedule_helper:
41+
ShaMessageHelperCols<T, WORD_U16S, ROUNDS_PER_ROW, ROUNDS_PER_ROW_MINUS_ONE>,
42+
pub message_schedule: ShaMessageScheduleCols<T, WORD_BITS, ROUNDS_PER_ROW, WORD_U8S>,
3643
}
3744

3845
#[repr(C)]
39-
#[derive(Clone, Copy, Debug, AlignedBorrow)]
40-
pub struct Sha256DigestCols<T> {
41-
pub flags: Sha256FlagsCols<T>,
46+
#[derive(Clone, Copy, Debug, ColsRef)]
47+
pub struct ShaDigestCols<
48+
T,
49+
const WORD_BITS: usize,
50+
const WORD_U8S: usize,
51+
const WORD_U16S: usize,
52+
const HASH_WORDS: usize,
53+
const ROUNDS_PER_ROW: usize,
54+
const ROUNDS_PER_ROW_MINUS_ONE: usize,
55+
const ROW_VAR_CNT: usize,
56+
> {
57+
pub flags: ShaFlagsCols<T, ROW_VAR_CNT>,
4258
/// Will serve as previous hash values for the next block.
4359
/// - on non-last blocks, this is the final hash of the current block
4460
/// - on last blocks, this is the initial state constants, SHA256_H.
4561
/// The work variables constraints are applied on all rows, so `carry_a` and `carry_e`
4662
/// must be filled in with dummy values to ensure these constraints hold.
47-
pub hash: Sha256WorkVarsCols<T>,
48-
pub schedule_helper: Sha256MessageHelperCols<T>,
63+
pub hash: ShaWorkVarsCols<T, WORD_BITS, ROUNDS_PER_ROW, WORD_U16S>,
64+
pub schedule_helper:
65+
ShaMessageHelperCols<T, WORD_U16S, ROUNDS_PER_ROW, ROUNDS_PER_ROW_MINUS_ONE>,
4966
/// The actual final hash values of the given block
5067
/// Note: the above `hash` will be equal to `final_hash` unless we are on the last block
51-
pub final_hash: [[T; SHA256_WORD_U8S]; SHA256_HASH_WORDS],
68+
pub final_hash: [[T; WORD_U8S]; HASH_WORDS],
5269
/// The final hash of the previous block
5370
/// Note: will be constrained using interactions with the chip itself
54-
pub prev_hash: [[T; SHA256_WORD_U16S]; SHA256_HASH_WORDS],
71+
pub prev_hash: [[T; WORD_U16S]; HASH_WORDS],
5572
}
5673

5774
#[repr(C)]
58-
#[derive(Clone, Copy, Debug, AlignedBorrow)]
59-
pub struct Sha256MessageScheduleCols<T> {
60-
/// The message schedule words as 32-bit integers
61-
/// The first 16 words will be the message data
62-
pub w: [[T; SHA256_WORD_BITS]; SHA256_ROUNDS_PER_ROW],
75+
#[derive(Clone, Copy, Debug, ColsRef)]
76+
pub struct ShaMessageScheduleCols<
77+
T,
78+
const WORD_BITS: usize,
79+
const ROUNDS_PER_ROW: usize,
80+
const WORD_U8S: usize,
81+
> {
82+
/// The message schedule words as C::WORD_BITS-bit integers
83+
/// The first 16 rows will be the message data
84+
pub w: [[T; WORD_BITS]; ROUNDS_PER_ROW],
6385
/// Will be message schedule carries for rows 4..16 and a buffer for rows 0..4 to be used freely by wrapper chips
6486
/// Note: carries are 2 bit numbers represented using 2 cells as individual bits
65-
pub carry_or_buffer: [[T; SHA256_WORD_U8S]; SHA256_ROUNDS_PER_ROW],
87+
pub carry_or_buffer: [[T; WORD_U8S]; ROUNDS_PER_ROW],
6688
}
6789

6890
#[repr(C)]
69-
#[derive(Clone, Copy, Debug, AlignedBorrow)]
70-
pub struct Sha256WorkVarsCols<T> {
91+
#[derive(Clone, Copy, Debug, ColsRef)]
92+
pub struct ShaWorkVarsCols<
93+
T,
94+
const WORD_BITS: usize,
95+
const ROUNDS_PER_ROW: usize,
96+
const WORD_U16S: usize,
97+
> {
7198
/// `a` and `e` after each iteration as 32-bits
72-
pub a: [[T; SHA256_WORD_BITS]; SHA256_ROUNDS_PER_ROW],
73-
pub e: [[T; SHA256_WORD_BITS]; SHA256_ROUNDS_PER_ROW],
99+
pub a: [[T; WORD_BITS]; ROUNDS_PER_ROW],
100+
pub e: [[T; WORD_BITS]; ROUNDS_PER_ROW],
74101
/// The carry's used for addition during each iteration when computing `a` and `e`
75-
pub carry_a: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW],
76-
pub carry_e: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW],
102+
pub carry_a: [[T; WORD_U16S]; ROUNDS_PER_ROW],
103+
pub carry_e: [[T; WORD_U16S]; ROUNDS_PER_ROW],
77104
}
78105

79106
/// These are the columns that are used to help with the message schedule additions
80107
/// Note: these need to be correctly assigned for every row even on padding rows
81108
#[repr(C)]
82-
#[derive(Clone, Copy, Debug, AlignedBorrow)]
83-
pub struct Sha256MessageHelperCols<T> {
109+
#[derive(Clone, Copy, Debug, ColsRef)]
110+
pub struct ShaMessageHelperCols<
111+
T,
112+
const WORD_U16S: usize,
113+
const ROUNDS_PER_ROW: usize,
114+
const ROUNDS_PER_ROW_MINUS_ONE: usize,
115+
> {
84116
/// The following are used to move data forward to constrain the message schedule additions
85117
/// The value of `w` (message schedule word) from 3 rounds ago
86118
/// In general, `w_i` means `w` from `i` rounds ago
87-
pub w_3: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW - 1],
119+
pub w_3: [[T; WORD_U16S]; ROUNDS_PER_ROW_MINUS_ONE],
88120
/// Here intermediate(i) = w_i + sig_0(w_{i+1})
89121
/// Intermed_t represents the intermediate t rounds ago
90122
/// This is needed to constrain the message schedule, since we can only constrain on two rows at a time
91-
pub intermed_4: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW],
92-
pub intermed_8: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW],
93-
pub intermed_12: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW],
123+
pub intermed_4: [[T; WORD_U16S]; ROUNDS_PER_ROW],
124+
pub intermed_8: [[T; WORD_U16S]; ROUNDS_PER_ROW],
125+
pub intermed_12: [[T; WORD_U16S]; ROUNDS_PER_ROW],
94126
}
95127

96128
#[repr(C)]
97-
#[derive(Clone, Copy, Debug, AlignedBorrow)]
98-
pub struct Sha256FlagsCols<T> {
99-
/// A flag that indicates if the current row is among the first 16 rows of a block.
129+
#[derive(Clone, Copy, Debug, ColsRef)]
130+
pub struct ShaFlagsCols<T, const ROW_VAR_CNT: usize> {
131+
/// A flag that indicates if the current row is among the first C::ROUND_ROWS rows of a block.
100132
pub is_round_row: T,
101133
/// A flag that indicates if the current row is among the first 4 rows of a block.
102134
pub is_first_4_rows: T,
@@ -106,7 +138,8 @@ pub struct Sha256FlagsCols<T> {
106138
// This flag is only used in digest rows.
107139
pub is_last_block: T,
108140
/// We will encode the row index [0..17) using 5 cells
109-
pub row_idx: [T; SHA256_ROW_VAR_CNT],
141+
//#[length(ROW_VAR_CNT)]
142+
pub row_idx: [T; ROW_VAR_CNT],
110143
/// The index of the current block in the trace starting at 1.
111144
/// Set to 0 on padding rows.
112145
pub global_block_idx: T,
@@ -116,7 +149,9 @@ pub struct Sha256FlagsCols<T> {
116149
pub local_block_idx: T,
117150
}
118151

119-
impl<O, T: Copy + core::ops::Add<Output = O>> Sha256FlagsCols<T> {
152+
impl<O, T: Copy + core::ops::Add<Output = O>, const ROW_VAR_CNT: usize>
153+
ShaFlagsCols<T, ROW_VAR_CNT>
154+
{
120155
// This refers to the padding rows that are added to the air to make the trace length a power of 2.
121156
// Not to be confused with the padding added to messages as part of the SHA hash function.
122157
pub fn is_not_padding_row(&self) -> O {
@@ -132,3 +167,19 @@ impl<O, T: Copy + core::ops::Add<Output = O>> Sha256FlagsCols<T> {
132167
not(self.is_not_padding_row())
133168
}
134169
}
170+
171+
// We need to implement this for the ColsRef type as well
172+
impl<'a, O, T: Copy + core::ops::Add<Output = O>> ShaFlagsColsRef<'a, T> {
173+
pub fn is_not_padding_row(&self) -> O {
174+
*self.is_round_row + *self.is_digest_row
175+
}
176+
177+
// This refers to the padding rows that are added to the air to make the trace length a power of 2.
178+
// Not to be confused with the padding added to messages as part of the SHA hash function.
179+
pub fn is_padding_row(&self) -> O
180+
where
181+
O: FieldAlgebra,
182+
{
183+
not(self.is_not_padding_row())
184+
}
185+
}

crates/circuits/sha256-air/src/config.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,14 @@ pub trait ShaConfig: Send + Sync + Clone {
3030
const BLOCK_BITS: usize = Self::BLOCK_WORDS * Self::WORD_BITS;
3131
/// Number of rows per block
3232
const ROWS_PER_BLOCK: usize;
33-
/// Number of rounds per row
33+
/// Number of rounds per row. Must divide Self::ROUNDS_PER_BLOCK
3434
const ROUNDS_PER_ROW: usize;
3535
/// Number of rounds per row minus one (needed for one of the column structs)
3636
const ROUNDS_PER_ROW_MINUS_ONE: usize = Self::ROUNDS_PER_ROW - 1;
37-
/// Number of rounds per block
37+
/// Number of rounds per block. Must be a multiple of Self::ROUNDS_PER_ROW
3838
const ROUNDS_PER_BLOCK: usize;
39+
/// Number of rows used to constrain rounds
40+
const ROUND_ROWS: usize = Self::ROUNDS_PER_BLOCK / Self::ROUNDS_PER_ROW;
3941
/// Number of words in a SHA hash
4042
const HASH_WORDS: usize;
4143
/// Number of vars needed to encode the row index with [Encoder]

crates/circuits/sha256-air/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
44
mod air;
55
mod columns;
6+
mod config;
67
mod trace;
78
mod utils;
89

910
pub use air::*;
1011
pub use columns::*;
12+
pub use config::*;
1113
pub use trace::*;
1214
pub use utils::*;
1315

0 commit comments

Comments
 (0)