Skip to content

Commit 3245429

Browse files
committed
boilerplate + fix edge cases
1 parent 1d8cb80 commit 3245429

File tree

6 files changed

+148
-17
lines changed

6 files changed

+148
-17
lines changed

backend/src/mle/mle_group.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,11 @@ impl<'a, EF: ExtensionField<PF<EF>>> MleGroup<'a, EF> {
5555
Self::Ref(_) => None,
5656
}
5757
}
58+
59+
pub fn is_packed(&self) -> bool {
60+
match self {
61+
Self::Owned(owned) => owned.is_packed(),
62+
Self::Ref(r) => r.is_packed(),
63+
}
64+
}
5865
}

backend/src/mle/mle_single.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,26 @@ impl<'a, EF: ExtensionField<PF<EF>>> Mle<'a, EF> {
5252
Self::Ref(poly) => poly.pack(),
5353
}
5454
}
55+
56+
pub fn unpack(&'a self) -> Self {
57+
match self {
58+
Self::Owned(poly) => poly.unpack(),
59+
Self::Ref(poly) => poly.unpack(),
60+
}
61+
}
62+
63+
pub fn is_packed(&self) -> bool {
64+
self.by_ref().is_packed()
65+
}
66+
67+
pub fn n_vars(&self) -> usize {
68+
self.by_ref().n_vars()
69+
}
70+
71+
pub fn as_owned_or_clone(self) -> MleOwned<EF> {
72+
match self {
73+
Self::Owned(o) => o,
74+
Self::Ref(r) => r.clone_to_owned(),
75+
}
76+
}
5577
}

backend/src/mle/mle_single_owned.rs

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use fiat_shamir::*;
22
use p3_field::ExtensionField;
33

4-
use crate::{pack_extension, Mle, MleRef, MultilinearPoint};
4+
use crate::{Mle, MleRef, MultilinearPoint, pack_extension, unpack_extension};
55
use p3_field::PackedValue;
66

77
#[derive(Debug, Clone)]
@@ -112,4 +112,45 @@ impl<EF: ExtensionField<PF<EF>>> MleOwned<EF> {
112112
Self::ExtensionPacked(_) => Mle::Ref(self.by_ref()),
113113
}
114114
}
115+
116+
pub fn unpack<'a>(&'a self) -> Mle<'a, EF> {
117+
match self {
118+
Self::Base(v) => Mle::Ref(MleRef::Base(v)),
119+
Self::Extension(v) => Mle::Ref(MleRef::Extension(v)),
120+
Self::BasePacked(pb) => Mle::Ref(MleRef::Base(PFPacking::<EF>::unpack_slice(pb))),
121+
Self::ExtensionPacked(ep) => Mle::Owned(MleOwned::Extension(unpack_extension(ep))),
122+
}
123+
}
124+
125+
pub fn is_packed(&self) -> bool {
126+
self.by_ref().is_packed()
127+
}
128+
129+
pub fn n_vars(&self) -> usize {
130+
self.by_ref().n_vars()
131+
}
132+
133+
pub fn halve(mut self) -> Self {
134+
match &mut self {
135+
Self::Base(v) => {
136+
v.truncate(v.len() / 2);
137+
}
138+
Self::Extension(v) => {
139+
v.truncate(v.len() / 2);
140+
}
141+
Self::BasePacked(v) => {
142+
if v.len() == 1 {
143+
return self.unpack().by_ref().clone_to_owned().halve();
144+
}
145+
v.truncate(v.len() / 2);
146+
}
147+
Self::ExtensionPacked(v) => {
148+
if v.len() == 1 {
149+
return self.unpack().by_ref().clone_to_owned().halve();
150+
}
151+
v.truncate(v.len() / 2);
152+
}
153+
}
154+
self
155+
}
115156
}

backend/src/mle/mle_single_ref.rs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,4 +86,49 @@ impl<'a, EF: ExtensionField<PF<EF>>> MleRef<'a, EF> {
8686
Self::ExtensionPacked(_) => Mle::Ref(self.clone()),
8787
}
8888
}
89+
90+
pub fn unpack(&self) -> Mle<'a, EF> {
91+
match self {
92+
Self::Base(v) => Mle::Ref(MleRef::Base(v)),
93+
Self::Extension(v) => Mle::Ref(MleRef::Extension(v)),
94+
Self::BasePacked(pb) => Mle::Ref(MleRef::Base(PFPacking::<EF>::unpack_slice(pb))),
95+
Self::ExtensionPacked(ep) => Mle::Owned(MleOwned::Extension(unpack_extension(ep))),
96+
}
97+
}
98+
99+
pub fn pack_if(&self, cond: bool) -> Mle<'a, EF> {
100+
if cond {
101+
self.pack()
102+
} else {
103+
Mle::Ref(self.clone())
104+
}
105+
}
106+
107+
pub fn clone_to_owned(&self) -> MleOwned<EF> {
108+
match self {
109+
Self::Base(v) => MleOwned::Base(v.to_vec()),
110+
Self::Extension(v) => MleOwned::Extension(v.to_vec()),
111+
Self::BasePacked(pb) => MleOwned::BasePacked(pb.to_vec()),
112+
Self::ExtensionPacked(ep) => MleOwned::ExtensionPacked(ep.to_vec()),
113+
}
114+
}
115+
116+
pub fn fold(&self, scalars: &[EF]) -> MleOwned<EF> {
117+
match self {
118+
Self::Base(pols) => MleOwned::Extension(fold_multilinear(pols, scalars, &|a, b| b * a)),
119+
Self::Extension(pols) => {
120+
MleOwned::Extension(fold_multilinear(pols, scalars, &|a, b| b * a))
121+
}
122+
Self::BasePacked(pols) => {
123+
let scalars_packed = scalars
124+
.iter()
125+
.map(|&s| EFPacking::<EF>::from(s))
126+
.collect::<Vec<_>>();
127+
MleOwned::ExtensionPacked(fold_multilinear(pols, &scalars_packed, &|a, b| b * a))
128+
}
129+
Self::ExtensionPacked(pols) => {
130+
MleOwned::ExtensionPacked(fold_multilinear(pols, scalars, &|a, b| a * b))
131+
}
132+
}
133+
}
89134
}

sumcheck/src/product_computation.rs

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ pub fn run_product_sumcheck<EF: ExtensionField<PF<EF>>>(
6161
EFPacking::<EF>::to_ext_iter([e]).collect()
6262
})
6363
}
64+
(MleRef::Base(evals), MleRef::Extension(weights)) => {
65+
compute_product_sumcheck_polynomial(evals, &weights, sum, |e| vec![e])
66+
}
6467
_ => unimplemented!(),
6568
};
6669

@@ -69,20 +72,34 @@ pub fn run_product_sumcheck<EF: ExtensionField<PF<EF>>>(
6972
let r1: EF = prover_state.sample();
7073
sum = first_sumcheck_poly.evaluate(r1);
7174

72-
if n_rounds < 2 {
73-
unimplemented!()
75+
if n_rounds == 1 {
76+
return (
77+
MultilinearPoint(vec![r1]),
78+
sum,
79+
pol_a.fold(&[EF::ONE - r1, r1]),
80+
pol_b.fold(&[EF::ONE - r1, r1]),
81+
);
7482
}
7583

7684
let (second_sumcheck_poly, folded) = match (pol_a, pol_b) {
7785
(MleRef::BasePacked(evals), MleRef::ExtensionPacked(weights)) => {
78-
fold_and_compute_product_sumcheck_polynomial(&evals, &weights, r1, sum, |e| {
79-
EFPacking::<EF>::to_ext_iter([e]).collect()
80-
})
86+
let (second_sumcheck_poly, folded) =
87+
fold_and_compute_product_sumcheck_polynomial(&evals, &weights, r1, sum, |e| {
88+
EFPacking::<EF>::to_ext_iter([e]).collect()
89+
});
90+
(second_sumcheck_poly, MleGroupOwned::ExtensionPacked(folded))
8191
}
8292
(MleRef::ExtensionPacked(evals), MleRef::ExtensionPacked(weights)) => {
83-
fold_and_compute_product_sumcheck_polynomial(evals, &weights, r1, sum, |e| {
84-
EFPacking::<EF>::to_ext_iter([e]).collect()
85-
})
93+
let (second_sumcheck_poly, folded) =
94+
fold_and_compute_product_sumcheck_polynomial(evals, &weights, r1, sum, |e| {
95+
EFPacking::<EF>::to_ext_iter([e]).collect()
96+
});
97+
(second_sumcheck_poly, MleGroupOwned::ExtensionPacked(folded))
98+
}
99+
(MleRef::Base(evals), MleRef::Extension(weights)) => {
100+
let (second_sumcheck_poly, folded) =
101+
fold_and_compute_product_sumcheck_polynomial(evals, &weights, r1, sum, |e| vec![e]);
102+
(second_sumcheck_poly, MleGroupOwned::Extension(folded))
86103
}
87104
_ => unimplemented!(),
88105
};
@@ -94,7 +111,7 @@ pub fn run_product_sumcheck<EF: ExtensionField<PF<EF>>>(
94111

95112
let (mut challenges, folds, sum) = sumcheck_prove_many_rounds(
96113
1,
97-
MleGroupOwned::ExtensionPacked(folded),
114+
folded,
98115
Some(vec![EF::ONE - r2, r2]),
99116
&ProductComputation,
100117
&[],

sumcheck/src/prove.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -127,17 +127,18 @@ where
127127
});
128128
(eq_point, eq_mle)
129129
});
130+
130131
let mut n_vars = multilinears.by_ref().n_vars();
131132
if let Some(prev_folding_factors) = &prev_folding_factors {
132133
n_vars -= log2_strict_usize(prev_folding_factors.len());
133134
}
134135
if let Some((eq_point, eq_mle)) = &eq_factor {
135136
assert_eq!(eq_point.len(), n_vars - skip + 1);
136137
assert_eq!(eq_mle.by_ref().n_vars(), eq_point.len() - 1);
137-
assert_eq!(
138-
eq_mle.by_ref().is_packed(),
139-
multilinears.by_ref().is_packed()
140-
);
138+
if eq_mle.by_ref().is_packed() && !multilinears.is_packed() {
139+
assert!(eq_point.len() < packing_log_width::<EF>());
140+
multilinears = multilinears.by_ref().unpack().into();
141+
}
141142
}
142143

143144
let mut challenges = Vec::new();
@@ -147,9 +148,7 @@ where
147148
// unpack
148149
multilinears = multilinears.by_ref().unpack().into();
149150
if let Some((_, eq_mle)) = &mut eq_factor {
150-
*eq_mle = MleOwned::Extension(unpack_extension(
151-
eq_mle.by_ref().as_extension_packed().unwrap(),
152-
));
151+
*eq_mle = eq_mle.by_ref().unpack().as_owned_or_clone();
153152
}
154153
}
155154

0 commit comments

Comments
 (0)