diff --git a/module-lattice/Cargo.toml b/module-lattice/Cargo.toml index 2d6e2a5..85bf7f5 100644 --- a/module-lattice/Cargo.toml +++ b/module-lattice/Cargo.toml @@ -29,3 +29,6 @@ getrandom = { version = "0.4.0-rc.1", features = ["sys_rng"] } [features] subtle = ["dep:subtle", "array/subtle"] zeroize = ["array/zeroize", "dep:zeroize"] + +[lints] +workspace = true diff --git a/module-lattice/src/algebra.rs b/module-lattice/src/algebra.rs index c1fb4f5..547502b 100644 --- a/module-lattice/src/algebra.rs +++ b/module-lattice/src/algebra.rs @@ -10,24 +10,35 @@ use subtle::{Choice, ConstantTimeEq}; #[cfg(feature = "zeroize")] use zeroize::Zeroize; +/// Finite field with efficient modular reduction for lattice-based cryptography. pub trait Field: Copy + Default + Debug + PartialEq { + /// Base integer type used to represent field elements type Int: PrimInt + Default + Debug + From + Into + Into + Truncate; + /// Double-width integer type used for intermediate computations. type Long: PrimInt + From; + /// Quadruple-width integer type used for Barrett reduction. type LongLong: PrimInt; + /// Field modulus. const Q: Self::Int; + /// Field modulus as [`Self::Long`]. const QL: Self::Long; + /// Field modulus as [`Self::LongLong`]. const QLL: Self::LongLong; + /// Bit shift used in Barrett reduction. const BARRETT_SHIFT: usize; + /// Precomputed multiplier for Barrett reduction. const BARRETT_MULTIPLIER: Self::LongLong; + /// Reduce a value that's already close to the modulus range. fn small_reduce(x: Self::Int) -> Self::Int; + /// Reduce a wider value to a field element using Barrett reduction. fn barrett_reduce(x: Self::Long) -> Self::Int; } -/// The `define_field` macro creates a zero-sized struct and an implementation of the Field trait -/// for that struct. The caller must specify: +/// The `define_field` macro creates a zero-sized struct and an implementation of the [`Field`] +/// trait for that struct. The caller must specify: /// /// * `$field`: The name of the zero-sized struct to be created /// * `$q`: The prime number that defines the field. @@ -39,6 +50,10 @@ pub trait Field: Copy + Default + Debug + PartialEq { #[macro_export] macro_rules! define_field { ($field:ident, $int:ty, $long:ty, $longlong:ty, $q:literal) => { + $crate::define_field!($field, $int, $long, $longlong, $q, "Finite field"); + }; + ($field:ident, $int:ty, $long:ty, $longlong:ty, $q:literal, $doc:expr) => { + #[doc = $doc] #[derive(Copy, Clone, Default, Debug, Eq, PartialEq)] pub struct $field; @@ -71,15 +86,19 @@ macro_rules! define_field { }; } -/// An `Elem` is a member of the specified prime-order field. Elements can be added, -/// subtracted, multiplied, and negated, and the overloaded operators will ensure both that the -/// integer values remain in the field, and that the reductions are done efficiently. For -/// addition and subtraction, a simple conditional subtraction is used; for multiplication, +/// An [`Elem`] is a member of the specified prime-order field. +/// +/// Elements can be added, subtracted, multiplied, and negated, and the overloaded operators will +/// ensure both that the integer values remain in the field, and that the reductions are done +/// efficiently. +/// +/// For addition and subtraction, a simple conditional subtraction is used; for multiplication, /// Barrett reduction. #[derive(Copy, Clone, Default, Debug, Eq, PartialEq)] pub struct Elem(pub F::Int); impl Elem { + /// Create a new field element. pub const fn new(x: F::Int) -> Self { Self(x) } @@ -141,12 +160,14 @@ impl Mul> for Elem { } /// A `Polynomial` is a member of the ring `R_q = Z_q[X] / (X^256)` of degree-256 polynomials -/// over the finite field with prime order `q`. Polynomials can be added, subtracted, negated, -/// and multiplied by field elements. We do not define multiplication of polynomials here. +/// over the finite field with prime order `q`. +/// +/// Polynomials can be added, subtracted, negated, and multiplied by field elements. #[derive(Clone, Copy, Default, Debug, PartialEq)] pub struct Polynomial(pub Array, U256>); impl Polynomial { + /// Create a new polynomial. pub const fn new(x: Array, U256>) -> Self { Self(x) } @@ -206,12 +227,14 @@ impl Neg for &Polynomial { } } -/// A `Vector` is a vector of polynomials from `R_q` of length `K`. Vectors can be -/// added, subtracted, negated, and multiplied by field elements. +/// A `Vector` is a vector of polynomials from `R_q` of length `K`. +/// +/// Vectors can be added, subtracted, negated, and multiplied by field elements. #[derive(Clone, Default, Debug, PartialEq)] pub struct Vector(pub Array, K>); impl Vector { + /// Create a new vector. pub const fn new(x: Array, K>) -> Self { Self(x) } @@ -278,14 +301,19 @@ impl Neg for &Vector { } /// An `NttPolynomial` is a member of the NTT algebra `T_q = Z_q[X]^256` of 256-tuples of field -/// elements. NTT polynomials can be added and -/// subtracted, negated, and multiplied by scalars. -/// We do not define multiplication of NTT polynomials here. We also do not define the -/// mappings between normal polynomials and NTT polynomials (i.e., between `R_q` and `T_q`). +/// elements. +/// +/// NTT polynomials can be added and subtracted, negated, and multiplied by scalars. +/// We do not define multiplication of NTT polynomials here: that is defined by the downstream +/// crate using the [`MultiplyNtt`] trait. +/// +/// We also do not define the mappings between normal polynomials and NTT polynomials (i.e., between +/// `R_q` and `T_q`). #[derive(Clone, Default, Debug, Eq, PartialEq)] pub struct NttPolynomial(pub Array, U256>); impl NttPolynomial { + /// Create a new NTT polynomial. pub const fn new(x: Array, U256>) -> Self { Self(x) } @@ -340,6 +368,7 @@ where /// Perform multiplication in the NTT domain. pub trait MultiplyNtt: Field { + /// Multiply two NTT polynomials. fn multiply_ntt(lhs: &NttPolynomial, rhs: &NttPolynomial) -> NttPolynomial; } @@ -383,14 +412,16 @@ where } } -/// An `NttVector` is a vector of polynomials from `T_q` of length `K`. NTT vectors can be -/// added and subtracted. If multiplication is defined for NTT polynomials, then NTT vectors -/// can be multiplied by NTT polynomials, and "multiplied" with each other to produce a dot -/// product. +/// An [`NttVector`] is a vector of polynomials from `T_q` of length `K`. +/// +/// NTT vectors can be added and subtracted. If multiplication is defined for NTT polynomials, then +/// NTT vectors can be multiplied by NTT polynomials, and "multiplied" with each other to produce a +/// dot product. #[derive(Clone, Default, Debug, Eq, PartialEq)] pub struct NttVector(pub Array, K>); impl NttVector { + /// Create a new NTT vector. pub const fn new(x: Array, K>) -> Self { Self(x) } @@ -470,14 +501,18 @@ where } } -/// A K x L matrix of NTT-domain polynomials. Each vector represents a row of the matrix, so that -/// multiplying on the right just requires iteration. Multiplication on the right by vectors -/// is the only defined operation, and is only defined when multiplication of NTT polynomials -/// is defined. +/// A `K x L` matrix of NTT-domain polynomials. +/// +/// Each vector represents a row of the matrix, so that multiplying on the right just requires +/// iteration. +/// +/// Multiplication on the right by vectors is the only defined operation, and is only defined when +/// multiplication of NTT polynomials is defined. #[derive(Clone, Default, Debug, PartialEq)] pub struct NttMatrix(pub Array, K>); impl NttMatrix { + /// Create a new NTT matrix. pub const fn new(x: Array, K>) -> Self { Self(x) } diff --git a/module-lattice/src/encoding.rs b/module-lattice/src/encoding.rs index 1bac24c..d649a96 100644 --- a/module-lattice/src/encoding.rs +++ b/module-lattice/src/encoding.rs @@ -16,14 +16,19 @@ impl ArraySize for T where T: array::ArraySize + PartialEq + Debug {} /// An integer that can describe encoded polynomials. pub trait EncodingSize: ArraySize { + /// Size of an encoded polynomial. type EncodedPolynomialSize: ArraySize; + /// Value step. type ValueStep: ArraySize; + /// Byte step. type ByteStep: ArraySize; } type EncodingUnit = Quot, Gcf>; +/// Size of an encoded polynomial. pub type EncodedPolynomialSize = ::EncodedPolynomialSize; +/// Encoded polynomial. pub type EncodedPolynomial = Array>; impl EncodingSize for D @@ -40,6 +45,7 @@ where type ByteStep = Quot, U8>; } +/// Decoded value. pub type DecodedValue = Array, U256>; /// An integer that can describe encoded vectors. @@ -47,13 +53,18 @@ pub trait VectorEncodingSize: EncodingSize where K: ArraySize, { + /// Size of an encoded vector. type EncodedVectorSize: ArraySize; + /// Flatten encoded polynomial array into encoded vector. fn flatten(polys: Array, K>) -> EncodedVector; + /// Unflatten encoded vector into encoded polynomial array. fn unflatten(vec: &EncodedVector) -> Array<&EncodedPolynomial, K>; } +/// Size of an encoded vector. pub type EncodedVectorSize = >::EncodedVectorSize; +/// Encoded vector. pub type EncodedVector = Array>; impl VectorEncodingSize for D @@ -75,8 +86,8 @@ where } } -// FIPS 203: Algorithm 4 ByteEncode_d -// FIPS 204: Algorithm 16 SimpleBitPack +/// FIPS 203: Algorithm 4 `ByteEncode_d`. +/// FIPS 204: Algorithm 16 `SimpleBitPack`. pub fn byte_encode(vals: &DecodedValue) -> EncodedPolynomial { let val_step = D::ValueStep::USIZE; let byte_step = D::ByteStep::USIZE; @@ -99,8 +110,8 @@ pub fn byte_encode(vals: &DecodedValue) -> Encoded bytes } -// FIPS 203: Algorithm 5 ByteDecode_d(F) -// FIPS 204: Algorithm 18 SimpleBitUnpack +/// FIPS 203: Algorithm 5 `ByteDecode_d(F)` +/// FIPS 204: Algorithm 18 `SimpleBitUnpack` pub fn byte_decode(bytes: &EncodedPolynomial) -> DecodedValue { let val_step = D::ValueStep::USIZE; let byte_step = D::ByteStep::USIZE; @@ -129,9 +140,13 @@ pub fn byte_decode(bytes: &EncodedPolynomial) -> D vals } +/// Encoding trait. pub trait Encode { + /// Size of the encoded object. type EncodedSize: ArraySize; + /// Encode object. fn encode(&self) -> Array; + /// Decode object. fn decode(enc: &Array) -> Self; } diff --git a/module-lattice/src/utils.rs b/module-lattice/src/utils.rs index 7afe59f..bfcccc1 100644 --- a/module-lattice/src/utils.rs +++ b/module-lattice/src/utils.rs @@ -10,6 +10,7 @@ use core::{ /// Safely truncate an unsigned integer value to shorter representation pub trait Truncate { + /// Truncate value to the width of `Self`. fn truncate(x: T) -> Self; } @@ -39,10 +40,12 @@ define_truncate!(u128, u32); define_truncate!(usize, u8); define_truncate!(usize, u16); -/// Defines a sequence of sequences that can be merged into a bigger overall seequence +/// Defines a sequence of sequences that can be merged into a bigger overall sequence. pub trait Flatten { + /// Size of the output array. type OutputSize: ArraySize; + /// Flatten array. fn flatten(self) -> Array; } @@ -54,58 +57,69 @@ where { type OutputSize = Prod; - // This is the reverse transmute between [T; K*N] and [[T; K], M], which is guaranteed to be - // safe by the Rust memory layout of these types. fn flatten(self) -> Array { let whole = ManuallyDrop::new(self); - unsafe { ptr::read(whole.as_ptr().cast()) } + + // SAFETY: this is the reverse transmute between [T; K*N] and [[T; K], M], which is guaranteed + // to be safe by the Rust memory layout of these types. + #[allow(unsafe_code)] + unsafe { + ptr::read(whole.as_ptr().cast()) + } } } -/// Defines a sequence that can be split into a sequence of smaller sequences of uniform size +/// Defines a sequence that can be split into a sequence of smaller sequences of uniform size. pub trait Unflatten where M: ArraySize, { + /// Part of the array we're decomposing into. type Part; + /// Unflatten array into `Self::Part` chunks. fn unflatten(self) -> Array; } impl Unflatten for Array where - T: Default, N: ArraySize + Div + Rem, M: ArraySize, Quot: ArraySize, { type Part = Array>; - // This requires some unsafeness, but it is the same as what is done in Array::split. - // Basically, this is doing transmute between [T; K*N] and [[T; K], M], which is guaranteed to - // be safe by the Rust memory layout of these types. fn unflatten(self) -> Array { let part_size = Quot::::USIZE; let whole = ManuallyDrop::new(self); - Array::from_fn(|i| unsafe { ptr::read(whole.as_ptr().add(i * part_size).cast()) }) + + // SAFETY: this is doing the same thing as what is done in `Array::split`. + // Basically, this is doing transmute between [T; K*N] and [[T; K], M], which is guaranteed to + // be safe by the Rust memory layout of these types. + #[allow(unsafe_code)] + Array::from_fn(|i| unsafe { + let offset = i.checked_mul(part_size).expect("overflow"); + ptr::read(whole.as_ptr().add(offset).cast()) + }) } } impl<'a, T, N, M> Unflatten for &'a Array where - T: Default, N: ArraySize + Div + Rem, M: ArraySize, Quot: ArraySize, { type Part = &'a Array>; - // This requires some unsafeness, but it is the same as what is done in Array::split. - // Basically, this is doing transmute between [T; K*N] and [[T; K], M], which is guaranteed to - // be safe by the Rust memory layout of these types. fn unflatten(self) -> Array { let part_size = Quot::::USIZE; let mut ptr: *const T = self.as_ptr(); + + // SAFETY: this is doing the same thing as what is done in `Array::split`. + // Basically, this is doing transmute between [T; K*N] and [[T; K], M], which is guaranteed to + // be safe by the Rust memory layout of these types. + #[allow(unsafe_code)] Array::from_fn(|_i| unsafe { let part = &*(ptr.cast()); ptr = ptr.add(part_size); diff --git a/module-lattice/tests/algebra.rs b/module-lattice/tests/algebra.rs index 1897447..2a06a20 100644 --- a/module-lattice/tests/algebra.rs +++ b/module-lattice/tests/algebra.rs @@ -1,5 +1,8 @@ //! Tests for the `algebra` module. +#![allow(clippy::cast_possible_truncation, reason = "tests")] +#![allow(clippy::integer_division_remainder_used, reason = "tests")] + use array::typenum::U2; use module_lattice::algebra::{ Elem, Field, NttMatrix, NttPolynomial, NttVector, Polynomial, Vector, diff --git a/module-lattice/tests/encode.rs b/module-lattice/tests/encode.rs index 7173d51..9bdfd11 100644 --- a/module-lattice/tests/encode.rs +++ b/module-lattice/tests/encode.rs @@ -1,6 +1,7 @@ //! Tests for the `encode` module. -#![allow(clippy::integer_division_remainder_used)] +#![allow(clippy::cast_possible_truncation, reason = "tests")] +#![allow(clippy::integer_division_remainder_used, reason = "tests")] use array::sizes::U3; use array::typenum::{Mod, Zero}; @@ -8,6 +9,7 @@ use array::{ Array, sizes::{U1, U2, U4, U5, U6, U8, U10, U11, U12, U256}, }; +use core::{fmt::Debug, ops::Rem}; use getrandom::{ SysRng, rand_core::{Rng, UnwrapErr}, @@ -17,8 +19,6 @@ use module_lattice::{ algebra::{Elem, Field, NttPolynomial, NttVector, Polynomial, Vector}, encoding::{ArraySize, Encode, EncodedPolynomial, EncodingSize, byte_decode, byte_encode}, }; -use std::fmt::Debug; -use std::ops::Rem; // Field used by ML-KEM. module_lattice::define_field!(KyberField, u16, u32, u64, 3329); @@ -65,7 +65,7 @@ where let decoded = Array::::from_fn(|_| (rng.next_u32() & 0xFFFF) as Int); let m = match D::USIZE { 12 => KyberField::Q, - d => (1 as Int) << d, + d => 1 << d, }; let decoded = decoded.iter().map(|x| Elem::new(x % m)).collect();