Merged
3 changes: 3 additions & 0 deletions module-lattice/Cargo.toml
@@ -29,3 +29,6 @@ getrandom = { version = "0.4.0-rc.1", features = ["sys_rng"] }
[features]
subtle = ["dep:subtle", "array/subtle"]
zeroize = ["array/zeroize", "dep:zeroize"]

[lints]
workspace = true
79 changes: 57 additions & 22 deletions module-lattice/src/algebra.rs
@@ -10,24 +10,35 @@ use subtle::{Choice, ConstantTimeEq};
#[cfg(feature = "zeroize")]
use zeroize::Zeroize;

/// Finite field with efficient modular reduction for lattice-based cryptography.
pub trait Field: Copy + Default + Debug + PartialEq {
/// Base integer type used to represent field elements
type Int: PrimInt + Default + Debug + From<u8> + Into<u128> + Into<Self::Long> + Truncate<u128>;
/// Double-width integer type used for intermediate computations.
type Long: PrimInt + From<Self::Int>;
/// Quadruple-width integer type used for Barrett reduction.
type LongLong: PrimInt;

/// Field modulus.
const Q: Self::Int;
/// Field modulus as [`Self::Long`].
const QL: Self::Long;
/// Field modulus as [`Self::LongLong`].
const QLL: Self::LongLong;

/// Bit shift used in Barrett reduction.
const BARRETT_SHIFT: usize;
/// Precomputed multiplier for Barrett reduction.
const BARRETT_MULTIPLIER: Self::LongLong;

/// Reduce a value that's already close to the modulus range.
fn small_reduce(x: Self::Int) -> Self::Int;
/// Reduce a wider value to a field element using Barrett reduction.
fn barrett_reduce(x: Self::Long) -> Self::Int;
}
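
For context, here is a minimal standalone sketch of the Barrett reduction strategy these constants describe, using q = 3329 (the ML-KEM modulus) purely as an example; the constants and function below are illustrative assumptions, not items from this crate.

// Illustrative only: Barrett reduction for q = 3329. `SHIFT` and `MULTIPLIER` play the
// roles of `BARRETT_SHIFT` and `BARRETT_MULTIPLIER` above.
const Q: u32 = 3329;
const SHIFT: u32 = 24;
const MULTIPLIER: u64 = (1u64 << SHIFT) / (Q as u64); // floor(2^24 / 3329) = 5039

/// Reduce any `x < 2^24` (which covers any product of two reduced elements) into `[0, q)`.
fn barrett_reduce_u32(x: u32) -> u16 {
    // Estimate the quotient x / q with a multiply and a shift instead of a division.
    let quotient = (u64::from(x) * MULTIPLIER) >> SHIFT;
    // The estimate is low by at most one, so a single conditional subtraction finishes.
    let mut r = x - (quotient as u32) * Q;
    if r >= Q {
        r -= Q;
    }
    // e.g. barrett_reduce_u32(3328 * 3328) == 1, since 3328 is -1 mod 3329
    r as u16
}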

/// The `define_field` macro creates a zero-sized struct and an implementation of the Field trait
/// for that struct. The caller must specify:
/// The `define_field` macro creates a zero-sized struct and an implementation of the [`Field`]
/// trait for that struct. The caller must specify:
///
/// * `$field`: The name of the zero-sized struct to be created
/// * `$q`: The prime number that defines the field.
@@ -39,6 +50,10 @@ pub trait Field: Copy + Default + Debug + PartialEq {
#[macro_export]
macro_rules! define_field {
($field:ident, $int:ty, $long:ty, $longlong:ty, $q:literal) => {
$crate::define_field!($field, $int, $long, $longlong, $q, "Finite field");
};
($field:ident, $int:ty, $long:ty, $longlong:ty, $q:literal, $doc:expr) => {
#[doc = $doc]
#[derive(Copy, Clone, Default, Debug, Eq, PartialEq)]
pub struct $field;

@@ -71,15 +86,19 @@
};
}
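
For illustration, hypothetical invocations of both macro arms (the field names and the ML-KEM-style parameters are assumptions used by the later sketches, not definitions from this crate):

// `ExampleField` is a hypothetical zero-sized field type reused in later sketches.
define_field!(ExampleField, u16, u32, u64, 3329);
// The new second arm additionally takes a custom doc string.
define_field!(ExampleField2, u16, u32, u64, 3329, "Finite field of order 3329");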

/// An `Elem` is a member of the specified prime-order field. Elements can be added,
/// subtracted, multiplied, and negated, and the overloaded operators will ensure both that the
/// integer values remain in the field, and that the reductions are done efficiently. For
/// addition and subtraction, a simple conditional subtraction is used; for multiplication,
/// An [`Elem`] is a member of the specified prime-order field.
///
/// Elements can be added, subtracted, multiplied, and negated, and the overloaded operators will
/// ensure both that the integer values remain in the field, and that the reductions are done
/// efficiently.
///
/// For addition and subtraction, a simple conditional subtraction is used; for multiplication,
/// Barrett reduction.
#[derive(Copy, Clone, Default, Debug, Eq, PartialEq)]
pub struct Elem<F: Field>(pub F::Int);

impl<F: Field> Elem<F> {
/// Create a new field element.
pub const fn new(x: F::Int) -> Self {
Self(x)
}
@@ -141,12 +160,14 @@ impl<F: Field> Mul<Elem<F>> for Elem<F> {
}
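
Illustrative usage of the element arithmetic described above, assuming the hypothetical `ExampleField` (q = 3329) from the previous sketch and the by-value operators the doc comment describes; the values are just a worked example.

use module_lattice::algebra::Elem;

fn elem_example() {
    let a = Elem::<ExampleField>::new(1234);
    let b = Elem::<ExampleField>::new(2345);
    let sum = a + b;  // conditional subtraction: 1234 + 2345 - 3329 = 250
    let prod = a * b; // 32-bit product reduced via Barrett reduction: 1234 * 2345 mod 3329 = 829
    assert_eq!(sum.0, 250);
    assert_eq!(prod.0, 829);
}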

/// A `Polynomial` is a member of the ring `R_q = Z_q[X] / (X^256)` of degree-256 polynomials
/// over the finite field with prime order `q`. Polynomials can be added, subtracted, negated,
/// and multiplied by field elements. We do not define multiplication of polynomials here.
/// over the finite field with prime order `q`.
///
/// Polynomials can be added, subtracted, negated, and multiplied by field elements.
#[derive(Clone, Copy, Default, Debug, PartialEq)]
pub struct Polynomial<F: Field>(pub Array<Elem<F>, U256>);

impl<F: Field> Polynomial<F> {
/// Create a new polynomial.
pub const fn new(x: Array<Elem<F>, U256>) -> Self {
Self(x)
}
@@ -206,12 +227,14 @@ impl<F: Field> Neg for &Polynomial<F> {
}
}
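
As a purely conceptual illustration of the ring arithmetic (independent of this crate's operator impls), addition in `R_q` is coefficient-wise addition mod `q`, using the same conditional-subtraction idea as `Elem`:

// Conceptual sketch only; q = 3329 is just an example modulus and inputs are assumed reduced.
fn add_in_rq(a: &[u16; 256], b: &[u16; 256]) -> [u16; 256] {
    const Q: u32 = 3329;
    let mut out = [0u16; 256];
    for i in 0..256 {
        let sum = u32::from(a[i]) + u32::from(b[i]);
        // Conditional subtraction keeps every coefficient in [0, q).
        out[i] = if sum >= Q { (sum - Q) as u16 } else { sum as u16 };
    }
    out
}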

/// A `Vector` is a vector of polynomials from `R_q` of length `K`. Vectors can be
/// added, subtracted, negated, and multiplied by field elements.
/// A `Vector` is a vector of polynomials from `R_q` of length `K`.
///
/// Vectors can be added, subtracted, negated, and multiplied by field elements.
#[derive(Clone, Default, Debug, PartialEq)]
pub struct Vector<F: Field, K: ArraySize>(pub Array<Polynomial<F>, K>);

impl<F: Field, K: ArraySize> Vector<F, K> {
/// Create a new vector.
pub const fn new(x: Array<Polynomial<F>, K>) -> Self {
Self(x)
}
@@ -278,14 +301,19 @@ impl<F: Field, K: ArraySize> Neg for &Vector<F, K> {
}

/// An `NttPolynomial` is a member of the NTT algebra `T_q = Z_q[X]^256` of 256-tuples of field
/// elements. NTT polynomials can be added and
/// subtracted, negated, and multiplied by scalars.
/// We do not define multiplication of NTT polynomials here. We also do not define the
/// mappings between normal polynomials and NTT polynomials (i.e., between `R_q` and `T_q`).
/// elements.
///
/// NTT polynomials can be added and subtracted, negated, and multiplied by scalars.
/// We do not define multiplication of NTT polynomials here: that is defined by the downstream
/// crate using the [`MultiplyNtt`] trait.
///
/// We also do not define the mappings between normal polynomials and NTT polynomials (i.e., between
/// `R_q` and `T_q`).
#[derive(Clone, Default, Debug, Eq, PartialEq)]
pub struct NttPolynomial<F: Field>(pub Array<Elem<F>, U256>);

impl<F: Field> NttPolynomial<F> {
/// Create a new NTT polynomial.
pub const fn new(x: Array<Elem<F>, U256>) -> Self {
Self(x)
}
@@ -340,6 +368,7 @@ where

/// Perform multiplication in the NTT domain.
pub trait MultiplyNtt: Field {
/// Multiply two NTT polynomials.
fn multiply_ntt(lhs: &NttPolynomial<Self>, rhs: &NttPolynomial<Self>) -> NttPolynomial<Self>;
}
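
A purely illustrative implementation for the hypothetical `ExampleField` above, using a coefficient-wise product (the rule ML-DSA uses in `T_q`; ML-KEM's base-case multiplication from FIPS 203 is different, and real downstream crates supply their own):

use array::Array;
use module_lattice::algebra::{MultiplyNtt, NttPolynomial};

impl MultiplyNtt for ExampleField {
    fn multiply_ntt(
        lhs: &NttPolynomial<Self>,
        rhs: &NttPolynomial<Self>,
    ) -> NttPolynomial<Self> {
        // Coefficient-wise products; each `Elem` multiplication is Barrett-reduced.
        NttPolynomial::new(Array::from_fn(|i| lhs.0[i] * rhs.0[i]))
    }
}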

@@ -383,14 +412,16 @@ where
}
}

/// An `NttVector` is a vector of polynomials from `T_q` of length `K`. NTT vectors can be
/// added and subtracted. If multiplication is defined for NTT polynomials, then NTT vectors
/// can be multiplied by NTT polynomials, and "multiplied" with each other to produce a dot
/// product.
/// An [`NttVector`] is a vector of polynomials from `T_q` of length `K`.
///
/// NTT vectors can be added and subtracted. If multiplication is defined for NTT polynomials, then
/// NTT vectors can be multiplied by NTT polynomials, and "multiplied" with each other to produce a
/// dot product.
#[derive(Clone, Default, Debug, Eq, PartialEq)]
pub struct NttVector<F: Field, K: ArraySize>(pub Array<NttPolynomial<F>, K>);

impl<F: Field, K: ArraySize> NttVector<F, K> {
/// Create a new NTT vector.
pub const fn new(x: Array<NttPolynomial<F>, K>) -> Self {
Self(x)
}
@@ -470,14 +501,18 @@ where
}
}
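
A conceptual sketch of the dot product mentioned above, written against the public tuple fields rather than the (elided) operator impls; it assumes only `Elem` addition and the `MultiplyNtt` trait, with a fixed length of 2 for simplicity:

use array::{typenum::U2, Array};
use module_lattice::algebra::{MultiplyNtt, NttPolynomial, NttVector};

// Illustrative only: multiply matching entries, then accumulate coefficient-wise.
fn ntt_dot_product_u2<F: MultiplyNtt>(
    a: &NttVector<F, U2>,
    b: &NttVector<F, U2>,
) -> NttPolynomial<F> {
    let mut acc = NttPolynomial::<F>::default();
    for (x, y) in a.0.iter().zip(b.0.iter()) {
        let prod = F::multiply_ntt(x, y);
        acc = NttPolynomial::new(Array::from_fn(|i| acc.0[i] + prod.0[i]));
    }
    acc
}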

/// A K x L matrix of NTT-domain polynomials. Each vector represents a row of the matrix, so that
/// multiplying on the right just requires iteration. Multiplication on the right by vectors
/// is the only defined operation, and is only defined when multiplication of NTT polynomials
/// is defined.
/// A `K x L` matrix of NTT-domain polynomials.
///
/// Each vector represents a row of the matrix, so that multiplying on the right just requires
/// iteration.
///
/// Multiplication on the right by vectors is the only defined operation, and is only defined when
/// multiplication of NTT polynomials is defined.
#[derive(Clone, Default, Debug, PartialEq)]
pub struct NttMatrix<F: Field, K: ArraySize, L: ArraySize>(pub Array<NttVector<F, L>, K>);

impl<F: Field, K: ArraySize, L: ArraySize> NttMatrix<F, K, L> {
/// Create a new NTT matrix.
pub const fn new(x: Array<NttVector<F, L>, K>) -> Self {
Self(x)
}
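
A conceptual companion to the dot-product sketch above: right-multiplication by a vector is one dot product per matrix row (again illustrative, with a fixed 2x2 shape, reusing `ntt_dot_product_u2`):

use array::{typenum::U2, Array};
use module_lattice::algebra::{MultiplyNtt, NttMatrix, NttVector};

fn ntt_matrix_vector_u2<F: MultiplyNtt>(
    m: &NttMatrix<F, U2, U2>,
    v: &NttVector<F, U2>,
) -> NttVector<F, U2> {
    // Each row of `m` is an `NttVector`, so the product is row-by-row dot products.
    NttVector::new(Array::from_fn(|i| ntt_dot_product_u2(&m.0[i], v)))
}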
23 changes: 19 additions & 4 deletions module-lattice/src/encoding.rs
@@ -16,14 +16,19 @@ impl<T> ArraySize for T where T: array::ArraySize + PartialEq + Debug {}

/// An integer that can describe encoded polynomials.
pub trait EncodingSize: ArraySize {
/// Size of an encoded polynomial.
type EncodedPolynomialSize: ArraySize;
/// Value step.
type ValueStep: ArraySize;
/// Byte step.
type ByteStep: ArraySize;
}

type EncodingUnit<D> = Quot<Prod<D, U8>, Gcf<D, U8>>;

/// Size of an encoded polynomial.
pub type EncodedPolynomialSize<D> = <D as EncodingSize>::EncodedPolynomialSize;
/// Encoded polynomial.
pub type EncodedPolynomial<D> = Array<u8, EncodedPolynomialSize<D>>;

impl<D> EncodingSize for D
@@ -40,20 +45,26 @@ where
type ByteStep = Quot<EncodingUnit<D>, U8>;
}

/// Decoded value.
pub type DecodedValue<F> = Array<Elem<F>, U256>;

/// An integer that can describe encoded vectors.
pub trait VectorEncodingSize<K>: EncodingSize
where
K: ArraySize,
{
/// Size of an encoded vector.
type EncodedVectorSize: ArraySize;

/// Flatten encoded polynomial array into encoded vector.
fn flatten(polys: Array<EncodedPolynomial<Self>, K>) -> EncodedVector<Self, K>;
/// Unflatten encoded vector into encoded polynomial array.
fn unflatten(vec: &EncodedVector<Self, K>) -> Array<&EncodedPolynomial<Self>, K>;
}

/// Size of an encoded vector.
pub type EncodedVectorSize<D, K> = <D as VectorEncodingSize<K>>::EncodedVectorSize;
/// Encoded vector.
pub type EncodedVector<D, K> = Array<u8, EncodedVectorSize<D, K>>;

impl<D, K> VectorEncodingSize<K> for D
@@ -75,8 +86,8 @@ where
}
}

// FIPS 203: Algorithm 4 ByteEncode_d
// FIPS 204: Algorithm 16 SimpleBitPack
/// FIPS 203: Algorithm 4 `ByteEncode_d`.
/// FIPS 204: Algorithm 16 `SimpleBitPack`.
pub fn byte_encode<F: Field, D: EncodingSize>(vals: &DecodedValue<F>) -> EncodedPolynomial<D> {
let val_step = D::ValueStep::USIZE;
let byte_step = D::ByteStep::USIZE;
@@ -99,8 +110,8 @@ pub fn byte_encode<F: Field, D: EncodingSize>(vals: &DecodedValue<F>) -> Encoded
bytes
}

// FIPS 203: Algorithm 5 ByteDecode_d(F)
// FIPS 204: Algorithm 18 SimpleBitUnpack
/// FIPS 203: Algorithm 5 `ByteDecode_d(F)`
/// FIPS 204: Algorithm 18 `SimpleBitUnpack`
pub fn byte_decode<F: Field, D: EncodingSize>(bytes: &EncodedPolynomial<D>) -> DecodedValue<F> {
let val_step = D::ValueStep::USIZE;
let byte_step = D::ByteStep::USIZE;
@@ -129,9 +140,13 @@ pub fn byte_decode<F: Field, D: EncodingSize>(bytes: &EncodedPolynomial<D>) -> D
vals
}
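
Illustrative usage of the two functions, assuming the hypothetical `ExampleField` (q = 3329) from earlier and a 12-bit width, so each polynomial packs into 256 * 12 / 8 = 384 bytes; the module path and the round-trip property (which needs coefficients already reduced mod q) are assumptions here, not guarantees from this diff.

use array::typenum::U12;
use module_lattice::encoding::{byte_decode, byte_encode, DecodedValue, EncodedPolynomial};

fn encoding_example(vals: &DecodedValue<ExampleField>) {
    // Pack 256 coefficients at 12 bits each into 384 bytes, then unpack them again.
    let bytes: EncodedPolynomial<U12> = byte_encode::<ExampleField, U12>(vals);
    let decoded: DecodedValue<ExampleField> = byte_decode::<ExampleField, U12>(&bytes);
    assert_eq!(&decoded, vals);
}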

/// Encoding trait.
pub trait Encode<D: EncodingSize> {
/// Size of the encoded object.
type EncodedSize: ArraySize;
/// Encode object.
fn encode(&self) -> Array<u8, Self::EncodedSize>;
/// Decode object.
fn decode(enc: &Array<u8, Self::EncodedSize>) -> Self;
}

42 changes: 28 additions & 14 deletions module-lattice/src/utils.rs
@@ -10,6 +10,7 @@ use core::{

/// Safely truncate an unsigned integer value to shorter representation
pub trait Truncate<T> {
/// Truncate value to the width of `Self`.
fn truncate(x: T) -> Self;
}

@@ -39,10 +40,12 @@ define_truncate!(u128, u32);
define_truncate!(usize, u8);
define_truncate!(usize, u16);
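
Hypothetical usage of `Truncate`, assuming (as the name and these `define_truncate!` pairs suggest) that only the low-order bits of the wider value are kept; the import path is guessed from the file name.

use module_lattice::utils::Truncate; // path assumed

fn truncate_example() {
    let wide: u128 = (1u128 << 40) | 42;
    // Presumably keeps only the low 32 bits, yielding 42.
    let narrow = u32::truncate(wide);
    assert_eq!(narrow, 42);
}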

/// Defines a sequence of sequences that can be merged into a bigger overall seequence
/// Defines a sequence of sequences that can be merged into a bigger overall sequence.
pub trait Flatten<T, M: ArraySize> {
/// Size of the output array.
type OutputSize: ArraySize;

/// Flatten array.
fn flatten(self) -> Array<T, Self::OutputSize>;
}

@@ -54,58 +57,69 @@
{
type OutputSize = Prod<M, N>;

// This is the reverse transmute between [T; K*N] and [[T; K], M], which is guaranteed to be
// safe by the Rust memory layout of these types.
fn flatten(self) -> Array<T, Self::OutputSize> {
let whole = ManuallyDrop::new(self);
unsafe { ptr::read(whole.as_ptr().cast()) }

// SAFETY: this is the reverse transmute between [T; K*N] and [[T; K], M], which is guaranteed
// to be safe by the Rust memory layout of these types.
#[allow(unsafe_code)]
unsafe {
ptr::read(whole.as_ptr().cast())
}
}
}

/// Defines a sequence that can be split into a sequence of smaller sequences of uniform size
/// Defines a sequence that can be split into a sequence of smaller sequences of uniform size.
pub trait Unflatten<M>
where
M: ArraySize,
{
/// Part of the array we're decomposing into.
type Part;

/// Unflatten array into `Self::Part` chunks.
fn unflatten(self) -> Array<Self::Part, M>;
}

impl<T, N, M> Unflatten<M> for Array<T, N>
where
T: Default,
N: ArraySize + Div<M> + Rem<M, Output = U0>,
M: ArraySize,
Quot<N, M>: ArraySize,
{
type Part = Array<T, Quot<N, M>>;

// This requires some unsafeness, but it is the same as what is done in Array::split.
// Basically, this is doing transmute between [T; K*N] and [[T; K], M], which is guaranteed to
// be safe by the Rust memory layout of these types.
fn unflatten(self) -> Array<Self::Part, M> {
let part_size = Quot::<N, M>::USIZE;
let whole = ManuallyDrop::new(self);
Array::from_fn(|i| unsafe { ptr::read(whole.as_ptr().add(i * part_size).cast()) })

// SAFETY: this is doing the same thing as what is done in `Array::split`.
// Basically, this is doing transmute between [T; K*N] and [[T; K], M], which is guaranteed to
// be safe by the Rust memory layout of these types.
#[allow(unsafe_code)]
Array::from_fn(|i| unsafe {
let offset = i.checked_mul(part_size).expect("overflow");
ptr::read(whole.as_ptr().add(offset).cast())
})
}
}
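
Illustrative round trip between nested and flat arrays (sizes picked arbitrarily; the import path is guessed from the file name):

use array::{typenum::{U2, U8, U16}, Array};
use module_lattice::utils::{Flatten, Unflatten}; // path assumed

fn flatten_example() {
    let nested: Array<Array<u8, U8>, U2> = Array::default();
    // [[u8; 8]; 2] -> [u8; 16], then back again.
    let flat: Array<u8, U16> = nested.clone().flatten();
    let back: Array<Array<u8, U8>, U2> = flat.unflatten();
    assert_eq!(back, nested);
}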

impl<'a, T, N, M> Unflatten<M> for &'a Array<T, N>
where
T: Default,
N: ArraySize + Div<M> + Rem<M, Output = U0>,
M: ArraySize,
Quot<N, M>: ArraySize,
{
type Part = &'a Array<T, Quot<N, M>>;

// This requires some unsafeness, but it is the same as what is done in Array::split.
// Basically, this is doing transmute between [T; K*N] and [[T; K], M], which is guaranteed to
// be safe by the Rust memory layout of these types.
fn unflatten(self) -> Array<Self::Part, M> {
let part_size = Quot::<N, M>::USIZE;
let mut ptr: *const T = self.as_ptr();

// SAFETY: this is doing the same thing as what is done in `Array::split`.
// Basically, this is doing transmute between [T; K*N] and [[T; K], M], which is guaranteed to
// be safe by the Rust memory layout of these types.
#[allow(unsafe_code)]
Array::from_fn(|_i| unsafe {
let part = &*(ptr.cast());
ptr = ptr.add(part_size);
3 changes: 3 additions & 0 deletions module-lattice/tests/algebra.rs
@@ -1,5 +1,8 @@
//! Tests for the `algebra` module.

#![allow(clippy::cast_possible_truncation, reason = "tests")]
#![allow(clippy::integer_division_remainder_used, reason = "tests")]

use array::typenum::U2;
use module_lattice::algebra::{
Elem, Field, NttMatrix, NttPolynomial, NttVector, Polynomial, Vector,