Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 77 additions & 3 deletions src/poly/evaluations/univariate/lagrange_interpolator.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::poly::domain::vanishing_poly::VanishingPolynomial;
use ark_ff::{batch_inversion_and_mul, PrimeField};
use ark_std::vec::Vec;
use ark_std::{collections::HashMap, vec::Vec};
/// Struct describing Lagrange interpolation for a multiplicative coset I,
/// with |I| a power of 2.
/// TODO: Pull in lagrange poly explanation from libiop
Expand All @@ -11,6 +11,8 @@ pub struct LagrangeInterpolator<F: PrimeField> {
pub(crate) v_inv_elems: Vec<F>,
pub(crate) domain_vp: VanishingPolynomial<F>,
poly_evaluations: Vec<F>,
/// Cache for vanishing polynomial evaluations to avoid recomputing Z_H(x) for the same point
vp_cache: std::cell::RefCell<HashMap<F, F>>,
}

impl<F: PrimeField> LagrangeInterpolator<F> {
Expand Down Expand Up @@ -48,7 +50,6 @@ impl<F: PrimeField> LagrangeInterpolator<F> {
v_inv_i *= g_inv;
}

// TODO: Cache the intermediate terms with Z_H(x) evaluations.
let vp = VanishingPolynomial::new(domain_offset, domain_dim);

let lagrange_interpolation: LagrangeInterpolator<F> = LagrangeInterpolator {
Expand All @@ -57,6 +58,7 @@ impl<F: PrimeField> LagrangeInterpolator<F> {
v_inv_elems,
domain_vp: vp,
poly_evaluations,
vp_cache: std::cell::RefCell::new(HashMap::new()),
};
lagrange_interpolation
}
Expand All @@ -76,7 +78,19 @@ impl<F: PrimeField> LagrangeInterpolator<F> {
let r = self.all_domain_elems[i];
inverted_lagrange_coeffs.push(l * (interpolation_point - r));
}
let vp_t = self.domain_vp.evaluate(&interpolation_point);

// Cache the vanishing polynomial evaluation to avoid recomputing Z_H(x) for the same point
let vp_t = {
let mut cache = self.vp_cache.borrow_mut();
if let Some(&cached_vp) = cache.get(&interpolation_point) {
cached_vp
} else {
let computed_vp = self.domain_vp.evaluate(&interpolation_point);
cache.insert(interpolation_point, computed_vp);
computed_vp
}
};

let lagrange_coeffs = inverted_lagrange_coeffs.as_mut_slice();
batch_inversion_and_mul::<F>(lagrange_coeffs, &vp_t);
lagrange_coeffs.iter().cloned().collect()
Expand Down Expand Up @@ -143,4 +157,64 @@ mod tests {

assert_eq!(actual, expected)
}

#[test]
pub fn test_caching_efficiency() {
let mut rng = test_rng();
let poly = DensePolynomial::rand(15, &mut rng);
let gen = Fr::get_root_of_unity(1 << 4).unwrap();
let domain = Radix2DomainVar::new(
gen,
4, // 2^4 = 16
FpVar::constant(Fr::GENERATOR),
)
.unwrap();

// generate evaluations of `poly` on this domain
let mut coset_point = domain.offset().value().unwrap();
let mut oracle_evals = Vec::new();
for _ in 0..(1 << 4) {
oracle_evals.push(poly.evaluate(&coset_point));
coset_point *= gen;
}

let interpolator = LagrangeInterpolator::new(
domain.offset().value().unwrap(),
domain.gen,
domain.dim,
oracle_evals,
);

// Test multiple interpolations at the same point to verify caching works
let interpolate_point = Fr::rand(&mut rng);

// First interpolation - should compute vanishing polynomial
let result1 = interpolator.interpolate(interpolate_point);

// Second interpolation at the same point - should use cached vanishing polynomial
let result2 = interpolator.interpolate(interpolate_point);

// Results should be identical
assert_eq!(result1, result2);

// Verify cache is populated
assert_eq!(interpolator.vp_cache.borrow().len(), 1);
assert!(interpolator.vp_cache.borrow().contains_key(&interpolate_point));

// Test interpolation at a different point
let different_point = Fr::rand(&mut rng);
let result3 = interpolator.interpolate(different_point);

// Cache should now contain both points
assert_eq!(interpolator.vp_cache.borrow().len(), 2);
assert!(interpolator.vp_cache.borrow().contains_key(&interpolate_point));
assert!(interpolator.vp_cache.borrow().contains_key(&different_point));

// Verify all results are correct
let expected1 = poly.evaluate(&interpolate_point);
let expected3 = poly.evaluate(&different_point);

assert_eq!(result1, expected1);
assert_eq!(result3, expected3);
}
}