diff --git a/Cargo.toml b/Cargo.toml
index f4ae6a7..2d7da7f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -15,7 +15,17 @@ edition = "2021"
 ark-ff = "0.4.0"
 ark-poly = "0.4.0"
 ark-std = "0.4.0"
+ark-serialize = "0.4.2"
+num-bigint = "0.4"
+num-traits = "0.2"
+zeroize = "1.8.1"
+
+[dev-dependencies]
+criterion = "0.4"
+
+[features]
+simd = []
 
 [[bench]]
-name = "explanation"
 harness = false
+name = "reduce_sum_benches"
\ No newline at end of file
diff --git a/benches/explanation.rs b/benches/explanation.rs
index 5323a06..4d8e432 100644
--- a/benches/explanation.rs
+++ b/benches/explanation.rs
@@ -1,5 +1,7 @@
 fn main() {
     eprintln!("Error: This project uses a custom benchmarking workflow.");
-    eprintln!("Please navigate to the appropriate bench directory and call the shell './run_bench.sh' directly.");
+    eprintln!("Please choose a bench:");
+    eprintln!("  Full Protocol Benches: 'cd ./benches/sumcheck-benches/ && cargo build --release && ./run_benches.sh'");
+    eprintln!("  Lagrange Polynomial Benches: 'cd ./benches/lag-poly-benches/ && cargo build --release && ./run_benches.sh'");
     std::process::exit(1);
 }
diff --git a/benches/lag-poly-benches/Cargo.lock b/benches/lag-poly-benches/Cargo.lock
index 01f4c9c..64e171e 100644
--- a/benches/lag-poly-benches/Cargo.lock
+++ b/benches/lag-poly-benches/Cargo.lock
@@ -334,11 +334,15 @@
 checksum = "f79dfe2d285b0488816f30e700a7438c5a73d816b5b7d3ac72fbc48b0d185e03"
 
 [[package]]
 name = "space-efficient-sumcheck"
-version = "0.0.1"
+version = "0.0.2"
 dependencies = [
  "ark-ff",
  "ark-poly",
+ "ark-serialize",
  "ark-std",
+ "num-bigint",
+ "num-traits",
+ "zeroize",
 ]
 
 [[package]]
diff --git a/benches/reduce_sum_benches.rs b/benches/reduce_sum_benches.rs
new file mode 100644
index 0000000..cd149c4
--- /dev/null
+++ b/benches/reduce_sum_benches.rs
@@ -0,0 +1,95 @@
+#![feature(portable_simd)]
+
+use ark_std::{
+    simd::{cmp::SimdPartialOrd, u32x4, Mask, Simd},
+    test_rng,
+};
+use criterion::{black_box, criterion_group, criterion_main, Criterion};
+use space_efficient_sumcheck::fields::{
+    aarch64_neon::reduce_sum_32_bit_modulus_asm, reduce_sum_naive, VecOps, M31, M31_MODULUS,
+};
+
+// TODO (z-tech): this is the benchmark we should hit with both Neon and AVX
+const LANES: usize = 4;
+
+// NOTE: processes four u32x4 vectors per iteration, so `values.len()` is
+// assumed to be a multiple of 16.
+pub fn reduce_sum_packed(values: &[u32]) -> u32 {
+    let packed_modulus: Simd<u32, LANES> = u32x4::splat(M31_MODULUS);
+    let mut packed_sums1: Simd<u32, LANES> = u32x4::splat(0);
+    let mut packed_sums2: Simd<u32, LANES> = u32x4::splat(0);
+    let mut packed_sums3: Simd<u32, LANES> = u32x4::splat(0);
+    let mut packed_sums4: Simd<u32, LANES> = u32x4::splat(0);
+    for i in (0..values.len()).step_by(16) {
+        let tmp_packed_sums_1: Simd<u32, LANES> =
+            packed_sums1 + u32x4::from_slice(&values[i..i + 4]);
+        let tmp_packed_sums_2: Simd<u32, LANES> =
+            packed_sums2 + u32x4::from_slice(&values[i + 4..i + 8]);
+        let tmp_packed_sums_3: Simd<u32, LANES> =
+            packed_sums3 + u32x4::from_slice(&values[i + 8..i + 12]);
+        let tmp_packed_sums_4: Simd<u32, LANES> =
+            packed_sums4 + u32x4::from_slice(&values[i + 12..i + 16]);
+        let is_mod_needed_1: Mask<i32, LANES> = tmp_packed_sums_1.simd_ge(packed_modulus);
+        let is_mod_needed_2: Mask<i32, LANES> = tmp_packed_sums_2.simd_ge(packed_modulus);
+        let is_mod_needed_3: Mask<i32, LANES> = tmp_packed_sums_3.simd_ge(packed_modulus);
+        let is_mod_needed_4: Mask<i32, LANES> = tmp_packed_sums_4.simd_ge(packed_modulus);
+        packed_sums1 =
+            is_mod_needed_1.select(tmp_packed_sums_1 - packed_modulus, tmp_packed_sums_1);
+        packed_sums2 =
+            is_mod_needed_2.select(tmp_packed_sums_2 - packed_modulus, tmp_packed_sums_2);
+        packed_sums3 =
+            is_mod_needed_3.select(tmp_packed_sums_3 - packed_modulus, tmp_packed_sums_3);
+        packed_sums4 =
+            is_mod_needed_4.select(tmp_packed_sums_4 - packed_modulus, tmp_packed_sums_4);
+    }
+    // Reduce each accumulator into the field first, then combine the four
+    // partial sums with reduce_sum_naive so the result stays below the
+    // modulus (a plain u32 add of four partial sums could overflow).
+    reduce_sum_naive(&[
+        reduce_sum_naive(&packed_sums1.to_array()),
+        reduce_sum_naive(&packed_sums2.to_array()),
+        reduce_sum_naive(&packed_sums3.to_array()),
+        reduce_sum_naive(&packed_sums4.to_array()),
+    ])
+}
+
+fn reduce_sum_naive_bench(c: &mut Criterion) {
+    let random_values: Vec<u32> = (0..2_i32.pow(13))
+        .map(|_| M31::rand(&mut test_rng()).to_u32())
+        .collect();
+
+    c.bench_function("reduce_sum_naive", |b| {
+        b.iter(|| black_box(reduce_sum_naive(&random_values)))
+    });
+}
+
+fn reduce_sum_simd_lib(c: &mut Criterion) {
+    let random_values: Vec<u32> = (0..2_i32.pow(13))
+        .map(|_| M31::rand(&mut test_rng()).to_u32())
+        .collect();
+
+    c.bench_function("reduce_sum_simd_lib", |b| {
+        b.iter(|| black_box(reduce_sum_packed(&random_values)))
+    });
+}
+
+fn reduce_sum_neon_intrinsics(c: &mut Criterion) {
+    let random_values: Vec<M31> = (0..2_i32.pow(13))
+        .map(|_| M31::rand(&mut test_rng()))
+        .collect();
+
+    c.bench_function("reduce_sum_neon_intrinsics", |b| {
+        b.iter(|| black_box(M31::reduce_sum(&random_values)))
+    });
+}
+
+fn reduce_sum_neon_asm(c: &mut Criterion) {
+    let random_values: Vec<u32> = (0..2_i32.pow(13))
+        .map(|_| M31::rand(&mut test_rng()).to_u32())
+        .collect();
+
+    c.bench_function("reduce_sum_neon_asm", |b| {
+        b.iter(|| black_box(reduce_sum_32_bit_modulus_asm(&random_values, M31_MODULUS)))
+    });
+}
+
+criterion_group!(
+    benches,
+    reduce_sum_naive_bench,
+    reduce_sum_simd_lib,
+    reduce_sum_neon_intrinsics,
+    reduce_sum_neon_asm,
+);
+criterion_main!(benches);
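The kernel above vectorizes one scalar idea: after each addition, conditionally subtract the modulus instead of paying for `%`. A minimal scalar sketch of that reduction (illustrative, not part of the patch; assumes both inputs are already below the modulus, so the u32 add cannot overflow):

    // Branchless-style modular add: one compare and one subtract, no division.
    fn mod_add(a: u32, b: u32, modulus: u32) -> u32 {
        let t = a + b; // a, b < 2^31, so this cannot overflow u32
        if t >= modulus { t - modulus } else { t }
    }

The SIMD version computes the compare as a lane mask (`simd_ge`) and uses `select` in place of the branch.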
diff --git a/src/fields/aarch64_neon/asm/mod.rs b/src/fields/aarch64_neon/asm/mod.rs
new file mode 100644
index 0000000..0428ef8
--- /dev/null
+++ b/src/fields/aarch64_neon/asm/mod.rs
@@ -0,0 +1,74 @@
+use ark_std::arch::asm;
+
+use crate::fields::m31::reduce_sum_naive;
+
+// NOTE: processes one u32x4 vector per iteration, so `values.len()` is
+// assumed to be a multiple of 4.
+pub fn reduce_sum_32_bit_modulus_asm(values: &[u32], modulus: u32) -> u32 {
+    // Bind the splatted modulus to a local so the pointer handed to the asm
+    // block stays valid (taking `.as_ptr()` of a temporary would dangle).
+    let packed_modulus: [u32; 4] = [modulus; 4];
+    let mut sums: [u32; 4] = [0; 4];
+    for step in (0..values.len()).step_by(4) {
+        let vals: *const u32 = unsafe { values.as_ptr().add(step) };
+
+        // TODO (z-tech): Again this should be unrolled; it's also important to understand whether these loads / stores are optimal
+        unsafe {
+            asm!(
+                // Load accumulated sums into register v0
+                "ldr q0, [{0}]",
+
+                // Load the new values into register v1
+                "ldr q1, [{1}]",
+
+                // Load the modulus into register v3
+                "ldr q3, [{2}]",
+
+                // Add values to accumulated sums and put result into v0
+                "add v0.4s, v0.4s, v1.4s",
+
+                // Subtract the modulus from the result and put it in v2
+                // (this wraps around when the sum is below the modulus)
+                "sub v2.4s, v0.4s, v3.4s",
+
+                // Keep the minimum, i.e. the correctly reduced value
+                "umin v0.4s, v0.4s, v2.4s",
+
+                // Store the accumulator back into sums
+                "st1 {{v0.4s}}, [{0}]",
+
+                inout(reg) sums.as_mut_ptr() => _,
+                in(reg) vals,
+                in(reg) packed_modulus.as_ptr(),
+            );
+        }
+    }
+
+    reduce_sum_naive(&sums)
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::fields::{
+        aarch64_neon::reduce_sum_32_bit_modulus_asm,
+        m31::{M31, M31_MODULUS},
+    };
+    use ark_ff::Zero;
+    use ark_std::test_rng;
+
+    #[test]
+    fn reduce_sum_correctness() {
+        fn reduce_sum_sanity(vec: &[M31]) -> M31 {
+            vec.iter().fold(M31::zero(), |acc, &x| acc + x)
+        }
+
+        let mut rng = test_rng();
+        let random_field_values: Vec<M31> = (0..1 << 13).map(|_| M31::rand(&mut rng)).collect();
+        let random_field_values_u32: Vec<u32> =
+            random_field_values.iter().map(|m| m.to_u32()).collect();
+        let exp = reduce_sum_sanity(&random_field_values);
+        assert_eq!(
+            exp,
+            M31::from(reduce_sum_32_bit_modulus_asm(
+                &random_field_values_u32,
+                M31_MODULUS
+            ))
+        );
+    }
+}
diff --git a/src/fields/aarch64_neon/intrinsics/mod.rs b/src/fields/aarch64_neon/intrinsics/mod.rs
new file mode 100644
index 0000000..7afdfa0
--- /dev/null
+++ b/src/fields/aarch64_neon/intrinsics/mod.rs
@@ -0,0 +1,86 @@
+use ark_std::{
+    arch::aarch64::{
+        uint32x4_t, vaddq_u32, vandq_u32, vcgeq_u32, vdupq_n_u32, vld1q_u32, vminq_u32, vmlsq_u32,
+        vmulq_u32, vqdmulhq_s32, vreinterpretq_s32_u32, vreinterpretq_u32_s32, vst1q_u32,
+        vsubq_u32,
+    },
+    mem::transmute,
+};
+
+use crate::fields::m31::reduce_sum_naive;
+
+#[inline(always)]
+fn sum_vectors(v0: &mut uint32x4_t, v1: &uint32x4_t, packed_modulus: &uint32x4_t) {
+    let raw_sum = unsafe { vaddq_u32(*v0, *v1) };
+    let gte_mask = unsafe { vcgeq_u32(raw_sum, *packed_modulus) };
+    *v0 = unsafe { vsubq_u32(raw_sum, vandq_u32(*packed_modulus, gte_mask)) };
+    // an alternative to the above three lines; experiment to see which is more performant:
+    // let sum1 = vaddq_u32(*v0, *v1);
+    // let sum2 = vsubq_u32(sum1, *packed_modulus); // wraps when sum1 < modulus
+    // *v0 = vminq_u32(sum1, sum2);
+}
+
+// NOTE: processes one u32x4 vector per iteration, so `values.len()` is
+// assumed to be a multiple of 4.
+pub fn reduce_sum_32_bit_modulus(values: &[u32], modulus: u32) -> u32 {
+    let modulus: uint32x4_t = unsafe { transmute::<[u32; 4], uint32x4_t>([modulus; 4]) };
+    let mut sums: uint32x4_t = unsafe { vdupq_n_u32(0) };
+
+    // TODO (z-tech): This should be unrolled, you have to figure out how much unrolling is the sweet spot (try 16, 32, ...)
+    for step in (0..values.len()).step_by(4) {
+        let v: uint32x4_t = unsafe { vld1q_u32(values.as_ptr().add(step)) };
+        sum_vectors(&mut sums, &v, &modulus);
+    }
+
+    let arr: [u32; 4] = unsafe { transmute(sums) };
+    reduce_sum_naive(&arr)
+}
+
+pub fn scalar_mult_32_bit_modulus(values: &mut [u32], scalar: u32, modulus: u32) {
+    let packed_modulus: uint32x4_t = unsafe { transmute::<[u32; 4], uint32x4_t>([modulus; 4]) };
+    let packed_scalar: uint32x4_t = unsafe { transmute::<[u32; 4], uint32x4_t>([scalar; 4]) };
+    for step in (0..values.len()).step_by(4) {
+        unsafe {
+            let lhs = vld1q_u32(values.as_ptr().add(step));
+            // High half of the product: vqdmulhq_s32 computes (2*a*b) >> 32,
+            // i.e. (a*b) >> 31, the quotient estimate for p = 2^31 - 1
+            let upper = vreinterpretq_u32_s32(vqdmulhq_s32(
+                vreinterpretq_s32_u32(lhs),
+                vreinterpretq_s32_u32(packed_scalar),
+            ));
+            let lower = vmulq_u32(lhs, packed_scalar);
+            // t = lower - upper * modulus, then one conditional subtract via umin
+            let t = vmlsq_u32(lower, upper, packed_modulus);
+            let res = vminq_u32(t, vsubq_u32(t, packed_modulus));
+            vst1q_u32(values.as_mut_ptr().add(step), res);
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::fields::{
+        aarch64_neon::reduce_sum_32_bit_modulus,
+        m31::{M31, M31_MODULUS},
+    };
+    use ark_ff::Zero;
+    use ark_std::test_rng;
+
+    #[test]
+    fn reduce_sum_correctness() {
+        fn reduce_sum_sanity(vec: &[M31]) -> M31 {
+            vec.iter().fold(M31::zero(), |acc, &x| acc + x)
+        }
+
+        let mut rng = test_rng();
+        let random_field_values: Vec<M31> = (0..1 << 13).map(|_| M31::rand(&mut rng)).collect();
+        let random_field_values_u32: Vec<u32> =
+            random_field_values.iter().map(|m| m.to_u32()).collect();
+        let exp = reduce_sum_sanity(&random_field_values);
+        assert_eq!(
+            exp,
+            M31::from(reduce_sum_32_bit_modulus(
+                &random_field_values_u32,
+                M31_MODULUS
+            ))
+        );
+    }
+}
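A hypothetical usage sketch of the intrinsics path (illustrative only; assumes an aarch64 target, inputs already reduced below the modulus, and a length divisible by 4):

    use space_efficient_sumcheck::fields::{aarch64_neon::reduce_sum_32_bit_modulus, M31_MODULUS};

    fn main() {
        let values: Vec<u32> = (0..4096u32).collect(); // all < 2^31 - 1
        let sum = reduce_sum_32_bit_modulus(&values, M31_MODULUS);
        assert_eq!(sum, (4095 * 4096) / 2); // small enough that no reduction fires
    }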
diff --git a/src/fields/aarch64_neon/mod.rs b/src/fields/aarch64_neon/mod.rs
new file mode 100644
index 0000000..29ed88f
--- /dev/null
+++ b/src/fields/aarch64_neon/mod.rs
@@ -0,0 +1,5 @@
+mod asm;
+mod intrinsics;
+
+pub use asm::reduce_sum_32_bit_modulus_asm;
+pub use intrinsics::{reduce_sum_32_bit_modulus, scalar_mult_32_bit_modulus};
diff --git a/src/fields/m31/fft_field.rs b/src/fields/m31/fft_field.rs
new file mode 100644
index 0000000..d96a67c
--- /dev/null
+++ b/src/fields/m31/fft_field.rs
@@ -0,0 +1,19 @@
+use super::M31;
+
+use ark_ff::FftField;
+
+// TODO (z-tech): These might be correct; we must verify each one
+
+impl FftField for M31 {
+    const GENERATOR: Self = M31 { value: 7 };
+
+    // p - 1 = 2^31 - 2 = 2 * (2^30 - 1), so the two-adicity is 1
+    const TWO_ADICITY: u32 = 1;
+
+    // p - 1 has order 2
+    const TWO_ADIC_ROOT_OF_UNITY: Self = M31 { value: 2147483646 };
+
+    const SMALL_SUBGROUP_BASE: Option<u32> = Some(3);
+
+    const SMALL_SUBGROUP_BASE_ADICITY: Option<u32> = Some(1);
+
+    const LARGE_SUBGROUP_ROOT_OF_UNITY: Option<Self> = Some(M31 { value: 6 });
+}
diff --git a/src/fields/m31/field.rs b/src/fields/m31/field.rs
new file mode 100644
index 0000000..43a62fe
--- /dev/null
+++ b/src/fields/m31/field.rs
@@ -0,0 +1,123 @@
+use ark_ff::{Field, SqrtPrecomputation, Zero};
+use ark_serialize::Flags;
+
+use crate::fields::m31::{M31, M31_MODULUS};
+
+// TODO (z-tech): Each of these needs to be implemented w/ tests
+
+impl Field for M31 {
+    type BasePrimeField = Self;
+
+    type BasePrimeFieldIter = std::iter::Empty<Self>;
+
+    const SQRT_PRECOMP: Option<SqrtPrecomputation<Self>> = None;
+
+    const ZERO: Self = Self { value: 0 };
+
+    const ONE: Self = Self { value: 1 };
+
+    fn double(&self) -> Self {
+        // self.value < 2^31, so 2 * self.value cannot overflow u32
+        M31::from((2 * self.value) % M31_MODULUS)
+    }
+
+    fn inverse(&self) -> Option<Self> {
+        if self.is_zero() {
+            return None;
+        }
+
+        // Addition chain for x^(2^31 - 3) = x^(p - 2), i.e. the inverse by
+        // Fermat's little theorem
+        let x = *self;
+        let y = x.exp_power_of_2(2) * x; // x^5
+        let z = y.square() * y; // x^15
+        let a = z.exp_power_of_2(4) * z; // x^255
+        let b = a.exp_power_of_2(4); // x^4080
+        let c = b * z; // x^4095
+        let d = b.exp_power_of_2(4) * a; // x^65535
+        let e = d.exp_power_of_2(12) * c; // x^268435455
+        let f = e.exp_power_of_2(3) * y; // x^(2^31 - 3)
+        Some(f)
+    }
+
+    fn frobenius_map(&self, _: usize) -> M31 {
+        // the Frobenius endomorphism is the identity on a prime field
+        Self { value: self.value }
+    }
+
+    fn extension_degree() -> u64 {
+        1
+    }
+
+    fn to_base_prime_field_elements(&self) -> Self::BasePrimeFieldIter {
+        todo!()
+    }
+
+    fn from_base_prime_field_elems(_elems: &[Self::BasePrimeField]) -> Option<Self> {
+        todo!()
+    }
+
+    fn from_base_prime_field(elem: Self::BasePrimeField) -> Self {
+        elem
+    }
+
+    fn double_in_place(&mut self) -> &mut Self {
+        *self = self.double();
+        self
+    }
+
+    fn neg_in_place(&mut self) -> &mut Self {
+        *self = -*self;
+        self
+    }
+
+    fn from_random_bytes_with_flags<F: Flags>(_bytes: &[u8]) -> Option<(Self, F)> {
+        todo!()
+    }
+
+    fn legendre(&self) -> ark_ff::LegendreSymbol {
+        todo!()
+    }
+
+    fn square(&self) -> Self {
+        *self * *self
+    }
+
+    fn square_in_place(&mut self) -> &mut Self {
+        *self = self.square();
+        self
+    }
+
+    fn inverse_in_place(&mut self) -> Option<&mut Self> {
+        *self = self.inverse()?;
+        Some(self)
+    }
+
+    fn frobenius_map_in_place(&mut self, _power: usize) {
+        // identity on a prime field
+    }
+
+    fn characteristic() -> &'static [u64] {
+        &[M31_MODULUS as u64]
+    }
+
+    fn from_random_bytes(_bytes: &[u8]) -> Option<Self> {
+        std::unimplemented!()
+    }
+
+    fn sqrt(&self) -> Option<Self> {
+        std::unimplemented!()
+    }
+
+    fn sqrt_in_place(&mut self) -> Option<&mut Self> {
+        std::unimplemented!()
+    }
+
+    fn sum_of_products<const T: usize>(a: &[Self; T], b: &[Self; T]) -> Self {
+        let mut sum = Self::zero();
+        for i in 0..a.len() {
+            sum += a[i] * b[i];
+        }
+        sum
+    }
+
+    fn pow<S: AsRef<[u64]>>(&self, exp: S) -> Self {
+        // Square-and-multiply over the exponent bits (big-endian within limbs)
+        let mut res = Self::ONE;
+        for &limb in exp.as_ref().iter().rev() {
+            for i in (0..64).rev() {
+                res = res.square();
+                if (limb >> i) & 1 == 1 {
+                    res *= *self;
+                }
+            }
+        }
+        res
+    }
+
+    fn pow_with_table<S: AsRef<[u64]>>(_powers_of_2: &[Self], _exp: S) -> Option<Self> {
+        std::unimplemented!()
+    }
+}
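The addition chain in `inverse` walks x^5, x^15, x^255, x^4095, x^65535, x^268435455 and finally x^(2^31 - 3) = x^(p - 2), the inverse by Fermat's little theorem. A hypothetical sanity test for it (not in the patch):

    #[test]
    fn inverse_satisfies_fermat() {
        let mut rng = ark_std::test_rng();
        for _ in 0..100 {
            let x = M31::rand(&mut rng);
            if let Some(inv) = x.inverse() {
                // x * x^(p - 2) == 1 for every nonzero x
                assert_eq!(x * inv, M31 { value: 1 });
            }
        }
    }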
diff --git a/src/fields/m31/m31.rs b/src/fields/m31/m31.rs
new file mode 100644
index 0000000..a1b161c
--- /dev/null
+++ b/src/fields/m31/m31.rs
@@ -0,0 +1,172 @@
+use ark_ff::{Field, One, Zero};
+use ark_serialize::{
+    CanonicalDeserialize, CanonicalDeserializeWithFlags, CanonicalSerialize,
+    CanonicalSerializeWithFlags, Flags, SerializationError,
+};
+use ark_std::rand::{distributions::Standard, prelude::Distribution, Rng};
+use zeroize::Zeroize;
+
+use std::{
+    fmt::{self, Display, Formatter},
+    io::{Read, Write},
+};
+
+// TODO (z-tech): Each of these should be verified w/ tests
+
+// The Mersenne prime 2^31 - 1
+pub const M31_MODULUS: u32 = 2147483647;
+
+// repr(transparent) guarantees the same layout as u32, which the SIMD paths
+// rely on when they reinterpret &[M31] as &[u32]
+#[repr(transparent)]
+#[derive(
+    Copy,
+    Clone,
+    PartialEq,
+    Eq,
+    Debug,
+    PartialOrd,
+    Ord,
+    Hash,
+    CanonicalDeserialize,
+    CanonicalSerialize,
+)]
+pub struct M31 {
+    pub value: u32,
+}
+
+impl M31 {
+    /// Computes self^(2^power_log) by repeated squaring.
+    pub fn exp_power_of_2(&self, power_log: usize) -> Self {
+        let mut res = *self;
+        for _ in 0..power_log {
+            res = res.square();
+        }
+        res
+    }
+    pub fn rand(rng: &mut impl Rng) -> Self {
+        let value = rng.gen_range(0..M31_MODULUS);
+        M31 { value }
+    }
+}
+
+impl Zero for M31 {
+    fn zero() -> Self {
+        M31::from(0)
+    }
+    fn is_zero(&self) -> bool {
+        self.value == 0
+    }
+}
+
+impl One for M31 {
+    fn one() -> Self {
+        M31::from(1)
+    }
+    fn is_one(&self) -> bool {
+        self.value == 1
+    }
+}
+
+impl Zeroize for M31 {
+    fn zeroize(&mut self) {
+        self.value.zeroize();
+    }
+}
+
+impl Distribution<M31> for Standard {
+    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> M31 {
+        let value = rng.gen_range(0..M31_MODULUS as u64);
+        M31::from(value)
+    }
+}
+
+impl Display for M31 {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        Display::fmt(&self.value, f)
+    }
+}
+
+impl CanonicalDeserializeWithFlags for M31 {
+    // placeholder (see TODO above): returns a dummy value
+    #[inline]
+    fn deserialize_with_flags<R: Read, F: Flags>(
+        _reader: R,
+    ) -> Result<(Self, F), SerializationError> {
+        Ok((Self { value: 1 }, F::from_u8(1).unwrap()))
+    }
+}
+
+impl CanonicalSerializeWithFlags for M31 {
+    // placeholder (see TODO above): writes nothing
+    #[inline]
+    fn serialize_with_flags<W: Write, F: Flags>(
+        &self,
+        _writer: W,
+        _flags: F,
+    ) -> Result<(), SerializationError> {
+        Ok(())
+    }
+
+    #[inline]
+    fn serialized_size_with_flags<F: Flags>(&self) -> usize {
+        1
+    }
+}
+
+impl Default for M31 {
+    // placeholder (see TODO above)
+    fn default() -> Self {
+        M31::from(1_u32)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::fields::{
+        m31::{M31, M31_MODULUS},
+        vec_ops::VecOps,
+    };
+    use ark_ff::{Field, One, Zero};
+    use ark_std::{rand::Rng, test_rng};
+
+    #[test]
+    fn inverse_correctness() {
+        let a = M31::from(2);
+        assert_eq!(M31::from(1073741824), a.inverse().unwrap());
+    }
+
+    #[test]
+    fn reduce_sum_correctness() {
+        fn reduce_sum_sanity(vec: &[M31]) -> M31 {
+            vec.iter().fold(M31::zero(), |acc, &x| acc + x)
+        }
+
+        let mut rng = test_rng();
+        let random_field_values: Vec<M31> = (0..1 << 13).map(|_| M31::rand(&mut rng)).collect();
+        let exp = reduce_sum_sanity(&random_field_values);
+        assert_eq!(exp, M31::reduce_sum(&random_field_values));
+    }
+
+    #[test]
+    fn scalar_mult_correctness() {
+        fn test_field_values(rng: &mut impl Rng) -> (Vec<M31>, Vec<M31>) {
+            let mut exp: Vec<M31> = (0..(1 << 10)).map(|_| M31::rand(rng)).collect();
+            exp.push(M31::from(M31_MODULUS - 1));
+            exp.push(M31::from(M31_MODULUS - 2));
+            exp.push(M31::zero());
+            exp.push(M31::one());
+            (exp.clone(), exp)
+        }
+        fn scalar_mult_sanity(values: &mut [M31], scalar: M31) {
+            for elem in values.iter_mut() {
+                *elem = *elem * scalar;
+            }
+        }
+
+        let mut rng = test_rng();
+        let (mut exp, mut rec) = test_field_values(&mut rng);
+        // get a random scalar
+        let scalar = M31::rand(&mut rng);
+        // apply the scaling
+        scalar_mult_sanity(&mut exp, scalar);
+        M31::scalar_mult(&mut rec, scalar);
+        // check parity
+        assert_eq!(exp, rec);
+    }
+}
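For orientation, a hedged usage sketch of the type defined above (illustrative only, not part of the patch):

    use space_efficient_sumcheck::fields::M31;

    fn main() {
        let a = M31::from(5u32);
        let minus_one = M31::from(2147483646u32); // p - 1 acts as -1
        assert_eq!(a + minus_one, M31::from(4u32));
        assert_eq!(a.exp_power_of_2(3), M31::from(390625u32)); // 5^(2^3) = 5^8
    }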
diff --git a/src/fields/m31/mod.rs b/src/fields/m31/mod.rs
new file mode 100644
index 0000000..d3b2838
--- /dev/null
+++ b/src/fields/m31/mod.rs
@@ -0,0 +1,10 @@
+mod fft_field;
+mod field;
+mod m31;
+mod ops;
+mod prime_field;
+mod transmute;
+mod vec_ops;
+
+pub use m31::{M31, M31_MODULUS};
+pub use vec_ops::reduce_sum_naive;
diff --git a/src/fields/m31/ops.rs b/src/fields/m31/ops.rs
new file mode 100644
index 0000000..6ecb86a
--- /dev/null
+++ b/src/fields/m31/ops.rs
@@ -0,0 +1,330 @@
+use ark_ff::Field;
+use ark_std::{
+    iter::{Product, Sum},
+    ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign},
+};
+
+use crate::fields::m31::{M31, M31_MODULUS};
+
+// TODO (z-tech): tests must be written for each of these
+// NOTE (z-tech): wherever we can avoid the expensive % operator we should
+
+// std::ops by value
+impl Add for M31 {
+    type Output = Self;
+    fn add(self, rhs: Self) -> Self {
+        Self::from((self.value + rhs.value) % M31_MODULUS)
+    }
+}
+impl Sub for M31 {
+    type Output = Self;
+    fn sub(self, rhs: Self) -> Self {
+        if self.value < rhs.value {
+            // add the modulus so the subtraction cannot underflow
+            return Self::from((self.value + M31_MODULUS - rhs.value) % M31_MODULUS);
+        }
+        Self::from((self.value - rhs.value) % M31_MODULUS)
+    }
+}
+impl Mul for M31 {
+    type Output = Self;
+    fn mul(self, rhs: Self) -> Self {
+        // Mersenne reduction: 2^31 ≡ 1 (mod p), so fold the 62-bit product
+        // as (x & p) + (x >> 31); two folds leave a value of at most p
+        let mut product = self.to_u64() * rhs.to_u64();
+        product = (product & M31_MODULUS as u64) + (product >> 31);
+        product = (product & M31_MODULUS as u64) + (product >> 31);
+        Self::from(product as u32)
+    }
+}
+impl Div for M31 {
+    type Output = Self;
+    fn div(self, rhs: Self) -> Self {
+        if rhs.value == 0 {
+            panic!("Division by zero");
+        }
+        Self {
+            value: ((self.value as u64 * rhs.inverse().unwrap().value as u64) % M31_MODULUS as u64)
+                as u32,
+        }
+    }
+}
+
+// std::ops by lifetimed reference
+impl<'a> Add<&'a M31> for M31 {
+    type Output = Self;
+    fn add(self, rhs: &'a Self) -> Self {
+        Self::from((self.value + rhs.value) % M31_MODULUS)
+    }
+}
+impl<'a> Sub<&'a M31> for M31 {
+    type Output = Self;
+    fn sub(self, rhs: &'a Self) -> Self {
+        if self.value < rhs.value {
+            // add the modulus so the subtraction cannot underflow
+            return Self::from((self.value + M31_MODULUS - rhs.value) % M31_MODULUS);
+        }
+        Self::from((self.value - rhs.value) % M31_MODULUS)
+    }
+}
+impl<'a> Mul<&'a M31> for M31 {
+    type Output = Self;
+    fn mul(self, other: &'a Self) -> Self {
+        Self::from(((self.value as u64 * other.value as u64) % M31_MODULUS as u64) as u32)
+    }
+}
+impl<'a> Div<&'a M31> for M31 {
+    type Output = Self;
+    fn div(self, rhs: &'a Self) -> Self {
+        if rhs.value == 0 {
+            panic!("Division by zero");
+        }
+        Self {
+            value: ((self.value as u64 * rhs.inverse().unwrap().value as u64) % M31_MODULUS as u64)
+                as u32,
+        }
+    }
+}
+
+// std::ops by mut reference (NOTE: not the same as the OpAssign impls below)
+impl Add<&mut M31> for M31 {
+    type Output = M31;
+    fn add(self, rhs: &mut Self) -> Self::Output {
+        Self::from((self.value + rhs.value) % M31_MODULUS)
+    }
+}
+impl Sub<&mut M31> for M31 {
+    type Output = M31;
+    fn sub(self, rhs: &mut Self) -> Self::Output {
+        if self.value < rhs.value {
+            // add the modulus so the subtraction cannot underflow
+            return Self::from((self.value + M31_MODULUS - rhs.value) % M31_MODULUS);
+        }
+        Self::from((self.value - rhs.value) % M31_MODULUS)
+    }
+}
+impl Mul<&mut M31> for M31 {
+    type Output = M31;
+    fn mul(self, rhs: &mut Self) -> Self::Output {
+        Self::from(((self.value as u64 * rhs.value as u64) % M31_MODULUS as u64) as u32)
+    }
+}
+impl Div<&mut M31> for M31 {
+    type Output = M31;
+    fn div(self, rhs: &mut Self) -> Self::Output {
+        if rhs.value == 0 {
+            panic!("Division by zero");
+        }
+        Self {
+            value: ((self.value as u64 * rhs.inverse().unwrap().value as u64) % M31_MODULUS as u64)
+                as u32,
+        }
+    }
+}
+
+// std::ops::*Assign by value
+impl AddAssign for M31 {
+    fn add_assign(&mut self, other: M31) {
+        // Add the values and reduce modulo M31_MODULUS
+        self.value = (self.value + other.value) % M31_MODULUS;
+    }
+}
+impl SubAssign for M31 {
+    fn sub_assign(&mut self, other: M31) {
+        // Add the modulus first when needed so the subtraction cannot underflow
+        if self.value >= other.value {
+            self.value = (self.value - other.value) % M31_MODULUS;
+        } else {
+            self.value = (self.value + M31_MODULUS - other.value) % M31_MODULUS;
+        }
+    }
+}
+impl MulAssign for M31 {
+    fn mul_assign(&mut self, other: M31) {
+        // Widen to u64 before multiplying so the product cannot overflow
+        self.value = ((self.value as u64 * other.value as u64) % M31_MODULUS as u64) as u32;
+    }
+}
+impl DivAssign for M31 {
+    fn div_assign(&mut self, other: M31) {
+        if other.value == 0 {
+            panic!("Division by zero or no modular inverse exists");
+        }
+        // Field division: multiply by the modular inverse
+        self.value =
+            ((self.value as u64 * other.inverse().unwrap().value as u64) % M31_MODULUS as u64)
+                as u32;
+    }
+}
+
+impl<'a> AddAssign<&'a mut M31> for M31 {
+    fn add_assign(&mut self, other: &'a mut M31) {
+        self.value = (self.value.wrapping_add(other.value)) % M31_MODULUS;
+    }
+}
+impl<'a> SubAssign<&'a mut M31> for M31 {
+    fn sub_assign(&mut self, rhs: &'a mut M31) {
+        if self.value < rhs.value {
+            // add the modulus so the subtraction cannot underflow
+            self.value = (self.value + M31_MODULUS - rhs.value) % M31_MODULUS;
+        } else {
+            self.value = (self.value - rhs.value) % M31_MODULUS;
+        }
+    }
+}
+impl<'a> MulAssign<&'a mut M31> for M31 {
+    fn mul_assign(&mut self, other: &'a mut M31) {
+        self.value = ((self.value as u64 * other.value as u64) % M31_MODULUS as u64) as u32;
+    }
+}
+impl<'a> DivAssign<&'a mut M31> for M31 {
+    fn div_assign(&mut self, rhs: &'a mut M31) {
+        if rhs.value == 0 {
+            panic!("Division by zero");
+        }
+        self.value =
+            ((self.value as u64 * rhs.inverse().unwrap().value as u64) % M31_MODULUS as u64) as u32;
+    }
+}
+
+impl<'a> AddAssign<&'a M31> for M31 {
+    fn add_assign(&mut self, other: &'a M31) {
+        self.value = (self.value.wrapping_add(other.value)) % M31_MODULUS;
+    }
+}
+impl<'a> SubAssign<&'a M31> for M31 {
+    fn sub_assign(&mut self, other: &'a M31) {
+        // Add the modulus first when needed so the subtraction cannot underflow
+        if self.value < other.value {
+            self.value = (self.value + M31_MODULUS - other.value) % M31_MODULUS;
+        } else {
+            self.value = (self.value - other.value) % M31_MODULUS;
+        }
+    }
+}
+impl<'a> MulAssign<&'a M31> for M31 {
+    fn mul_assign(&mut self, other: &'a M31) {
+        // Widen to u64 before multiplying so the product cannot overflow
+        self.value = ((self.value as u64 * other.value as u64) % M31_MODULUS as u64) as u32;
+    }
+}
+impl<'a> DivAssign<&'a M31> for M31 {
+    fn div_assign(&mut self, other: &'a M31) {
+        if other.value == 0 {
+            panic!("Division by zero or no modular inverse exists");
+        }
+        // Field division: multiply by the modular inverse
+        self.value =
+            ((self.value as u64 * other.inverse().unwrap().value as u64) % M31_MODULUS as u64)
+                as u32;
+    }
+}
+
+impl Neg for M31 {
+    type Output = Self;
+
+    fn neg(self) -> Self {
+        Self::from(M31_MODULUS - self.value)
+    }
+}
+impl Product for M31 {
+    fn product<I: Iterator<Item = M31>>(iter: I) -> Self {
+        iter.fold(M31 { value: 1 }, |acc, item| acc * item)
+    }
+}
+impl Sum for M31 {
+    fn sum<I: Iterator<Item = M31>>(iter: I) -> Self {
+        iter.fold(M31 { value: 0 }, |acc, item| acc + item)
+    }
+}
+
+impl<'a> Product<&'a M31> for M31 {
+    fn product<I: Iterator<Item = &'a M31>>(iter: I) -> Self {
+        iter.fold(M31 { value: 1 }, |acc, item| acc * item)
+    }
+}
+impl<'a> Sum<&'a M31> for M31 {
+    fn sum<I: Iterator<Item = &'a M31>>(iter: I) -> Self {
+        iter.fold(M31 { value: 0 }, |acc, item| acc + item)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::fields::m31::{M31, M31_MODULUS};
+    use ark_ff::Field;
+
+    #[test]
+    fn test_add() {
+        // basic
+        let a = M31::from(10);
+        let b = M31::from(22);
+        assert_eq!(a + b, M31::from(32));
+        // larger than modulus
+        let c = M31::from(M31_MODULUS);
+        let d = M31::from(1);
+        assert_eq!(c + d, M31::from(1));
+        // doesn't overflow
+        let e = M31::from(u32::MAX - 2);
+        let f = M31::from(3);
+        assert_eq!(e + f, M31::from(2));
+        // doesn't overflow
+        let g = M31::from(M31_MODULUS - 1);
+        let h = M31::from(M31_MODULUS - 1);
+        assert_eq!(g + h, M31::from(2147483645));
+    }
+
+    #[test]
+    fn test_sub() {
+        // basic
+        let a = M31::from(22);
+        let b = M31::from(10);
+        assert_eq!(a - b, M31::from(12));
+        // doesn't underflow
+        let c = M31::from(10);
+        let d = M31::from(22);
+        assert_eq!(c - d, M31::from(2147483635_u32));
+    }
+
+    #[test]
+    fn test_mul() {
+        // basic
+        let a = M31::from(10);
+        let b = M31::from(22);
+        assert_eq!(a * b, M31::from(220));
+        // doesn't overflow
+        let c = M31::from(M31_MODULUS);
+        let d = M31::from(M31_MODULUS);
+        assert_eq!(c * d, M31::from(4611686014132420609_u64));
+    }
+
+    #[test]
+    fn test_div() {
+        // basic
+        let a = M31::from(10);
+        let b = M31::from(2);
+        assert_eq!(a / b, M31::from(5));
+        // 3 does not divide 10, so this exercises the inverse path
+        let c = M31::from(10);
+        let d = M31::from(3);
+        assert_eq!(d.inverse().unwrap(), M31::from(1431655765));
+        assert_eq!(c / d, M31::from(1431655768));
+    }
+
+    #[test]
+    #[should_panic(expected = "Division by zero")]
+    fn test_div_by_zero() {
+        let a = M31::from(10);
+        let b = M31::from(0);
+        let _result = a / b;
+    }
+}
diff --git a/src/fields/m31/prime_field.rs b/src/fields/m31/prime_field.rs
new file mode 100644
index 0000000..42c0787
--- /dev/null
+++ b/src/fields/m31/prime_field.rs
@@ -0,0 +1,37 @@
+use ark_ff::{BigInt, BigInteger256, PrimeField};
+
+use super::{M31, M31_MODULUS};
+
+pub const M31_MODULUS_BIGINT4: BigInt<4> = BigInt::new([M31_MODULUS as u64, 0, 0, 0]);
+pub const M31_MODULUS_MINUS_ONE_DIV_TWO_BIGINT4: BigInt<4> =
+    BigInt::new([(M31_MODULUS as u64 - 1) / 2, 0, 0, 0]);
+
+impl PrimeField for M31 {
+    type BigInt = BigInteger256;
+
+    const MODULUS: Self::BigInt = M31_MODULUS_BIGINT4;
+
+    const MODULUS_MINUS_ONE_DIV_TWO: Self::BigInt = M31_MODULUS_MINUS_ONE_DIV_TWO_BIGINT4;
+
+    // 2^31 - 1 is a 31-bit number
+    const MODULUS_BIT_SIZE: u32 = 31;
+
+    // p - 1 = 2^1 * t with t odd: t = 2^30 - 1
+    const TRACE: Self::BigInt = BigInt::new([1073741823, 0, 0, 0]);
+
+    // (t - 1) / 2 = 2^29 - 1
+    const TRACE_MINUS_ONE_DIV_TWO: Self::BigInt = BigInt::new([536870911, 0, 0, 0]);
+
+    fn from_bigint(repr: Self::BigInt) -> Option<Self> {
+        if repr >= M31_MODULUS_BIGINT4 {
+            return None;
+        }
+        Some(Self {
+            value: repr.0[0] as u32,
+        })
+    }
+
+    fn into_bigint(self) -> Self::BigInt {
+        BigInt::new([self.value as u64, 0, 0, 0])
+    }
+
+    // NOTE: placeholders, the byte decoding is still TODO
+    fn from_be_bytes_mod_order(_bytes: &[u8]) -> Self {
+        Self { value: 0 }
+    }
+
+    fn from_le_bytes_mod_order(_bytes: &[u8]) -> Self {
+        Self { value: 0 }
+    }
+}
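The by-value `Mul` in ops.rs above relies on the Mersenne identity 2^31 ≡ 1 (mod p): any 62-bit product can be folded as (x & p) + (x >> 31). A scalar walkthrough (illustrative, mirrors the impl):

    fn mul_mod_m31(a: u32, b: u32) -> u32 {
        const P: u64 = (1 << 31) - 1;
        let mut x = a as u64 * b as u64; // at most (p - 1)^2, a 62-bit value
        x = (x & P) + (x >> 31); // first fold: below 2^32
        x = (x & P) + (x >> 31); // second fold: at most p
        (if x == P { 0 } else { x }) as u32 // the From<u32> conversion does this last step
    }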
diff --git a/src/fields/m31/transmute.rs b/src/fields/m31/transmute.rs
new file mode 100644
index 0000000..25a99a3
--- /dev/null
+++ b/src/fields/m31/transmute.rs
@@ -0,0 +1,138 @@
+use ark_ff::{BigInt, BigInteger256};
+use ark_std::{num::ParseIntError, str::FromStr};
+use num_bigint::BigUint;
+
+use crate::fields::m31::{M31, M31_MODULUS};
+
+// TODO (z-tech): tests must be written for each of these
+
+impl M31 {
+    pub fn to_u32(&self) -> u32 {
+        self.value
+    }
+
+    pub fn to_u64(&self) -> u64 {
+        self.value as u64
+    }
+}
+
+impl From<M31> for BigInt<4> {
+    fn from(field: M31) -> BigInt<4> {
+        BigInt::<4>([field.value as u64, 0, 0, 0])
+    }
+}
+
+impl From<BigUint> for M31 {
+    fn from(biguint: BigUint) -> Self {
+        let reduced_value = biguint % BigUint::from(M31_MODULUS);
+        let value = reduced_value.to_u32_digits().first().copied().unwrap_or(0);
+        M31::from(value)
+    }
+}
+
+impl From<BigInteger256> for M31 {
+    fn from(bigint: BigInteger256) -> Self {
+        // NOTE: only the low limb is reduced; assumes the higher limbs are zero
+        let bigint_u64 = bigint.0[0];
+        let reduced_value = bigint_u64 % (M31_MODULUS as u64);
+        let value = reduced_value as u32;
+        M31::from(value)
+    }
+}
+
+impl FromStr for M31 {
+    type Err = ParseIntError;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        let value = usize::from_str(s)?;
+        let reduced_value = value % M31_MODULUS as usize;
+        Ok(M31::from(reduced_value as u32))
+    }
+}
+
+impl From<M31> for BigUint {
+    fn from(element: M31) -> BigUint {
+        BigUint::from(element.value)
+    }
+}
+
+impl From<bool> for M31 {
+    fn from(b: bool) -> Self {
+        M31 {
+            value: if b { 1 } else { 0 },
+        }
+    }
+}
+
+impl From<u8> for M31 {
+    fn from(value: u8) -> Self {
+        M31 {
+            value: value as u32,
+        }
+    }
+}
+
+impl From<u16> for M31 {
+    fn from(value: u16) -> Self {
+        M31 {
+            value: value as u32,
+        }
+    }
+}
+
+impl From<u32> for M31 {
+    fn from(value: u32) -> Self {
+        M31 {
+            value: if value == M31_MODULUS {
+                0
+            } else if value > M31_MODULUS {
+                value % M31_MODULUS
+            } else {
+                value
+            },
+        }
+    }
+}
+
+impl From<i32> for M31 {
+    fn from(value: i32) -> Self {
+        // rem_euclid maps negative inputs into [0, p), e.g. -1 becomes p - 1
+        M31 {
+            value: value.rem_euclid(M31_MODULUS as i32) as u32,
+        }
+    }
+}
+
+impl From<u64> for M31 {
+    fn from(value: u64) -> Self {
+        M31 {
+            value: if value == M31_MODULUS as u64 {
+                0
+            } else if value > M31_MODULUS as u64 {
+                (value % M31_MODULUS as u64) as u32
+            } else {
+                value as u32
+            },
+        }
+    }
+}
+
+impl From<u128> for M31 {
+    fn from(value: u128) -> Self {
+        M31 {
+            value: if value == M31_MODULUS as u128 {
+                0
+            } else if value > M31_MODULUS as u128 {
+                (value % M31_MODULUS as u128) as u32
+            } else {
+                value as u32
+            },
+        }
+    }
+}
diff --git a/src/fields/m31/vec_ops.rs b/src/fields/m31/vec_ops.rs
new file mode 100644
index 0000000..9538a3e
--- /dev/null
+++ b/src/fields/m31/vec_ops.rs
@@ -0,0 +1,52 @@
+use ark_std::slice::{from_raw_parts, from_raw_parts_mut};
+
+use crate::fields::{
+    m31::{M31, M31_MODULUS},
+    vec_ops::VecOps,
+};
+
+#[cfg(target_arch = "aarch64")]
+use crate::fields::aarch64_neon;
+
+pub fn reduce_sum_naive(vec: &[u32]) -> u32 {
+    vec.iter().fold(0, |acc, &x| {
+        // conditional subtract instead of `%`: acc and x are both < modulus
+        let tmp = acc + x;
+        if tmp < M31_MODULUS {
+            tmp
+        } else {
+            tmp - M31_MODULUS
+        }
+    })
+}
+
+impl VecOps for M31 {
+    fn reduce_sum(vec: &[M31]) -> Self {
+        // SAFETY: M31 is #[repr(transparent)] over u32, so &[M31] can be
+        // reinterpreted as &[u32]
+        #[cfg(target_arch = "aarch64")]
+        return M31 {
+            value: aarch64_neon::reduce_sum_32_bit_modulus(
+                unsafe { from_raw_parts(vec.as_ptr() as *const u32, vec.len()) },
+                M31_MODULUS,
+            ),
+        };
+
+        #[cfg(not(target_arch = "aarch64"))]
+        M31::from(reduce_sum_naive(unsafe {
+            from_raw_parts(vec.as_ptr() as *const u32, vec.len())
+        }))
+    }
+
+    fn scalar_mult(vec: &mut [Self], scalar: M31) {
+        // SAFETY: same layout argument as in reduce_sum
+        #[cfg(target_arch = "aarch64")]
+        aarch64_neon::scalar_mult_32_bit_modulus(
+            unsafe { from_raw_parts_mut(vec.as_mut_ptr() as *mut u32, vec.len()) },
+            scalar.to_u32(),
+            M31_MODULUS,
+        );
+
+        #[cfg(not(target_arch = "aarch64"))]
+        for elem in vec.iter_mut() {
+            *elem = *elem * scalar;
+        }
+    }
+}
diff --git a/src/fields/mod.rs b/src/fields/mod.rs
new file mode 100644
index 0000000..1957361
--- /dev/null
+++ b/src/fields/mod.rs
@@ -0,0 +1,8 @@
+mod m31;
+mod vec_ops;
+
+#[cfg(target_arch = "aarch64")]
+pub mod aarch64_neon;
+
+pub use m31::{reduce_sum_naive, M31, M31_MODULUS};
+pub use vec_ops::VecOps;
diff --git a/src/fields/vec_ops/mod.rs b/src/fields/vec_ops/mod.rs
new file mode 100644
index 0000000..e5e811a
--- /dev/null
+++ b/src/fields/vec_ops/mod.rs
@@ -0,0 +1,2 @@
+mod vec_ops;
+pub use vec_ops::VecOps;
diff --git a/src/fields/vec_ops/vec_ops.rs b/src/fields/vec_ops/vec_ops.rs
new file mode 100644
index 0000000..b30e167
--- /dev/null
+++ b/src/fields/vec_ops/vec_ops.rs
@@ -0,0 +1,6 @@
+use ark_ff::Field;
+
+pub trait VecOps: Field {
+    fn reduce_sum(vec: &[Self]) -> Self;
+    fn scalar_mult(vec: &mut [Self], scalar: Self);
+}
diff --git a/src/lib.rs b/src/lib.rs
index 86d3c63..272a426 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,6 +1,7 @@
 #[doc(hidden)]
 pub mod tests;
 
+pub mod fields;
 pub mod hypercube;
 pub mod interpolation;
 pub mod messages;
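Finally, a hedged end-to-end sketch of the VecOps entry points this PR wires up (illustrative only; on aarch64 these dispatch to the NEON kernels, elsewhere to the scalar fallbacks):

    use space_efficient_sumcheck::fields::{M31, VecOps};

    fn main() {
        let mut values = vec![M31::from(3u32); 1024];
        assert_eq!(M31::reduce_sum(&values), M31::from(3072u32)); // 3 * 1024
        M31::scalar_mult(&mut values, M31::from(2u32));
        assert_eq!(values[0], M31::from(6u32));
    }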