Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 119 additions & 84 deletions crates/stark-backend-v2/src/poly_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,29 @@ pub fn eval_eq_uni_at_one<F: Field>(l_skip: usize, x: F) -> F {
res * F::ONE.halve().exp_u64(l_skip as u64)
}

/// Returns `eq_D(x, Z)` as a polynomial in `Z` in coefficient form.
/// Derived from `eq_D(x, Z)` being the Lagrange basis at `x`, which is the character sum over the
/// roots of unity.
///
/// If z in D, then `eq_D(x, z) = 1/N sum_{k=1}^N (x/z)^k = 1/N sum_{k=1}^N x^k
/// z^{N-k}`.
pub fn eq_uni_poly<F, EF>(l_skip: usize, x: EF) -> UnivariatePoly<EF>
where
F: Field,
EF: ExtensionField<F>,
{
let n_inv = F::ONE.halve().exp_u64(l_skip as u64);
let mut coeffs = x
.powers()
.skip(1)
.take(1 << l_skip)
.map(|x_pow| x_pow * n_inv)
.collect_vec();
coeffs.reverse();
coeffs[0] = n_inv.into();
UnivariatePoly::new(coeffs)
}

pub fn eval_in_uni<F: Field>(l_skip: usize, n: isize, z: F) -> F {
debug_assert!(n >= -(l_skip as isize));
if n.is_negative() {
Expand Down Expand Up @@ -107,6 +130,11 @@ where
res
}

pub fn eq_sharp_uni_poly<EF: TwoAdicField>(xi_1: &[EF]) -> UnivariatePoly<EF> {
let evals = evals_eq_hypercube_serial(xi_1);
UnivariatePoly::from_evals_idft(&evals)
}

/// `\kappa_\rot(x, y)` should equal `\delta_{x,rot(y)}` on hyperprism.
///
/// `omega_pows` must have length `2^{l_skip}`.
Expand Down Expand Up @@ -142,7 +170,7 @@ pub fn eval_eq_rot_cube<F: Field>(x: &[F], y: &[F]) -> (F, F) {
#[derive(Clone, Debug)]
pub struct UnivariatePoly<F>(pub(crate) Vec<F>);

impl<F: TwoAdicField> UnivariatePoly<F> {
impl<F> UnivariatePoly<F> {
pub fn new(coeffs: Vec<F>) -> Self {
Self(coeffs)
}
Expand All @@ -158,11 +186,101 @@ impl<F: TwoAdicField> UnivariatePoly<F> {
pub fn into_coeffs(self) -> Vec<F> {
self.0
}
}

impl<F: Field> UnivariatePoly<F> {
pub fn eval_at_point<EF: ExtensionField<F>>(&self, x: EF) -> EF {
horner_eval(&self.0, x)
}

#[instrument(level = "debug", skip_all)]
pub fn lagrange_interpolate<BF: Field>(points: &[BF], evals: &[F]) -> Self
where
F: ExtensionField<BF>,
{
assert_eq!(points.len(), evals.len());
let len = points.len();

// Special case: empty or single evaluation
if len == 0 {
return Self(vec![]);
}
if len == 1 {
return Self(vec![evals[0]]);
}

// Lagrange interpolation algorithm
// P(x) = sum_{i=0}^{len-1} evals[i] * L_i(x)
// where L_i(x) = prod_{j != i} (x - points[j]) / (points[i] - points[j])

// Step 1: Compute all denominators (points[i] - points[j]) for i != j
let mut denominators = Vec::with_capacity(len * (len - 1));
for i in 0..len {
for j in 0..len {
if i != j {
denominators.push(points[i] - points[j]);
}
}
}

// Step 2: Batch invert all denominators
let inv_denominators = batch_multiplicative_inverse_serial(&denominators);

// Step 3: Build coefficient form by accumulating Lagrange basis polynomials
let mut coeffs = vec![F::ZERO; len];

// Reusable workspace for Lagrange polynomial computation
let mut lagrange_poly = Vec::with_capacity(len);

#[allow(clippy::needless_range_loop)]
for i in 0..len {
// Skip if evaluation is zero (optimization)
if evals[i] == F::ZERO {
continue;
}

// Build L_i(x) in coefficient form using polynomial multiplication
// L_i(x) = prod_{j != i} (x - points[j]) / (points[i] - points[j])

// Start with constant polynomial 1
lagrange_poly.clear();
lagrange_poly.push(F::ONE);

// Get the precomputed inverse denominators for this i
let inv_denom_start = i * (len - 1);
let mut inv_idx = 0;

// Multiply by (x - points[j]) / (points[i] - points[j]) for each j != i
#[allow(clippy::needless_range_loop)]
for j in 0..len {
if i != j {
let scale = inv_denominators[inv_denom_start + inv_idx];
inv_idx += 1;

// Multiply lagrange_poly by (x - points[j]) * scale in place
// This is equivalent to: lagrange_poly * (x - points[j]) * scale
// = lagrange_poly * x * scale - lagrange_poly * points[j] * scale

lagrange_poly.push(F::ZERO); // Extend by one for the new highest degree term
for k in (1..lagrange_poly.len()).rev() {
let prev_coeff = lagrange_poly[k - 1] * scale;
lagrange_poly[k] += prev_coeff;
lagrange_poly[k - 1] = -prev_coeff * points[j];
}
}
}

// Add evals[i] * L_i(x) to the result
for (k, &coeff) in lagrange_poly.iter().enumerate() {
coeffs[k] += evals[i] * coeff;
}
}

Self(coeffs)
}
}

impl<F: TwoAdicField> UnivariatePoly<F> {
/// Computes P(1), P(omega), ..., P(omega^{n-1}).
fn chirp_z(poly: &[F], omega: F, n: usize) -> Vec<F> {
if n == 0 {
Expand Down Expand Up @@ -302,89 +420,6 @@ impl<F: TwoAdicField> UnivariatePoly<F> {
Self(res)
}

#[instrument(level = "debug", skip_all)]
pub fn lagrange_interpolate(points: &[F], evals: &[F]) -> Self {
assert_eq!(points.len(), evals.len());
let len = points.len();

// Special case: empty or single evaluation
if len == 0 {
return Self(vec![]);
}
if len == 1 {
return Self(vec![evals[0]]);
}

// Lagrange interpolation algorithm
// P(x) = sum_{i=0}^{len-1} evals[i] * L_i(x)
// where L_i(x) = prod_{j != i} (x - points[j]) / (points[i] - points[j])

// Step 1: Compute all denominators (points[i] - points[j]) for i != j
let mut denominators = Vec::with_capacity(len * (len - 1));
for i in 0..len {
for j in 0..len {
if i != j {
denominators.push(points[i] - points[j]);
}
}
}

// Step 2: Batch invert all denominators
let inv_denominators = batch_multiplicative_inverse_serial(&denominators);

// Step 3: Build coefficient form by accumulating Lagrange basis polynomials
let mut coeffs = vec![F::ZERO; len];

// Reusable workspace for Lagrange polynomial computation
let mut lagrange_poly = Vec::with_capacity(len);

#[allow(clippy::needless_range_loop)]
for i in 0..len {
// Skip if evaluation is zero (optimization)
if evals[i] == F::ZERO {
continue;
}

// Build L_i(x) in coefficient form using polynomial multiplication
// L_i(x) = prod_{j != i} (x - points[j]) / (points[i] - points[j])

// Start with constant polynomial 1
lagrange_poly.clear();
lagrange_poly.push(F::ONE);

// Get the precomputed inverse denominators for this i
let inv_denom_start = i * (len - 1);
let mut inv_idx = 0;

// Multiply by (x - points[j]) / (points[i] - points[j]) for each j != i
#[allow(clippy::needless_range_loop)]
for j in 0..len {
if i != j {
let scale = inv_denominators[inv_denom_start + inv_idx];
inv_idx += 1;

// Multiply lagrange_poly by (x - points[j]) * scale in place
// This is equivalent to: lagrange_poly * (x - points[j]) * scale
// = lagrange_poly * x * scale - lagrange_poly * points[j] * scale

lagrange_poly.push(F::ZERO); // Extend by one for the new highest degree term
for k in (1..lagrange_poly.len()).rev() {
let prev_coeff = lagrange_poly[k - 1] * scale;
lagrange_poly[k] += prev_coeff;
lagrange_poly[k - 1] = -prev_coeff * points[j];
}
}
}

// Add evals[i] * L_i(x) to the result
for (k, &coeff) in lagrange_poly.iter().enumerate() {
coeffs[k] += evals[i] * coeff;
}
}

Self(coeffs)
}

/// Constructs the polynomial in coefficient form from its evaluations on a smooth subgroup of
/// `F^*` by performing inverse DFT.
///
Expand Down
14 changes: 4 additions & 10 deletions crates/stark-backend-v2/src/prover/cpu_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ use crate::{
stacked_reduction::{prove_stacked_opening_reduction, StackedReductionCpu},
whir::WhirProver,
ColMajorMatrix, CommittedTraceDataV2, DeviceDataTransporterV2,
DeviceMultiStarkProvingKeyV2, DeviceStarkProvingKeyV2, LogupZerocheckCpu, MultiRapProver,
OpeningProverV2, ProverBackendV2, ProverDeviceV2, ProvingContextV2, TraceCommitterV2,
DeviceMultiStarkProvingKeyV2, DeviceStarkProvingKeyV2, MultiRapProver, OpeningProverV2,
ProverBackendV2, ProverDeviceV2, ProvingContextV2, TraceCommitterV2,
},
Digest, SystemParams, D_EF, EF, F,
};
Expand Down Expand Up @@ -69,16 +69,10 @@ impl<TS: FiatShamirTranscript> MultiRapProver<CpuBackendV2, TS> for CpuDeviceV2
transcript: &mut TS,
mpk: &DeviceMultiStarkProvingKeyV2<CpuBackendV2>,
ctx: &ProvingContextV2<CpuBackendV2>,
common_main_pcs_data: &StackedPcsData<F, Digest>,
_common_main_pcs_data: &StackedPcsData<F, Digest>,
) -> ((GkrProof, BatchConstraintProof), Vec<EF>) {
let (gkr_proof, batch_constraint_proof, r) =
prove_zerocheck_and_logup::<_, _, TS, LogupZerocheckCpu>(
self,
transcript,
mpk,
ctx,
common_main_pcs_data,
);
prove_zerocheck_and_logup(transcript, mpk, ctx);
((gkr_proof, batch_constraint_proof), r)
}
}
Expand Down
Loading