Skip to content

Commit 71b60bd

Browse files
Added sha512 support to the air
1 parent ffa2cf5 commit 71b60bd

File tree

8 files changed

+452
-298
lines changed

8 files changed

+452
-298
lines changed

Cargo.lock

Lines changed: 240 additions & 159 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/circuits/sha-macros/src/lib.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,12 +172,12 @@ fn make_struct(struct_info: StructInfo) -> proc_macro2::TokenStream {
172172
fn make_from_mut(struct_info: StructInfo) -> Result<proc_macro2::TokenStream, String> {
173173
let StructInfo {
174174
name,
175-
vis,
175+
vis: _,
176176
generic_type,
177-
field_infos,
177+
field_infos: _,
178178
fields,
179-
from_args,
180-
derive_clone,
179+
from_args: _,
180+
derive_clone: _,
181181
} = struct_info;
182182

183183
let fields = match fields {

crates/circuits/sha256-air/src/air.rs

Lines changed: 51 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ impl<C: ShaConfig + ShaPrecomputedValues<C::Word>> ShaAir<C> {
3636
pub fn new(bitwise_lookup_bus: BitwiseOperationLookupBus, self_bus_idx: usize) -> Self {
3737
Self {
3838
bitwise_lookup_bus,
39-
row_idx_encoder: Encoder::new(17, 2, false),
39+
row_idx_encoder: Encoder::new(C::ROWS_PER_BLOCK, 2, false),
4040
bus_idx: self_bus_idx,
4141
_phantom: PhantomData,
4242
}
@@ -91,29 +91,35 @@ impl<C: ShaConfig + ShaPrecomputedValues<C::Word>> ShaAir<C> {
9191

9292
self.row_idx_encoder
9393
.eval(builder, local_cols.flags.row_idx.to_slice().unwrap());
94-
builder.assert_one(
95-
self.row_idx_encoder
96-
.contains_flag_range::<AB>(local_cols.flags.row_idx.to_slice().unwrap(), 0..=17),
97-
);
94+
builder.assert_one(self.row_idx_encoder.contains_flag_range::<AB>(
95+
local_cols.flags.row_idx.to_slice().unwrap(),
96+
0..=C::ROWS_PER_BLOCK,
97+
));
9898
builder.assert_eq(
9999
self.row_idx_encoder
100100
.contains_flag_range::<AB>(local_cols.flags.row_idx.to_slice().unwrap(), 0..=3),
101101
*flags.is_first_4_rows,
102102
);
103103
builder.assert_eq(
104-
self.row_idx_encoder
105-
.contains_flag_range::<AB>(local_cols.flags.row_idx.to_slice().unwrap(), 0..=15),
104+
self.row_idx_encoder.contains_flag_range::<AB>(
105+
local_cols.flags.row_idx.to_slice().unwrap(),
106+
0..=C::ROUND_ROWS - 1,
107+
),
106108
*flags.is_round_row,
107109
);
108110
builder.assert_eq(
109-
self.row_idx_encoder
110-
.contains_flag::<AB>(local_cols.flags.row_idx.to_slice().unwrap(), &[16]),
111+
self.row_idx_encoder.contains_flag::<AB>(
112+
local_cols.flags.row_idx.to_slice().unwrap(),
113+
&[C::ROUND_ROWS],
114+
),
111115
*flags.is_digest_row,
112116
);
113-
// If padding row we want the row_idx to be 17
117+
// If padding row we want the row_idx to be C::ROWS_PER_BLOCK
114118
builder.assert_eq(
115-
self.row_idx_encoder
116-
.contains_flag::<AB>(local_cols.flags.row_idx.to_slice().unwrap(), &[17]),
119+
self.row_idx_encoder.contains_flag::<AB>(
120+
local_cols.flags.row_idx.to_slice().unwrap(),
121+
&[C::ROWS_PER_BLOCK],
122+
),
117123
flags.is_padding_row(),
118124
);
119125

@@ -131,13 +137,13 @@ impl<C: ShaConfig + ShaPrecomputedValues<C::Word>> ShaAir<C> {
131137
/// Implements constraints for a digest row that ensure proper state transitions between blocks
132138
/// This validates that:
133139
/// The work variables are correctly initialized for the next message block
134-
/// For the last message block, the initial state matches SHA256_H constants
140+
/// For the last message block, the initial state matches SHA_H constants
135141
fn eval_digest_row<AB: InteractionBuilder>(
136142
&self,
137143
builder: &mut AB,
138144
local: ShaDigestColsRef<AB::Var>,
139145
) {
140-
// Check that if this is the last row of a message or an inpadding row, the hash should be the [SHA256_H]
146+
// Check that if this is the last row of a message or an inpadding row, the hash should be the [SHA_H]
141147
for i in 0..C::ROUNDS_PER_ROW {
142148
let a = local.hash.a.row(i).mapv(|x| x.into()).to_vec();
143149
let e = local.hash.e.row(i).mapv(|x| x.into()).to_vec();
@@ -146,7 +152,7 @@ impl<C: ShaConfig + ShaPrecomputedValues<C::Word>> ShaAir<C> {
146152
let a_limb = compose::<AB::Expr>(&a[j * 16..(j + 1) * 16], 1);
147153
let e_limb = compose::<AB::Expr>(&e[j * 16..(j + 1) * 16], 1);
148154

149-
// If it is a padding row or the last row of a message, the `hash` should be the [SHA256_H]
155+
// If it is a padding row or the last row of a message, the `hash` should be the [SHA_H]
150156
builder
151157
.when(
152158
local.flags.is_padding_row()
@@ -244,24 +250,24 @@ impl<C: ShaConfig + ShaPrecomputedValues<C::Word>> ShaAir<C> {
244250
// Constrin how much the row index changes by
245251
// round->round: 1
246252
// round->digest: 1
247-
// digest->round: -16 // TODO: sha512
253+
// digest->round: -C::ROUND_ROWS
248254
// digest->padding: 1
249255
// padding->padding: 0
250256
// Other transitions are not allowed by the above
251257
let delta = *local_cols.flags.is_round_row * AB::Expr::ONE
252258
+ *local_cols.flags.is_digest_row
253259
* *next_cols.flags.is_round_row
254-
* AB::Expr::from_canonical_u32(16)
260+
* AB::Expr::from_canonical_usize(C::ROUND_ROWS)
255261
* AB::Expr::NEG_ONE
256262
+ *local_cols.flags.is_digest_row * next_is_padding_row.clone() * AB::Expr::ONE;
257263

258264
let local_row_idx = self.row_idx_encoder.flag_with_val::<AB>(
259265
local_cols.flags.row_idx.to_slice().unwrap(),
260-
&(0..18).map(|i| (i, i)).collect::<Vec<_>>(),
266+
&(0..=C::ROWS_PER_BLOCK).map(|i| (i, i)).collect::<Vec<_>>(),
261267
);
262268
let next_row_idx = self.row_idx_encoder.flag_with_val::<AB>(
263269
next_cols.flags.row_idx.to_slice().unwrap(),
264-
&(0..18).map(|i| (i, i)).collect::<Vec<_>>(),
270+
&(0..=C::ROWS_PER_BLOCK).map(|i| (i, i)).collect::<Vec<_>>(),
265271
);
266272

267273
builder
@@ -411,20 +417,22 @@ impl<C: ShaConfig + ShaPrecomputedValues<C::Word>> ShaAir<C> {
411417
}
412418

413419
// Constrain intermed for `next` row
414-
// We will only constrain intermed_12 for rows [3, 14], and let it unconstrained for other rows
420+
// We will only constrain intermed_12 for rows [3, C::ROUND_ROWS - 1], and let it unconstrained for other rows
415421
// Other rows should put the needed value in intermed_12 to make the below summation constraint hold
416-
let is_row_3_14 = self
417-
.row_idx_encoder
418-
.contains_flag_range::<AB>(next.flags.row_idx.to_slice().unwrap(), 3..=14);
419-
// We will only constrain intermed_8 for rows [2, 13], and let it unconstrained for other rows
420-
let is_row_2_13 = self
421-
.row_idx_encoder
422-
.contains_flag_range::<AB>(next.flags.row_idx.to_slice().unwrap(), 2..=13);
422+
let is_row_intermed_12 = self.row_idx_encoder.contains_flag_range::<AB>(
423+
next.flags.row_idx.to_slice().unwrap(),
424+
3..=C::ROUND_ROWS - 2,
425+
);
426+
// We will only constrain intermed_8 for rows [2, C::ROUND_ROWS - 2], and let it unconstrained for other rows
427+
let is_row_intermed_8 = self.row_idx_encoder.contains_flag_range::<AB>(
428+
next.flags.row_idx.to_slice().unwrap(),
429+
2..=C::ROUND_ROWS - 3,
430+
);
423431
for i in 0..C::ROUNDS_PER_ROW {
424432
// w_idx
425433
let w_idx = w.row(i).mapv(|x| x.into()).to_vec();
426434
// sig_0(w_{idx+1})
427-
let sig_w = small_sig0_field::<AB::Expr>(w.row(i + 1).as_slice().unwrap());
435+
let sig_w = small_sig0_field::<AB::Expr, C>(w.row(i + 1).as_slice().unwrap());
428436
for j in 0..C::WORD_U16S {
429437
let w_idx_limb = compose::<AB::Expr>(&w_idx[j * 16..(j + 1) * 16], 1);
430438
let sig_w_limb = compose::<AB::Expr>(&sig_w[j * 16..(j + 1) * 16], 1);
@@ -434,12 +442,12 @@ impl<C: ShaConfig + ShaPrecomputedValues<C::Word>> ShaAir<C> {
434442
w_idx_limb + sig_w_limb,
435443
);
436444

437-
builder.when(is_row_2_13.clone()).assert_eq(
445+
builder.when(is_row_intermed_8.clone()).assert_eq(
438446
next.schedule_helper.intermed_8[[i, j]],
439447
local.schedule_helper.intermed_4[[i, j]],
440448
);
441449

442-
builder.when(is_row_3_14.clone()).assert_eq(
450+
builder.when(is_row_intermed_12.clone()).assert_eq(
443451
next.schedule_helper.intermed_12[[i, j]],
444452
local.schedule_helper.intermed_8[[i, j]],
445453
);
@@ -472,7 +480,7 @@ impl<C: ShaConfig + ShaPrecomputedValues<C::Word>> ShaAir<C> {
472480
constraint_word_addition::<_, C>(
473481
// Note: here we can't do a conditional check because the degree of sum is already 3
474482
&mut builder.when_transition(),
475-
&[&small_sig1_field::<AB::Expr>(
483+
&[&small_sig1_field::<AB::Expr, C>(
476484
w.row(i + 2).as_slice().unwrap(),
477485
)],
478486
&[&w_7, intermed_16.as_slice().unwrap()],
@@ -481,13 +489,13 @@ impl<C: ShaConfig + ShaPrecomputedValues<C::Word>> ShaAir<C> {
481489
);
482490

483491
for j in 0..C::WORD_U16S {
484-
// When on rows 4..16 message schedule carries should be 0 or 1
485-
let is_row_4_15 = *next.flags.is_round_row - *next.flags.is_first_4_rows;
492+
// When on rows 4..C::ROUND_ROWS message schedule carries should be 0 or 1
493+
let is_row_4_or_more = *next.flags.is_round_row - *next.flags.is_first_4_rows;
486494
builder
487-
.when(is_row_4_15.clone())
495+
.when(is_row_4_or_more.clone())
488496
.assert_bool(next.message_schedule.carry_or_buffer[[i, j * 2]]);
489497
builder
490-
.when(is_row_4_15)
498+
.when(is_row_4_or_more)
491499
.assert_bool(next.message_schedule.carry_or_buffer[[i, j * 2 + 1]]);
492500
// Constrain w being composed of bits
493501
for j in 0..C::WORD_BITS {
@@ -499,7 +507,7 @@ impl<C: ShaConfig + ShaPrecomputedValues<C::Word>> ShaAir<C> {
499507
}
500508
}
501509

502-
/// Constrain the work vars on `next` row according to the sha256 documentation
510+
/// Constrain the work vars on `next` row according to the sha documentation
503511
/// Refer to [https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf]
504512
fn eval_work_vars<'a, AB: InteractionBuilder>(
505513
&self,
@@ -536,11 +544,12 @@ impl<C: ShaConfig + ShaPrecomputedValues<C::Word>> ShaAir<C> {
536544
) * *next.flags.is_round_row
537545
})
538546
.collect::<Vec<_>>();
547+
539548
let k_limbs = (0..C::WORD_U16S)
540549
.map(|j| {
541550
self.row_idx_encoder.flag_with_val::<AB>(
542551
next.flags.row_idx.to_slice().unwrap(),
543-
&(0..16)
552+
&(0..C::ROUND_ROWS)
544553
.map(|rw_idx| {
545554
(
546555
rw_idx,
@@ -559,13 +568,13 @@ impl<C: ShaConfig + ShaPrecomputedValues<C::Word>> ShaAir<C> {
559568
builder,
560569
&[
561570
e.row(i).mapv(|x| x.into()).as_slice().unwrap(), // previous `h`
562-
&big_sig1_field::<AB::Expr>(e.row(i + 3).as_slice().unwrap()), // sig_1 of previous `e`
571+
&big_sig1_field::<AB::Expr, C>(e.row(i + 3).as_slice().unwrap()), // sig_1 of previous `e`
563572
&ch_field::<AB::Expr>(
564573
e.row(i + 3).as_slice().unwrap(),
565574
e.row(i + 2).as_slice().unwrap(),
566575
e.row(i + 1).as_slice().unwrap(),
567576
), // Ch of previous `e`, `f`, `g`
568-
&big_sig0_field::<AB::Expr>(a.row(i + 3).as_slice().unwrap()), // sig_0 of previous `a`
577+
&big_sig0_field::<AB::Expr, C>(a.row(i + 3).as_slice().unwrap()), // sig_0 of previous `a`
569578
&maj_field::<AB::Expr>(
570579
a.row(i + 3).as_slice().unwrap(),
571580
a.row(i + 2).as_slice().unwrap(),
@@ -581,9 +590,9 @@ impl<C: ShaConfig + ShaPrecomputedValues<C::Word>> ShaAir<C> {
581590
constraint_word_addition::<_, C>(
582591
builder,
583592
&[
584-
&a.row(i).mapv(|x| x.into()).as_slice().unwrap(), // previous `d`
585-
&e.row(i).mapv(|x| x.into()).as_slice().unwrap(), // previous `h`
586-
&big_sig1_field::<AB::Expr>(e.row(i + 3).as_slice().unwrap()), // sig_1 of previous `e`
593+
a.row(i).mapv(|x| x.into()).as_slice().unwrap(), // previous `d`
594+
e.row(i).mapv(|x| x.into()).as_slice().unwrap(), // previous `h`
595+
&big_sig1_field::<AB::Expr, C>(e.row(i + 3).as_slice().unwrap()), // sig_1 of previous `e`
587596
&ch_field::<AB::Expr>(
588597
e.row(i + 3).as_slice().unwrap(),
589598
e.row(i + 2).as_slice().unwrap(),

crates/circuits/sha256-air/src/columns.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ pub struct ShaMessageScheduleCols<
7171
> {
7272
/// The message schedule words as 32-bit intergers
7373
pub w: [[T; WORD_BITS]; ROUNDS_PER_ROW],
74-
/// Will be message schedule carries for rows 4..16 and a buffer for rows 0..4 to be used freely by wrapper chips
74+
/// Will be message schedule carries for rows 4..C::ROUND_ROWS and a buffer for rows 0..4 to be used freely by wrapper chips
7575
/// Note: carries are represented as 2 bit numbers
7676
pub carry_or_buffer: [[T; WORD_U8S]; ROUNDS_PER_ROW],
7777
}

crates/circuits/sha256-air/src/config.rs

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,12 @@ pub trait ShaConfig: Send + Sync + Clone {
3030
const BLOCK_BITS: usize = Self::BLOCK_WORDS * Self::WORD_BITS;
3131
/// Number of rows per block
3232
const ROWS_PER_BLOCK: usize;
33-
/// Number of rounds per row
33+
/// Number of rounds per row (must divide BLOCK_WORDS)
3434
const ROUNDS_PER_ROW: usize;
35+
/// Number of rows used for the sha rounds
36+
const ROUND_ROWS: usize = Self::ROUNDS_PER_BLOCK / Self::ROUNDS_PER_ROW;
37+
/// Number of rows used for the message
38+
const MESSAGE_ROWS: usize = Self::BLOCK_WORDS / Self::ROUNDS_PER_ROW;
3539
/// Number of rounds per row minus one (needed for one of the column structs)
3640
const ROUNDS_PER_ROW_MINUS_ONE: usize = Self::ROUNDS_PER_ROW - 1;
3741
/// Number of rounds per block
@@ -119,6 +123,7 @@ pub const SHA256_INVALID_CARRY_E: [[u32; Sha256Config::WORD_U16S]; Sha256Config:
119123
[719953922, 1888246508],
120124
[194580482, 1075725211],
121125
];
126+
122127
/// SHA256 constant K's
123128
pub const SHA256_K: [u32; 64] = [
124129
0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
@@ -161,25 +166,34 @@ impl ShaConfig for Sha512Config {
161166
/// Number of words in a SHA512 block
162167
const BLOCK_WORDS: usize = 16;
163168
/// Number of rows per block
164-
const ROWS_PER_BLOCK: usize = 21; // SHA-512 has 80 rounds, so needs more rows
169+
const ROWS_PER_BLOCK: usize = 21;
165170
/// Number of rounds per row
166171
const ROUNDS_PER_ROW: usize = 4;
167172
/// Number of rounds per block
168173
const ROUNDS_PER_BLOCK: usize = 80;
169174
/// Number of words in a SHA512 hash
170175
const HASH_WORDS: usize = 8;
171176
/// Number of vars needed to encode the row index with [Encoder]
172-
const ROW_VAR_CNT: usize = 5;
177+
const ROW_VAR_CNT: usize = 6;
173178
}
174179

175-
// TODO: fill in these constants
176-
177180
/// We can notice that `carry_a`'s and `carry_e`'s are always the same on invalid rows
178181
/// To optimize the trace generation of invalid rows, we have those values precomputed here
179182
pub(crate) const SHA512_INVALID_CARRY_A: [[u32; Sha512Config::WORD_U16S];
180-
Sha512Config::ROUNDS_PER_ROW] = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]];
183+
Sha512Config::ROUNDS_PER_ROW] = [
184+
[55971842, 827997017, 993005918, 512731953],
185+
[227512322, 1697529235, 1936430385, 940122990],
186+
[1939875843, 1173318562, 826201586, 1513494849],
187+
[891955202, 1732283693, 1736658755, 223514501],
188+
];
189+
181190
pub(crate) const SHA512_INVALID_CARRY_E: [[u32; Sha512Config::WORD_U16S];
182-
Sha512Config::ROUNDS_PER_ROW] = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]];
191+
Sha512Config::ROUNDS_PER_ROW] = [
192+
[1384427522, 1509509767, 153131516, 102514978],
193+
[1527552003, 1041677071, 837289497, 843522538],
194+
[775188482, 1620184630, 744892564, 892058728],
195+
[1801267202, 1393118048, 1846108940, 830635531],
196+
];
183197

184198
/// SHA512 constant K's
185199
pub const SHA512_K: [u64; 80] = [

crates/circuits/sha256-air/src/tests.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::{array, cmp::max, sync::Arc};
1+
use std::{cmp::max, sync::Arc};
22

33
use openvm_circuit::arch::{
44
instructions::riscv::RV32_CELL_BITS,
@@ -12,19 +12,15 @@ use openvm_stark_backend::{
1212
config::{StarkGenericConfig, Val},
1313
interaction::InteractionBuilder,
1414
p3_air::{Air, BaseAir},
15-
p3_field::{Field, FieldAlgebra, PrimeField32},
15+
p3_field::{Field, PrimeField32},
1616
prover::types::AirProofInput,
1717
rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir},
1818
AirRef, Chip, ChipUsageGetter,
1919
};
20-
use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng};
20+
use openvm_stark_sdk::utils::create_seeded_rng;
2121
use rand::Rng;
22-
use sha2::Sha256;
2322

24-
use crate::{
25-
Sha256Config, Sha512Config, ShaAir, ShaConfig, ShaFlagsColsRef, ShaFlagsColsRefMut,
26-
ShaPrecomputedValues,
27-
};
23+
use crate::{Sha256Config, Sha512Config, ShaAir, ShaConfig, ShaPrecomputedValues};
2824

2925
// A wrapper AIR purely for testing purposes
3026
#[derive(Clone, Debug)]
@@ -126,3 +122,8 @@ fn rand_sha_test<C: ShaConfig + ShaPrecomputedValues<C::Word> + 'static>() {
126122
fn rand_sha256_test() {
127123
rand_sha_test::<Sha256Config>();
128124
}
125+
126+
#[test]
127+
fn rand_sha512_test() {
128+
rand_sha_test::<Sha512Config>();
129+
}

0 commit comments

Comments
 (0)