Skip to content

Commit 378e076

Browse files
feat: modify fri records + tests (#1819)
Towards INT-4341. This PR is blocking the related PR in axiom-gpu. Made some changes to `NativeFri`: - In `WorkLoadRecord`, instead of `b` I now keep the current `result`. `b` can be computed from `a`, `alpha`, `result`, and `previous_result`. This allows tracegen to skip a serial pass calculating all the intermediate `results` at the start. - Separated `a_prev_data` from `WorkLoadRecord` so we can skip allocating memory for it for a setup instruction. This is an optimization that reduces record sizes, since we don't always need the `prev_data`. - Ported the test file to the new framework. - Divided the WorkloadRows tracegen into two parts: a lighter serial pass over all the `WorkLoadRows` that writes the records into the corresponding places, and a second parallel pass that fills the rest of the trace. - Also fixed a small bug related to not zeroing out the `a_prev_data` when no write happened.
1 parent 0ed7820 commit 378e076

File tree

2 files changed

+220
-190
lines changed

2 files changed

+220
-190
lines changed

extensions/native/circuit/src/fri/mod.rs

Lines changed: 111 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ use openvm_stark_backend::{
3636
p3_air::{Air, AirBuilder, BaseAir},
3737
p3_field::{Field, FieldAlgebra, PrimeField32},
3838
p3_matrix::{dense::RowMajorMatrix, Matrix},
39-
p3_maybe_rayon::prelude::{IntoParallelIterator, ParallelIterator},
39+
p3_maybe_rayon::prelude::{
40+
IndexedParallelIterator, IntoParallelIterator, ParallelIterator, ParallelSliceMut,
41+
},
4042
rap::{BaseAirWithPublicValues, PartitionedBaseAir},
4143
};
4244
use static_assertions::const_assert_eq;
@@ -552,6 +554,7 @@ fn elem_to_ext<F: Field>(elem: F) -> [F; EXT_DEG] {
552554
#[derive(Copy, Clone, Debug)]
553555
pub struct FriReducedOpeningMetadata {
554556
length: usize,
557+
is_init: bool,
555558
}
556559

557560
impl MultiRowMetadata for FriReducedOpeningMetadata {
@@ -569,6 +572,7 @@ type FriReducedOpeningLayout = MultiRowLayout<FriReducedOpeningMetadata>;
569572
#[derive(AlignedBytesBorrow, Debug)]
570573
pub struct FriReducedOpeningHeaderRecord {
571574
pub length: u32,
575+
pub is_init: bool,
572576
}
573577

574578
// Part of record that is common for all trace rows for an instruction
@@ -578,11 +582,9 @@ pub struct FriReducedOpeningHeaderRecord {
578582
pub struct FriReducedOpeningCommonRecord<F> {
579583
pub timestamp: u32,
580584

581-
pub a_ptr: F,
585+
pub a_ptr: u32,
582586

583-
pub is_init: bool,
584-
585-
pub b_ptr: F,
587+
pub b_ptr: u32,
586588

587589
pub alpha: [F; EXT_DEG],
588590

@@ -615,8 +617,11 @@ pub struct FriReducedOpeningCommonRecord<F> {
615617
#[derive(AlignedBytesBorrow, Debug)]
616618
pub struct FriReducedOpeningWorkloadRowRecord<F> {
617619
pub a: F,
618-
pub a_aux: MemoryWriteAuxRecord<F, 1>,
619-
pub b: [F; EXT_DEG],
620+
pub a_aux: MemoryReadAuxRecord,
621+
// The result of this workload row
622+
// b can be computed from a, alpha, result, and previous result:
623+
// b = result + a - prev_result * alpha
624+
pub result: [F; EXT_DEG],
620625
pub b_aux: MemoryReadAuxRecord,
621626
}
622627

@@ -625,6 +630,9 @@ pub struct FriReducedOpeningWorkloadRowRecord<F> {
625630
pub struct FriReducedOpeningRecordMut<'a, F> {
626631
pub header: &'a mut FriReducedOpeningHeaderRecord,
627632
pub workload: &'a mut [FriReducedOpeningWorkloadRowRecord<F>],
633+
// if is_init this will be an empty slice, otherwise it will be the previous data of writing
634+
// `a`s
635+
pub a_write_prev_data: &'a mut [F],
628636
pub common: &'a mut FriReducedOpeningCommonRecord<F>,
629637
}
630638

@@ -641,8 +649,17 @@ impl<'a, F> CustomBorrow<'a, FriReducedOpeningRecordMut<'a, F>, FriReducedOpenin
641649

642650
let workload_size =
643651
layout.metadata.length * size_of::<FriReducedOpeningWorkloadRowRecord<F>>();
644-
let (workload_buf, common_buf) = unsafe { rest.split_at_mut_unchecked(workload_size) };
645652

653+
let (workload_buf, rest) = unsafe { rest.split_at_mut_unchecked(workload_size) };
654+
let a_prev_size = if layout.metadata.is_init {
655+
0
656+
} else {
657+
layout.metadata.length * size_of::<F>()
658+
};
659+
660+
let (a_prev_buf, common_buf) = unsafe { rest.split_at_mut_unchecked(a_prev_size) };
661+
662+
let (_, a_prev_records, _) = unsafe { a_prev_buf.align_to_mut::<F>() };
646663
let (_, workload_records, _) =
647664
unsafe { workload_buf.align_to_mut::<FriReducedOpeningWorkloadRowRecord<F>>() };
648665

@@ -651,6 +668,7 @@ impl<'a, F> CustomBorrow<'a, FriReducedOpeningRecordMut<'a, F>, FriReducedOpenin
651668
FriReducedOpeningRecordMut {
652669
header,
653670
workload: &mut workload_records[..layout.metadata.length],
671+
a_write_prev_data: &mut a_prev_records[..],
654672
common,
655673
}
656674
}
@@ -659,6 +677,7 @@ impl<'a, F> CustomBorrow<'a, FriReducedOpeningRecordMut<'a, F>, FriReducedOpenin
659677
let header: &FriReducedOpeningHeaderRecord = self.borrow();
660678
FriReducedOpeningLayout::new(FriReducedOpeningMetadata {
661679
length: header.length as usize,
680+
is_init: header.is_init,
662681
})
663682
}
664683
}
@@ -732,9 +751,13 @@ where
732751
let length_ptr = c.as_canonical_u32();
733752
let [length]: [F; 1] = memory_read_native(&state.memory.data, length_ptr);
734753
let length = length.as_canonical_u32();
754+
let is_init_ptr = g.as_canonical_u32();
755+
let [is_init]: [F; 1] = memory_read_native(&state.memory.data, is_init_ptr);
756+
let is_init = is_init != F::ZERO;
735757

736758
let metadata = FriReducedOpeningMetadata {
737759
length: length as usize,
760+
is_init,
738761
};
739762
let record = arena.alloc(MultiRowLayout::new(metadata));
740763

@@ -765,7 +788,7 @@ where
765788
&mut record.common.a_ptr_aux.prev_timestamp,
766789
);
767790
record.common.a_ptr_ptr = a;
768-
record.common.a_ptr = a_ptr;
791+
record.common.a_ptr = a_ptr.as_canonical_u32();
769792

770793
let b_ptr_ptr = b.as_canonical_u32();
771794
let [b_ptr]: [F; 1] = tracing_read_native(
@@ -774,17 +797,15 @@ where
774797
&mut record.common.b_ptr_aux.prev_timestamp,
775798
);
776799
record.common.b_ptr_ptr = b;
777-
record.common.b_ptr = b_ptr;
800+
record.common.b_ptr = b_ptr.as_canonical_u32();
778801

779-
let is_init_ptr = g.as_canonical_u32();
780-
let [is_init]: [F; 1] = tracing_read_native(
802+
tracing_read_native::<F, 1>(
781803
state.memory,
782804
is_init_ptr,
783805
&mut record.common.is_init_aux.prev_timestamp,
784806
);
785-
let is_init = is_init != F::ZERO;
786807
record.common.is_init_ptr = g;
787-
record.common.is_init = is_init;
808+
record.header.is_init = is_init;
788809

789810
let hint_id_ptr = f.as_canonical_u32();
790811
let [hint_id]: [F; 1] = memory_read_native(state.memory.data(), hint_id_ptr);
@@ -805,15 +826,17 @@ where
805826
for i in 0..length {
806827
let workload_row = &mut record.workload[length - i - 1];
807828

808-
let a_ptr_i = (a_ptr + F::from_canonical_usize(i)).as_canonical_u32();
829+
let a_ptr_i = record.common.a_ptr + i as u32;
809830
let [a]: [F; 1] = if !is_init {
831+
let mut prev = [F::ZERO; 1];
810832
tracing_write_native(
811833
state.memory,
812834
a_ptr_i,
813835
[data[i]],
814836
&mut workload_row.a_aux.prev_timestamp,
815-
&mut workload_row.a_aux.prev_data,
837+
&mut prev,
816838
);
839+
record.a_write_prev_data[length - i - 1] = prev[0];
817840
[data[i]]
818841
} else {
819842
tracing_read_native(
@@ -822,7 +845,7 @@ where
822845
&mut workload_row.a_aux.prev_timestamp,
823846
)
824847
};
825-
let b_ptr_i = (b_ptr + F::from_canonical_usize(EXT_DEG * i)).as_canonical_u32();
848+
let b_ptr_i = record.common.b_ptr + (EXT_DEG * i) as u32;
826849
let b = tracing_read_native::<F, EXT_DEG>(
827850
state.memory,
828851
b_ptr_i,
@@ -836,14 +859,13 @@ where
836859
for (i, (a, b)) in as_and_bs.into_iter().rev().enumerate() {
837860
let workload_row = &mut record.workload[i];
838861

839-
workload_row.a = a;
840-
workload_row.b = b;
841-
842862
// result = result * alpha + (b - a)
843863
result = FieldExtension::add(
844864
FieldExtension::multiply(result, alpha),
845865
FieldExtension::subtract(b, elem_to_ext(a)),
846866
);
867+
workload_row.a = a;
868+
workload_row.result = result;
847869
}
848870

849871
let result_ptr = e.as_canonical_u32();
@@ -887,22 +909,23 @@ where
887909
let num_rows = header.length as usize + 2;
888910
let chunk_size = OVERALL_WIDTH * num_rows;
889911
let (chunk, rest) = remaining_trace.split_at_mut(chunk_size);
890-
chunks.push(chunk);
912+
chunks.push((chunk, header.is_init));
891913
remaining_trace = rest;
892914
}
893915

894-
chunks.into_par_iter().for_each(|mut chunk| {
916+
chunks.into_par_iter().for_each(|(mut chunk, is_init)| {
895917
let num_rows = chunk.len() / OVERALL_WIDTH;
896918
let metadata = FriReducedOpeningMetadata {
897919
length: num_rows - 2,
920+
is_init,
898921
};
899922
let record: FriReducedOpeningRecordMut<F> =
900923
unsafe { get_record_from_slice(&mut chunk, MultiRowLayout::new(metadata)) };
901924

902925
let timestamp = record.common.timestamp;
903926
let length = record.header.length as usize;
904927
let alpha = record.common.alpha;
905-
let is_init = record.common.is_init;
928+
let is_init = record.header.is_init;
906929
let write_a = F::from_bool(!is_init);
907930

908931
let a_ptr = record.common.a_ptr;
@@ -911,23 +934,6 @@ where
911934
let (workload_chunk, rest) = chunk.split_at_mut(length * OVERALL_WIDTH);
912935
let (ins1_chunk, ins2_chunk) = rest.split_at_mut(OVERALL_WIDTH);
913936

914-
let mut results: Vec<[F; EXT_DEG]> =
915-
std::iter::once([F::ZERO; EXT_DEG])
916-
.chain(record.workload.iter().scan(
917-
[F::ZERO; EXT_DEG],
918-
|result, workload_row| {
919-
let a = workload_row.a;
920-
let b = workload_row.b;
921-
922-
*result = FieldExtension::add(
923-
FieldExtension::multiply(*result, alpha),
924-
FieldExtension::subtract(b, elem_to_ext(a)),
925-
);
926-
Some(*result)
927-
},
928-
))
929-
.collect();
930-
931937
{
932938
// ins2 row
933939
let cols: &mut Instruction2Cols<F> = ins2_chunk[..INS_2_WIDTH].borrow_mut();
@@ -998,31 +1004,43 @@ where
9981004
cols.pc = F::from_canonical_u32(record.common.from_pc);
9991005

10001006
cols.prefix.data.alpha = alpha;
1001-
cols.prefix.data.result = results.pop().unwrap();
1007+
cols.prefix.data.result = record.workload.last().unwrap().result;
10021008
cols.prefix.data.idx = F::from_canonical_usize(length);
1003-
cols.prefix.data.b_ptr = b_ptr;
1009+
cols.prefix.data.b_ptr = F::from_canonical_u32(b_ptr);
10041010
cols.prefix.data.write_a = write_a;
1005-
cols.prefix.data.a_ptr = a_ptr;
1011+
cols.prefix.data.a_ptr = F::from_canonical_u32(a_ptr);
10061012

10071013
cols.prefix.a_or_is_first = F::ONE;
10081014

10091015
cols.prefix.general.timestamp = F::from_canonical_u32(timestamp);
10101016
cols.prefix.general.is_ins_row = F::ONE;
10111017
cols.prefix.general.is_workload_row = F::ZERO;
1012-
10131018
ins1_chunk[INS_1_WIDTH..OVERALL_WIDTH].fill(F::ZERO);
10141019
}
10151020

1016-
for (i, (workload_row, result)) in record
1021+
// To fill the WorkloadRows we do 2 passes:
1022+
// - First, a serial pass to fill some of the records into the trace
1023+
// - Then, a parallel pass to fill the rest of the records into the trace
1024+
// Note, the first pass is done to avoid overwriting the records
1025+
1026+
// Copy of `a_write_prev_data` to avoid overwriting it and to use it in the parallel
1027+
// pass
1028+
let a_prev_data = if !is_init {
1029+
let mut tmp = Vec::with_capacity(length);
1030+
tmp.extend_from_slice(record.a_write_prev_data);
1031+
tmp
1032+
} else {
1033+
vec![]
1034+
};
1035+
1036+
for (i, (workload_row, row_chunk)) in record
10171037
.workload
10181038
.iter()
1019-
.zip(results.into_iter())
1039+
.zip(workload_chunk.chunks_exact_mut(OVERALL_WIDTH))
10201040
.enumerate()
10211041
.rev()
10221042
{
1023-
let offset = i * OVERALL_WIDTH;
1024-
let cols: &mut WorkloadCols<F> =
1025-
workload_chunk[offset..offset + WL_WIDTH].borrow_mut();
1043+
let cols: &mut WorkloadCols<F> = row_chunk[..WL_WIDTH].borrow_mut();
10261044

10271045
let timestamp = timestamp + ((length - i) * 2) as u32;
10281046

@@ -1032,32 +1050,59 @@ where
10321050
timestamp + 4,
10331051
cols.b_aux.as_mut(),
10341052
);
1035-
cols.b = workload_row.b;
10361053

1037-
if !is_init {
1038-
cols.a_aux.set_prev_data(workload_row.a_aux.prev_data);
1039-
}
1054+
// We temporarily store the result here
1055+
// the correct value of b is computed during the serial pass below
1056+
cols.b = record.workload[i].result;
1057+
10401058
mem_helper.fill(
10411059
workload_row.a_aux.prev_timestamp,
10421060
timestamp + 3,
10431061
cols.a_aux.as_mut(),
10441062
);
1045-
1046-
cols.prefix.data.alpha = alpha;
1047-
cols.prefix.data.result = result;
1048-
cols.prefix.data.idx = F::from_canonical_usize(i);
1049-
cols.prefix.data.b_ptr = b_ptr + F::from_canonical_usize((length - i) * EXT_DEG);
1050-
cols.prefix.data.write_a = write_a;
1051-
cols.prefix.data.a_ptr = a_ptr + F::from_canonical_usize(length - i);
1052-
10531063
cols.prefix.a_or_is_first = workload_row.a;
10541064

1055-
cols.prefix.general.timestamp = F::from_canonical_u32(timestamp);
1056-
cols.prefix.general.is_ins_row = F::ZERO;
1057-
cols.prefix.general.is_workload_row = F::ONE;
1058-
1059-
workload_chunk[offset + WL_WIDTH..offset + OVERALL_WIDTH].fill(F::ZERO);
1065+
if i > 0 {
1066+
cols.prefix.data.result = record.workload[i - 1].result;
1067+
}
10601068
}
1069+
1070+
workload_chunk
1071+
.par_chunks_exact_mut(OVERALL_WIDTH)
1072+
.enumerate()
1073+
.for_each(|(i, row_chunk)| {
1074+
let cols: &mut WorkloadCols<F> = row_chunk[..WL_WIDTH].borrow_mut();
1075+
let timestamp = timestamp + ((length - i) * 2) as u32;
1076+
if is_init {
1077+
cols.a_aux.set_prev_data([F::ZERO; 1]);
1078+
} else {
1079+
cols.a_aux.set_prev_data([a_prev_data[i]]);
1080+
}
1081+
1082+
// DataCols
1083+
cols.prefix.data.a_ptr = F::from_canonical_u32(a_ptr + (length - i) as u32);
1084+
cols.prefix.data.write_a = write_a;
1085+
cols.prefix.data.b_ptr =
1086+
F::from_canonical_u32(b_ptr + ((length - i) * EXT_DEG) as u32);
1087+
cols.prefix.data.idx = F::from_canonical_usize(i);
1088+
if i == 0 {
1089+
cols.prefix.data.result = [F::ZERO; EXT_DEG];
1090+
}
1091+
cols.prefix.data.alpha = alpha;
1092+
1093+
// GeneralCols
1094+
cols.prefix.general.is_workload_row = F::ONE;
1095+
cols.prefix.general.is_ins_row = F::ZERO;
1096+
1097+
// WorkloadCols
1098+
cols.prefix.general.timestamp = F::from_canonical_u32(timestamp);
1099+
1100+
cols.b = FieldExtension::subtract(
1101+
FieldExtension::add(cols.b, elem_to_ext(cols.prefix.a_or_is_first)),
1102+
FieldExtension::multiply(cols.prefix.data.result, alpha),
1103+
);
1104+
row_chunk[WL_WIDTH..OVERALL_WIDTH].fill(F::ZERO);
1105+
});
10611106
});
10621107
}
10631108
}

0 commit comments

Comments
 (0)