1+ use std:: mem:: size_of;
2+
3+ use bytesize:: ByteSize ;
14use getset:: { Setters , WithSetters } ;
25use openvm_stark_backend:: p3_field:: PrimeField32 ;
36use p3_baby_bear:: BabyBear ;
@@ -7,8 +10,10 @@ pub const DEFAULT_SEGMENT_CHECK_INSNS: u64 = 1000;
710
811pub const DEFAULT_MAX_TRACE_HEIGHT_BITS : u8 = 22 ;
912pub const DEFAULT_MAX_TRACE_HEIGHT : u32 = 1 << DEFAULT_MAX_TRACE_HEIGHT_BITS ;
10- pub const DEFAULT_MAX_CELLS : usize = 1_200_000_000 ; // 1.2B
13+ pub const DEFAULT_MAX_MEMORY : usize = 15 << 30 ; // 15GiB
1114const DEFAULT_MAX_INTERACTIONS : usize = BabyBear :: ORDER_U32 as usize ;
15+ const DEFAULT_MAIN_CELL_WEIGHT : usize = 3 ; // 1 + 2^{log_blowup=1}
16+ const DEFAULT_INTERACTION_CELL_WEIGHT : usize = 8 ; // 2 * D_EF
1217
1318#[ derive( derive_new:: new, Clone , Debug , Serialize , Deserialize ) ]
1419pub struct Segment {
@@ -17,19 +22,36 @@ pub struct Segment {
1722 pub trace_heights : Vec < u32 > ,
1823}
1924
20- #[ derive( Clone , Debug , Default , WithSetters ) ]
25+ #[ derive( Clone , Debug , WithSetters ) ]
2126pub struct SegmentationConfig {
2227 pub limits : SegmentationLimits ,
23- /// Cells per row contributed by each interaction in cell count.
28+ /// Weight multiplier for main trace cells in memory calculation.
29+ #[ getset( set_with = "pub" ) ]
30+ pub main_cell_weight : usize ,
31+ /// Weight multiplier for interaction cells in memory calculation.
2432 #[ getset( set_with = "pub" ) ]
2533 pub interaction_cell_weight : usize ,
34+ /// Size of the base field in bytes. Used to convert cell count to memory bytes.
35+ #[ getset( set_with = "pub" ) ]
36+ pub base_field_size : usize ,
37+ }
38+
39+ impl Default for SegmentationConfig {
40+ fn default ( ) -> Self {
41+ Self {
42+ limits : SegmentationLimits :: default ( ) ,
43+ main_cell_weight : DEFAULT_MAIN_CELL_WEIGHT ,
44+ interaction_cell_weight : DEFAULT_INTERACTION_CELL_WEIGHT ,
45+ base_field_size : size_of :: < u32 > ( ) ,
46+ }
47+ }
2648}
2749
2850#[ derive( Clone , Debug , WithSetters , Setters ) ]
2951pub struct SegmentationLimits {
3052 pub max_trace_height : u32 ,
3153 #[ getset( set = "pub" , set_with = "pub" ) ]
32- pub max_cells : usize ,
54+ pub max_memory : usize ,
3355 #[ getset( set_with = "pub" ) ]
3456 pub max_interactions : usize ,
3557}
@@ -38,21 +60,21 @@ impl Default for SegmentationLimits {
3860 fn default ( ) -> Self {
3961 Self {
4062 max_trace_height : DEFAULT_MAX_TRACE_HEIGHT ,
41- max_cells : DEFAULT_MAX_CELLS ,
63+ max_memory : DEFAULT_MAX_MEMORY ,
4264 max_interactions : DEFAULT_MAX_INTERACTIONS ,
4365 }
4466 }
4567}
4668
4769impl SegmentationLimits {
48- pub fn new ( max_trace_height : u32 , max_cells : usize , max_interactions : usize ) -> Self {
70+ pub fn new ( max_trace_height : u32 , max_memory : usize , max_interactions : usize ) -> Self {
4971 debug_assert ! (
5072 max_trace_height. is_power_of_two( ) ,
5173 "max_trace_height should be a power of two"
5274 ) ;
5375 Self {
5476 max_trace_height,
55- max_cells ,
77+ max_memory ,
5678 max_interactions,
5779 }
5880 }
@@ -120,18 +142,26 @@ impl SegmentationCtx {
120142 self . config . limits . set_max_trace_height ( max_trace_height) ;
121143 }
122144
123- pub fn set_max_cells ( & mut self , max_cells : usize ) {
124- self . config . limits . max_cells = max_cells ;
145+ pub fn set_max_memory ( & mut self , max_memory : usize ) {
146+ self . config . limits . max_memory = max_memory ;
125147 }
126148
127149 pub fn set_max_interactions ( & mut self , max_interactions : usize ) {
128150 self . config . limits . max_interactions = max_interactions;
129151 }
130152
153+ pub fn set_main_cell_weight ( & mut self , weight : usize ) {
154+ self . config . main_cell_weight = weight;
155+ }
156+
131157 pub fn set_interaction_cell_weight ( & mut self , weight : usize ) {
132158 self . config . interaction_cell_weight = weight;
133159 }
134160
161+ pub fn set_base_field_size ( & mut self , base_field_size : usize ) {
162+ self . config . base_field_size = base_field_size;
163+ }
164+
135165 /// Calculate the maximum trace height and corresponding air name
136166 #[ inline( always) ]
137167 fn calculate_max_trace_height_with_name ( & self , trace_heights : & [ u32 ] ) -> ( u32 , & str ) {
@@ -144,22 +174,44 @@ impl SegmentationCtx {
144174 . unwrap_or ( ( 0 , "unknown" ) )
145175 }
146176
147- /// Calculate the total cells used based on trace heights and widths,
148- /// including weighted contribution from interactions if `interaction_cell_weight > 0`.
177+ /// Calculate total memory in bytes based on trace heights and widths.
178+ /// Formula: base_field_size * (main_cell_weight * main_cells + interaction_cell_weight *
179+ /// interaction_cells)
149180 #[ inline( always) ]
150- fn calculate_total_cells ( & self , trace_heights : & [ u32 ] ) -> usize {
181+ fn calculate_total_memory (
182+ & self ,
183+ trace_heights : & [ u32 ] ,
184+ ) -> (
185+ usize , /* memory */
186+ usize , /* main */
187+ usize , /* interaction */
188+ ) {
151189 debug_assert_eq ! ( trace_heights. len( ) , self . widths. len( ) ) ;
152190
191+ let main_weight = self . config . main_cell_weight ;
153192 let interaction_weight = self . config . interaction_cell_weight ;
154- trace_heights
193+ let base_field_size = self . config . base_field_size ;
194+
195+ let mut main_cnt = 0 ;
196+ let mut interaction_cnt = 0 ;
197+ for ( ( & height, & width) , & interactions) in trace_heights
155198 . iter ( )
156199 . zip ( self . widths . iter ( ) )
157200 . zip ( self . interactions . iter ( ) )
158- . map ( |( ( & height, & width) , & interactions) | {
159- let padded_height = height. next_power_of_two ( ) as usize ;
160- padded_height * ( width + interactions * interaction_weight)
161- } )
162- . sum ( )
201+ {
202+ let padded_height = height. next_power_of_two ( ) as usize ;
203+ main_cnt += padded_height * width;
204+ interaction_cnt += padded_height * interactions;
205+ }
206+
207+ let main_memory = main_cnt * main_weight * base_field_size;
208+ let interaction_memory =
209+ ( interaction_cnt + 1 ) . next_power_of_two ( ) * interaction_weight * base_field_size;
210+ (
211+ main_memory + interaction_memory,
212+ main_memory,
213+ interaction_memory,
214+ )
163215 }
164216
165217 /// Calculate the total interactions based on trace heights
@@ -199,8 +251,11 @@ impl SegmentationCtx {
199251 return false ;
200252 }
201253
254+ let main_weight = self . config . main_cell_weight ;
202255 let interaction_weight = self . config . interaction_cell_weight ;
203- let mut total_cells = 0 ;
256+ let base_field_size = self . config . base_field_size ;
257+ let mut main_cnt = 0usize ;
258+ let mut interaction_cnt = 0usize ;
204259 for ( i, ( ( ( padded_height, width) , interactions) , is_constant) ) in trace_heights
205260 . iter ( )
206261 . map ( |& height| height. next_power_of_two ( ) )
@@ -214,7 +269,7 @@ impl SegmentationCtx {
214269 if !is_constant && padded_height > self . config . limits . max_trace_height {
215270 let air_name = unsafe { self . air_names . get_unchecked ( i) } ;
216271 tracing:: info!(
217- "instret {:10} | height ({:8}) > max ({:8}) | chip {:3} ({}) " ,
272+ "overshoot: instret {:10} | height ({:8}) > max ({:8}) | chip {:3} ({}) " ,
218273 instret,
219274 padded_height,
220275 self . config. limits. max_trace_height,
@@ -223,23 +278,30 @@ impl SegmentationCtx {
223278 ) ;
224279 return true ;
225280 }
226- total_cells += padded_height as usize * ( width + interactions * interaction_weight) ;
281+ main_cnt += padded_height as usize * width;
282+ interaction_cnt += padded_height as usize * interactions;
227283 }
284+ // interaction rounding to match n_logup calculation
285+ let total_memory = ( main_cnt * main_weight
286+ + ( interaction_cnt + 1 ) . next_power_of_two ( ) * interaction_weight)
287+ * base_field_size;
228288
229- if total_cells > self . config . limits . max_cells {
289+ if total_memory > self . config . limits . max_memory {
230290 tracing:: info!(
231- "instret {:10} | total cells ({:10}) > max ({:10})" ,
291+ "overshoot: instret {:10} | total memory ({:10}) > max ({:10}) | main ({:10}) | interaction ({:10})" ,
232292 instret,
233- total_cells,
234- self . config. limits. max_cells
293+ total_memory,
294+ self . config. limits. max_memory,
295+ main_cnt,
296+ interaction_cnt
235297 ) ;
236298 return true ;
237299 }
238300
239301 let total_interactions = self . calculate_total_interactions ( trace_heights) ;
240302 if total_interactions > self . config . limits . max_interactions {
241303 tracing:: info!(
242- "instret {:10} | total interactions ({:10}) > max ({:10})" ,
304+ "overshoot: instret {:10} | total interactions ({:10}) > max ({:10})" ,
243305 instret,
244306 total_interactions,
245307 self . config. limits. max_interactions
@@ -403,18 +465,21 @@ impl SegmentationCtx {
403465 trace_heights : & [ u32 ] ,
404466 ) {
405467 let ( max_trace_height, air_name) = self . calculate_max_trace_height_with_name ( trace_heights) ;
406- let total_cells = self . calculate_total_cells ( trace_heights) ;
468+ let ( total_memory, main_memory, interaction_memory) =
469+ self . calculate_total_memory ( trace_heights) ;
407470 let total_interactions = self . calculate_total_interactions ( trace_heights) ;
408471 let utilization = self . calculate_trace_utilization ( trace_heights) ;
409472
410473 let final_marker = if IS_FINAL { " [TERMINATED]" } else { "" } ;
411474
412475 tracing:: info!(
413- "Segment {:3} | instret {:10} | {:8} instructions | {:10} cells | {:10} interactions | {:8} max height ({}) | {:.2}% utilization{}" ,
476+ "Segment {:3} | instret {:10} | {:8} instructions | {:5} memory ({:5}, {:5}) | {:10} interactions | {:8} max height ({}) | {:.2}% utilization{}" ,
414477 self . segments. len( ) ,
415478 instret_start,
416479 num_insns,
417- total_cells,
480+ ByteSize :: b( total_memory as u64 ) ,
481+ ByteSize :: b( main_memory as u64 ) ,
482+ ByteSize :: b( interaction_memory as u64 ) ,
418483 total_interactions,
419484 max_trace_height,
420485 air_name,
0 commit comments