@@ -158,13 +158,23 @@ impl BoolReader {
// Reads one boolean with the given `probability`. The fast reader is tried
// first; when it yields `None`, the read is redone through `cold_read_bool`.
158158 // Do not inline this because inlining seems to worsen performance.
159159 #[ inline( never) ]
160160 pub ( crate ) fn read_bool ( & mut self , probability : u8 ) -> BitResult < bool > {
// Diff pair: the fast-reader method is renamed from `read_bit` to `read_bool`.
161- if let Some ( b) = self . fast ( ) . read_bit ( probability) {
161+ if let Some ( b) = self . fast ( ) . read_bool ( probability) {
162162 return BitResult :: ok ( b) ;
163163 }
164164
165165 self . cold_read_bool ( probability)
166166 }
167167
// Reads one flag. The fast reader's `read_flag` is tried first; when it
// yields `None`, the read falls back to `cold_read_flag`, which performs a
// probability-128 bit read.
168+ // Do not inline this because inlining seems to worsen performance.
169+ #[ inline( never) ]
170+ pub ( crate ) fn read_flag ( & mut self ) -> BitResult < bool > {
171+ if let Some ( b) = self . fast ( ) . read_flag ( ) {
172+ return BitResult :: ok ( b) ;
173+ }
174+
175+ self . cold_read_flag ( )
176+ }
177+
168178 // Do not inline this because inlining seems to worsen performance.
169179 #[ inline( never) ]
170180 pub ( crate ) fn read_literal ( & mut self , n : u8 ) -> BitResult < u8 > {
@@ -206,13 +216,6 @@ impl BoolReader {
206216 self . cold_read_with_tree ( tree, usize:: from ( first_node. index ) )
207217 }
208218
209- // This should be inlined to allow it to share the instruction cache with
210- // `read_bool`, as both functions are short and called often.
211- #[ inline]
212- pub ( crate ) fn read_flag ( & mut self ) -> BitResult < bool > {
213- self . read_bool ( 128 )
214- }
215-
216219 // As a similar (but different) speedup to BitResult, the FastReader reads
217220 // bits under an assumption and validates it at the end.
218221 //
@@ -312,15 +315,21 @@ impl BoolReader {
312315 self . cold_read_bit ( probability)
313316 }
314317
// Slow-path flag read: a single bit read with probability 128, delegated to
// `cold_read_bit`. Marked `#[cold]` and never inlined.
318+ #[ cold]
319+ #[ inline( never) ]
320+ fn cold_read_flag ( & mut self ) -> BitResult < bool > {
321+ self . cold_read_bit ( 128 )
322+ }
323+
315324 #[ cold]
316325 #[ inline( never) ]
317326 fn cold_read_literal ( & mut self , n : u8 ) -> BitResult < u8 > {
318327 let mut v = 0u8 ;
319328 let mut res = self . start_accumulated_result ( ) ;
320329
321330 for _ in 0 ..n {
322- let b = self . cold_read_bit ( 128 ) . or_accumulate ( & mut res) ;
323- v = ( v << 1 ) + b as u8 ;
331+ let b = self . cold_read_flag ( ) . or_accumulate ( & mut res) ;
332+ v = ( v << 1 ) + u8 :: from ( b ) ;
324333 }
325334
326335 self . keep_accumulating ( res, v)
@@ -330,13 +339,13 @@ impl BoolReader {
330339 #[ inline( never) ]
331340 fn cold_read_optional_signed_value ( & mut self , n : u8 ) -> BitResult < i32 > {
332341 let mut res = self . start_accumulated_result ( ) ;
333- let flag = self . cold_read_bool ( 128 ) . or_accumulate ( & mut res) ;
342+ let flag = self . cold_read_flag ( ) . or_accumulate ( & mut res) ;
334343 if !flag {
335344 // We should not read further bits if the flag is not set.
336345 return self . keep_accumulating ( res, 0 ) ;
337346 }
338347 let magnitude = self . cold_read_literal ( n) . or_accumulate ( & mut res) ;
339- let sign = self . cold_read_bool ( 128 ) . or_accumulate ( & mut res) ;
348+ let sign = self . cold_read_flag ( ) . or_accumulate ( & mut res) ;
340349
341350 let value = if sign {
342351 -i32:: from ( magnitude)
@@ -380,24 +389,29 @@ impl FastReader<'_> {
380389 }
381390 }
382391
// Reads one bit with `probability` against the uncommitted state, then hands
// the result to `commit_if_valid`, whose return value decides whether the
// caller sees `Some(bit)` or `None`.
// Diff pair: this method is renamed from `read_bit` to `read_bool`.
383- fn read_bit ( mut self , probability : u8 ) -> Option < bool > {
392+ fn read_bool ( mut self , probability : u8 ) -> Option < bool > {
384393 let bit = self . fast_read_bit ( probability) ;
385394 self . commit_if_valid ( bit)
386395 }
387396
// Reads one flag with `fast_read_flag` against the uncommitted state, then
// hands the result to `commit_if_valid`, whose return value is passed on.
397+ fn read_flag ( mut self ) -> Option < bool > {
398+ let value = self . fast_read_flag ( ) ;
399+ self . commit_if_valid ( value)
400+ }
401+
// Reads an `n`-flag literal with `fast_read_literal` against the uncommitted
// state, then hands the result to `commit_if_valid`, whose return value is
// passed on.
388402 fn read_literal ( mut self , n : u8 ) -> Option < u8 > {
389403 let value = self . fast_read_literal ( n) ;
390404 self . commit_if_valid ( value)
391405 }
392406
393407 fn read_optional_signed_value ( mut self , n : u8 ) -> Option < i32 > {
394- let flag = self . fast_read_bit ( 128 ) ;
408+ let flag = self . fast_read_flag ( ) ;
395409 if !flag {
396410 // We should not read further bits if the flag is not set.
397411 return self . commit_if_valid ( 0 ) ;
398412 }
399413 let magnitude = self . fast_read_literal ( n) ;
400- let sign = self . fast_read_bit ( 128 ) ;
414+ let sign = self . fast_read_flag ( ) ;
401415 let value = if sign {
402416 -i32:: from ( magnitude)
403417 } else {
@@ -467,11 +481,67 @@ impl FastReader<'_> {
467481 retval
468482 }
469483
// Reads a single flag from the uncommitted state without validating the
// input. Instead of scaling `range` by a probability, the split point is
// obtained by halving `range`. Within this change it replaces the former
// `fast_read_bit(128)` calls.
//
// Returns `true` when `value` is at or above the split, `false` otherwise.
484+ fn fast_read_flag ( & mut self ) -> bool {
// Work on a local copy of the state; it is written back at the end.
485+ let State {
486+ mut chunk_index,
487+ mut value,
488+ mut range,
489+ mut bit_count,
490+ } = self . uncommitted_state ;
491+
// Refill: a negative `bit_count` means another chunk has to be shifted in.
492+ if bit_count < 0 {
493+ let chunk = self . chunks . get ( chunk_index) . copied ( ) ;
494+ // We ignore invalid data inside the `fast_` functions,
495+ // but we increase `chunk_index` below, so we can check
496+ // whether we read invalid data in `commit_if_valid`.
// A missing chunk is replaced by the chunk type's default value.
497+ let chunk = chunk. unwrap_or_default ( ) ;
498+
// The chunk is interpreted big-endian and appended as the low 32 bits.
499+ let v = u32:: from_be_bytes ( chunk) ;
500+ chunk_index += 1 ;
501+ value <<= 32 ;
502+ value |= u64:: from ( v) ;
503+ bit_count += 32 ;
504+ }
505+ debug_assert ! ( bit_count >= 0 ) ;
506+
// `half_range` is `range / 2` rounded down and `split` is the remainder
// (the half rounded up), so the two always add up to `range`.
507+ let half_range = range / 2 ;
508+ let split = range - half_range;
// Align `split` with the position of the current bits inside `value`.
509+ let bigsplit = u64:: from ( split) << bit_count;
510+
// `checked_sub` yields `Some` exactly when `value >= bigsplit`.
511+ let retval = if let Some ( new_value) = value. checked_sub ( bigsplit) {
512+ range = half_range;
513+ value = new_value;
514+ true
515+ } else {
516+ range = split;
517+ false
518+ } ;
519+ debug_assert ! ( range > 0 ) ;
520+
521+ // Compute shift required to satisfy `range >= 128`.
522+ // Apply that shift to `range` and `bit_count`.
523+ //
524+ // Subtract 24 because we only care about leading zeros in the
525+ // lowest byte of `range` which is a `u32`.
526+ let shift = range. leading_zeros ( ) . saturating_sub ( 24 ) ;
527+ range <<= shift;
528+ bit_count -= shift as i32 ;
529+ debug_assert ! ( range >= 128 ) ;
530+
// Store the advanced state; it stays uncommitted until the caller validates it.
531+ self . uncommitted_state = State {
532+ chunk_index,
533+ value,
534+ range,
535+ bit_count,
536+ } ;
537+ retval
538+ }
539+
// Reads `n` flags and packs them into a `u8`, with earlier flags ending up
// in the higher bits.
// NOTE(review): for `n > 8` the earliest flags would be shifted out of the
// `u8`; presumably callers keep `n <= 8` — confirm.
470540 fn fast_read_literal ( & mut self , n : u8 ) -> u8 {
471541 let mut v = 0u8 ;
472542 for _ in 0 ..n {
// Diff pair: `fast_read_bit(128)` becomes `fast_read_flag()`, and the
// `as u8` cast becomes `u8::from`.
473- let b = self . fast_read_bit ( 128 ) ;
474- v = ( v << 1 ) + b as u8 ;
543+ let b = self . fast_read_flag ( ) ;
544+ v = ( v << 1 ) + u8 :: from ( b ) ;
475545 }
476546 v
477547 }
@@ -502,7 +572,7 @@ mod tests {
502572 buf. as_mut_slice ( ) . as_flattened_mut ( ) [ ..size] . copy_from_slice ( & data[ ..] ) ;
503573 reader. init ( buf, size) . unwrap ( ) ;
504574 let mut res = reader. start_accumulated_result ( ) ;
505- assert_eq ! ( false , reader. read_bool ( 128 ) . or_accumulate( & mut res) ) ;
575+ assert_eq ! ( false , reader. read_flag ( ) . or_accumulate( & mut res) ) ;
506576 assert_eq ! ( true , reader. read_bool( 10 ) . or_accumulate( & mut res) ) ;
507577 assert_eq ! ( false , reader. read_bool( 250 ) . or_accumulate( & mut res) ) ;
508578 assert_eq ! ( 1 , reader. read_literal( 1 ) . or_accumulate( & mut res) ) ;
@@ -521,7 +591,7 @@ mod tests {
521591 buf. as_mut_slice ( ) . as_flattened_mut ( ) [ ..size] . copy_from_slice ( & data[ ..] ) ;
522592 reader. init ( buf, size) . unwrap ( ) ;
523593 let mut res = reader. start_accumulated_result ( ) ;
524- assert_eq ! ( false , reader. read_bool ( 128 ) . or_accumulate( & mut res) ) ;
594+ assert_eq ! ( false , reader. read_flag ( ) . or_accumulate( & mut res) ) ;
525595 assert_eq ! ( true , reader. read_bool( 10 ) . or_accumulate( & mut res) ) ;
526596 assert_eq ! ( false , reader. read_bool( 250 ) . or_accumulate( & mut res) ) ;
527597 assert_eq ! ( 1 , reader. read_literal( 1 ) . or_accumulate( & mut res) ) ;
0 commit comments