Skip to content

Commit dc8938d

Browse files
Added sha512 support to the air
1 parent 43f3715 commit dc8938d

File tree

8 files changed

+187
-119
lines changed

8 files changed

+187
-119
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/circuits/sha-macros/src/lib.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,12 +172,12 @@ fn make_struct(struct_info: StructInfo) -> proc_macro2::TokenStream {
172172
fn make_from_mut(struct_info: StructInfo) -> Result<proc_macro2::TokenStream, String> {
173173
let StructInfo {
174174
name,
175-
vis,
175+
vis: _,
176176
generic_type,
177-
field_infos,
177+
field_infos: _,
178178
fields,
179-
from_args,
180-
derive_clone,
179+
from_args: _,
180+
derive_clone: _,
181181
} = struct_info;
182182

183183
let fields = match fields {

crates/circuits/sha256-air/src/air.rs

Lines changed: 51 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ impl<C: ShaConfig + ShaPrecomputedValues<C::Word>> ShaAir<C> {
3737
pub fn new(bitwise_lookup_bus: BitwiseOperationLookupBus, self_bus_idx: BusIndex) -> Self {
3838
Self {
3939
bitwise_lookup_bus,
40-
row_idx_encoder: Encoder::new(18, 2, false),
40+
row_idx_encoder: Encoder::new(C::ROWS_PER_BLOCK + 1, 2, false), // + 1 for dummy (padding) rows
4141
bus: PermutationCheckBus::new(self_bus_idx),
4242
_phantom: PhantomData,
4343
}
@@ -91,29 +91,35 @@ impl<C: ShaConfig + ShaPrecomputedValues<C::Word>> ShaAir<C> {
9191

9292
self.row_idx_encoder
9393
.eval(builder, local_cols.flags.row_idx.to_slice().unwrap());
94-
builder.assert_one(
95-
self.row_idx_encoder
96-
.contains_flag_range::<AB>(local_cols.flags.row_idx.to_slice().unwrap(), 0..=17),
97-
);
94+
builder.assert_one(self.row_idx_encoder.contains_flag_range::<AB>(
95+
local_cols.flags.row_idx.to_slice().unwrap(),
96+
0..=C::ROWS_PER_BLOCK,
97+
));
9898
builder.assert_eq(
9999
self.row_idx_encoder
100100
.contains_flag_range::<AB>(local_cols.flags.row_idx.to_slice().unwrap(), 0..=3),
101101
*flags.is_first_4_rows,
102102
);
103103
builder.assert_eq(
104-
self.row_idx_encoder
105-
.contains_flag_range::<AB>(local_cols.flags.row_idx.to_slice().unwrap(), 0..=15),
104+
self.row_idx_encoder.contains_flag_range::<AB>(
105+
local_cols.flags.row_idx.to_slice().unwrap(),
106+
0..=C::ROUND_ROWS - 1,
107+
),
106108
*flags.is_round_row,
107109
);
108110
builder.assert_eq(
109-
self.row_idx_encoder
110-
.contains_flag::<AB>(local_cols.flags.row_idx.to_slice().unwrap(), &[16]),
111+
self.row_idx_encoder.contains_flag::<AB>(
112+
local_cols.flags.row_idx.to_slice().unwrap(),
113+
&[C::ROUND_ROWS],
114+
),
111115
*flags.is_digest_row,
112116
);
113-
// If padding row we want the row_idx to be 17
117+
// If padding row we want the row_idx to be C::ROWS_PER_BLOCK
114118
builder.assert_eq(
115-
self.row_idx_encoder
116-
.contains_flag::<AB>(local_cols.flags.row_idx.to_slice().unwrap(), &[17]),
119+
self.row_idx_encoder.contains_flag::<AB>(
120+
local_cols.flags.row_idx.to_slice().unwrap(),
121+
&[C::ROWS_PER_BLOCK],
122+
),
117123
flags.is_padding_row(),
118124
);
119125

@@ -130,14 +136,14 @@ impl<C: ShaConfig + ShaPrecomputedValues<C::Word>> ShaAir<C> {
130136
/// Implements constraints for a digest row that ensure proper state transitions between blocks
131137
/// This validates that:
132138
/// The work variables are correctly initialized for the next message block
133-
/// For the last message block, the initial state matches SHA256_H constants
139+
/// For the last message block, the initial state matches SHA_H constants
134140
fn eval_digest_row<AB: InteractionBuilder>(
135141
&self,
136142
builder: &mut AB,
137143
local: ShaRoundColsRef<AB::Var>,
138144
next: ShaDigestColsRef<AB::Var>,
139145
) {
140-
// Check that if this is the last row of a message or an inpadding row, the hash should be the [SHA256_H]
146+
// Check that if this is the last row of a message or an inpadding row, the hash should be the [SHA_H]
141147
for i in 0..C::ROUNDS_PER_ROW {
142148
let a = next.hash.a.row(i).mapv(|x| x.into()).to_vec();
143149
let e = next.hash.e.row(i).mapv(|x| x.into()).to_vec();
@@ -146,7 +152,7 @@ impl<C: ShaConfig + ShaPrecomputedValues<C::Word>> ShaAir<C> {
146152
let a_limb = compose::<AB::Expr>(&a[j * 16..(j + 1) * 16], 1);
147153
let e_limb = compose::<AB::Expr>(&e[j * 16..(j + 1) * 16], 1);
148154

149-
// If it is a padding row or the last row of a message, the `hash` should be the [SHA256_H]
155+
// If it is a padding row or the last row of a message, the `hash` should be the [SHA_H]
150156
builder
151157
.when(
152158
next.flags.is_padding_row()
@@ -278,24 +284,24 @@ impl<C: ShaConfig + ShaPrecomputedValues<C::Word>> ShaAir<C> {
278284
// Constrain how much the row index changes by
279285
// round->round: 1
280286
// round->digest: 1
281-
// digest->round: -16 // TODO: sha512
287+
// digest->round: -C::ROUND_ROWS
282288
// digest->padding: 1
283289
// padding->padding: 0
284290
// Other transitions are not allowed by the above constraints
285291
let delta = *local_cols.flags.is_round_row * AB::Expr::ONE
286292
+ *local_cols.flags.is_digest_row
287293
* *next_cols.flags.is_round_row
288-
* AB::Expr::from_canonical_u32(16)
294+
* AB::Expr::from_canonical_usize(C::ROUND_ROWS)
289295
* AB::Expr::NEG_ONE
290296
+ *local_cols.flags.is_digest_row * next_is_padding_row.clone() * AB::Expr::ONE;
291297

292298
let local_row_idx = self.row_idx_encoder.flag_with_val::<AB>(
293299
local_cols.flags.row_idx.to_slice().unwrap(),
294-
&(0..18).map(|i| (i, i)).collect::<Vec<_>>(),
300+
&(0..=C::ROWS_PER_BLOCK).map(|i| (i, i)).collect::<Vec<_>>(),
295301
);
296302
let next_row_idx = self.row_idx_encoder.flag_with_val::<AB>(
297303
next_cols.flags.row_idx.to_slice().unwrap(),
298-
&(0..18).map(|i| (i, i)).collect::<Vec<_>>(),
304+
&(0..=C::ROWS_PER_BLOCK).map(|i| (i, i)).collect::<Vec<_>>(),
299305
);
300306

301307
builder
@@ -458,20 +464,22 @@ impl<C: ShaConfig + ShaPrecomputedValues<C::Word>> ShaAir<C> {
458464
}
459465

460466
// Constrain intermed for `next` row
461-
// We will only constrain intermed_12 for rows [3, 14], and let it be unconstrained for other rows
467+
// We will only constrain intermed_12 for rows [3, 14], and let it unconstrained for other rows
462468
// Other rows should put the needed value in intermed_12 to make the below summation constraint hold
463-
let is_row_3_14 = self
464-
.row_idx_encoder
465-
.contains_flag_range::<AB>(next.flags.row_idx.to_slice().unwrap(), 3..=14);
466-
// We will only constrain intermed_8 for rows [2, 13], and let it unconstrained for other rows
467-
let is_row_2_13 = self
468-
.row_idx_encoder
469-
.contains_flag_range::<AB>(next.flags.row_idx.to_slice().unwrap(), 2..=13);
469+
let is_row_intermed_12 = self.row_idx_encoder.contains_flag_range::<AB>(
470+
next.flags.row_idx.to_slice().unwrap(),
471+
3..=C::ROUND_ROWS - 2,
472+
);
473+
// We will only constrain intermed_8 for rows [2, C::ROUND_ROWS - 2], and let it unconstrained for other rows
474+
let is_row_intermed_8 = self.row_idx_encoder.contains_flag_range::<AB>(
475+
next.flags.row_idx.to_slice().unwrap(),
476+
2..=C::ROUND_ROWS - 3,
477+
);
470478
for i in 0..C::ROUNDS_PER_ROW {
471479
// w_idx
472480
let w_idx = w.row(i).mapv(|x| x.into()).to_vec();
473481
// sig_0(w_{idx+1})
474-
let sig_w = small_sig0_field::<AB::Expr>(w.row(i + 1).as_slice().unwrap());
482+
let sig_w = small_sig0_field::<AB::Expr, C>(w.row(i + 1).as_slice().unwrap());
475483
for j in 0..C::WORD_U16S {
476484
let w_idx_limb = compose::<AB::Expr>(&w_idx[j * 16..(j + 1) * 16], 1);
477485
let sig_w_limb = compose::<AB::Expr>(&sig_w[j * 16..(j + 1) * 16], 1);
@@ -483,12 +491,12 @@ impl<C: ShaConfig + ShaPrecomputedValues<C::Word>> ShaAir<C> {
483491
w_idx_limb + sig_w_limb,
484492
);
485493

486-
builder.when(is_row_2_13.clone()).assert_eq(
494+
builder.when(is_row_intermed_8.clone()).assert_eq(
487495
next.schedule_helper.intermed_8[[i, j]],
488496
local.schedule_helper.intermed_4[[i, j]],
489497
);
490498

491-
builder.when(is_row_3_14.clone()).assert_eq(
499+
builder.when(is_row_intermed_12.clone()).assert_eq(
492500
next.schedule_helper.intermed_12[[i, j]],
493501
local.schedule_helper.intermed_8[[i, j]],
494502
);
@@ -524,7 +532,7 @@ impl<C: ShaConfig + ShaPrecomputedValues<C::Word>> ShaAir<C> {
524532
constraint_word_addition::<_, C>(
525533
// Note: here we can't do a conditional check because the degree of sum is already 3
526534
&mut builder.when_transition(),
527-
&[&small_sig1_field::<AB::Expr>(
535+
&[&small_sig1_field::<AB::Expr, C>(
528536
w.row(i + 2).as_slice().unwrap(),
529537
)],
530538
&[&w_7, intermed_16.as_slice().unwrap()],
@@ -533,13 +541,13 @@ impl<C: ShaConfig + ShaPrecomputedValues<C::Word>> ShaAir<C> {
533541
);
534542

535543
for j in 0..C::WORD_U16S {
536-
// When on rows 4..16 message schedule carries should be 0 or 1
537-
let is_row_4_15 = *next.flags.is_round_row - *next.flags.is_first_4_rows;
544+
// When on rows 4..C::ROUND_ROWS message schedule carries should be 0 or 1
545+
let is_row_4_or_more = *next.flags.is_round_row - *next.flags.is_first_4_rows;
538546
builder
539-
.when(is_row_4_15.clone())
547+
.when(is_row_4_or_more.clone())
540548
.assert_bool(next.message_schedule.carry_or_buffer[[i, j * 2]]);
541549
builder
542-
.when(is_row_4_15)
550+
.when(is_row_4_or_more)
543551
.assert_bool(next.message_schedule.carry_or_buffer[[i, j * 2 + 1]]);
544552
}
545553
// Constrain w being composed of bits
@@ -551,7 +559,7 @@ impl<C: ShaConfig + ShaPrecomputedValues<C::Word>> ShaAir<C> {
551559
}
552560
}
553561

554-
/// Constrain the work vars on `next` row according to the sha256 documentation
562+
/// Constrain the work vars on `next` row according to the sha documentation
555563
/// Refer to [https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf]
556564
fn eval_work_vars<'a, AB: InteractionBuilder>(
557565
&self,
@@ -588,11 +596,12 @@ impl<C: ShaConfig + ShaPrecomputedValues<C::Word>> ShaAir<C> {
588596
) * *next.flags.is_round_row
589597
})
590598
.collect::<Vec<_>>();
599+
591600
let k_limbs = (0..C::WORD_U16S)
592601
.map(|j| {
593602
self.row_idx_encoder.flag_with_val::<AB>(
594603
next.flags.row_idx.to_slice().unwrap(),
595-
&(0..16)
604+
&(0..C::ROUND_ROWS)
596605
.map(|rw_idx| {
597606
(
598607
rw_idx,
@@ -613,13 +622,13 @@ impl<C: ShaConfig + ShaPrecomputedValues<C::Word>> ShaAir<C> {
613622
builder,
614623
&[
615624
e.row(i).mapv(|x| x.into()).as_slice().unwrap(), // previous `h`
616-
&big_sig1_field::<AB::Expr>(e.row(i + 3).as_slice().unwrap()), // sig_1 of previous `e`
625+
&big_sig1_field::<AB::Expr, C>(e.row(i + 3).as_slice().unwrap()), // sig_1 of previous `e`
617626
&ch_field::<AB::Expr>(
618627
e.row(i + 3).as_slice().unwrap(),
619628
e.row(i + 2).as_slice().unwrap(),
620629
e.row(i + 1).as_slice().unwrap(),
621630
), // Ch of previous `e`, `f`, `g`
622-
&big_sig0_field::<AB::Expr>(a.row(i + 3).as_slice().unwrap()), // sig_0 of previous `a`
631+
&big_sig0_field::<AB::Expr, C>(a.row(i + 3).as_slice().unwrap()), // sig_0 of previous `a`
623632
&maj_field::<AB::Expr>(
624633
a.row(i + 3).as_slice().unwrap(),
625634
a.row(i + 2).as_slice().unwrap(),
@@ -637,9 +646,9 @@ impl<C: ShaConfig + ShaPrecomputedValues<C::Word>> ShaAir<C> {
637646
constraint_word_addition::<_, C>(
638647
builder,
639648
&[
640-
&a.row(i).mapv(|x| x.into()).as_slice().unwrap(), // previous `d`
641-
&e.row(i).mapv(|x| x.into()).as_slice().unwrap(), // previous `h`
642-
&big_sig1_field::<AB::Expr>(e.row(i + 3).as_slice().unwrap()), // sig_1 of previous `e`
649+
a.row(i).mapv(|x| x.into()).as_slice().unwrap(), // previous `d`
650+
e.row(i).mapv(|x| x.into()).as_slice().unwrap(), // previous `h`
651+
&big_sig1_field::<AB::Expr, C>(e.row(i + 3).as_slice().unwrap()), // sig_1 of previous `e`
643652
&ch_field::<AB::Expr>(
644653
e.row(i + 3).as_slice().unwrap(),
645654
e.row(i + 2).as_slice().unwrap(),

crates/circuits/sha256-air/src/columns.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ pub struct ShaMessageScheduleCols<
8282
/// The message schedule words as C::WORD_BITS-bit integers
8383
/// The first 16 rows will be the message data
8484
pub w: [[T; WORD_BITS]; ROUNDS_PER_ROW],
85-
/// Will be message schedule carries for rows 4..16 and a buffer for rows 0..4 to be used freely by wrapper chips
85+
/// Will be message schedule carries for rows 4..C::ROUND_ROWS and a buffer for rows 0..4 to be used freely by wrapper chips
8686
/// Note: carries are 2 bit numbers represented using 2 cells as individual bits
8787
pub carry_or_buffer: [[T; WORD_U8S]; ROUNDS_PER_ROW],
8888
}

crates/circuits/sha256-air/src/config.rs

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ pub trait ShaConfig: Send + Sync + Clone {
3232
const ROWS_PER_BLOCK: usize;
3333
/// Number of rounds per row. Must divide Self::ROUNDS_PER_BLOCK
3434
const ROUNDS_PER_ROW: usize;
35+
/// Number of rows used for the sha rounds
36+
const ROUND_ROWS: usize = Self::ROUNDS_PER_BLOCK / Self::ROUNDS_PER_ROW;
37+
/// Number of rows used for the message
38+
const MESSAGE_ROWS: usize = Self::BLOCK_WORDS / Self::ROUNDS_PER_ROW;
3539
/// Number of rounds per row minus one (needed for one of the column structs)
3640
const ROUNDS_PER_ROW_MINUS_ONE: usize = Self::ROUNDS_PER_ROW - 1;
3741
/// Number of rounds per block. Must be a multiple of Self::ROUNDS_PER_ROW
@@ -121,6 +125,7 @@ pub const SHA256_INVALID_CARRY_E: [[u32; Sha256Config::WORD_U16S]; Sha256Config:
121125
[719953922, 1888246508],
122126
[194580482, 1075725211],
123127
];
128+
124129
/// SHA256 constant K's
125130
pub const SHA256_K: [u32; 64] = [
126131
0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
@@ -163,25 +168,34 @@ impl ShaConfig for Sha512Config {
163168
/// Number of words in a SHA512 block
164169
const BLOCK_WORDS: usize = 16;
165170
/// Number of rows per block
166-
const ROWS_PER_BLOCK: usize = 21; // SHA-512 has 80 rounds, so needs more rows
171+
const ROWS_PER_BLOCK: usize = 21;
167172
/// Number of rounds per row
168173
const ROUNDS_PER_ROW: usize = 4;
169174
/// Number of rounds per block
170175
const ROUNDS_PER_BLOCK: usize = 80;
171176
/// Number of words in a SHA512 hash
172177
const HASH_WORDS: usize = 8;
173178
/// Number of vars needed to encode the row index with [Encoder]
174-
const ROW_VAR_CNT: usize = 5;
179+
const ROW_VAR_CNT: usize = 6;
175180
}
176181

177-
// TODO: fill in these constants
178-
179182
/// We can notice that `carry_a`'s and `carry_e`'s are always the same on invalid rows
180183
/// To optimize the trace generation of invalid rows, we have those values precomputed here
181184
pub(crate) const SHA512_INVALID_CARRY_A: [[u32; Sha512Config::WORD_U16S];
182-
Sha512Config::ROUNDS_PER_ROW] = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]];
185+
Sha512Config::ROUNDS_PER_ROW] = [
186+
[55971842, 827997017, 993005918, 512731953],
187+
[227512322, 1697529235, 1936430385, 940122990],
188+
[1939875843, 1173318562, 826201586, 1513494849],
189+
[891955202, 1732283693, 1736658755, 223514501],
190+
];
191+
183192
pub(crate) const SHA512_INVALID_CARRY_E: [[u32; Sha512Config::WORD_U16S];
184-
Sha512Config::ROUNDS_PER_ROW] = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]];
193+
Sha512Config::ROUNDS_PER_ROW] = [
194+
[1384427522, 1509509767, 153131516, 102514978],
195+
[1527552003, 1041677071, 837289497, 843522538],
196+
[775188482, 1620184630, 744892564, 892058728],
197+
[1801267202, 1393118048, 1846108940, 830635531],
198+
];
185199

186200
/// SHA512 constant K's
187201
pub const SHA512_K: [u64; 80] = [

crates/circuits/sha256-air/src/tests.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,8 @@ use openvm_stark_backend::{
1818
rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir},
1919
AirRef, Chip, ChipUsageGetter,
2020
};
21-
use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng};
21+
use openvm_stark_sdk::utils::create_seeded_rng;
2222
use rand::Rng;
23-
use sha2::Sha256;
2423

2524
use crate::{
2625
compose, small_sig0_field, Sha256Config, Sha512Config, ShaAir, ShaConfig, ShaFlagsColsRef,
@@ -128,6 +127,11 @@ fn rand_sha256_test() {
128127
rand_sha_test::<Sha256Config>();
129128
}
130129

130+
#[test]
131+
fn rand_sha512_test() {
132+
rand_sha_test::<Sha512Config>();
133+
}
134+
131135
// A wrapper Chip to test that the final_hash is properly constrained.
132136
// This chip implements a malicious trace gen that violates the final_hash constraints.
133137
pub struct ShaTestBadFinalHashChip<C: ShaConfig + ShaPrecomputedValues<C::Word>> {

0 commit comments

Comments
 (0)