diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1e5eced54..10e382fb7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,7 @@
 - (`ark-poly`) Add fast polynomial division
 - (`ark-ec`) Improve GLV scalar multiplication performance by skipping leading zeroes.
 - (`ark-poly`) Make `SparsePolynomial.coeffs` field public
+- [\#1044](https://github.com/arkworks-rs/algebra/pull/1044) Add implementation for small field with native integer types
 
 ### Breaking changes
 
diff --git a/ff-macros/src/lib.rs b/ff-macros/src/lib.rs
index 41e19f0bd..1efd9c4e2 100644
--- a/ff-macros/src/lib.rs
+++ b/ff-macros/src/lib.rs
@@ -12,6 +12,7 @@ use proc_macro::TokenStream;
 use syn::{Expr, ExprLit, Item, ItemFn, Lit, Meta};
 
 mod montgomery;
+mod small_fp;
 mod unroll;
 
 pub(crate) mod utils;
@@ -74,6 +75,39 @@ pub fn mont_config(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
     .into()
 }
 
+/// Derive the `SmallFpConfig` trait for small prime fields.
+///
+/// The attributes available to this macro are:
+/// * `modulus`: Specify the prime modulus underlying this prime field.
+/// * `generator`: Specify the generator of the multiplicative subgroup.
+/// * `backend`: Specify either "standard" or "montgomery" backend.
+///
+/// # Panics
+///
+/// Expansion panics if an attribute is missing, if `modulus` or `generator`
+/// does not parse as a `u128`, or if `backend` names an unknown backend.
+#[proc_macro_derive(SmallFpConfig, attributes(modulus, generator, backend))]
+pub fn small_fp_config(input: TokenStream) -> TokenStream {
+    let ast: syn::DeriveInput = syn::parse(input).unwrap();
+
+    let modulus: u128 = fetch_attr("modulus", &ast.attrs)
+        .expect("Please supply a modulus attribute")
+        .parse()
+        .expect("Modulus should be a number");
+
+    let generator: u128 = fetch_attr("generator", &ast.attrs)
+        .expect("Please supply a generator attribute")
+        .parse()
+        .expect("Generator should be a number");
+
+    let backend: String = fetch_attr("backend", &ast.attrs)
+        .expect("Please supply a backend attribute")
+        .parse()
+        .expect("Backend should be a string");
+
+    small_fp::small_fp_config_helper(modulus, generator, backend, ast.ident).into()
+}
+
 const ARG_MSG: &str = "Failed to parse unroll threshold; must be a positive integer";
 
 /// Attribute used to unroll for loops found inside a function block.
diff --git a/ff-macros/src/small_fp/mod.rs b/ff-macros/src/small_fp/mod.rs
new file mode 100644
index 000000000..f23546867
--- /dev/null
+++ b/ff-macros/src/small_fp/mod.rs
@@ -0,0 +1,62 @@
+mod montgomery_backend;
+mod standard_backend;
+mod utils;
+
+use quote::quote;
+
+/// This function is called by the `#[derive(SmallFpConfig)]` macro and generates
+/// the implementation of the `SmallFpConfig` trait, plus an inherent `impl`
+/// block carrying the backend's constructor(s).
+///
+/// The storage type is the narrowest of `u8`/`u16`/`u32`/`u64`/`u128` that can
+/// hold `modulus`.
+///
+/// # Panics
+///
+/// Panics if `backend` is neither `"standard"` nor `"montgomery"`, or if the
+/// montgomery backend is requested with a modulus of 2^127 or more.
+pub(crate) fn small_fp_config_helper(
+    modulus: u128,
+    generator: u128,
+    backend: String,
+    config_name: proc_macro2::Ident,
+) -> proc_macro2::TokenStream {
+    let ty = match modulus {
+        m if m < 1u128 << 8 => quote! { u8 },
+        m if m < 1u128 << 16 => quote! { u16 },
+        m if m < 1u128 << 32 => quote! { u32 },
+        m if m < 1u128 << 64 => quote! { u64 },
+        _ => quote! { u128 },
+    };
+
+    // Resolve the backend once so the trait items and the constructor can never
+    // be generated by two different backends.
+    let (backend_impl, new_impl) = match backend.as_str() {
+        "standard" => (
+            standard_backend::backend_impl(&ty, modulus, generator),
+            standard_backend::new(),
+        ),
+        "montgomery" => {
+            if modulus >= 1u128 << 127 {
+                panic!(
+                    "SmallFpConfig montgomery backend supports only moduli < 2^127. Use MontConfig with BigInt instead of SmallFp."
+                )
+            }
+            (
+                montgomery_backend::backend_impl(&ty, modulus, generator),
+                montgomery_backend::new(modulus, ty),
+            )
+        },
+        _ => panic!("Unknown backend type: {}", backend),
+    };
+
+    quote! {
+        impl SmallFpConfig for #config_name {
+            #backend_impl
+        }
+
+        impl #config_name {
+            #new_impl
+        }
+    }
+}
diff --git a/ff-macros/src/small_fp/montgomery_backend.rs b/ff-macros/src/small_fp/montgomery_backend.rs
new file mode 100644
index 000000000..8b8815d78
--- /dev/null
+++ b/ff-macros/src/small_fp/montgomery_backend.rs
@@ -0,0 +1,283 @@
+use std::u32;
+
+use super::*;
+use crate::small_fp::utils::{
+    compute_two_adic_root_of_unity, compute_two_adicity, generate_montgomery_bigint_casts,
+    generate_sqrt_precomputation, mod_mul_const,
+};
+
+/// Generates the `SmallFpConfig` trait items for the Montgomery backend.
+///
+/// Elements are stored in Montgomery form `a * R mod N`, where `N = modulus`
+/// and `R = 2^k_bits` is the smallest power of two strictly greater than `N`.
+/// The constants (`GENERATOR`, `ONE`, `NEG_ONE`, the two-adic root of unity)
+/// are therefore multiplied by `R mod N` at macro-expansion time.
+///
+/// The caller must ensure `modulus < 2^127` so that `R` fits in a `u128`.
+pub(crate) fn backend_impl(
+    ty: &proc_macro2::TokenStream,
+    modulus: u128,
+    generator: u128,
+) -> proc_macro2::TokenStream {
+    // R = 2^k_bits, where k_bits is the bit length of the modulus.
+    let k_bits = 128 - modulus.leading_zeros();
+    let r: u128 = 1u128 << k_bits;
+    let r_mod_n = r % modulus;
+    let r_mask = r - 1;
+
+    // n_prime = -N^{-1} mod R, the constant used by Montgomery reduction.
+    let n_prime = mod_inverse_pow2(modulus, k_bits);
+    let one_mont = r_mod_n;
+    let generator_mont = mod_mul_const(generator % modulus, r_mod_n % modulus, modulus);
+
+    let two_adicity = compute_two_adicity(modulus);
+    let two_adic_root = compute_two_adic_root_of_unity(modulus, two_adicity, generator);
+    let two_adic_root_mont = mod_mul_const(two_adic_root, r_mod_n, modulus);
+
+    let neg_one_mont = mod_mul_const(modulus - 1, r_mod_n, modulus);
+
+    let (from_bigint_impl, into_bigint_impl) =
+        generate_montgomery_bigint_casts(modulus, k_bits, r_mod_n);
+    let sqrt_precomp_impl = generate_sqrt_precomputation(modulus, two_adicity, Some(r_mod_n));
+
+    // Generate multiplication implementation based on type
+    let mul_impl = generate_mul_impl(ty, modulus, k_bits, r_mask, n_prime);
+
+    quote! {
+        type T = #ty;
+        const MODULUS: Self::T = #modulus as Self::T;
+        const MODULUS_128: u128 = #modulus;
+        const GENERATOR: SmallFp<Self> = SmallFp::new(#generator_mont as Self::T);
+        const ZERO: SmallFp<Self> = SmallFp::new(0 as Self::T);
+        const ONE: SmallFp<Self> = SmallFp::new(#one_mont as Self::T);
+        const NEG_ONE: SmallFp<Self> = SmallFp::new(#neg_one_mont as Self::T);
+
+        const TWO_ADICITY: u32 = #two_adicity;
+        const TWO_ADIC_ROOT_OF_UNITY: SmallFp<Self> = SmallFp::new(#two_adic_root_mont as Self::T);
+        #sqrt_precomp_impl
+
+        #[inline(always)]
+        fn add_assign(a: &mut SmallFp<Self>, b: &SmallFp<Self>) {
+            let (mut val, overflow) = a.value.overflowing_add(b.value);
+
+            // The true sum exceeded `T::MAX`; fold the lost 2^bits back in and
+            // subtract the modulus in one step (a + b < 2N, so the result is < N).
+            if overflow {
+                val = Self::T::MAX - Self::MODULUS + 1 + val
+            }
+
+            if val >= Self::MODULUS {
+                val -= Self::MODULUS;
+            }
+            a.value = val;
+        }
+
+        #[inline(always)]
+        fn sub_assign(a: &mut SmallFp<Self>, b: &SmallFp<Self>) {
+            if a.value >= b.value {
+                a.value -= b.value;
+            } else {
+                a.value = Self::MODULUS - (b.value - a.value);
+            }
+        }
+
+        #[inline(always)]
+        fn double_in_place(a: &mut SmallFp<Self>) {
+            let tmp = *a;
+            Self::add_assign(a, &tmp);
+        }
+
+        #[inline(always)]
+        fn neg_in_place(a: &mut SmallFp<Self>) {
+            if a.value != (0 as Self::T) {
+                a.value = Self::MODULUS - a.value;
+            }
+        }
+
+        #mul_impl
+
+        #[inline(always)]
+        fn sum_of_products<const T: usize>(
+            a: &[SmallFp<Self>; T],
+            b: &[SmallFp<Self>; T],) -> SmallFp<Self> {
+            let mut acc = SmallFp::new(0 as Self::T);
+            for (x, y) in a.iter().zip(b.iter()) {
+                let mut prod = *x;
+                Self::mul_assign(&mut prod, y);
+                Self::add_assign(&mut acc, &prod);
+            }
+            acc
+        }
+
+        #[inline(always)]
+        fn square_in_place(a: &mut SmallFp<Self>) {
+            let tmp = *a;
+            Self::mul_assign(a, &tmp);
+        }
+
+        fn inverse(a: &SmallFp<Self>) -> Option<SmallFp<Self>> {
+            if a.value == 0 {
+                return None;
+            }
+
+            // Fermat's little theorem: a^(N - 2) is the inverse of a for prime N.
+            let mut result = Self::ONE;
+            let mut base = *a;
+            let mut exp = Self::MODULUS - 2;
+
+            while exp > 0 {
+                if exp & 1 == 1 {
+                    Self::mul_assign(&mut result, &base);
+                }
+
+                let mut sq = base;
+                Self::square_in_place(&mut sq);
+                base = sq;
+                exp >>= 1;
+            }
+
+            Some(result)
+        }
+
+        #from_bigint_impl
+
+        #into_bigint_impl
+    }
+}
+
+/// Generates the Montgomery `mul_assign`, selecting the algorithm at
+/// macro-expansion time:
+/// * if the storage type is at most `u64`, multiply by casting to the next
+///   largest primitive;
+/// * otherwise (`u128`), multiply in 64-bit limbs to form a 256-bit product
+///   before reduction.
+///
+/// `r_mask` is `R - 1` and `n_prime` is `-N^{-1} mod R` for `R = 2^k_bits`.
+fn generate_mul_impl(
+    ty: &proc_macro2::TokenStream,
+    modulus: u128,
+    k_bits: u32,
+    r_mask: u128,
+    n_prime: u128,
+) -> proc_macro2::TokenStream {
+    let ty_str = ty.to_string();
+
+    if ty_str == "u128" {
+        quote! {
+            #[inline(always)]
+            fn mul_assign(a: &mut SmallFp<Self>, b: &SmallFp<Self>) {
+                // 256-bit result stored as lo, hi
+                // t = a * b
+                let lolo = (a.value & 0xFFFFFFFFFFFFFFFF) * (b.value & 0xFFFFFFFFFFFFFFFF);
+                let lohi = (a.value & 0xFFFFFFFFFFFFFFFF) * (b.value >> 64);
+                let hilo = (a.value >> 64) * (b.value & 0xFFFFFFFFFFFFFFFF);
+                let hihi = (a.value >> 64) * (b.value >> 64);
+
+                let (cross_sum, cross_carry) = lohi.overflowing_add(hilo);
+                let (mid, mid_carry) = lolo.overflowing_add(cross_sum << 64);
+                let t_lo = mid;
+                let t_hi = hihi + (cross_sum >> 64) + ((cross_carry as u128) << 64) + (mid_carry as u128);
+
+                // m = t_lo * n_prime & r_mask
+                let m = t_lo.wrapping_mul(#n_prime) & #r_mask;
+
+                // mn = m * modulus
+                let lolo = (m & 0xFFFFFFFFFFFFFFFF) * (#modulus & 0xFFFFFFFFFFFFFFFF);
+                let lohi = (m & 0xFFFFFFFFFFFFFFFF) * (#modulus >> 64);
+                let hilo = (m >> 64) * (#modulus & 0xFFFFFFFFFFFFFFFF);
+                let hihi = (m >> 64) * (#modulus >> 64);
+
+                let (cross_sum, cross_carry) = lohi.overflowing_add(hilo);
+                let (mid, mid_carry) = lolo.overflowing_add(cross_sum << 64);
+                let mn_lo = mid;
+                let mn_hi = hihi + (cross_sum >> 64) + ((cross_carry as u128) << 64) + (mid_carry as u128);
+
+                // (t + mn) / R
+                let (sum_lo, carry) = t_lo.overflowing_add(mn_lo);
+                let sum_hi = t_hi + mn_hi + (carry as u128);
+
+                let mut u = (sum_lo >> #k_bits) | (sum_hi << (128 - #k_bits));
+                u -= #modulus * (u >= #modulus) as u128;
+                a.value = u as Self::T;
+            }
+        }
+    } else {
+        let (mul_ty, bits) = match ty_str.as_str() {
+            "u8" => (quote! {u16}, 16u32),
+            "u16" => (quote! {u32}, 32u32),
+            "u32" => (quote! {u64}, 64u32),
+            _ => (quote! {u128}, 128u32),
+        };
+
+        let r_mask_downcast = quote! { #r_mask as #mul_ty };
+        let n_prime_downcast = quote! { #n_prime as #mul_ty };
+        let modulus_downcast = quote! { #modulus as #mul_ty };
+        let one = quote! { 1 as #mul_ty };
+
+        quote! {
+            #[inline(always)]
+            fn mul_assign(a: &mut SmallFp<Self>, b: &SmallFp<Self>) {
+                let a_val = a.value as #mul_ty;
+                let b_val = b.value as #mul_ty;
+
+                let t = a_val * b_val;
+                let t_low = t & #r_mask_downcast;
+
+                // m = t_lo * n_prime & r_mask
+                let m = t_low.wrapping_mul(#n_prime_downcast) & #r_mask_downcast;
+
+                // mn = m * modulus
+                let mn = m * #modulus_downcast;
+
+                // (t + mn) / R
+                let (sum, overflow) = t.overflowing_add(mn);
+                let mut u = sum >> #k_bits;
+
+                u += ((#one) << (#bits - #k_bits)) * (overflow as #mul_ty);
+                u -= #modulus_downcast * ((u >= #modulus_downcast) as #mul_ty);
+                a.value = u as Self::T;
+            }
+        }
+    }
+}
+
+/// Computes `-n^{-1} mod 2^k_bits` for odd `n`.
+///
+/// Uses Newton's iteration `inv <- inv * (2 - n * inv)`, which doubles the
+/// number of correct low bits each round; `k_bits` rounds are more than enough.
+/// Requires `k_bits < 128` so the mask computation does not overflow.
+fn mod_inverse_pow2(n: u128, k_bits: u32) -> u128 {
+    let mut inv = 1u128;
+    for _ in 0..k_bits {
+        inv = inv.wrapping_mul(2u128.wrapping_sub(n.wrapping_mul(inv)));
+    }
+    let mask = (1u128 << k_bits) - 1;
+    inv.wrapping_neg() & mask
+}
+
+/// Generates the inherent constructor `new`, which reduces a raw integer and
+/// converts it into Montgomery form (by multiplying with `R^2 mod N`), and
+/// `exit`, which converts an element out of Montgomery form in place.
+pub(crate) fn new(modulus: u128, _ty: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
+    let k_bits = 128 - modulus.leading_zeros();
+    let r: u128 = 1u128 << k_bits;
+    let r_mod_n = r % modulus;
+    let r2 = mod_mul_const(r_mod_n, r_mod_n, modulus);
+
+    quote! {
+        pub fn new(value: <Self as SmallFpConfig>::T) -> SmallFp<Self> {
+            let reduced_value = value % <Self as SmallFpConfig>::MODULUS;
+            let mut tmp = SmallFp::new(reduced_value);
+            let r2_elem = SmallFp::new(#r2 as <Self as SmallFpConfig>::T);
+            <Self as SmallFpConfig>::mul_assign(&mut tmp, &r2_elem);
+            tmp
+        }
+
+        pub fn exit(a: &mut SmallFp<Self>) {
+            let mut tmp = *a;
+            let one = SmallFp::new(1 as <Self as SmallFpConfig>::T);
+            <Self as SmallFpConfig>::mul_assign(&mut tmp, &one);
+            a.value = tmp.value;
+        }
+    }
+}
diff --git a/ff-macros/src/small_fp/standard_backend.rs b/ff-macros/src/small_fp/standard_backend.rs
new file mode 100644
index 000000000..a51260806
--- /dev/null
+++ b/ff-macros/src/small_fp/standard_backend.rs
@@ -0,0 +1,148 @@
+use super::*;
+use crate::small_fp::utils::{
+    compute_two_adic_root_of_unity, compute_two_adicity, generate_bigint_casts,
+    generate_sqrt_precomputation,
+};
+
+/// Generates the `SmallFpConfig` trait items for the standard backend.
+///
+/// Elements are stored as plain residues in `[0, modulus)`, so the constants
+/// are emitted as-is and every operation reduces its result modulo `MODULUS`.
+pub(crate) fn backend_impl(
+    ty: &proc_macro2::TokenStream,
+    modulus: u128,
+    generator: u128,
+) -> proc_macro2::TokenStream {
+    let two_adicity = compute_two_adicity(modulus);
+    let two_adic_root_of_unity = compute_two_adic_root_of_unity(modulus, two_adicity, generator);
+
+    let (from_bigint_impl, into_bigint_impl) = generate_bigint_casts(modulus);
+    let sqrt_precomp_impl = generate_sqrt_precomputation(modulus, two_adicity, None);
+
+    quote! {
+        type T = #ty;
+        const MODULUS: Self::T = #modulus as Self::T;
+        const MODULUS_128: u128 = #modulus;
+        const GENERATOR: SmallFp<Self> = SmallFp::new(#generator as Self::T);
+        const ZERO: SmallFp<Self> = SmallFp::new(0 as Self::T);
+        const ONE: SmallFp<Self> = SmallFp::new(1 as Self::T);
+        const NEG_ONE: SmallFp<Self> = SmallFp::new((Self::MODULUS - 1) as Self::T);
+
+        const TWO_ADICITY: u32 = #two_adicity;
+        const TWO_ADIC_ROOT_OF_UNITY: SmallFp<Self> = SmallFp::new(#two_adic_root_of_unity as Self::T);
+        #sqrt_precomp_impl
+
+        #[inline(always)]
+        fn add_assign(a: &mut SmallFp<Self>, b: &SmallFp<Self>) {
+            // On overflow the true sum is `val + 2^bits`; adding `2^bits - N` keeps it congruent.
+            a.value = match a.value.overflowing_add(b.value) {
+                (val, false) => val % Self::MODULUS,
+                (val, true) => (Self::T::MAX - Self::MODULUS + 1 + val) % Self::MODULUS,
+            };
+        }
+
+        #[inline(always)]
+        fn sub_assign(a: &mut SmallFp<Self>, b: &SmallFp<Self>) {
+            if a.value >= b.value {
+                a.value -= b.value;
+            } else {
+                a.value = Self::MODULUS - (b.value - a.value);
+            }
+        }
+
+        #[inline(always)]
+        fn double_in_place(a: &mut SmallFp<Self>) {
+            let tmp = *a;
+            Self::add_assign(a, &tmp);
+        }
+
+        #[inline(always)]
+        fn neg_in_place(a: &mut SmallFp<Self>) {
+            if a.value != (0 as Self::T) {
+                a.value = Self::MODULUS - a.value;
+            }
+        }
+
+        #[inline(always)]
+        fn mul_assign(a: &mut SmallFp<Self>, b: &SmallFp<Self>) {
+            let a_128 = (a.value as u128) % #modulus;
+            let b_128 = (b.value as u128) % #modulus;
+            let mod_add = |x: u128, y: u128| -> u128 {
+                if x >= #modulus - y {
+                    x - (#modulus - y)
+                } else {
+                    x + y
+                }
+            };
+            // Fast path: the product fits in a `u128`. Otherwise fall back to
+            // double-and-add, whose intermediates never reach the modulus.
+            a.value = match a_128.overflowing_mul(b_128) {
+                (val, false) => (val % #modulus) as Self::T,
+                (_, true) => {
+                    let mut result = 0u128;
+                    let mut base = a_128 % #modulus;
+                    let mut exp = b_128;
+                    while exp > 0 {
+                        if exp & 1 == 1 {
+                            result = mod_add(result, base);
+                        }
+                        base = mod_add(base, base);
+                        exp >>= 1;
+                    }
+                    result as Self::T
+                }
+            };
+        }
+
+        fn sum_of_products<const T: usize>(
+            a: &[SmallFp<Self>; T],
+            b: &[SmallFp<Self>; T],) -> SmallFp<Self> {
+            let mut acc = SmallFp::new(0 as Self::T);
+            for (x, y) in a.iter().zip(b.iter()) {
+                let mut prod = *x;
+                Self::mul_assign(&mut prod, y);
+                Self::add_assign(&mut acc, &prod);
+            }
+            acc
+        }
+
+        fn square_in_place(a: &mut SmallFp<Self>) {
+            let tmp = *a;
+            Self::mul_assign(a, &tmp);
+        }
+
+        fn inverse(a: &SmallFp<Self>) -> Option<SmallFp<Self>> {
+            if a.value == 0 {
+                return None;
+            }
+            // Fermat's little theorem: a^(N - 2) is the inverse of a for prime N.
+            let mut base = *a;
+            let mut exp = Self::MODULUS - 2;
+            let mut acc = Self::ONE;
+            while exp > 0 {
+                if (exp & 1) == 1 {
+                    Self::mul_assign(&mut acc, &base);
+                }
+                let mut sq = base;
+                Self::mul_assign(&mut sq, &base);
+                base = sq;
+                exp >>= 1;
+            }
+            Some(acc)
+        }
+
+        #from_bigint_impl
+
+        #into_bigint_impl
+    }
+}
+
+/// Generates the inherent constructor `new`, which reduces a raw integer
+/// modulo `MODULUS`.
+pub(crate) fn new() -> proc_macro2::TokenStream {
+    quote! {
+        pub fn new(value: <Self as SmallFpConfig>::T) -> SmallFp<Self> {
+            SmallFp::new(value % <Self as SmallFpConfig>::MODULUS)
+        }
+    }
+}
diff --git a/ff-macros/src/small_fp/utils.rs b/ff-macros/src/small_fp/utils.rs
new file mode 100644
index 000000000..39fe370f5
--- /dev/null
+++ b/ff-macros/src/small_fp/utils.rs
@@ -0,0 +1,215 @@
+use super::*;
+
+/// Compute the largest integer `s` such that `N - 1 = 2**s * t` for odd `t`.
+///
+/// Panics if `modulus` is even or not greater than 1.
+pub(crate) const fn compute_two_adicity(modulus: u128) -> u32 {
+    assert!(modulus % 2 == 1, "Modulus must be odd");
+    assert!(modulus > 1, "Modulus must be greater than 1");
+
+    let mut n_minus_1 = modulus - 1;
+    let mut two_adicity = 0;
+
+    while n_minus_1 % 2 == 0 {
+        n_minus_1 /= 2;
+        two_adicity += 1;
+    }
+    two_adicity
+}
+
+/// Computes `(x + y) mod modulus` without overflowing `u128`.
+/// Both operands are expected to be already reduced modulo `modulus`.
+const fn mod_add_const(x: u128, y: u128, modulus: u128) -> u128 {
+    if x >= modulus - y {
+        x - (modulus - y)
+    } else {
+        x + y
+    }
+}
+
+/// Computes `(a * b) mod modulus`, falling back to double-and-add when the
+/// product does not fit in a `u128`.
+pub(crate) const fn mod_mul_const(a: u128, b: u128, modulus: u128) -> u128 {
+    match a.overflowing_mul(b) {
+        (val, false) => val % modulus,
+        (_, true) => {
+            let mut result = 0u128;
+            let mut base = a % modulus;
+            let mut exp = b;
+
+            while exp > 0 {
+                if exp & 1 == 1 {
+                    result = mod_add_const(result, base, modulus);
+                }
+                base = mod_add_const(base, base, modulus);
+                exp >>= 1;
+            }
+            result
+        },
+    }
+}
+
+/// Computes `base^exp mod modulus` by square-and-multiply.
+const fn pow_mod_const(mut base: u128, mut exp: u128, modulus: u128) -> u128 {
+    let mut result = 1;
+    base %= modulus;
+    while exp > 0 {
+        if exp % 2 == 1 {
+            result = mod_mul_const(result, base, modulus);
+        }
+        base = mod_mul_const(base, base, modulus);
+        exp /= 2;
+    }
+    result
+}
+
+/// Computes `generator^((modulus - 1) / 2^two_adicity) mod modulus`, which is
+/// a primitive `2^two_adicity`-th root of unity when `generator` generates the
+/// multiplicative group.
+pub(crate) const fn compute_two_adic_root_of_unity(
+    modulus: u128,
+    two_adicity: u32,
+    generator: u128,
+) -> u128 {
+    let exp = (modulus - 1) >> two_adicity;
+    let base = generator % modulus;
+    pow_mod_const(base, exp, modulus)
+}
+
+/// Finds the smallest quadratic non-residue using Euler's criterion.
+///
+/// Panics if no candidate in `[2, modulus)` qualifies, which can only happen
+/// when `modulus` is not an odd prime. The bound guarantees that the search
+/// terminates instead of looping forever during macro expansion.
+pub(crate) const fn find_quadratic_non_residue(modulus: u128) -> u128 {
+    let exponent = (modulus - 1) / 2;
+    let mut z = 2;
+    while z < modulus {
+        let legendre = pow_mod_const(z, exponent, modulus);
+        if legendre == modulus - 1 {
+            return z;
+        }
+        z += 1;
+    }
+    panic!("No quadratic non-residue found; modulus must be an odd prime")
+}
+
+/// Generates `from_bigint` / `into_bigint` for the standard backend, where the
+/// stored value is the residue itself.
+pub(crate) fn generate_bigint_casts(
+    modulus: u128,
+) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
+    (
+        quote! {
+            fn from_bigint(a: BigInt<2>) -> Option<SmallFp<Self>> {
+                let val = (a.0[0] as u128) + ((a.0[1] as u128) << 64);
+                // Only canonical representatives (val < MODULUS) are accepted.
+                if val >= Self::MODULUS_128 {
+                    None
+                } else {
+                    let reduced_val = val % #modulus;
+                    Some(SmallFp::new(reduced_val as Self::T))
+                }
+            }
+        },
+        quote! {
+            fn into_bigint(a: SmallFp<Self>) -> BigInt<2> {
+                let val = a.value as u128;
+                let lo = val as u64;
+                let hi = (val >> 64) as u64;
+                ark_ff::BigInt([lo, hi])
+            }
+        },
+    )
+}
+
+/// Generates `from_bigint` / `into_bigint` for the Montgomery backend.
+///
+/// `from_bigint` multiplies by `R^2 mod N` to enter Montgomery form and
+/// `into_bigint` multiplies by the raw value 1 to leave it.
+pub(crate) fn generate_montgomery_bigint_casts(
+    modulus: u128,
+    _k_bits: u32,
+    r_mod_n: u128,
+) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
+    let r2 = mod_mul_const(r_mod_n, r_mod_n, modulus);
+    (
+        quote! {
+            fn from_bigint(a: BigInt<2>) -> Option<SmallFp<Self>> {
+                let val = (a.0[0] as u128) + ((a.0[1] as u128) << 64);
+                // Only canonical representatives (val < MODULUS) are accepted.
+                if val >= Self::MODULUS_128 {
+                    None
+                } else {
+                    let reduced_val = val % #modulus;
+                    let mut tmp = SmallFp::new(reduced_val as Self::T);
+                    let r2_elem = SmallFp::new(#r2 as Self::T);
+                    <Self as SmallFpConfig>::mul_assign(&mut tmp, &r2_elem);
+                    Some(tmp)
+                }
+            }
+        },
+        quote! {
+            fn into_bigint(a: SmallFp<Self>) -> BigInt<2> {
+                let mut tmp = a;
+                let one = SmallFp::new(1 as Self::T);
+                <Self as SmallFpConfig>::mul_assign(&mut tmp, &one);
+                let val = tmp.value as u128;
+                let lo = val as u64;
+                let hi = (val >> 64) as u64;
+                ark_ff::BigInt([lo, hi])
+            }
+        },
+    )
+}
+
+/// Generates the `SQRT_PRECOMP` constant: `Case3Mod4` when `modulus % 4 == 3`,
+/// `TonelliShanks` otherwise.
+///
+/// When `r_mod_n` is `Some`, the Tonelli-Shanks non-residue power is converted
+/// to Montgomery form by multiplying with it.
+pub(crate) fn generate_sqrt_precomputation(
+    modulus: u128,
+    two_adicity: u32,
+    r_mod_n: Option<u128>,
+) -> proc_macro2::TokenStream {
+    if modulus % 4 == 3 {
+        // Equal to (modulus + 1) / 4 when modulus % 4 == 3, but cannot overflow.
+        let modulus_plus_one_div_four = modulus / 4 + 1;
+        let lo = modulus_plus_one_div_four as u64;
+        let hi = (modulus_plus_one_div_four >> 64) as u64;
+
+        quote! {
+            // Case3Mod4 square root precomputation
+            const SQRT_PRECOMP: Option<SqrtPrecomputation<SmallFp<Self>>> = {
+                const MODULUS_PLUS_ONE_DIV_FOUR: [u64; 2] = [#lo, #hi];
+                Some(SqrtPrecomputation::Case3Mod4 {
+                    modulus_plus_one_div_four: &MODULUS_PLUS_ONE_DIV_FOUR,
+                })
+            };
+        }
+    } else {
+        let trace = (modulus - 1) >> two_adicity;
+        let trace_minus_one_div_two = trace / 2;
+        let lo = trace_minus_one_div_two as u64;
+        let hi = (trace_minus_one_div_two >> 64) as u64;
+        let qnr = find_quadratic_non_residue(modulus);
+        let mut qnr_to_trace = pow_mod_const(qnr, trace, modulus);
+
+        if let Some(r) = r_mod_n {
+            qnr_to_trace = mod_mul_const(qnr_to_trace, r, modulus);
+        }
+
+        quote! {
+            // TonelliShanks square root precomputation
+            const SQRT_PRECOMP: Option<SqrtPrecomputation<SmallFp<Self>>> = {
+                const TRACE_MINUS_ONE_DIV_TWO: [u64; 2] = [#lo, #hi];
+                Some(SqrtPrecomputation::TonelliShanks {
+                    two_adicity: #two_adicity,
+                    quadratic_nonresidue_to_trace: SmallFp::new(#qnr_to_trace as Self::T),
+                    trace_of_modulus_minus_one_div_two: &TRACE_MINUS_ONE_DIV_TWO,
+                })
+            };
+        }
+    }
+}
diff --git a/ff/Cargo.toml b/ff/Cargo.toml
index 30323fdaa..31965cb92 100644
--- a/ff/Cargo.toml
+++ b/ff/Cargo.toml
@@ -42,4 +42,4 @@ hex.workspace = true
 default = []
 std = [ "ark-std/std", "ark-serialize/std" ]
 parallel = [ "std", "rayon", "ark-std/parallel", "ark-serialize/parallel" ]
-asm = []
+asm = []
\ No newline at end of file