Skip to content
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions ed448-goldilocks/src/field/element.rs
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ impl FieldElement {
mod tests {
use super::*;
use elliptic_curve::consts::U32;
use hash2curve::{ExpandMsg, ExpandMsgXof, Expander};
use hash2curve::{ExpandMsg, ExpandMsgXof};
use hex_literal::hex;
use sha3::Shake256;

Expand All @@ -463,16 +463,16 @@ mod tests {
(84 * 2).try_into().unwrap(),
)
.unwrap();
let mut data = Array::<u8, U84>::default();
expander.fill_bytes(&mut data);
let mut data = Array::<u8, U84>::from_iter(expander.by_ref().take(84));
// TODO: This should be `Curve448FieldElement`.
let u0 = Ed448FieldElement::from_okm(&data).0;
let mut e_u0 = *expected_u0;
e_u0.reverse();
let mut e_u1 = *expected_u1;
e_u1.reverse();
assert_eq!(u0.to_bytes(), e_u0);
expander.fill_bytes(&mut data);
data = Array::<u8, U84>::from_iter(expander);

// TODO: This should be `Curve448FieldElement`.
let u1 = Ed448FieldElement::from_okm(&data).0;
assert_eq!(u1.to_bytes(), e_u1);
Expand All @@ -497,15 +497,14 @@ mod tests {
(84 * 2).try_into().unwrap(),
)
.unwrap();
let mut data = Array::<u8, U84>::default();
expander.fill_bytes(&mut data);
let mut data = Array::<u8, U84>::from_iter(expander.by_ref().take(84));
let u0 = Ed448FieldElement::from_okm(&data).0;
let mut e_u0 = *expected_u0;
e_u0.reverse();
let mut e_u1 = *expected_u1;
e_u1.reverse();
assert_eq!(u0.to_bytes(), e_u0);
expander.fill_bytes(&mut data);
data = Array::<u8, U84>::from_iter(expander.by_ref());
let u1 = Ed448FieldElement::from_okm(&data).0;
assert_eq!(u1.to_bytes(), e_u1);
}
Expand Down
15 changes: 9 additions & 6 deletions hash2curve/src/group_digest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ pub trait GroupDigest: MapToCurve {
///
/// [`ExpandMsgXmd`]: crate::ExpandMsgXmd
/// [`ExpandMsgXof`]: crate::ExpandMsgXof
fn hash_from_bytes<X>(msg: &[&[u8]], dst: &[&[u8]]) -> Result<ProjectivePoint<Self>>
fn hash_from_bytes<'dst, X>(msg: &[&[u8]], dst: &'dst [&[u8]]) -> Result<ProjectivePoint<Self>>
where
X: ExpandMsg<Self::K>,
X: ExpandMsg<'dst, Self::K>,
{
let [u0, u1] = hash_to_field::<2, X, _, Self::FieldElement>(msg, dst)?;
let q0 = Self::map_to_curve(u0);
Expand Down Expand Up @@ -62,9 +62,12 @@ pub trait GroupDigest: MapToCurve {
///
/// [`ExpandMsgXmd`]: crate::ExpandMsgXmd
/// [`ExpandMsgXof`]: crate::ExpandMsgXof
fn encode_from_bytes<X>(msg: &[&[u8]], dst: &[&[u8]]) -> Result<ProjectivePoint<Self>>
fn encode_from_bytes<'dst, X>(
msg: &[&[u8]],
dst: &'dst [&[u8]],
) -> Result<ProjectivePoint<Self>>
where
X: ExpandMsg<Self::K>,
X: ExpandMsg<'dst, Self::K>,
{
let [u] = hash_to_field::<1, X, _, Self::FieldElement>(msg, dst)?;
let q0 = Self::map_to_curve(u);
Expand All @@ -85,9 +88,9 @@ pub trait GroupDigest: MapToCurve {
///
/// [`ExpandMsgXmd`]: crate::ExpandMsgXmd
/// [`ExpandMsgXof`]: crate::ExpandMsgXof
fn hash_to_scalar<X>(msg: &[&[u8]], dst: &[&[u8]]) -> Result<Self::Scalar>
fn hash_to_scalar<'dst, X>(msg: &[&[u8]], dst: &'dst [&[u8]]) -> Result<Self::Scalar>
where
X: ExpandMsg<Self::K>,
X: ExpandMsg<'dst, Self::K>,
{
let [u] = hash_to_field::<1, X, _, Self::Scalar>(msg, dst)?;
Ok(u)
Expand Down
12 changes: 8 additions & 4 deletions hash2curve/src/hash2field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,24 @@ pub trait FromOkm {
/// [`ExpandMsgXmd`]: crate::hash2field::ExpandMsgXmd
/// [`ExpandMsgXof`]: crate::hash2field::ExpandMsgXof
#[doc(hidden)]
pub fn hash_to_field<const N: usize, E, K, T>(data: &[&[u8]], domain: &[&[u8]]) -> Result<[T; N]>
pub fn hash_to_field<'dst, const N: usize, E, K, T>(
data: &[&[u8]],
domain: &'dst [&[u8]],
) -> Result<[T; N]>
where
E: ExpandMsg<K>,
E: ExpandMsg<'dst, K>,
T: FromOkm + Default,
{
let len_in_bytes = T::Length::USIZE
.checked_mul(N)
.and_then(|len| len.try_into().ok())
.and_then(NonZeroU16::new)
.ok_or(Error)?;
let mut tmp = Array::<u8, <T as FromOkm>::Length>::default();
let mut expander = E::expand_message(data, domain, len_in_bytes)?;
Ok(core::array::from_fn(|_| {
expander.fill_bytes(&mut tmp);
let tmp = Array::<u8, <T as FromOkm>::Length>::from_iter(
expander.by_ref().take(T::Length::USIZE),
);
T::from_okm(&tmp)
}))
}
15 changes: 3 additions & 12 deletions hash2curve/src/hash2field/expand_msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,16 @@ const MAX_DST_LEN: usize = 255;
///
/// # Errors
/// See implementors of [`ExpandMsg`] for errors.
pub trait ExpandMsg<K> {
/// Type holding data for the [`Expander`].
type Expander<'dst>: Expander + Sized;

pub trait ExpandMsg<'dst, K>: Iterator<Item = u8> + Sized {
/// Expands `msg` to the required number of bytes.
///
/// Returns an expander that can be used to call `read` until enough
/// bytes have been consumed
fn expand_message<'dst>(
fn expand_message(
msg: &[&[u8]],
dst: &'dst [&[u8]],
len_in_bytes: NonZero<u16>,
) -> Result<Self::Expander<'dst>>;
}

/// Expander that, call `read` until enough bytes have been consumed.
pub trait Expander {
/// Fill the array with the expanded bytes
fn fill_bytes(&mut self, okm: &mut [u8]);
) -> Result<Self>;
}

/// The domain separation tag
Expand Down
112 changes: 44 additions & 68 deletions hash2curve/src/hash2field/expand_msg/xmd.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
//! `expand_message_xmd` based on a hash function.

use core::{marker::PhantomData, num::NonZero, ops::Mul};
use core::{num::NonZero, ops::Mul};

use super::{Domain, ExpandMsg, Expander};
use super::{Domain, ExpandMsg};
use digest::{
FixedOutput, HashMarker,
array::{
Expand All @@ -21,12 +21,20 @@ use elliptic_curve::{Error, Result};
/// - `dst > 255 && HashT::OutputSize > 255`
/// - `len_in_bytes > 255 * HashT::OutputSize`
#[derive(Debug)]
pub struct ExpandMsgXmd<HashT>(PhantomData<HashT>)
pub struct ExpandMsgXmd<'a, HashT>
where
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>;
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>,
{
b_0: Array<u8, HashT::OutputSize>,
b_vals: Array<u8, HashT::OutputSize>,
domain: Domain<'a, HashT::OutputSize>,
index: u8,
offset: usize,
length: u16,
}

impl<HashT, K> ExpandMsg<K> for ExpandMsgXmd<HashT>
impl<'dst, HashT, K> ExpandMsg<'dst, K> for ExpandMsgXmd<'dst, HashT>
where
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
// The number of bits output by `HashT` MUST be at most `HashT::BlockSize`.
Expand All @@ -37,23 +45,18 @@ where
K: Mul<U2>,
HashT::OutputSize: IsGreaterOrEqual<Prod<K, U2>, Output = True>,
{
type Expander<'dst> = ExpanderXmd<'dst, HashT>;

fn expand_message<'dst>(
fn expand_message(
msg: &[&[u8]],
dst: &'dst [&[u8]],
len_in_bytes: NonZero<u16>,
) -> Result<Self::Expander<'dst>> {
) -> Result<Self> {
let b_in_bytes = HashT::OutputSize::USIZE;

// `255 * <b_in_bytes>` can not exceed `u16::MAX`
if usize::from(len_in_bytes.get()) > 255 * b_in_bytes {
return Err(Error);
}

let ell = u8::try_from(usize::from(len_in_bytes.get()).div_ceil(b_in_bytes))
.expect("should never pass the previous check");

let domain = Domain::xmd::<HashT>(dst)?;
let mut b_0 = HashT::default();
b_0.update(&Array::<u8, HashT::BlockSize>::default());
Expand All @@ -75,74 +78,49 @@ where
b_vals.update(&[domain.len()]);
let b_vals = b_vals.finalize_fixed();

Ok(ExpanderXmd {
Ok(Self {
b_0,
b_vals,
domain,
index: 1,
offset: 0,
ell,
length: len_in_bytes.get(),
})
}
}

/// [`Expander`] type for [`ExpandMsgXmd`].
#[derive(Debug)]
pub struct ExpanderXmd<'a, HashT>
where
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>,
{
b_0: Array<u8, HashT::OutputSize>,
b_vals: Array<u8, HashT::OutputSize>,
domain: Domain<'a, HashT::OutputSize>,
index: u8,
offset: usize,
ell: u8,
}

impl<HashT> ExpanderXmd<'_, HashT>
impl<HashT> Iterator for ExpandMsgXmd<'_, HashT>
where
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>,
{
fn next(&mut self) -> bool {
if self.index < self.ell {
self.index += 1;
self.offset = 0;
// b_0 XOR b_(idx - 1)
let mut tmp = Array::<u8, HashT::OutputSize>::default();
self.b_0
.iter()
.zip(&self.b_vals[..])
.enumerate()
.for_each(|(j, (b0val, bi1val))| tmp[j] = b0val ^ bi1val);
let mut b_vals = HashT::default();
b_vals.update(&tmp);
b_vals.update(&[self.index]);
self.domain.update_hash(&mut b_vals);
b_vals.update(&[self.domain.len()]);
self.b_vals = b_vals.finalize_fixed();
true
} else {
false
}
}
}
type Item = u8;

impl<HashT> Expander for ExpanderXmd<'_, HashT>
where
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>,
{
fn fill_bytes(&mut self, okm: &mut [u8]) {
for b in okm {
if self.offset == self.b_vals.len() && !self.next() {
return;
}
*b = self.b_vals[self.offset];
fn next(&mut self) -> Option<u8> {
if (self.index as u16 - 1) * HashT::OutputSize::U16 + self.offset as u16 == self.length {
return None;
} else if self.offset != self.b_vals.len() {
let byte = self.b_vals[self.offset];
self.offset += 1;
return Some(byte);
}

self.index += 1;
self.offset = 1;
// b_0 XOR b_(idx - 1)
let mut tmp = Array::<u8, HashT::OutputSize>::default();
self.b_0
.iter()
.zip(&self.b_vals[..])
.enumerate()
.for_each(|(j, (b0val, bi1val))| tmp[j] = b0val ^ bi1val);
let mut b_vals = HashT::default();
b_vals.update(&tmp);
b_vals.update(&[self.index]);
self.domain.update_hash(&mut b_vals);
b_vals.update(&[self.domain.len()]);
self.b_vals = b_vals.finalize_fixed();
Some(self.b_vals[0])
}
}

Expand Down Expand Up @@ -210,15 +188,13 @@ mod test {
assert_message::<HashT>(self.msg, domain, L::U16, self.msg_prime);

let dst = [dst];
let mut expander = <ExpandMsgXmd<HashT> as ExpandMsg<U4>>::expand_message(
let expander = <ExpandMsgXmd<HashT> as ExpandMsg<U4>>::expand_message(
&[self.msg],
&dst,
NonZero::new(L::U16).ok_or(Error)?,
)?;

let mut uniform_bytes = Array::<u8, L>::default();
expander.fill_bytes(&mut uniform_bytes);

let uniform_bytes = Array::<u8, L>::from_iter(expander);
assert_eq!(uniform_bytes.as_slice(), self.uniform_bytes);
Ok(())
}
Expand Down
Loading