diff --git a/ed448-goldilocks/src/field/element.rs b/ed448-goldilocks/src/field/element.rs index 2a1d9d59b..fcbf6c59d 100644 --- a/ed448-goldilocks/src/field/element.rs +++ b/ed448-goldilocks/src/field/element.rs @@ -441,7 +441,7 @@ impl FieldElement { mod tests { use super::*; use elliptic_curve::consts::U32; - use hash2curve::{ExpandMsg, ExpandMsgXof, Expander}; + use hash2curve::{ExpandMsg, ExpandMsgXof}; use hex_literal::hex; use sha3::Shake256; @@ -463,18 +463,16 @@ mod tests { (84 * 2).try_into().unwrap(), ) .unwrap(); - let mut data = Array::::default(); - expander.fill_bytes(&mut data); // TODO: This should be `Curve448FieldElement`. - let u0 = Ed448FieldElement::from_okm(&data).0; + let u0 = Ed448FieldElement::from_okm(&expander.by_ref().take(84).collect()).0; let mut e_u0 = *expected_u0; e_u0.reverse(); let mut e_u1 = *expected_u1; e_u1.reverse(); assert_eq!(u0.to_bytes(), e_u0); - expander.fill_bytes(&mut data); + // TODO: This should be `Curve448FieldElement`. - let u1 = Ed448FieldElement::from_okm(&data).0; + let u1 = Ed448FieldElement::from_okm(&expander.collect()).0; assert_eq!(u1.to_bytes(), e_u1); } } @@ -497,16 +495,13 @@ mod tests { (84 * 2).try_into().unwrap(), ) .unwrap(); - let mut data = Array::::default(); - expander.fill_bytes(&mut data); - let u0 = Ed448FieldElement::from_okm(&data).0; + let u0 = Ed448FieldElement::from_okm(&expander.by_ref().take(84).collect()).0; let mut e_u0 = *expected_u0; e_u0.reverse(); let mut e_u1 = *expected_u1; e_u1.reverse(); assert_eq!(u0.to_bytes(), e_u0); - expander.fill_bytes(&mut data); - let u1 = Ed448FieldElement::from_okm(&data).0; + let u1 = Ed448FieldElement::from_okm(&expander.collect()).0; assert_eq!(u1.to_bytes(), e_u1); } } diff --git a/hash2curve/src/hash2field.rs b/hash2curve/src/hash2field.rs index 267cdefc6..12094c37c 100644 --- a/hash2curve/src/hash2field.rs +++ b/hash2curve/src/hash2field.rs @@ -48,10 +48,8 @@ where .and_then(|len| len.try_into().ok()) .and_then(NonZeroU16::new) .ok_or(Error)?; - let mut tmp = Array::::Length>::default(); let mut expander = E::expand_message(data, domain, len_in_bytes)?; Ok(core::array::from_fn(|_| { - expander.fill_bytes(&mut tmp); - T::from_okm(&tmp) + T::from_okm(&expander.by_ref().take(T::Length::USIZE).collect()) })) } diff --git a/hash2curve/src/hash2field/expand_msg.rs b/hash2curve/src/hash2field/expand_msg.rs index 5db42b73a..f070b3a82 100644 --- a/hash2curve/src/hash2field/expand_msg.rs +++ b/hash2curve/src/hash2field/expand_msg.rs @@ -23,8 +23,8 @@ const MAX_DST_LEN: usize = 255; /// # Errors /// See implementors of [`ExpandMsg`] for errors. pub trait ExpandMsg { - /// Type holding data for the [`Expander`]. - type Expander<'dst>: Expander + Sized; + /// The expanded message. + type Expanded<'a>: Iterator; /// Expands `msg` to the required number of bytes. /// @@ -34,13 +34,7 @@ pub trait ExpandMsg { msg: &[&[u8]], dst: &'dst [&[u8]], len_in_bytes: NonZero, - ) -> Result>; -} - -/// Expander that, call `read` until enough bytes have been consumed. -pub trait Expander { - /// Fill the array with the expanded bytes - fn fill_bytes(&mut self, okm: &mut [u8]); + ) -> Result>; } /// The domain separation tag diff --git a/hash2curve/src/hash2field/expand_msg/xmd.rs b/hash2curve/src/hash2field/expand_msg/xmd.rs index c464700d2..4491fd444 100644 --- a/hash2curve/src/hash2field/expand_msg/xmd.rs +++ b/hash2curve/src/hash2field/expand_msg/xmd.rs @@ -1,8 +1,8 @@ //! `expand_message_xmd` based on a hash function. -use core::{marker::PhantomData, num::NonZero, ops::Mul}; +use core::{num::NonZero, ops::Mul}; -use super::{Domain, ExpandMsg, Expander}; +use super::{Domain, ExpandMsg}; use digest::{ FixedOutput, HashMarker, array::{ @@ -20,11 +20,8 @@ use elliptic_curve::{Error, Result}; /// - `dst` contains no bytes /// - `dst > 255 && HashT::OutputSize > 255` /// - `len_in_bytes > 255 * HashT::OutputSize` -#[derive(Debug)] -pub struct ExpandMsgXmd(PhantomData) -where - HashT: BlockSizeUser + Default + FixedOutput + HashMarker, - HashT::OutputSize: IsLessOrEqual; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ExpandMsgXmd(core::marker::PhantomData); impl ExpandMsg for ExpandMsgXmd where @@ -37,13 +34,13 @@ where K: Mul, HashT::OutputSize: IsGreaterOrEqual, Output = True>, { - type Expander<'dst> = ExpanderXmd<'dst, HashT>; + type Expanded<'a> = ExpandedXmd<'a, HashT>; - fn expand_message<'dst>( + fn expand_message<'a>( msg: &[&[u8]], - dst: &'dst [&[u8]], + dst: &'a [&[u8]], len_in_bytes: NonZero, - ) -> Result> { + ) -> Result> { let b_in_bytes = HashT::OutputSize::USIZE; // `255 * ` can not exceed `u16::MAX` @@ -51,9 +48,6 @@ where return Err(Error); } - let ell = u8::try_from(usize::from(len_in_bytes.get()).div_ceil(b_in_bytes)) - .expect("should never pass the previous check"); - let domain = Domain::xmd::(dst)?; let mut b_0 = HashT::default(); b_0.update(&Array::::default()); @@ -75,20 +69,20 @@ where b_vals.update(&[domain.len()]); let b_vals = b_vals.finalize_fixed(); - Ok(ExpanderXmd { + Ok(ExpandedXmd { b_0, b_vals, domain, index: 1, offset: 0, - ell, + remaining: len_in_bytes.get(), }) } } -/// [`Expander`] type for [`ExpandMsgXmd`]. +/// The expanded bytes of `expand_message_xmd`. #[derive(Debug)] -pub struct ExpanderXmd<'a, HashT> +pub struct ExpandedXmd<'a, HashT> where HashT: BlockSizeUser + Default + FixedOutput + HashMarker, HashT::OutputSize: IsLessOrEqual, @@ -98,16 +92,22 @@ where domain: Domain<'a, HashT::OutputSize>, index: u8, offset: usize, - ell: u8, + remaining: u16, } -impl ExpanderXmd<'_, HashT> +impl Iterator for ExpandedXmd<'_, HashT> where HashT: BlockSizeUser + Default + FixedOutput + HashMarker, HashT::OutputSize: IsLessOrEqual, { - fn next(&mut self) -> bool { - if self.index < self.ell { + type Item = u8; + + fn next(&mut self) -> Option { + if self.remaining == 0 { + return None; + } + + if self.offset == self.b_vals.len() { self.index += 1; self.offset = 0; // b_0 XOR b_(idx - 1) @@ -123,26 +123,12 @@ where self.domain.update_hash(&mut b_vals); b_vals.update(&[self.domain.len()]); self.b_vals = b_vals.finalize_fixed(); - true - } else { - false } - } -} -impl Expander for ExpanderXmd<'_, HashT> -where - HashT: BlockSizeUser + Default + FixedOutput + HashMarker, - HashT::OutputSize: IsLessOrEqual, -{ - fn fill_bytes(&mut self, okm: &mut [u8]) { - for b in okm { - if self.offset == self.b_vals.len() && !self.next() { - return; - } - *b = self.b_vals[self.offset]; - self.offset += 1; - } + let byte = self.b_vals[self.offset]; + self.offset += 1; + self.remaining -= 1; + Some(byte) } } @@ -210,15 +196,13 @@ mod test { assert_message::(self.msg, domain, L::U16, self.msg_prime); let dst = [dst]; - let mut expander = as ExpandMsg>::expand_message( + let expander = as ExpandMsg>::expand_message( &[self.msg], &dst, NonZero::new(L::U16).ok_or(Error)?, )?; - let mut uniform_bytes = Array::::default(); - expander.fill_bytes(&mut uniform_bytes); - + let uniform_bytes: Array = expander.collect(); assert_eq!(uniform_bytes.as_slice(), self.uniform_bytes); Ok(()) } diff --git a/hash2curve/src/hash2field/expand_msg/xof.rs b/hash2curve/src/hash2field/expand_msg/xof.rs index 32d00a7c1..5745a2b4a 100644 --- a/hash2curve/src/hash2field/expand_msg/xof.rs +++ b/hash2curve/src/hash2field/expand_msg/xof.rs @@ -1,9 +1,10 @@ //! `expand_message_xof` for the `ExpandMsg` trait -use super::{Domain, ExpandMsg, Expander}; -use core::{fmt, num::NonZero, ops::Mul}; +use super::{Domain, ExpandMsg}; +use core::{array, fmt, num::NonZero, ops::Mul}; +use digest::XofReader; use digest::{ - CollisionResistance, ExtendableOutput, HashMarker, Update, XofReader, typenum::IsGreaterOrEqual, + CollisionResistance, ExtendableOutput, HashMarker, Update, typenum::IsGreaterOrEqual, }; use elliptic_curve::Result; use elliptic_curve::array::{ @@ -21,7 +22,8 @@ pub struct ExpandMsgXof where HashT: Default + ExtendableOutput + Update + HashMarker, { - reader: ::Reader, + reader: HashT::Reader, + length: u16, } impl fmt::Debug for ExpandMsgXof @@ -46,13 +48,9 @@ where // https://www.rfc-editor.org/rfc/rfc9380.html#section-5.3.2-2.1 HashT: CollisionResistance>, { - type Expander<'dst> = Self; + type Expanded<'a> = Self; - fn expand_message<'dst>( - msg: &[&[u8]], - dst: &'dst [&[u8]], - len_in_bytes: NonZero, - ) -> Result> { + fn expand_message(msg: &[&[u8]], dst: &[&[u8]], len_in_bytes: NonZero) -> Result { let len_in_bytes = len_in_bytes.get(); let domain = Domain::>::xof::(dst)?; @@ -66,16 +64,27 @@ where domain.update_hash(&mut reader); reader.update(&[domain.len()]); let reader = reader.finalize_xof(); - Ok(Self { reader }) + Ok(Self { + reader, + length: len_in_bytes, + }) } } -impl Expander for ExpandMsgXof +impl Iterator for ExpandMsgXof where HashT: Default + ExtendableOutput + Update + HashMarker, { - fn fill_bytes(&mut self, okm: &mut [u8]) { - self.reader.read(okm); + type Item = u8; + + fn next(&mut self) -> Option { + if self.length == 0 { + return None; + } + self.length -= 1; + let mut byte = 0; + self.reader.read(array::from_mut(&mut byte)); + Some(byte) } } @@ -130,15 +139,13 @@ mod test { { assert_message(self.msg, domain, L::to_u16(), self.msg_prime); - let mut expander = as ExpandMsg>::expand_message( + let expander = as ExpandMsg>::expand_message( &[self.msg], &[dst], NonZero::new(L::U16).ok_or(Error)?, )?; - let mut uniform_bytes = Array::::default(); - expander.fill_bytes(&mut uniform_bytes); - + let uniform_bytes: Array = expander.collect(); assert_eq!(uniform_bytes.as_slice(), self.uniform_bytes); Ok(()) }