@@ -629,19 +629,29 @@ impl Tip5 {
629629 /// Specifically, if the top 32 bits of a BFieldElement are all ones, then the bottom 32 bits
630630 /// are not uniformly distributed, and so they are dropped. This method invokes squeeze until
631631 /// enough uniform u32s have been sampled.
632+ ///
633+ /// # Panics
634+ ///
635+ /// Panics if `upper_bound` is not a power of two.
632636 pub fn sample_indices ( & mut self , upper_bound : u32 , num_indices : usize ) -> Vec < u32 > {
633- debug_assert ! ( upper_bound. is_power_of_two( ) ) ;
634- let mut indices = vec ! [ ] ;
635- let mut squeezed_elements = vec ! [ ] ;
636- while indices. len ( ) != num_indices {
637- if squeezed_elements. is_empty ( ) {
638- squeezed_elements = self . squeeze ( ) . into_iter ( ) . rev ( ) . collect_vec ( ) ;
637+ assert ! ( upper_bound. is_power_of_two( ) ) ;
638+
639+ let mut indices = Vec :: with_capacity ( num_indices) ;
640+ let mut buffer = const { [ BFieldElement :: ZERO ; RATE ] } ;
641+ let mut next_in_buffer = RATE ;
642+ while indices. len ( ) < num_indices {
643+ if next_in_buffer == RATE {
644+ buffer = self . squeeze ( ) ;
645+ next_in_buffer = 0 ;
639646 }
640- let element = squeezed_elements. pop ( ) . unwrap ( ) ;
641- if element != BFieldElement :: new ( BFieldElement :: MAX ) {
647+ let element = buffer[ next_in_buffer] ;
648+ next_in_buffer += 1 ;
649+
650+ if element != const { BFieldElement :: new ( BFieldElement :: MAX ) } {
642651 indices. push ( element. value ( ) as u32 % upper_bound) ;
643652 }
644653 }
654+
645655 indices
646656 }
647657
@@ -653,18 +663,13 @@ impl Tip5 {
653663 /// [rate]: Sponge::RATE
654664 pub fn sample_scalars ( & mut self , num_elements : usize ) -> Vec < XFieldElement > {
655665 let num_squeezes = ( num_elements * EXTENSION_DEGREE ) . div_ceil ( Self :: RATE ) ;
656- debug_assert ! (
657- num_elements * EXTENSION_DEGREE <= num_squeezes * Self :: RATE ,
658- "need {} elements but getting {}" ,
659- num_elements * EXTENSION_DEGREE ,
660- num_squeezes * Self :: RATE
661- ) ;
666+ debug_assert ! ( num_elements * EXTENSION_DEGREE <= num_squeezes * Self :: RATE ) ;
667+
662668 ( 0 ..num_squeezes)
663669 . flat_map ( |_| self . squeeze ( ) )
664- . collect_vec ( )
665- . chunks ( 3 )
670+ . tuples ( )
666671 . take ( num_elements)
667- . map ( |elem | XFieldElement :: new ( [ elem [ 0 ] , elem [ 1 ] , elem [ 2 ] ] ) )
672+ . map ( |( x0 , x1 , x2 ) | XFieldElement :: new ( [ x0 , x1 , x2 ] ) )
668673 . collect ( )
669674 }
670675}
0 commit comments