diff --git a/src/poly/evaluations/univariate/lagrange_interpolator.rs b/src/poly/evaluations/univariate/lagrange_interpolator.rs index 3812f221..3b12a7c2 100644 --- a/src/poly/evaluations/univariate/lagrange_interpolator.rs +++ b/src/poly/evaluations/univariate/lagrange_interpolator.rs @@ -1,6 +1,6 @@ use crate::poly::domain::vanishing_poly::VanishingPolynomial; use ark_ff::{batch_inversion_and_mul, PrimeField}; -use ark_std::vec::Vec; +use ark_std::{collections::HashMap, vec::Vec}; /// Struct describing Lagrange interpolation for a multiplicative coset I, /// with |I| a power of 2. /// TODO: Pull in lagrange poly explanation from libiop @@ -11,6 +11,8 @@ pub struct LagrangeInterpolator { pub(crate) v_inv_elems: Vec, pub(crate) domain_vp: VanishingPolynomial, poly_evaluations: Vec, + /// Cache for vanishing polynomial evaluations to avoid recomputing Z_H(x) for the same point + vp_cache: std::cell::RefCell>, } impl LagrangeInterpolator { @@ -48,7 +50,6 @@ impl LagrangeInterpolator { v_inv_i *= g_inv; } - // TODO: Cache the intermediate terms with Z_H(x) evaluations. let vp = VanishingPolynomial::new(domain_offset, domain_dim); let lagrange_interpolation: LagrangeInterpolator = LagrangeInterpolator { @@ -57,6 +58,7 @@ impl LagrangeInterpolator { v_inv_elems, domain_vp: vp, poly_evaluations, + vp_cache: std::cell::RefCell::new(HashMap::new()), }; lagrange_interpolation } @@ -76,7 +78,19 @@ impl LagrangeInterpolator { let r = self.all_domain_elems[i]; inverted_lagrange_coeffs.push(l * (interpolation_point - r)); } - let vp_t = self.domain_vp.evaluate(&interpolation_point); + + // Cache the vanishing polynomial evaluation to avoid recomputing Z_H(x) for the same point + let vp_t = { + let mut cache = self.vp_cache.borrow_mut(); + if let Some(&cached_vp) = cache.get(&interpolation_point) { + cached_vp + } else { + let computed_vp = self.domain_vp.evaluate(&interpolation_point); + cache.insert(interpolation_point, computed_vp); + computed_vp + } + }; + let lagrange_coeffs = inverted_lagrange_coeffs.as_mut_slice(); batch_inversion_and_mul::(lagrange_coeffs, &vp_t); lagrange_coeffs.iter().cloned().collect() @@ -143,4 +157,64 @@ mod tests { assert_eq!(actual, expected) } + + #[test] + pub fn test_caching_efficiency() { + let mut rng = test_rng(); + let poly = DensePolynomial::rand(15, &mut rng); + let gen = Fr::get_root_of_unity(1 << 4).unwrap(); + let domain = Radix2DomainVar::new( + gen, + 4, // 2^4 = 16 + FpVar::constant(Fr::GENERATOR), + ) + .unwrap(); + + // generate evaluations of `poly` on this domain + let mut coset_point = domain.offset().value().unwrap(); + let mut oracle_evals = Vec::new(); + for _ in 0..(1 << 4) { + oracle_evals.push(poly.evaluate(&coset_point)); + coset_point *= gen; + } + + let interpolator = LagrangeInterpolator::new( + domain.offset().value().unwrap(), + domain.gen, + domain.dim, + oracle_evals, + ); + + // Test multiple interpolations at the same point to verify caching works + let interpolate_point = Fr::rand(&mut rng); + + // First interpolation - should compute vanishing polynomial + let result1 = interpolator.interpolate(interpolate_point); + + // Second interpolation at the same point - should use cached vanishing polynomial + let result2 = interpolator.interpolate(interpolate_point); + + // Results should be identical + assert_eq!(result1, result2); + + // Verify cache is populated + assert_eq!(interpolator.vp_cache.borrow().len(), 1); + assert!(interpolator.vp_cache.borrow().contains_key(&interpolate_point)); + + // Test interpolation at a different point + let different_point = Fr::rand(&mut rng); + let result3 = interpolator.interpolate(different_point); + + // Cache should now contain both points + assert_eq!(interpolator.vp_cache.borrow().len(), 2); + assert!(interpolator.vp_cache.borrow().contains_key(&interpolate_point)); + assert!(interpolator.vp_cache.borrow().contains_key(&different_point)); + + // Verify all results are correct + let expected1 = poly.evaluate(&interpolate_point); + let expected3 = poly.evaluate(&different_point); + + assert_eq!(result1, expected1); + assert_eq!(result3, expected3); + } }