Skip to content

Commit ad61d3c

Browse files
committed
Return Result from Expander::fill_bytes()
1 parent f541a9a commit ad61d3c

File tree

5 files changed

+32
-13
lines changed

5 files changed

+32
-13
lines changed

ed448-goldilocks/src/field/element.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -465,15 +465,15 @@ mod tests {
465465
)
466466
.unwrap();
467467
let mut data = Array::<u8, U84>::default();
468-
expander.fill_bytes(&mut data);
468+
expander.fill_bytes(&mut data).unwrap();
469469
// TODO: This should be `Curve448FieldElement`.
470470
let u0 = Ed448FieldElement::from_okm(&data).0;
471471
let mut e_u0 = *expected_u0;
472472
e_u0.reverse();
473473
let mut e_u1 = *expected_u1;
474474
e_u1.reverse();
475475
assert_eq!(u0.to_bytes(), e_u0);
476-
expander.fill_bytes(&mut data);
476+
expander.fill_bytes(&mut data).unwrap();
477477
// TODO: This should be `Curve448FieldElement`.
478478
let u1 = Ed448FieldElement::from_okm(&data).0;
479479
assert_eq!(u1.to_bytes(), e_u1);
@@ -499,14 +499,14 @@ mod tests {
499499
)
500500
.unwrap();
501501
let mut data = Array::<u8, U84>::default();
502-
expander.fill_bytes(&mut data);
502+
expander.fill_bytes(&mut data).unwrap();
503503
let u0 = Ed448FieldElement::from_okm(&data).0;
504504
let mut e_u0 = *expected_u0;
505505
e_u0.reverse();
506506
let mut e_u1 = *expected_u1;
507507
e_u1.reverse();
508508
assert_eq!(u0.to_bytes(), e_u0);
509-
expander.fill_bytes(&mut data);
509+
expander.fill_bytes(&mut data).unwrap();
510510
let u1 = Ed448FieldElement::from_okm(&data).0;
511511
assert_eq!(u1.to_bytes(), e_u1);
512512
}

hash2curve/src/hash2field.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ where
5151
let mut tmp = Array::<u8, <T as FromOkm>::Length>::default();
5252
let mut expander = E::expand_message(data, domain, len_in_bytes)?;
5353
Ok(core::array::from_fn(|_| {
54-
expander.fill_bytes(&mut tmp);
54+
expander
55+
.fill_bytes(&mut tmp)
56+
.expect("never exceeds `len_in_bytes`");
5557
T::from_okm(&tmp)
5658
}))
5759
}

hash2curve/src/hash2field/expand_msg.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ pub(super) mod xof;
66
use core::num::NonZero;
77

88
use digest::{Digest, ExtendableOutput, Update, XofReader};
9+
use elliptic_curve::Error;
910
use elliptic_curve::array::{Array, ArraySize};
1011
use xmd::ExpandMsgXmdError;
1112
use xof::ExpandMsgXofError;
@@ -42,8 +43,12 @@ pub trait ExpandMsg<K> {
4243

4344
/// Expander that, call `read` until enough bytes have been consumed.
4445
pub trait Expander {
45-
/// Fill the array with the expanded bytes
46-
fn fill_bytes(&mut self, okm: &mut [u8]);
46+
/// Fill the array with the expanded bytes, returning how many bytes were read.
47+
///
48+
/// # Errors
49+
///
50+
/// If no bytes are left.
51+
fn fill_bytes(&mut self, okm: &mut [u8]) -> Result<usize, Error>;
4752
}
4853

4954
/// The domain separation tag

hash2curve/src/hash2field/expand_msg/xmd.rs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use digest::{
1111
},
1212
block_api::BlockSizeUser,
1313
};
14+
use elliptic_curve::Error;
1415

1516
/// Implements `expand_message_xof` via the [`ExpandMsg`] trait:
1617
/// <https://www.rfc-editor.org/rfc/rfc9380.html#name-expand_message_xmd>
@@ -107,10 +108,16 @@ where
107108
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
108109
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>,
109110
{
110-
fn fill_bytes(&mut self, okm: &mut [u8]) {
111+
fn fill_bytes(&mut self, okm: &mut [u8]) -> Result<usize, Error> {
112+
let mut read_bytes = 0;
113+
111114
for b in okm {
112115
if self.remaining == 0 {
113-
return;
116+
if read_bytes == 0 {
117+
return Err(Error);
118+
} else {
119+
return Ok(read_bytes);
120+
}
114121
}
115122

116123
if self.offset == self.b_vals.len() {
@@ -134,7 +141,10 @@ where
134141
*b = self.b_vals[self.offset];
135142
self.offset += 1;
136143
self.remaining -= 1;
144+
read_bytes += 1;
137145
}
146+
147+
Ok(read_bytes)
138148
}
139149
}
140150

@@ -232,7 +242,7 @@ mod test {
232242
.unwrap();
233243

234244
let mut uniform_bytes = Array::<u8, L>::default();
235-
expander.fill_bytes(&mut uniform_bytes);
245+
expander.fill_bytes(&mut uniform_bytes).unwrap();
236246

237247
assert_eq!(uniform_bytes.as_slice(), self.uniform_bytes);
238248
}

hash2curve/src/hash2field/expand_msg/xof.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
use super::{Domain, ExpandMsg, Expander};
44
use core::{fmt, num::NonZero, ops::Mul};
55
use digest::{CollisionResistance, ExtendableOutput, HashMarker, Update, XofReader};
6+
use elliptic_curve::Error;
67
use elliptic_curve::array::{
78
ArraySize,
89
typenum::{IsGreaterOrEqual, Prod, True, U2},
@@ -76,14 +77,15 @@ impl<HashT> Expander for ExpandMsgXof<HashT>
7677
where
7778
HashT: Default + ExtendableOutput + Update + HashMarker,
7879
{
79-
fn fill_bytes(&mut self, okm: &mut [u8]) {
80+
fn fill_bytes(&mut self, okm: &mut [u8]) -> Result<usize, Error> {
8081
if self.remaining == 0 {
81-
return;
82+
return Err(Error);
8283
}
8384

8485
let bytes_to_read = self.remaining.min(okm.len().try_into().unwrap_or(u16::MAX));
8586
self.reader.read(&mut okm[..bytes_to_read.into()]);
8687
self.remaining -= bytes_to_read;
88+
Ok(bytes_to_read.into())
8789
}
8890
}
8991

@@ -165,7 +167,7 @@ mod test {
165167
.unwrap();
166168

167169
let mut uniform_bytes = Array::<u8, L>::default();
168-
expander.fill_bytes(&mut uniform_bytes);
170+
expander.fill_bytes(&mut uniform_bytes).unwrap();
169171

170172
assert_eq!(uniform_bytes.as_slice(), self.uniform_bytes);
171173
}

0 commit comments

Comments
 (0)