@@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
5
5
6
6
pub const DEFAULT_SEGMENT_CHECK_INSNS : u64 = 1000 ;
7
7
8
- pub const DEFAULT_MAX_TRACE_HEIGHT : u32 = ( 1 << 23 ) - 10000 ;
8
+ pub const DEFAULT_MAX_TRACE_HEIGHT : u32 = 1 << 23 ;
9
9
pub const DEFAULT_MAX_CELLS : usize = 2_000_000_000 ; // 2B
10
10
const DEFAULT_MAX_INTERACTIONS : usize = BabyBear :: ORDER_U32 as usize ;
11
11
@@ -46,6 +46,10 @@ pub struct SegmentationCtx {
46
46
pub instret_last_segment_check : u64 ,
47
47
#[ getset( set_with = "pub" ) ]
48
48
pub segment_check_insns : u64 ,
49
+ /// Checkpoint of trace heights at last known state where all thresholds satisfied
50
+ pub ( crate ) checkpoint_trace_heights : Vec < u32 > ,
51
+ /// Instruction count at the checkpoint
52
+ checkpoint_instret : u64 ,
49
53
}
50
54
51
55
impl SegmentationCtx {
@@ -58,6 +62,7 @@ impl SegmentationCtx {
58
62
assert_eq ! ( air_names. len( ) , widths. len( ) ) ;
59
63
assert_eq ! ( air_names. len( ) , interactions. len( ) ) ;
60
64
65
+ let num_airs = air_names. len ( ) ;
61
66
Self {
62
67
segments : Vec :: new ( ) ,
63
68
air_names,
@@ -66,6 +71,8 @@ impl SegmentationCtx {
66
71
segmentation_limits,
67
72
segment_check_insns : DEFAULT_SEGMENT_CHECK_INSNS ,
68
73
instret_last_segment_check : 0 ,
74
+ checkpoint_trace_heights : vec ! [ 0 ; num_airs] ,
75
+ checkpoint_instret : 0 ,
69
76
}
70
77
}
71
78
@@ -77,6 +84,7 @@ impl SegmentationCtx {
77
84
assert_eq ! ( air_names. len( ) , widths. len( ) ) ;
78
85
assert_eq ! ( air_names. len( ) , interactions. len( ) ) ;
79
86
87
+ let num_airs = air_names. len ( ) ;
80
88
Self {
81
89
segments : Vec :: new ( ) ,
82
90
air_names,
@@ -85,6 +93,8 @@ impl SegmentationCtx {
85
93
segmentation_limits : SegmentationLimits :: default ( ) ,
86
94
segment_check_insns : DEFAULT_SEGMENT_CHECK_INSNS ,
87
95
instret_last_segment_check : 0 ,
96
+ checkpoint_trace_heights : vec ! [ 0 ; num_airs] ,
97
+ checkpoint_instret : 0 ,
88
98
}
89
99
}
90
100
@@ -100,37 +110,6 @@ impl SegmentationCtx {
100
110
self . segmentation_limits . max_interactions = max_interactions;
101
111
}
102
112
103
- /// Calculate the total cells used based on trace heights and widths
104
- #[ inline( always) ]
105
- fn calculate_total_cells ( & self , trace_heights : & [ u32 ] ) -> usize {
106
- debug_assert_eq ! ( trace_heights. len( ) , self . widths. len( ) ) ;
107
-
108
- // SAFETY: Length equality is asserted during initialization
109
- let widths_slice = unsafe { self . widths . get_unchecked ( ..trace_heights. len ( ) ) } ;
110
-
111
- trace_heights
112
- . iter ( )
113
- . zip ( widths_slice)
114
- . map ( |( & height, & width) | height as usize * width)
115
- . sum ( )
116
- }
117
-
118
- /// Calculate the total interactions based on trace heights and interaction counts
119
- #[ inline( always) ]
120
- fn calculate_total_interactions ( & self , trace_heights : & [ u32 ] ) -> usize {
121
- debug_assert_eq ! ( trace_heights. len( ) , self . interactions. len( ) ) ;
122
-
123
- // SAFETY: Length equality is asserted during initialization
124
- let interactions_slice = unsafe { self . interactions . get_unchecked ( ..trace_heights. len ( ) ) } ;
125
-
126
- trace_heights
127
- . iter ( )
128
- . zip ( interactions_slice)
129
- // We add 1 for the zero messages from the padding rows
130
- . map ( |( & height, & interactions) | ( height + 1 ) as usize * interactions)
131
- . sum ( )
132
- }
133
-
134
113
#[ inline( always) ]
135
114
fn should_segment (
136
115
& self ,
@@ -140,6 +119,8 @@ impl SegmentationCtx {
140
119
) -> bool {
141
120
debug_assert_eq ! ( trace_heights. len( ) , is_trace_height_constant. len( ) ) ;
142
121
debug_assert_eq ! ( trace_heights. len( ) , self . air_names. len( ) ) ;
122
+ debug_assert_eq ! ( trace_heights. len( ) , self . widths. len( ) ) ;
123
+ debug_assert_eq ! ( trace_heights. len( ) , self . interactions. len( ) ) ;
143
124
144
125
let instret_start = self
145
126
. segments
@@ -152,44 +133,51 @@ impl SegmentationCtx {
152
133
return false ;
153
134
}
154
135
155
- for ( i, ( & height, is_constant) ) in trace_heights
136
+ let mut total_cells = 0 ;
137
+ for ( i, ( ( padded_height, width) , is_constant) ) in trace_heights
156
138
. iter ( )
139
+ . map ( |& height| height. next_power_of_two ( ) )
140
+ . zip ( self . widths . iter ( ) )
157
141
. zip ( is_trace_height_constant. iter ( ) )
158
142
. enumerate ( )
159
143
{
160
- // Only segment if the height is not constant and exceeds the maximum height
161
- if !is_constant && height > self . segmentation_limits . max_trace_height {
162
- let air_name = & self . air_names [ i] ;
144
+ // Only segment if the height is not constant and exceeds the maximum height after
145
+ // padding
146
+ if !is_constant && padded_height > self . segmentation_limits . max_trace_height {
147
+ let air_name = unsafe { self . air_names . get_unchecked ( i) } ;
163
148
tracing:: info!(
164
- "Segment {:2} | instret {:9} | chip {} ({}) height ({:8}) > max ({:8})" ,
165
- self . segments. len( ) ,
149
+ "instret {:9} | chip {} ({}) height ({:8}) > max ({:8})" ,
166
150
instret,
167
151
i,
168
152
air_name,
169
- height ,
153
+ padded_height ,
170
154
self . segmentation_limits. max_trace_height
171
155
) ;
172
156
return true ;
173
157
}
158
+ total_cells += padded_height as usize * width;
174
159
}
175
160
176
- let total_cells = self . calculate_total_cells ( trace_heights) ;
177
161
if total_cells > self . segmentation_limits . max_cells {
178
162
tracing:: info!(
179
- "Segment {:2} | instret {:9} | total cells ({:10}) > max ({:10})" ,
180
- self . segments. len( ) ,
163
+ "instret {:9} | total cells ({:10}) > max ({:10})" ,
181
164
instret,
182
165
total_cells,
183
166
self . segmentation_limits. max_cells
184
167
) ;
185
168
return true ;
186
169
}
187
170
188
- let total_interactions = self . calculate_total_interactions ( trace_heights) ;
171
+ // All padding rows contribute a single message to the interactions (+1) since
172
+ // we assume chips don't send/receive with nonzero multiplicity on padding rows.
173
+ let total_interactions: usize = trace_heights
174
+ . iter ( )
175
+ . zip ( self . interactions . iter ( ) )
176
+ . map ( |( & height, & interactions) | ( height + 1 ) as usize * interactions)
177
+ . sum ( ) ;
189
178
if total_interactions > self . segmentation_limits . max_interactions {
190
179
tracing:: info!(
191
- "Segment {:2} | instret {:9} | total interactions ({:11}) > max ({:11})" ,
192
- self . segments. len( ) ,
180
+ "instret {:9} | total interactions ({:11}) > max ({:11})" ,
193
181
instret,
194
182
total_interactions,
195
183
self . segmentation_limits. max_interactions
@@ -204,16 +192,84 @@ impl SegmentationCtx {
204
192
pub fn check_and_segment (
205
193
& mut self ,
206
194
instret : u64 ,
207
- trace_heights : & [ u32 ] ,
195
+ trace_heights : & mut [ u32 ] ,
208
196
is_trace_height_constant : & [ bool ] ,
209
197
) -> bool {
210
- let ret = self . should_segment ( instret, trace_heights, is_trace_height_constant) ;
211
- if ret {
212
- self . segment ( instret, trace_heights) ;
198
+ let should_seg = self . should_segment ( instret, trace_heights, is_trace_height_constant) ;
199
+
200
+ if should_seg {
201
+ self . create_segment_from_checkpoint ( instret, trace_heights, is_trace_height_constant) ;
202
+ } else {
203
+ self . update_checkpoint ( instret, trace_heights) ;
213
204
}
205
+
214
206
self . instret_last_segment_check = instret;
207
+ should_seg
208
+ }
215
209
216
- ret
210
+ #[ inline( always) ]
211
+ fn create_segment_from_checkpoint (
212
+ & mut self ,
213
+ instret : u64 ,
214
+ trace_heights : & mut [ u32 ] ,
215
+ is_trace_height_constant : & [ bool ] ,
216
+ ) {
217
+ let instret_start = self
218
+ . segments
219
+ . last ( )
220
+ . map_or ( 0 , |s| s. instret_start + s. num_insns ) ;
221
+
222
+ let ( segment_instret, segment_heights) = if self . checkpoint_instret > instret_start {
223
+ (
224
+ self . checkpoint_instret ,
225
+ self . checkpoint_trace_heights . clone ( ) ,
226
+ )
227
+ } else {
228
+ // No valid checkpoint, use current values
229
+ ( instret, trace_heights. to_vec ( ) )
230
+ } ;
231
+
232
+ // Reset current trace heights and checkpoint
233
+ self . reset_trace_heights ( trace_heights, & segment_heights, is_trace_height_constant) ;
234
+ self . checkpoint_instret = 0 ;
235
+
236
+ tracing:: info!(
237
+ "Segment {:2} | instret {:9} | {} instructions" ,
238
+ self . segments. len( ) ,
239
+ instret_start,
240
+ segment_instret - instret_start
241
+ ) ;
242
+ self . segments . push ( Segment {
243
+ instret_start,
244
+ num_insns : segment_instret - instret_start,
245
+ trace_heights : segment_heights,
246
+ } ) ;
247
+ }
248
+
249
+ /// Resets trace heights by subtracting segment heights
250
+ #[ inline( always) ]
251
+ fn reset_trace_heights (
252
+ & self ,
253
+ trace_heights : & mut [ u32 ] ,
254
+ segment_heights : & [ u32 ] ,
255
+ is_trace_height_constant : & [ bool ] ,
256
+ ) {
257
+ for ( ( trace_height, & segment_height) , & is_trace_height_constant) in trace_heights
258
+ . iter_mut ( )
259
+ . zip ( segment_heights. iter ( ) )
260
+ . zip ( is_trace_height_constant. iter ( ) )
261
+ {
262
+ if !is_trace_height_constant {
263
+ * trace_height = trace_height. checked_sub ( segment_height) . unwrap ( ) ;
264
+ }
265
+ }
266
+ }
267
+
268
+ /// Updates the checkpoint with current safe state
269
+ #[ inline( always) ]
270
+ fn update_checkpoint ( & mut self , instret : u64 , trace_heights : & [ u32 ] ) {
271
+ self . checkpoint_trace_heights . copy_from_slice ( trace_heights) ;
272
+ self . checkpoint_instret = instret;
217
273
}
218
274
219
275
/// Try segment if there is at least one cycle
@@ -227,6 +283,12 @@ impl SegmentationCtx {
227
283
228
284
debug_assert ! ( num_insns > 0 , "Segment should contain at least one cycle" ) ;
229
285
286
+ tracing:: info!(
287
+ "Segment {:2} | instret {:9} | {} instructions [FINAL]" ,
288
+ self . segments. len( ) ,
289
+ instret_start,
290
+ num_insns
291
+ ) ;
230
292
self . segments . push ( Segment {
231
293
instret_start,
232
294
num_insns,
0 commit comments