Skip to content

Commit 2c0f6ca

Browse files
committed
perf: Avoid allocations when sampling from Tip5
1 parent 97fc0b3 commit 2c0f6ca

File tree

1 file changed

+22
-17
lines changed

1 file changed

+22
-17
lines changed

twenty-first/src/tip5/mod.rs

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -629,19 +629,29 @@ impl Tip5 {
629629
/// Specifically, if the top 32 bits of a BFieldElement are all ones, then the bottom 32 bits
630630
/// are not uniformly distributed, and so they are dropped. This method invokes squeeze until
631631
/// enough uniform u32s have been sampled.
632+
///
633+
/// # Panics
634+
///
635+
/// Panics if `upper_bound` is not a power of two.
632636
pub fn sample_indices(&mut self, upper_bound: u32, num_indices: usize) -> Vec<u32> {
633-
debug_assert!(upper_bound.is_power_of_two());
634-
let mut indices = vec![];
635-
let mut squeezed_elements = vec![];
636-
while indices.len() != num_indices {
637-
if squeezed_elements.is_empty() {
638-
squeezed_elements = self.squeeze().into_iter().rev().collect_vec();
637+
assert!(upper_bound.is_power_of_two());
638+
639+
let mut indices = Vec::with_capacity(num_indices);
640+
let mut buffer = const { [BFieldElement::ZERO; RATE] };
641+
let mut next_in_buffer = RATE;
642+
while indices.len() < num_indices {
643+
if next_in_buffer == RATE {
644+
buffer = self.squeeze();
645+
next_in_buffer = 0;
639646
}
640-
let element = squeezed_elements.pop().unwrap();
641-
if element != BFieldElement::new(BFieldElement::MAX) {
647+
let element = buffer[next_in_buffer];
648+
next_in_buffer += 1;
649+
650+
if element != const { BFieldElement::new(BFieldElement::MAX) } {
642651
indices.push(element.value() as u32 % upper_bound);
643652
}
644653
}
654+
645655
indices
646656
}
647657

@@ -653,18 +663,13 @@ impl Tip5 {
653663
/// [rate]: Sponge::RATE
654664
pub fn sample_scalars(&mut self, num_elements: usize) -> Vec<XFieldElement> {
655665
let num_squeezes = (num_elements * EXTENSION_DEGREE).div_ceil(Self::RATE);
656-
debug_assert!(
657-
num_elements * EXTENSION_DEGREE <= num_squeezes * Self::RATE,
658-
"need {} elements but getting {}",
659-
num_elements * EXTENSION_DEGREE,
660-
num_squeezes * Self::RATE
661-
);
666+
debug_assert!(num_elements * EXTENSION_DEGREE <= num_squeezes * Self::RATE);
667+
662668
(0..num_squeezes)
663669
.flat_map(|_| self.squeeze())
664-
.collect_vec()
665-
.chunks(3)
670+
.tuples()
666671
.take(num_elements)
667-
.map(|elem| XFieldElement::new([elem[0], elem[1], elem[2]]))
672+
.map(|(x0, x1, x2)| XFieldElement::new([x0, x1, x2]))
668673
.collect()
669674
}
670675
}

0 commit comments

Comments
 (0)