Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions ed448-goldilocks/src/field/element.rs
Original file line number Diff line number Diff line change
Expand Up @@ -536,15 +536,15 @@ mod tests {
)
.unwrap();
let mut data = Array::<u8, U84>::default();
expander.fill_bytes(&mut data);
expander.fill_bytes(&mut data).unwrap();
// TODO: This should be `Curve448FieldElement`.
let u0 = Ed448FieldElement::from_okm(&data).0;
let mut e_u0 = *expected_u0;
e_u0.reverse();
let mut e_u1 = *expected_u1;
e_u1.reverse();
assert_eq!(u0.to_bytes(), e_u0);
expander.fill_bytes(&mut data);
expander.fill_bytes(&mut data).unwrap();
// TODO: This should be `Curve448FieldElement`.
let u1 = Ed448FieldElement::from_okm(&data).0;
assert_eq!(u1.to_bytes(), e_u1);
Expand All @@ -570,14 +570,14 @@ mod tests {
)
.unwrap();
let mut data = Array::<u8, U84>::default();
expander.fill_bytes(&mut data);
expander.fill_bytes(&mut data).unwrap();
let u0 = Ed448FieldElement::from_okm(&data).0;
let mut e_u0 = *expected_u0;
e_u0.reverse();
let mut e_u1 = *expected_u1;
e_u1.reverse();
assert_eq!(u0.to_bytes(), e_u0);
expander.fill_bytes(&mut data);
expander.fill_bytes(&mut data).unwrap();
let u1 = Ed448FieldElement::from_okm(&data).0;
assert_eq!(u1.to_bytes(), e_u1);
}
Expand Down
4 changes: 3 additions & 1 deletion hash2curve/src/hash2field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ where
let mut tmp = Array::<u8, <T as FromOkm>::Length>::default();
let mut expander = E::expand_message(data, domain, len_in_bytes)?;
Ok(core::array::from_fn(|_| {
expander.fill_bytes(&mut tmp);
expander
.fill_bytes(&mut tmp)
.expect("never exceeds `len_in_bytes`");
T::from_okm(&tmp)
}))
}
9 changes: 7 additions & 2 deletions hash2curve/src/hash2field/expand_msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pub(super) mod xof;
use core::num::NonZero;

use digest::{Digest, ExtendableOutput, Update, XofReader};
use elliptic_curve::Error;
use elliptic_curve::array::{Array, ArraySize};
use xmd::ExpandMsgXmdError;
use xof::ExpandMsgXofError;
Expand Down Expand Up @@ -42,8 +43,12 @@ pub trait ExpandMsg<K> {

/// Expander that, call `read` until enough bytes have been consumed.
pub trait Expander {
/// Fill the array with the expanded bytes
fn fill_bytes(&mut self, okm: &mut [u8]);
/// Fill the array with the expanded bytes, returning how many bytes were read.
///
/// # Errors
///
/// If no bytes are left.
fn fill_bytes(&mut self, okm: &mut [u8]) -> Result<usize, Error>;
}

/// The domain separation tag
Expand Down
119 changes: 80 additions & 39 deletions hash2curve/src/hash2field/expand_msg/xmd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use digest::{
},
block_api::BlockSizeUser,
};
use elliptic_curve::Error;

/// Implements `expand_message_xof` via the [`ExpandMsg`] trait:
/// <https://www.rfc-editor.org/rfc/rfc9380.html#name-expand_message_xmd>
Expand Down Expand Up @@ -50,8 +51,10 @@ where
return Err(ExpandMsgXmdError::Length);
}

let ell = u8::try_from(usize::from(len_in_bytes.get()).div_ceil(b_in_bytes))
.expect("should never pass the previous check");
debug_assert!(
usize::from(len_in_bytes.get()).div_ceil(b_in_bytes) <= u8::MAX.into(),
"should never pass the previous check"
);

let domain = Domain::xmd::<HashT>(dst)?;
let mut b_0 = HashT::default();
Expand Down Expand Up @@ -80,7 +83,7 @@ where
domain,
index: 1,
offset: 0,
ell,
remaining: len_in_bytes.get(),
})
}
}
Expand All @@ -97,51 +100,64 @@ where
domain: Domain<'a, HashT::OutputSize>,
index: u8,
offset: usize,
ell: u8,
remaining: u16,
}

impl<HashT> ExpanderXmd<'_, HashT>
impl<HashT> Expander for ExpanderXmd<'_, HashT>
where
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>,
{
fn next(&mut self) -> bool {
if self.index < self.ell {
self.index += 1;
self.offset = 0;
// b_0 XOR b_(idx - 1)
let mut tmp = Array::<u8, HashT::OutputSize>::default();
self.b_0
.iter()
.zip(&self.b_vals[..])
.enumerate()
.for_each(|(j, (b0val, bi1val))| tmp[j] = b0val ^ bi1val);
let mut b_vals = HashT::default();
b_vals.update(&tmp);
b_vals.update(&[self.index]);
self.domain.update_hash(&mut b_vals);
b_vals.update(&[self.domain.len()]);
self.b_vals = b_vals.finalize_fixed();
true
} else {
false
fn fill_bytes(&mut self, mut okm: &mut [u8]) -> Result<usize, Error> {
if self.remaining == 0 {
return Err(Error);
}
}
}

impl<HashT> Expander for ExpanderXmd<'_, HashT>
where
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>,
{
fn fill_bytes(&mut self, okm: &mut [u8]) {
for b in okm {
if self.offset == self.b_vals.len() && !self.next() {
return;
let mut read_bytes = 0;

while self.remaining != 0 {
if self.offset == self.b_vals.len() {
self.index += 1;
self.offset = 0;
// b_0 XOR b_(idx - 1)
let mut tmp = Array::<u8, HashT::OutputSize>::default();
self.b_0
.iter()
.zip(&self.b_vals[..])
.enumerate()
.for_each(|(j, (b0val, bi1val))| tmp[j] = b0val ^ bi1val);
let mut b_vals = HashT::default();
b_vals.update(&tmp);
b_vals.update(&[self.index]);
self.domain.update_hash(&mut b_vals);
b_vals.update(&[self.domain.len()]);
self.b_vals = b_vals.finalize_fixed();
}

let bytes_to_read = self
.remaining
.min(okm.len().try_into().unwrap_or(u16::MAX))
.min(
(self.b_vals.len() - self.offset)
.try_into()
.unwrap_or(u16::MAX),
);

if bytes_to_read == 0 {
return Ok(read_bytes);
}
*b = self.b_vals[self.offset];
self.offset += 1;

okm[..bytes_to_read.into()].copy_from_slice(
&self.b_vals[self.offset..self.offset + usize::from(bytes_to_read)],
);
okm = &mut okm[bytes_to_read.into()..];

self.offset += usize::from(bytes_to_read);
self.remaining -= bytes_to_read;
read_bytes += usize::from(bytes_to_read);
}

Ok(read_bytes)
}
}

Expand Down Expand Up @@ -181,6 +197,31 @@ mod test {
use hex_literal::hex;
use sha2::Sha256;

#[test]
fn edge_cases() {
fn generate() -> ExpanderXmd<'static, Sha256> {
<ExpandMsgXmd<Sha256> as ExpandMsg<U4>>::expand_message(
&[b"test message"],
&[b"test DST"],
NonZero::new(64).unwrap(),
)
.unwrap()
}

assert_eq!(generate().fill_bytes(&mut [0; 0]), Ok(0));
assert_eq!(generate().fill_bytes(&mut [0; 1]), Ok(1));
assert_eq!(generate().fill_bytes(&mut [0; 64]), Ok(64));
assert_eq!(generate().fill_bytes(&mut [0; 65]), Ok(64));

let mut expander = generate();
assert_eq!(expander.fill_bytes(&mut [0; 0]), Ok(0));
assert_eq!(expander.fill_bytes(&mut [0; 32]), Ok(32));
assert_eq!(expander.fill_bytes(&mut [0; 0]), Ok(0));
assert_eq!(expander.fill_bytes(&mut [0; 31]), Ok(31));
assert_eq!(expander.fill_bytes(&mut [0; 2]), Ok(1));
assert_eq!(expander.fill_bytes(&mut [0; 1]), Err(Error));
}

fn assert_message<HashT>(
msg: &[u8],
domain: &Domain<'_, HashT::OutputSize>,
Expand Down Expand Up @@ -239,7 +280,7 @@ mod test {
.unwrap();

let mut uniform_bytes = Array::<u8, L>::default();
expander.fill_bytes(&mut uniform_bytes);
expander.fill_bytes(&mut uniform_bytes).unwrap();

assert_eq!(uniform_bytes.as_slice(), self.uniform_bytes);
}
Expand Down
43 changes: 39 additions & 4 deletions hash2curve/src/hash2field/expand_msg/xof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use super::{Domain, ExpandMsg, Expander};
use core::{fmt, num::NonZero, ops::Mul};
use digest::{CollisionResistance, ExtendableOutput, HashMarker, Update, XofReader};
use elliptic_curve::Error;
use elliptic_curve::array::{
ArraySize,
typenum::{IsGreaterOrEqual, Prod, True, U2},
Expand All @@ -19,6 +20,7 @@ where
HashT: Default + ExtendableOutput + Update + HashMarker,
{
reader: <HashT as ExtendableOutput>::Reader,
remaining: u16,
}

impl<HashT> fmt::Debug for ExpandMsgXof<HashT>
Expand Down Expand Up @@ -64,16 +66,26 @@ where
domain.update_hash(&mut reader);
reader.update(&[domain.len()]);
let reader = reader.finalize_xof();
Ok(Self { reader })
Ok(Self {
reader,
remaining: len_in_bytes,
})
}
}

impl<HashT> Expander for ExpandMsgXof<HashT>
where
HashT: Default + ExtendableOutput + Update + HashMarker,
{
fn fill_bytes(&mut self, okm: &mut [u8]) {
self.reader.read(okm);
fn fill_bytes(&mut self, okm: &mut [u8]) -> Result<usize, Error> {
if self.remaining == 0 {
return Err(Error);
}

let bytes_to_read = self.remaining.min(okm.len().try_into().unwrap_or(u16::MAX));
self.reader.read(&mut okm[..bytes_to_read.into()]);
self.remaining -= bytes_to_read;
Ok(bytes_to_read.into())
}
}

Expand Down Expand Up @@ -109,6 +121,29 @@ mod test {
use hex_literal::hex;
use sha3::Shake128;

#[test]
fn edge_cases() {
fn generate() -> ExpandMsgXof<Shake128> {
<ExpandMsgXof<Shake128> as ExpandMsg<U16>>::expand_message(
&[b"test message"],
&[b"test DST"],
NonZero::new(64).unwrap(),
)
.unwrap()
}

assert_eq!(generate().fill_bytes(&mut [0; 0]), Ok(0));
assert_eq!(generate().fill_bytes(&mut [0; 1]), Ok(1));
assert_eq!(generate().fill_bytes(&mut [0; 64]), Ok(64));
assert_eq!(generate().fill_bytes(&mut [0; 65]), Ok(64));

let mut expander = generate();
assert_eq!(expander.fill_bytes(&mut [0; 0]), Ok(0));
assert_eq!(expander.fill_bytes(&mut [0; 1]), Ok(1));
assert_eq!(expander.fill_bytes(&mut [0; 64]), Ok(63));
assert_eq!(expander.fill_bytes(&mut [0; 1]), Err(Error));
}

fn assert_message(msg: &[u8], domain: &Domain<'_, U32>, len_in_bytes: u16, bytes: &[u8]) {
let msg_len = msg.len();
assert_eq!(msg, &bytes[..msg_len]);
Expand Down Expand Up @@ -155,7 +190,7 @@ mod test {
.unwrap();

let mut uniform_bytes = Array::<u8, L>::default();
expander.fill_bytes(&mut uniform_bytes);
expander.fill_bytes(&mut uniform_bytes).unwrap();

assert_eq!(uniform_bytes.as_slice(), self.uniform_bytes);
}
Expand Down