Skip to content

Commit c714d13

Browse files
Update SHA-256 subair to support SHA-512 and SHA_384
1 parent 399a393 commit c714d13

File tree

8 files changed

+1681
-816
lines changed

8 files changed

+1681
-816
lines changed

crates/circuits/sha2-air/Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
[package]
2-
name = "openvm-sha256-air"
2+
name = "openvm-sha2-air"
33
version.workspace = true
44
authors.workspace = true
55
edition.workspace = true
66

77
[dependencies]
88
openvm-circuit-primitives = { workspace = true }
99
openvm-stark-backend = { workspace = true }
10+
openvm-circuit-primitives-derive = { workspace = true }
1011
sha2 = { version = "0.10", features = ["compress"] }
1112
rand.workspace = true
13+
ndarray.workspace = true
1214

1315
[dev-dependencies]
1416
openvm-stark-sdk = { workspace = true }

crates/circuits/sha2-air/src/air.rs

Lines changed: 331 additions & 250 deletions
Large diffs are not rendered by default.
Lines changed: 114 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,20 @@
11
//! WARNING: the order of fields in the structs is important, do not change it
22
3-
use openvm_circuit_primitives::{utils::not, AlignedBorrow};
3+
use openvm_circuit_primitives::utils::not;
4+
use openvm_circuit_primitives_derive::ColsRef;
45
use openvm_stark_backend::p3_field::FieldAlgebra;
56

6-
use super::{
7-
SHA256_HASH_WORDS, SHA256_ROUNDS_PER_ROW, SHA256_ROW_VAR_CNT, SHA256_WORD_BITS,
8-
SHA256_WORD_U16S, SHA256_WORD_U8S,
9-
};
7+
use crate::Sha2Config;
108

11-
/// In each SHA256 block:
12-
/// - First 16 rows use Sha256RoundCols
13-
/// - Final row uses Sha256DigestCols
9+
/// In each SHA block:
10+
/// - First C::ROUND_ROWS rows use ShaRoundCols
11+
/// - Final row uses ShaDigestCols
1412
///
1513
/// Note that for soundness, we require that there is always a padding row after the last digest row
16-
/// in the trace. Right now, this is true because the unpadded height is a multiple of 17, and thus
17-
/// not a power of 2.
14+
/// in the trace. Right now, this is true because the unpadded height is a multiple of 17 (SHA-256)
15+
/// or 21 (SHA-512), and thus not a power of 2.
1816
///
19-
/// Sha256RoundCols and Sha256DigestCols share the same first 3 fields:
17+
/// ShaRoundCols and ShaDigestCols share the same first 3 fields:
2018
/// - flags
2119
/// - work_vars/hash (same type, different name)
2220
/// - schedule_helper
@@ -26,101 +24,131 @@ use super::{
2624
/// 2. Specific constraints to use the appropriate struct, with flags helping to do conditional
2725
/// constraints
2826
///
29-
/// Note that the `Sha256WorkVarsCols` field it is used for different purposes in the two structs.
27+
/// Note that the `ShaWorkVarsCols` field is used for different purposes in the two structs.
3028
#[repr(C)]
31-
#[derive(Clone, Copy, Debug, AlignedBorrow)]
32-
pub struct Sha256RoundCols<T> {
33-
pub flags: Sha256FlagsCols<T>,
34-
/// Stores the current state of the working variables
35-
pub work_vars: Sha256WorkVarsCols<T>,
36-
pub schedule_helper: Sha256MessageHelperCols<T>,
37-
pub message_schedule: Sha256MessageScheduleCols<T>,
29+
#[derive(Clone, Copy, Debug, ColsRef)]
30+
#[config(Sha2Config)]
31+
pub struct ShaRoundCols<
32+
T,
33+
const WORD_BITS: usize,
34+
const WORD_U8S: usize,
35+
const WORD_U16S: usize,
36+
const ROUNDS_PER_ROW: usize,
37+
const ROUNDS_PER_ROW_MINUS_ONE: usize,
38+
const ROW_VAR_CNT: usize,
39+
> {
40+
pub flags: Sha2FlagsCols<T, ROW_VAR_CNT>,
41+
pub work_vars: ShaWorkVarsCols<T, WORD_BITS, ROUNDS_PER_ROW, WORD_U16S>,
42+
pub schedule_helper:
43+
Sha2MessageHelperCols<T, WORD_U16S, ROUNDS_PER_ROW, ROUNDS_PER_ROW_MINUS_ONE>,
44+
pub message_schedule: ShaMessageScheduleCols<T, WORD_BITS, ROUNDS_PER_ROW, WORD_U8S>,
3845
}
3946

4047
#[repr(C)]
41-
#[derive(Clone, Copy, Debug, AlignedBorrow)]
42-
pub struct Sha256DigestCols<T> {
43-
pub flags: Sha256FlagsCols<T>,
44-
/// Will serve as previous hash values for the next block.
45-
/// - on non-last blocks, this is the final hash of the current block
46-
/// - on last blocks, this is the initial state constants, SHA256_H.
47-
/// The work variables constraints are applied on all rows, so `carry_a` and `carry_e`
48-
/// must be filled in with dummy values to ensure these constraints hold.
49-
pub hash: Sha256WorkVarsCols<T>,
50-
pub schedule_helper: Sha256MessageHelperCols<T>,
48+
#[derive(Clone, Copy, Debug, ColsRef)]
49+
#[config(Sha2Config)]
50+
pub struct ShaDigestCols<
51+
T,
52+
const WORD_BITS: usize,
53+
const WORD_U8S: usize,
54+
const WORD_U16S: usize,
55+
const HASH_WORDS: usize,
56+
const ROUNDS_PER_ROW: usize,
57+
const ROUNDS_PER_ROW_MINUS_ONE: usize,
58+
const ROW_VAR_CNT: usize,
59+
> {
60+
pub flags: Sha2FlagsCols<T, ROW_VAR_CNT>,
61+
/// Will serve as previous hash values for the next block
62+
pub hash: ShaWorkVarsCols<T, WORD_BITS, ROUNDS_PER_ROW, WORD_U16S>,
63+
pub schedule_helper:
64+
Sha2MessageHelperCols<T, WORD_U16S, ROUNDS_PER_ROW, ROUNDS_PER_ROW_MINUS_ONE>,
5165
/// The actual final hash values of the given block
5266
/// Note: the above `hash` will be equal to `final_hash` unless we are on the last block
53-
pub final_hash: [[T; SHA256_WORD_U8S]; SHA256_HASH_WORDS],
67+
pub final_hash: [[T; WORD_U8S]; HASH_WORDS],
5468
/// The final hash of the previous block
5569
/// Note: will be constrained using interactions with the chip itself
56-
pub prev_hash: [[T; SHA256_WORD_U16S]; SHA256_HASH_WORDS],
70+
pub prev_hash: [[T; WORD_U16S]; HASH_WORDS],
5771
}
5872

5973
#[repr(C)]
60-
#[derive(Clone, Copy, Debug, AlignedBorrow)]
61-
pub struct Sha256MessageScheduleCols<T> {
62-
/// The message schedule words as 32-bit integers
74+
#[derive(Clone, Copy, Debug, ColsRef)]
75+
#[config(Sha2Config)]
76+
pub struct ShaMessageScheduleCols<
77+
T,
78+
const WORD_BITS: usize,
79+
const ROUNDS_PER_ROW: usize,
80+
const WORD_U8S: usize,
81+
> {
82+
/// The message schedule words as bits
6383
/// The first 16 words will be the message data
64-
pub w: [[T; SHA256_WORD_BITS]; SHA256_ROUNDS_PER_ROW],
65-
/// Will be message schedule carries for rows 4..16 and a buffer for rows 0..4 to be used
66-
/// freely by wrapper chips Note: carries are 2 bit numbers represented using 2 cells as
67-
/// individual bits
68-
pub carry_or_buffer: [[T; SHA256_WORD_U8S]; SHA256_ROUNDS_PER_ROW],
84+
pub w: [[T; WORD_BITS]; ROUNDS_PER_ROW],
85+
/// Will be message schedule carries for rows 4..C::ROUND_ROWS and a buffer for rows 0..4 to be
86+
/// used freely by wrapper chips Note: carries are 2 bit numbers represented using 2 cells
87+
/// as individual bits
88+
pub carry_or_buffer: [[T; WORD_U8S]; ROUNDS_PER_ROW],
6989
}
7090

7191
#[repr(C)]
72-
#[derive(Clone, Copy, Debug, AlignedBorrow)]
73-
pub struct Sha256WorkVarsCols<T> {
92+
#[derive(Clone, Copy, Debug, ColsRef)]
93+
#[config(Sha2Config)]
94+
pub struct ShaWorkVarsCols<
95+
T,
96+
const WORD_BITS: usize,
97+
const ROUNDS_PER_ROW: usize,
98+
const WORD_U16S: usize,
99+
> {
74100
/// `a` and `e` after each iteration as 32-bits
75-
pub a: [[T; SHA256_WORD_BITS]; SHA256_ROUNDS_PER_ROW],
76-
pub e: [[T; SHA256_WORD_BITS]; SHA256_ROUNDS_PER_ROW],
101+
pub a: [[T; WORD_BITS]; ROUNDS_PER_ROW],
102+
pub e: [[T; WORD_BITS]; ROUNDS_PER_ROW],
77103
/// The carry's used for addition during each iteration when computing `a` and `e`
78-
pub carry_a: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW],
79-
pub carry_e: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW],
104+
pub carry_a: [[T; WORD_U16S]; ROUNDS_PER_ROW],
105+
pub carry_e: [[T; WORD_U16S]; ROUNDS_PER_ROW],
80106
}
81107

82108
/// These are the columns that are used to help with the message schedule additions
83109
/// Note: these need to be correctly assigned for every row even on padding rows
84110
#[repr(C)]
85-
#[derive(Clone, Copy, Debug, AlignedBorrow)]
86-
pub struct Sha256MessageHelperCols<T> {
111+
#[derive(Clone, Copy, Debug, ColsRef)]
112+
#[config(Sha2Config)]
113+
pub struct Sha2MessageHelperCols<
114+
T,
115+
const WORD_U16S: usize,
116+
const ROUNDS_PER_ROW: usize,
117+
const ROUNDS_PER_ROW_MINUS_ONE: usize,
118+
> {
87119
/// The following are used to move data forward to constrain the message schedule additions
88-
/// The value of `w` (message schedule word) from 3 rounds ago
89-
/// In general, `w_i` means `w` from `i` rounds ago
90-
pub w_3: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW - 1],
120+
/// The value of `w` from 3 rounds ago
121+
pub w_3: [[T; WORD_U16S]; ROUNDS_PER_ROW_MINUS_ONE],
91122
/// Here intermediate(i) = w_i + sig_0(w_{i+1})
92123
/// Intermed_t represents the intermediate t rounds ago
93124
/// This is needed to constrain the message schedule, since we can only constrain on two rows
94125
/// at a time
95-
pub intermed_4: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW],
96-
pub intermed_8: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW],
97-
pub intermed_12: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW],
126+
pub intermed_4: [[T; WORD_U16S]; ROUNDS_PER_ROW],
127+
pub intermed_8: [[T; WORD_U16S]; ROUNDS_PER_ROW],
128+
pub intermed_12: [[T; WORD_U16S]; ROUNDS_PER_ROW],
98129
}
99130

100131
#[repr(C)]
101-
#[derive(Clone, Copy, Debug, AlignedBorrow)]
102-
pub struct Sha256FlagsCols<T> {
103-
/// A flag that indicates if the current row is among the first 16 rows of a block.
132+
#[derive(Clone, Copy, Debug, ColsRef)]
133+
#[config(Sha2Config)]
134+
pub struct Sha2FlagsCols<T, const ROW_VAR_CNT: usize> {
104135
pub is_round_row: T,
105-
/// A flag that indicates if the current row is among the first 4 rows of a block.
136+
/// A flag that indicates if the current row is among the first 4 rows of a block (the message
137+
/// rows)
106138
pub is_first_4_rows: T,
107-
/// A flag that indicates if the current row is the last (17th) row of a block.
108139
pub is_digest_row: T,
109-
// A flag that indicates if the current row is the last block of the message.
110-
// This flag is only used in digest rows.
111140
pub is_last_block: T,
112141
/// We will encode the row index [0..17) using 5 cells
113-
pub row_idx: [T; SHA256_ROW_VAR_CNT],
114-
/// The index of the current block in the trace starting at 1.
115-
/// Set to 0 on padding rows.
142+
pub row_idx: [T; ROW_VAR_CNT],
143+
/// The global index of the current block
116144
pub global_block_idx: T,
117-
/// The index of the current block in the current message starting at 0.
118-
/// Resets after every message.
119-
/// Set to 0 on padding rows.
145+
/// Will store the index of the current block in the current message starting from 0
120146
pub local_block_idx: T,
121147
}
122148

123-
impl<O, T: Copy + core::ops::Add<Output = O>> Sha256FlagsCols<T> {
149+
impl<O, T: Copy + core::ops::Add<Output = O>, const ROW_VAR_CNT: usize>
150+
Sha2FlagsCols<T, ROW_VAR_CNT>
151+
{
124152
// This refers to the padding rows that are added to the air to make the trace length a power of
125153
// 2. Not to be confused with the padding added to messages as part of the SHA hash
126154
// function.
@@ -138,3 +166,22 @@ impl<O, T: Copy + core::ops::Add<Output = O>> Sha256FlagsCols<T> {
138166
not(self.is_not_padding_row())
139167
}
140168
}
169+
170+
impl<O, T: Copy + core::ops::Add<Output = O>> Sha2FlagsColsRef<'_, T> {
171+
// This refers to the padding rows that are added to the air to make the trace length a power of
172+
// 2. Not to be confused with the padding added to messages as part of the SHA hash
173+
// function.
174+
pub fn is_not_padding_row(&self) -> O {
175+
*self.is_round_row + *self.is_digest_row
176+
}
177+
178+
// This refers to the padding rows that are added to the air to make the trace length a power of
179+
// 2. Not to be confused with the padding added to messages as part of the SHA hash
180+
// function.
181+
pub fn is_padding_row(&self) -> O
182+
where
183+
O: FieldAlgebra,
184+
{
185+
not(self.is_not_padding_row())
186+
}
187+
}

0 commit comments

Comments
 (0)