@@ -36,7 +36,9 @@ use openvm_stark_backend::{
36
36
p3_air:: { Air , AirBuilder , BaseAir } ,
37
37
p3_field:: { Field , FieldAlgebra , PrimeField32 } ,
38
38
p3_matrix:: { dense:: RowMajorMatrix , Matrix } ,
39
- p3_maybe_rayon:: prelude:: { IntoParallelIterator , ParallelIterator } ,
39
+ p3_maybe_rayon:: prelude:: {
40
+ IndexedParallelIterator , IntoParallelIterator , ParallelIterator , ParallelSliceMut ,
41
+ } ,
40
42
rap:: { BaseAirWithPublicValues , PartitionedBaseAir } ,
41
43
} ;
42
44
use static_assertions:: const_assert_eq;
@@ -552,6 +554,7 @@ fn elem_to_ext<F: Field>(elem: F) -> [F; EXT_DEG] {
552
554
#[ derive( Copy , Clone , Debug ) ]
553
555
pub struct FriReducedOpeningMetadata {
554
556
length : usize ,
557
+ is_init : bool ,
555
558
}
556
559
557
560
impl MultiRowMetadata for FriReducedOpeningMetadata {
@@ -569,6 +572,7 @@ type FriReducedOpeningLayout = MultiRowLayout<FriReducedOpeningMetadata>;
569
572
#[ derive( AlignedBytesBorrow , Debug ) ]
570
573
pub struct FriReducedOpeningHeaderRecord {
571
574
pub length : u32 ,
575
+ pub is_init : bool ,
572
576
}
573
577
574
578
// Part of record that is common for all trace rows for an instruction
@@ -578,11 +582,9 @@ pub struct FriReducedOpeningHeaderRecord {
578
582
pub struct FriReducedOpeningCommonRecord < F > {
579
583
pub timestamp : u32 ,
580
584
581
- pub a_ptr : F ,
585
+ pub a_ptr : u32 ,
582
586
583
- pub is_init : bool ,
584
-
585
- pub b_ptr : F ,
587
+ pub b_ptr : u32 ,
586
588
587
589
pub alpha : [ F ; EXT_DEG ] ,
588
590
@@ -615,8 +617,11 @@ pub struct FriReducedOpeningCommonRecord<F> {
615
617
#[ derive( AlignedBytesBorrow , Debug ) ]
616
618
pub struct FriReducedOpeningWorkloadRowRecord < F > {
617
619
pub a : F ,
618
- pub a_aux : MemoryWriteAuxRecord < F , 1 > ,
619
- pub b : [ F ; EXT_DEG ] ,
620
+ pub a_aux : MemoryReadAuxRecord ,
621
+ // The result of this workload row
622
+ // b can be computed from a, alpha, result, and previous result:
623
+ // b = result + a - prev_result * alpha
624
+ pub result : [ F ; EXT_DEG ] ,
620
625
pub b_aux : MemoryReadAuxRecord ,
621
626
}
622
627
@@ -625,6 +630,9 @@ pub struct FriReducedOpeningWorkloadRowRecord<F> {
625
630
pub struct FriReducedOpeningRecordMut < ' a , F > {
626
631
pub header : & ' a mut FriReducedOpeningHeaderRecord ,
627
632
pub workload : & ' a mut [ FriReducedOpeningWorkloadRowRecord < F > ] ,
633
+ // if is_init this will be an empty slice, otherwise it will be the previous data of writing
634
+ // `a`s
635
+ pub a_write_prev_data : & ' a mut [ F ] ,
628
636
pub common : & ' a mut FriReducedOpeningCommonRecord < F > ,
629
637
}
630
638
@@ -641,8 +649,17 @@ impl<'a, F> CustomBorrow<'a, FriReducedOpeningRecordMut<'a, F>, FriReducedOpenin
641
649
642
650
let workload_size =
643
651
layout. metadata . length * size_of :: < FriReducedOpeningWorkloadRowRecord < F > > ( ) ;
644
- let ( workload_buf, common_buf) = unsafe { rest. split_at_mut_unchecked ( workload_size) } ;
645
652
653
+ let ( workload_buf, rest) = unsafe { rest. split_at_mut_unchecked ( workload_size) } ;
654
+ let a_prev_size = if layout. metadata . is_init {
655
+ 0
656
+ } else {
657
+ layout. metadata . length * size_of :: < F > ( )
658
+ } ;
659
+
660
+ let ( a_prev_buf, common_buf) = unsafe { rest. split_at_mut_unchecked ( a_prev_size) } ;
661
+
662
+ let ( _, a_prev_records, _) = unsafe { a_prev_buf. align_to_mut :: < F > ( ) } ;
646
663
let ( _, workload_records, _) =
647
664
unsafe { workload_buf. align_to_mut :: < FriReducedOpeningWorkloadRowRecord < F > > ( ) } ;
648
665
@@ -651,6 +668,7 @@ impl<'a, F> CustomBorrow<'a, FriReducedOpeningRecordMut<'a, F>, FriReducedOpenin
651
668
FriReducedOpeningRecordMut {
652
669
header,
653
670
workload : & mut workload_records[ ..layout. metadata . length ] ,
671
+ a_write_prev_data : & mut a_prev_records[ ..] ,
654
672
common,
655
673
}
656
674
}
@@ -659,6 +677,7 @@ impl<'a, F> CustomBorrow<'a, FriReducedOpeningRecordMut<'a, F>, FriReducedOpenin
659
677
let header: & FriReducedOpeningHeaderRecord = self . borrow ( ) ;
660
678
FriReducedOpeningLayout :: new ( FriReducedOpeningMetadata {
661
679
length : header. length as usize ,
680
+ is_init : header. is_init ,
662
681
} )
663
682
}
664
683
}
@@ -732,9 +751,13 @@ where
732
751
let length_ptr = c. as_canonical_u32 ( ) ;
733
752
let [ length] : [ F ; 1 ] = memory_read_native ( & state. memory . data , length_ptr) ;
734
753
let length = length. as_canonical_u32 ( ) ;
754
+ let is_init_ptr = g. as_canonical_u32 ( ) ;
755
+ let [ is_init] : [ F ; 1 ] = memory_read_native ( & state. memory . data , is_init_ptr) ;
756
+ let is_init = is_init != F :: ZERO ;
735
757
736
758
let metadata = FriReducedOpeningMetadata {
737
759
length : length as usize ,
760
+ is_init,
738
761
} ;
739
762
let record = arena. alloc ( MultiRowLayout :: new ( metadata) ) ;
740
763
@@ -765,7 +788,7 @@ where
765
788
& mut record. common . a_ptr_aux . prev_timestamp ,
766
789
) ;
767
790
record. common . a_ptr_ptr = a;
768
- record. common . a_ptr = a_ptr;
791
+ record. common . a_ptr = a_ptr. as_canonical_u32 ( ) ;
769
792
770
793
let b_ptr_ptr = b. as_canonical_u32 ( ) ;
771
794
let [ b_ptr] : [ F ; 1 ] = tracing_read_native (
@@ -774,17 +797,15 @@ where
774
797
& mut record. common . b_ptr_aux . prev_timestamp ,
775
798
) ;
776
799
record. common . b_ptr_ptr = b;
777
- record. common . b_ptr = b_ptr;
800
+ record. common . b_ptr = b_ptr. as_canonical_u32 ( ) ;
778
801
779
- let is_init_ptr = g. as_canonical_u32 ( ) ;
780
- let [ is_init] : [ F ; 1 ] = tracing_read_native (
802
+ tracing_read_native :: < F , 1 > (
781
803
state. memory ,
782
804
is_init_ptr,
783
805
& mut record. common . is_init_aux . prev_timestamp ,
784
806
) ;
785
- let is_init = is_init != F :: ZERO ;
786
807
record. common . is_init_ptr = g;
787
- record. common . is_init = is_init;
808
+ record. header . is_init = is_init;
788
809
789
810
let hint_id_ptr = f. as_canonical_u32 ( ) ;
790
811
let [ hint_id] : [ F ; 1 ] = memory_read_native ( state. memory . data ( ) , hint_id_ptr) ;
@@ -805,15 +826,17 @@ where
805
826
for i in 0 ..length {
806
827
let workload_row = & mut record. workload [ length - i - 1 ] ;
807
828
808
- let a_ptr_i = ( a_ptr + F :: from_canonical_usize ( i ) ) . as_canonical_u32 ( ) ;
829
+ let a_ptr_i = record . common . a_ptr + i as u32 ;
809
830
let [ a] : [ F ; 1 ] = if !is_init {
831
+ let mut prev = [ F :: ZERO ; 1 ] ;
810
832
tracing_write_native (
811
833
state. memory ,
812
834
a_ptr_i,
813
835
[ data[ i] ] ,
814
836
& mut workload_row. a_aux . prev_timestamp ,
815
- & mut workload_row . a_aux . prev_data ,
837
+ & mut prev ,
816
838
) ;
839
+ record. a_write_prev_data [ length - i - 1 ] = prev[ 0 ] ;
817
840
[ data[ i] ]
818
841
} else {
819
842
tracing_read_native (
@@ -822,7 +845,7 @@ where
822
845
& mut workload_row. a_aux . prev_timestamp ,
823
846
)
824
847
} ;
825
- let b_ptr_i = ( b_ptr + F :: from_canonical_usize ( EXT_DEG * i) ) . as_canonical_u32 ( ) ;
848
+ let b_ptr_i = record . common . b_ptr + ( EXT_DEG * i) as u32 ;
826
849
let b = tracing_read_native :: < F , EXT_DEG > (
827
850
state. memory ,
828
851
b_ptr_i,
@@ -836,14 +859,13 @@ where
836
859
for ( i, ( a, b) ) in as_and_bs. into_iter ( ) . rev ( ) . enumerate ( ) {
837
860
let workload_row = & mut record. workload [ i] ;
838
861
839
- workload_row. a = a;
840
- workload_row. b = b;
841
-
842
862
// result = result * alpha + (b - a)
843
863
result = FieldExtension :: add (
844
864
FieldExtension :: multiply ( result, alpha) ,
845
865
FieldExtension :: subtract ( b, elem_to_ext ( a) ) ,
846
866
) ;
867
+ workload_row. a = a;
868
+ workload_row. result = result;
847
869
}
848
870
849
871
let result_ptr = e. as_canonical_u32 ( ) ;
@@ -887,22 +909,23 @@ where
887
909
let num_rows = header. length as usize + 2 ;
888
910
let chunk_size = OVERALL_WIDTH * num_rows;
889
911
let ( chunk, rest) = remaining_trace. split_at_mut ( chunk_size) ;
890
- chunks. push ( chunk) ;
912
+ chunks. push ( ( chunk, header . is_init ) ) ;
891
913
remaining_trace = rest;
892
914
}
893
915
894
- chunks. into_par_iter ( ) . for_each ( |mut chunk| {
916
+ chunks. into_par_iter ( ) . for_each ( |( mut chunk, is_init ) | {
895
917
let num_rows = chunk. len ( ) / OVERALL_WIDTH ;
896
918
let metadata = FriReducedOpeningMetadata {
897
919
length : num_rows - 2 ,
920
+ is_init,
898
921
} ;
899
922
let record: FriReducedOpeningRecordMut < F > =
900
923
unsafe { get_record_from_slice ( & mut chunk, MultiRowLayout :: new ( metadata) ) } ;
901
924
902
925
let timestamp = record. common . timestamp ;
903
926
let length = record. header . length as usize ;
904
927
let alpha = record. common . alpha ;
905
- let is_init = record. common . is_init ;
928
+ let is_init = record. header . is_init ;
906
929
let write_a = F :: from_bool ( !is_init) ;
907
930
908
931
let a_ptr = record. common . a_ptr ;
@@ -911,23 +934,6 @@ where
911
934
let ( workload_chunk, rest) = chunk. split_at_mut ( length * OVERALL_WIDTH ) ;
912
935
let ( ins1_chunk, ins2_chunk) = rest. split_at_mut ( OVERALL_WIDTH ) ;
913
936
914
- let mut results: Vec < [ F ; EXT_DEG ] > =
915
- std:: iter:: once ( [ F :: ZERO ; EXT_DEG ] )
916
- . chain ( record. workload . iter ( ) . scan (
917
- [ F :: ZERO ; EXT_DEG ] ,
918
- |result, workload_row| {
919
- let a = workload_row. a ;
920
- let b = workload_row. b ;
921
-
922
- * result = FieldExtension :: add (
923
- FieldExtension :: multiply ( * result, alpha) ,
924
- FieldExtension :: subtract ( b, elem_to_ext ( a) ) ,
925
- ) ;
926
- Some ( * result)
927
- } ,
928
- ) )
929
- . collect ( ) ;
930
-
931
937
{
932
938
// ins2 row
933
939
let cols: & mut Instruction2Cols < F > = ins2_chunk[ ..INS_2_WIDTH ] . borrow_mut ( ) ;
@@ -998,31 +1004,43 @@ where
998
1004
cols. pc = F :: from_canonical_u32 ( record. common . from_pc ) ;
999
1005
1000
1006
cols. prefix . data . alpha = alpha;
1001
- cols. prefix . data . result = results . pop ( ) . unwrap ( ) ;
1007
+ cols. prefix . data . result = record . workload . last ( ) . unwrap ( ) . result ;
1002
1008
cols. prefix . data . idx = F :: from_canonical_usize ( length) ;
1003
- cols. prefix . data . b_ptr = b_ptr;
1009
+ cols. prefix . data . b_ptr = F :: from_canonical_u32 ( b_ptr) ;
1004
1010
cols. prefix . data . write_a = write_a;
1005
- cols. prefix . data . a_ptr = a_ptr;
1011
+ cols. prefix . data . a_ptr = F :: from_canonical_u32 ( a_ptr) ;
1006
1012
1007
1013
cols. prefix . a_or_is_first = F :: ONE ;
1008
1014
1009
1015
cols. prefix . general . timestamp = F :: from_canonical_u32 ( timestamp) ;
1010
1016
cols. prefix . general . is_ins_row = F :: ONE ;
1011
1017
cols. prefix . general . is_workload_row = F :: ZERO ;
1012
-
1013
1018
ins1_chunk[ INS_1_WIDTH ..OVERALL_WIDTH ] . fill ( F :: ZERO ) ;
1014
1019
}
1015
1020
1016
- for ( i, ( workload_row, result) ) in record
1021
+ // To fill the WorkloadRows we do 2 passes:
1022
+ // - First, a serial pass to fill some of the records into the trace
1023
+ // - Then, a parallel pass to fill the rest of the records into the trace
1024
+ // Note, the first pass is done to avoid overwriting the records
1025
+
1026
+ // Copy of `a_write_prev_data` to avoid overwriting it and to use it in the parallel
1027
+ // pass
1028
+ let a_prev_data = if !is_init {
1029
+ let mut tmp = Vec :: with_capacity ( length) ;
1030
+ tmp. extend_from_slice ( record. a_write_prev_data ) ;
1031
+ tmp
1032
+ } else {
1033
+ vec ! [ ]
1034
+ } ;
1035
+
1036
+ for ( i, ( workload_row, row_chunk) ) in record
1017
1037
. workload
1018
1038
. iter ( )
1019
- . zip ( results . into_iter ( ) )
1039
+ . zip ( workload_chunk . chunks_exact_mut ( OVERALL_WIDTH ) )
1020
1040
. enumerate ( )
1021
1041
. rev ( )
1022
1042
{
1023
- let offset = i * OVERALL_WIDTH ;
1024
- let cols: & mut WorkloadCols < F > =
1025
- workload_chunk[ offset..offset + WL_WIDTH ] . borrow_mut ( ) ;
1043
+ let cols: & mut WorkloadCols < F > = row_chunk[ ..WL_WIDTH ] . borrow_mut ( ) ;
1026
1044
1027
1045
let timestamp = timestamp + ( ( length - i) * 2 ) as u32 ;
1028
1046
@@ -1032,32 +1050,59 @@ where
1032
1050
timestamp + 4 ,
1033
1051
cols. b_aux . as_mut ( ) ,
1034
1052
) ;
1035
- cols. b = workload_row. b ;
1036
1053
1037
- if !is_init {
1038
- cols. a_aux . set_prev_data ( workload_row. a_aux . prev_data ) ;
1039
- }
1054
+ // We temporarily store the result here
1055
+ // the correct value of b is computed during the serial pass below
1056
+ cols. b = record. workload [ i] . result ;
1057
+
1040
1058
mem_helper. fill (
1041
1059
workload_row. a_aux . prev_timestamp ,
1042
1060
timestamp + 3 ,
1043
1061
cols. a_aux . as_mut ( ) ,
1044
1062
) ;
1045
-
1046
- cols. prefix . data . alpha = alpha;
1047
- cols. prefix . data . result = result;
1048
- cols. prefix . data . idx = F :: from_canonical_usize ( i) ;
1049
- cols. prefix . data . b_ptr = b_ptr + F :: from_canonical_usize ( ( length - i) * EXT_DEG ) ;
1050
- cols. prefix . data . write_a = write_a;
1051
- cols. prefix . data . a_ptr = a_ptr + F :: from_canonical_usize ( length - i) ;
1052
-
1053
1063
cols. prefix . a_or_is_first = workload_row. a ;
1054
1064
1055
- cols. prefix . general . timestamp = F :: from_canonical_u32 ( timestamp) ;
1056
- cols. prefix . general . is_ins_row = F :: ZERO ;
1057
- cols. prefix . general . is_workload_row = F :: ONE ;
1058
-
1059
- workload_chunk[ offset + WL_WIDTH ..offset + OVERALL_WIDTH ] . fill ( F :: ZERO ) ;
1065
+ if i > 0 {
1066
+ cols. prefix . data . result = record. workload [ i - 1 ] . result ;
1067
+ }
1060
1068
}
1069
+
1070
+ workload_chunk
1071
+ . par_chunks_exact_mut ( OVERALL_WIDTH )
1072
+ . enumerate ( )
1073
+ . for_each ( |( i, row_chunk) | {
1074
+ let cols: & mut WorkloadCols < F > = row_chunk[ ..WL_WIDTH ] . borrow_mut ( ) ;
1075
+ let timestamp = timestamp + ( ( length - i) * 2 ) as u32 ;
1076
+ if is_init {
1077
+ cols. a_aux . set_prev_data ( [ F :: ZERO ; 1 ] ) ;
1078
+ } else {
1079
+ cols. a_aux . set_prev_data ( [ a_prev_data[ i] ] ) ;
1080
+ }
1081
+
1082
+ // DataCols
1083
+ cols. prefix . data . a_ptr = F :: from_canonical_u32 ( a_ptr + ( length - i) as u32 ) ;
1084
+ cols. prefix . data . write_a = write_a;
1085
+ cols. prefix . data . b_ptr =
1086
+ F :: from_canonical_u32 ( b_ptr + ( ( length - i) * EXT_DEG ) as u32 ) ;
1087
+ cols. prefix . data . idx = F :: from_canonical_usize ( i) ;
1088
+ if i == 0 {
1089
+ cols. prefix . data . result = [ F :: ZERO ; EXT_DEG ] ;
1090
+ }
1091
+ cols. prefix . data . alpha = alpha;
1092
+
1093
+ // GeneralCols
1094
+ cols. prefix . general . is_workload_row = F :: ONE ;
1095
+ cols. prefix . general . is_ins_row = F :: ZERO ;
1096
+
1097
+ // WorkloadCols
1098
+ cols. prefix . general . timestamp = F :: from_canonical_u32 ( timestamp) ;
1099
+
1100
+ cols. b = FieldExtension :: subtract (
1101
+ FieldExtension :: add ( cols. b , elem_to_ext ( cols. prefix . a_or_is_first ) ) ,
1102
+ FieldExtension :: multiply ( cols. prefix . data . result , alpha) ,
1103
+ ) ;
1104
+ row_chunk[ WL_WIDTH ..OVERALL_WIDTH ] . fill ( F :: ZERO ) ;
1105
+ } ) ;
1061
1106
} ) ;
1062
1107
}
1063
1108
}
0 commit comments