@@ -17,52 +17,82 @@ impl BitSet {
 
     #[inline(always)]
     pub fn insert(&mut self, index: usize) -> bool {
-        let word_index = index / 64;
-        let bit_index = index % 64;
+        let word_index = index >> 6;
+        let bit_index = index & 63;
         let mask = 1u64 << bit_index;
 
-        let was_set = (self.words[word_index] & mask) != 0;
-        self.words[word_index] |= mask;
+        debug_assert!(word_index < self.words.len(), "BitSet index out of bounds");
+
+        // SAFETY: word_index is derived from a memory address that is bounds-checked
+        // during memory access. The bitset is sized to accommodate all valid
+        // memory addresses, so word_index is always within bounds.
+        let word = unsafe { self.words.get_unchecked_mut(word_index) };
+        let was_set = (*word & mask) != 0;
+        *word |= mask;
         !was_set
     }
 
     /// Set all bits within [start, end) to 1, return the number of flipped bits.
+    /// Assumes start < end and end <= self.words.len() * 64.
     #[inline(always)]
     pub fn insert_range(&mut self, start: usize, end: usize) -> usize {
         debug_assert!(start < end);
+        debug_assert!(end <= self.words.len() * 64, "BitSet range out of bounds");
+
         let mut ret = 0;
-        let start_word_index = start / u64::BITS as usize;
-        let end_word_index = (end - 1) / u64::BITS as usize;
-        let start_bit = start as u32 % u64::BITS;
+        let start_word_index = start >> 6;
+        let end_word_index = (end - 1) >> 6;
+        let start_bit = (start & 63) as u32;
+
         if start_word_index == end_word_index {
-            let end_bit = (end - 1) as u32 % u64::BITS + 1;
+            let end_bit = ((end - 1) & 63) as u32 + 1;
             let mask_bits = end_bit - start_bit;
-            let mask = (u64::MAX >> (u64::BITS - mask_bits)) << start_bit;
-            ret += mask_bits - (self.words[start_word_index] & mask).count_ones();
-            self.words[start_word_index] |= mask;
+            let mask = (u64::MAX >> (64 - mask_bits)) << start_bit;
+            // SAFETY: Caller ensures start < end and end <= self.words.len() * 64,
+            // so start_word_index < self.words.len()
+            let word = unsafe { self.words.get_unchecked_mut(start_word_index) };
+            ret += mask_bits - (*word & mask).count_ones();
+            *word |= mask;
         } else {
-            let end_bit = end as u32 % u64::BITS;
-            let mask_bits = u64::BITS - start_bit;
+            let end_bit = (end & 63) as u32;
+            let mask_bits = 64 - start_bit;
             let mask = u64::MAX << start_bit;
-            ret += mask_bits - (self.words[start_word_index] & mask).count_ones();
-            self.words[start_word_index] |= mask;
+            // SAFETY: Caller ensures start < end and end <= self.words.len() * 64,
+            // so start_word_index < self.words.len()
+            let start_word = unsafe { self.words.get_unchecked_mut(start_word_index) };
+            ret += mask_bits - (*start_word & mask).count_ones();
+            *start_word |= mask;
+
             let mask_bits = end_bit;
-            let (mask, _) = u64::MAX.overflowing_shr(u64::BITS - end_bit);
-            ret += mask_bits - (self.words[end_word_index] & mask).count_ones();
-            self.words[end_word_index] |= mask;
+            let mask = if end_bit == 0 {
+                0
+            } else {
+                u64::MAX >> (64 - end_bit)
+            };
+            // SAFETY: Caller ensures end <= self.words.len() * 64, so
+            // end_word_index < self.words.len()
+            let end_word = unsafe { self.words.get_unchecked_mut(end_word_index) };
+            ret += mask_bits - (*end_word & mask).count_ones();
+            *end_word |= mask;
         }
+
         if start_word_index + 1 < end_word_index {
             for i in (start_word_index + 1)..end_word_index {
-                ret += self.words[i].count_zeros();
-                self.words[i] = u64::MAX;
+                // SAFETY: Caller ensures proper start and end, so i is within bounds
+                // of self.words.len()
+                let word = unsafe { self.words.get_unchecked_mut(i) };
+                ret += word.count_zeros();
+                *word = u64::MAX;
             }
         }
         ret as usize
     }
 
+    #[inline(always)]
     pub fn clear(&mut self) {
-        for item in self.words.iter_mut() {
-            *item = 0;
+        // SAFETY: words is valid for self.words.len() elements
+        unsafe {
+            std::ptr::write_bytes(self.words.as_mut_ptr(), 0, self.words.len());
         }
     }
 }
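The BitSet hunk above swaps division and modulo for shifts and masks (valid because 64 is a power of two: `x / 64 == x >> 6` and `x % 64 == x & 63` for unsigned `x`) and replaces indexed accesses with `get_unchecked_mut` guarded by `debug_assert!`. A minimal usage sketch of the resulting semantics follows; it is not part of the commit, and the `BitSet::new(num_words)` constructor used here is hypothetical (the commit only touches `insert`, `insert_range`, and `clear`):

#[test]
fn bitset_insert_semantics_sketch() {
    // Shift/mask forms match division/modulo for the power-of-two word size 64.
    for index in [0usize, 1, 63, 64, 65, 1000] {
        assert_eq!(index / 64, index >> 6);
        assert_eq!(index % 64, index & 63);
    }

    // Hypothetical constructor: 4 zeroed u64 words = 256 bits.
    let mut set = BitSet::new(4);
    assert!(set.insert(10));   // bit 10 was newly set
    assert!(!set.insert(10));  // bit 10 was already set
    // [8, 72) covers 64 bits across two words; 63 flip because bit 10 is already set.
    assert_eq!(set.insert_range(8, 72), 63);

    set.clear();
    // After clear(), all 100 bits in [0, 100) flip.
    assert_eq!(set.insert_range(0, 100), 100);
}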
@@ -132,6 +162,7 @@ impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
             addr_space_access_count: vec![0; (1 << memory_dimensions.addr_space_height) + 1],
         }
     }
+
     #[inline(always)]
     pub fn clear(&mut self) {
         self.page_indices.clear();
@@ -147,6 +178,8 @@ impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
         ptr: u32,
         size: u32,
     ) {
+        debug_assert!((address_space as usize) < self.addr_space_access_count.len());
+
         let num_blocks = (size + self.chunk - 1) >> self.chunk_bits;
         let start_chunk_id = ptr >> self.chunk_bits;
         let start_block_id = if self.chunk == 1 {
@@ -159,10 +192,17 @@ impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
         let end_block_id = start_block_id + num_blocks;
         let start_page_id = start_block_id >> PAGE_BITS;
         let end_page_id = ((end_block_id - 1) >> PAGE_BITS) + 1;
+
         for page_id in start_page_id..end_page_id {
             if self.page_indices.insert(page_id as usize) {
                 self.page_access_count += 1;
-                self.addr_space_access_count[address_space as usize] += 1;
+                // SAFETY: address_space passed is usually a hardcoded constant or derived from an
+                // Instruction where it is bounds checked before passing
+                unsafe {
+                    *self
+                        .addr_space_access_count
+                        .get_unchecked_mut(address_space as usize) += 1;
+                }
             }
         }
     }
@@ -185,38 +225,68 @@ impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
         size_bits: u32,
         num: u32,
     ) {
-        let align_bits = self.as_byte_alignment_bits[address_space as usize];
+        debug_assert!((address_space as usize) < self.as_byte_alignment_bits.len());
+
+        // SAFETY: address_space passed is usually a hardcoded constant or derived from an
+        // Instruction where it is bounds checked before passing
+        let align_bits = unsafe {
+            *self
+                .as_byte_alignment_bits
+                .get_unchecked(address_space as usize)
+        };
         debug_assert!(
             align_bits as u32 <= size_bits,
             "align_bits ({}) must be <= size_bits ({})",
             align_bits,
             size_bits
         );
+
         for adapter_bits in (align_bits as u32 + 1..=size_bits).rev() {
             let adapter_idx = self.adapter_offset + adapter_bits as usize - 1;
-            trace_heights[adapter_idx] += num << (size_bits - adapter_bits + 1);
+            debug_assert!(adapter_idx < trace_heights.len());
+            // SAFETY: trace_heights is initialized taking access adapters into account
+            unsafe {
+                *trace_heights.get_unchecked_mut(adapter_idx) +=
+                    num << (size_bits - adapter_bits + 1);
+            }
         }
     }
 
     /// Resolve all lazy updates of each memory access for memory adapters/poseidon2/merkle chip.
     #[inline(always)]
     pub(crate) fn lazy_update_boundary_heights(&mut self, trace_heights: &mut [u32]) {
+        debug_assert!(self.boundary_idx < trace_heights.len());
+
         // On page fault, assume we add all leaves in a page
         let leaves = (self.page_access_count << PAGE_BITS) as u32;
-        trace_heights[self.boundary_idx] += leaves;
+        // SAFETY: boundary_idx is a compile time constant within bounds
+        unsafe {
+            *trace_heights.get_unchecked_mut(self.boundary_idx) += leaves;
+        }
 
         if let Some(merkle_tree_idx) = self.merkle_tree_index {
+            debug_assert!(merkle_tree_idx < trace_heights.len());
+            debug_assert!(trace_heights.len() >= 2);
+
             let poseidon2_idx = trace_heights.len() - 2;
-            trace_heights[poseidon2_idx] += leaves * 2;
+            // SAFETY: poseidon2_idx is trace_heights.len() - 2, guaranteed to be in bounds
+            unsafe {
+                *trace_heights.get_unchecked_mut(poseidon2_idx) += leaves * 2;
+            }
 
             let merkle_height = self.memory_dimensions.overall_height();
             let nodes = (((1 << PAGE_BITS) - 1) + (merkle_height - PAGE_BITS)) as u32;
-            trace_heights[poseidon2_idx] += nodes * 2;
-            trace_heights[merkle_tree_idx] += nodes * 2;
+            // SAFETY: merkle_tree_idx is guaranteed to be in bounds
+            unsafe {
+                *trace_heights.get_unchecked_mut(poseidon2_idx) += nodes * 2;
+                *trace_heights.get_unchecked_mut(merkle_tree_idx) += nodes * 2;
+            }
         }
         self.page_access_count = 0;
+
         for address_space in 0..self.addr_space_access_count.len() {
-            let x = self.addr_space_access_count[address_space];
+            // SAFETY: address_space is from 0 to len(), guaranteed to be in bounds
+            let x = unsafe { *self.addr_space_access_count.get_unchecked(address_space) };
             if x > 0 {
                 // After finalize, we'll need to read it in chunk-sized units for the merkle chip
                 self.update_adapter_heights_batch(
@@ -225,7 +295,12 @@ impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
                     self.chunk_bits,
                     (x << PAGE_BITS) as u32,
                 );
-                self.addr_space_access_count[address_space] = 0;
+                // SAFETY: address_space is from 0 to len(), guaranteed to be in bounds
+                unsafe {
+                    *self
+                        .addr_space_access_count
+                        .get_unchecked_mut(address_space) = 0;
+                }
             }
         }
     }
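The MemoryCtx hunks all follow one pattern: keep a `debug_assert!` on the index so debug builds still catch out-of-bounds accesses, then use `get_unchecked`/`get_unchecked_mut` so release builds skip the bounds check. A standalone sketch of that pattern, with illustrative names not taken from the commit:

/// Add `delta` to `counts[idx]` without a release-mode bounds check.
/// The caller must guarantee `idx < counts.len()`; debug builds still verify it.
#[inline(always)]
fn add_unchecked(counts: &mut [u32], idx: usize, delta: u32) {
    debug_assert!(idx < counts.len(), "index out of bounds");
    // SAFETY: the caller guarantees idx < counts.len(), checked above in debug builds.
    unsafe {
        *counts.get_unchecked_mut(idx) += delta;
    }
}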