Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

### Bugfixes

- [\#1082](https://github.com/arkworks-rs/algebra/pull/1082) (`ark-ff`) Fix `SmallFp::from_random_bytes` / `from_be_bytes_mod_order` silently producing incorrect field elements by treating plaintext bytes as Montgomery-encoded.
- (`ark-ff`) Fix `SmallFp::from_random_bytes` / `from_be_bytes_mod_order` silently producing incorrect field elements by treating plaintext bytes as Montgomery-encoded.

## v0.5.0

Expand Down
232 changes: 208 additions & 24 deletions ff-macros/src/small_fp/montgomery_backend.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::*;
use crate::small_fp::utils::{
compute_two_adic_root_of_unity, compute_two_adicity, generate_montgomery_bigint_casts,
generate_sqrt_precomputation, mod_mul_const,
generate_sqrt_precomputation, mod_mul_const, pow_mod_const,
};

pub(crate) fn backend_impl(
Expand Down Expand Up @@ -62,6 +62,7 @@ pub(crate) fn backend_impl(

// Generate multiplication implementation based on type
let mul_impl = generate_mul_impl(ty, modulus, k_bits, r_mask, n_prime);
let inverse_impl = generate_inverse_impl(ty, modulus, r_mod_n, r2);

let type_bits = match ty_str.as_str() {
"u8" => 8u32,
Expand Down Expand Up @@ -138,6 +139,8 @@ pub(crate) fn backend_impl(

#mul_impl

#inverse_impl

#[inline(always)]
fn sum_of_products<const T: usize>(
a: &[SmallFp<Self>; T],
Expand Down Expand Up @@ -174,29 +177,6 @@ pub(crate) fn backend_impl(
Self::mul_assign(a, &tmp);
}

fn inverse(a: &SmallFp<Self>) -> Option<SmallFp<Self>> {
if a.value == 0 {
return None;
}

let mut result = Self::ONE;
let mut base = *a;
let mut exp = Self::MODULUS - 2;

while exp > 0 {
if exp & 1 == 1 {
Self::mul_assign(&mut result, &base);
}

let mut sq = base;
Self::square_in_place(&mut sq);
base = sq;
exp >>= 1;
}

Some(result)
}

#[inline]
fn new(value: Self::T) -> SmallFp<Self> {
let reduced_value = value % Self::MODULUS;
Expand All @@ -212,6 +192,210 @@ pub(crate) fn backend_impl(
}
}

// Generates the `inverse` function using the binary extended GCD algorithm for u8/u16/u32/u64
// fields, falling back to Fermat's little theorem for u128 fields.
//
// The GCD algorithm runs for NUM_ITERS = 2*FIELD_BITS - 2 iterations of cheap integer ops
// (no modular reduction), returning v ≡ 2^NUM_ITERS · (a·R)^{-1} (mod P). A single
// Montgomery multiplication by the precomputed constant C = R^3 · 2^{-NUM_ITERS} mod P
// then corrects the result to a^{-1}·R mod P (the Montgomery form of the inverse).
//
// Parameters (all compile-time constants of the macro invocation):
// - `ty`:      the backing primitive type token (`u8`/`u16`/`u32`/`u64`/`u128`)
// - `modulus`: the field modulus P (odd, since Montgomery form requires gcd(R, P) = 1)
// - `r_mod_n`: R mod P
// - `r2`:      R^2 mod P
//
// Returns the token stream for the generated `fn inverse(..)` item; the interpolated
// constants (`#corr`, `#num_iters`, `#half_iters_i64`) are baked in as suffixed literals.
fn generate_inverse_impl(
    ty: &proc_macro2::TokenStream,
    modulus: u128,
    r_mod_n: u128,
    r2: u128,
) -> proc_macro2::TokenStream {
    let ty_str = ty.to_string();

    if ty_str == "u128" {
        // GCD coefficients would require 256-bit signed integers; use Fermat's little theorem.
        // Square-and-multiply of a^(P-2): correct for prime P by Fermat; zero has no inverse.
        return quote! {
            fn inverse(a: &SmallFp<Self>) -> Option<SmallFp<Self>> {
                if a.value == 0 {
                    return None;
                }
                let mut result = Self::ONE;
                let mut base = *a;
                let mut exp = Self::MODULUS - 2;
                while exp > 0 {
                    if exp & 1 == 1 {
                        Self::mul_assign(&mut result, &base);
                    }
                    let mut sq = base;
                    Self::square_in_place(&mut sq);
                    base = sq;
                    exp >>= 1;
                }
                Some(result)
            }
        };
    }

    // FIELD_BITS = bit-length of modulus; NUM_ITERS = 2*FIELD_BITS - 2 ensures convergence.
    let field_bits = 128 - modulus.leading_zeros();
    let num_iters = 2 * field_bits - 2;

    // Correction constant: C = R^3 · 2^{-NUM_ITERS} mod P
    // 2^{-1} mod P = (P+1)/2 (P is odd)
    // 2^{-NUM_ITERS} mod P = ((P+1)/2)^NUM_ITERS mod P
    let half = (modulus + 1) / 2;
    let two_neg_iters = pow_mod_const(half, num_iters as u128, modulus);
    let r3 = mod_mul_const(r2, r_mod_n, modulus);
    let corr = mod_mul_const(r3, two_neg_iters, modulus);

    if ty_str == "u64" {
        // Two-round binary extended GCD for 64-bit fields (Plonky3 approach).
        //
        // Split NUM_ITERS into two rounds of half_iters each. (a, b) stay u64 since
        // they're always in [0, P). The first (half_iters − 1) iterations per round
        // use i64 matrix entries (|entry| ≤ 2^{HALF_ITERS−1} ≤ 2^62, no overflow).
        // The final iteration of each round is promoted to i128 to safely handle the
        // subtraction/shift that would otherwise produce ±2^63, which is out of i64
        // range. Recombination: sum = f11*f00 + g11*f10 in i128 (≤ 2^126 < 2^127).
        // Note: the final modular reduction uses u128 to handle P ≈ 2^64 where
        // (sum % p) + p can reach ~2P ≈ 2^65, overflowing u64.

        let half_iters = num_iters / 2;
        let half_iters_i64 = half_iters - 1; // iterations using i64, one fewer per round

        quote! {
            #[inline]
            fn inverse(a: &SmallFp<Self>) -> Option<SmallFp<Self>> {
                if a.value == 0 {
                    return None;
                }
                // Invariant per round after k iters: f0·a0 ≡ aa·2^k (mod P)
                //                                    f1·a0 ≡ bb·2^k (mod P)
                let mut aa: u64 = a.value;
                let mut bb: u64 = Self::MODULUS;

                // ---- Round 1: produces (f00, f10) in i128 ----
                let (f00, f10): (i128, i128);
                {
                    let (mut f0, mut g0, mut f1, mut g1): (i64, i64, i64, i64) = (1, 0, 0, 1);
                    // First (half_iters − 1) iterations in i64: |entry| ≤ 2^{half_iters−1} ≤ 2^62
                    let mut i = 0u32;
                    while i < #half_iters_i64 {
                        if aa & 1 != 0 {
                            // Keep aa ≥ bb before subtracting (swap rows if needed)
                            if aa < bb {
                                let t = aa; aa = bb; bb = t;
                                let t = f0; f0 = f1; f1 = t;
                                let t = g0; g0 = g1; g1 = t;
                            }
                            aa -= bb; aa >>= 1;
                            f0 -= f1; g0 -= g1;
                        } else {
                            aa >>= 1;
                        }
                        f1 <<= 1; g1 <<= 1;
                        i += 1;
                    }
                    // Final iteration in i128 to avoid ±2^63 overflow
                    let (mut f0, mut g0, mut f1, mut g1) =
                        (f0 as i128, g0 as i128, f1 as i128, g1 as i128);
                    if aa & 1 != 0 {
                        if aa < bb {
                            let t = aa; aa = bb; bb = t;
                            let t = f0; f0 = f1; f1 = t;
                            let t = g0; g0 = g1; g1 = t;
                        }
                        aa -= bb; aa >>= 1; f0 -= f1; g0 -= g1;
                    } else {
                        aa >>= 1;
                    }
                    f1 <<= 1; g1 <<= 1;
                    f00 = f0; f10 = f1;
                }

                // ---- Round 2: identical structure, produces (f11, g11) ----
                let (f11, g11): (i128, i128);
                {
                    let (mut f0, mut g0, mut f1, mut g1): (i64, i64, i64, i64) = (1, 0, 0, 1);
                    let mut i = 0u32;
                    while i < #half_iters_i64 {
                        if aa & 1 != 0 {
                            if aa < bb {
                                let t = aa; aa = bb; bb = t;
                                let t = f0; f0 = f1; f1 = t;
                                let t = g0; g0 = g1; g1 = t;
                            }
                            aa -= bb; aa >>= 1;
                            f0 -= f1; g0 -= g1;
                        } else {
                            aa >>= 1;
                        }
                        f1 <<= 1; g1 <<= 1;
                        i += 1;
                    }
                    // Final iteration in i128, as in round 1
                    let (mut f0, mut g0, mut f1, mut g1) =
                        (f0 as i128, g0 as i128, f1 as i128, g1 as i128);
                    if aa & 1 != 0 {
                        if aa < bb {
                            let t = aa; aa = bb; bb = t;
                            let t = f0; f0 = f1; f1 = t;
                            let t = g0; g0 = g1; g1 = t;
                        }
                        aa -= bb; aa >>= 1; f0 -= f1; g0 -= g1;
                    } else {
                        aa >>= 1;
                    }
                    f1 <<= 1; g1 <<= 1;
                    // Round 2 keeps the second row (f1, g1) of its transition matrix
                    f11 = f1; g11 = g1;
                }

                // sum = f11*f00 + g11*f10 ≡ (a·R)^{-1} · 2^{NUM_ITERS} (mod P)
                // Each factor ≤ 2^63 in magnitude ⇒ product ≤ 2^126 < i128::MAX.
                let sum_raw = f11 * f00 + g11 * f10;
                let p = Self::MODULUS as i128;
                // Use u128 for intermediate: (sum % p) + p can reach ~2P ≈ 2^65 for
                // 64-bit moduli, which overflows u64.
                let sum_reduced = ((sum_raw % p) + p) as u128 % Self::MODULUS as u128;
                // Multiply by C = R^3 · 2^{-NUM_ITERS} mod P to get a^{-1}·R mod P
                let mut result = SmallFp::from_raw(sum_reduced as Self::T);
                let corr = SmallFp::from_raw(#corr as Self::T);
                Self::mul_assign(&mut result, &corr);
                Some(result)
            }
        }
    } else {
        // u8, u16, u32: GCD coefficients fit in i64 (|v| ≤ 2^{2*32-2} = 2^62 < 2^63)
        // so a single round of NUM_ITERS iterations suffices — no i128 promotion needed.
        quote! {
            #[inline]
            fn inverse(a: &SmallFp<Self>) -> Option<SmallFp<Self>> {
                if a.value == 0 {
                    return None;
                }
                // Binary extended GCD: v = 2^NUM_ITERS · (a·R)^{-1} mod P
                let mut aa: u64 = a.value as u64;
                let mut bb: u64 = Self::MODULUS as u64;
                let mut u: i64 = 1;
                let mut v: i64 = 0;
                let mut i = 0u32;
                while i < #num_iters {
                    if aa & 1 != 0 {
                        // Keep aa ≥ bb before subtracting (swap pairs if needed)
                        if aa < bb {
                            let tmp_a = aa; aa = bb; bb = tmp_a;
                            let tmp_u = u; u = v; v = tmp_u;
                        }
                        aa -= bb;
                        u -= v;
                    }
                    aa >>= 1;
                    v <<= 1;
                    i += 1;
                }
                // Lift v from a signed representative into [0, P)
                let p = Self::MODULUS as i64;
                let v_reduced = ((v % p) + p) as u64 % Self::MODULUS as u64;
                // Multiply by C = R^3 · 2^{-NUM_ITERS} mod P to get a^{-1}·R mod P
                let mut result = SmallFp::from_raw(v_reduced as Self::T);
                let corr = SmallFp::from_raw(#corr as Self::T);
                Self::mul_assign(&mut result, &corr);
                Some(result)
            }
        }
    }
}

// Selects the appropriate multiplication algorithm at compile time:
// if modulus <= u64, multiply by casting to the next largest primitive
// otherwise, multiply in parts to form a 256-bit product
Expand Down
2 changes: 1 addition & 1 deletion ff-macros/src/small_fp/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ pub(crate) const fn mod_mul_const(a: u128, b: u128, modulus: u128) -> u128 {
}
}

const fn pow_mod_const(mut base: u128, mut exp: u128, modulus: u128) -> u128 {
pub(crate) const fn pow_mod_const(mut base: u128, mut exp: u128, modulus: u128) -> u128 {
let mut result = 1;
base %= modulus;
while exp > 0 {
Expand Down
Loading