Skip to content

Commit adab831

Browse files
committed
Return Result from Expander::fill_bytes()
1 parent 24674c1 commit adab831

File tree

5 files changed

+30
-14
lines changed

5 files changed

+30
-14
lines changed

ed448-goldilocks/src/field/element.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -464,15 +464,15 @@ mod tests {
464464
)
465465
.unwrap();
466466
let mut data = Array::<u8, U84>::default();
467-
expander.fill_bytes(&mut data);
467+
expander.fill_bytes(&mut data).unwrap();
468468
// TODO: This should be `Curve448FieldElement`.
469469
let u0 = Ed448FieldElement::from_okm(&data).0;
470470
let mut e_u0 = *expected_u0;
471471
e_u0.reverse();
472472
let mut e_u1 = *expected_u1;
473473
e_u1.reverse();
474474
assert_eq!(u0.to_bytes(), e_u0);
475-
expander.fill_bytes(&mut data);
475+
expander.fill_bytes(&mut data).unwrap();
476476
// TODO: This should be `Curve448FieldElement`.
477477
let u1 = Ed448FieldElement::from_okm(&data).0;
478478
assert_eq!(u1.to_bytes(), e_u1);
@@ -498,14 +498,14 @@ mod tests {
498498
)
499499
.unwrap();
500500
let mut data = Array::<u8, U84>::default();
501-
expander.fill_bytes(&mut data);
501+
expander.fill_bytes(&mut data).unwrap();
502502
let u0 = Ed448FieldElement::from_okm(&data).0;
503503
let mut e_u0 = *expected_u0;
504504
e_u0.reverse();
505505
let mut e_u1 = *expected_u1;
506506
e_u1.reverse();
507507
assert_eq!(u0.to_bytes(), e_u0);
508-
expander.fill_bytes(&mut data);
508+
expander.fill_bytes(&mut data).unwrap();
509509
let u1 = Ed448FieldElement::from_okm(&data).0;
510510
assert_eq!(u1.to_bytes(), e_u1);
511511
}

hash2curve/src/hash2field.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ where
5151
let mut tmp = Array::<u8, <T as FromOkm>::Length>::default();
5252
let mut expander = E::expand_message(data, domain, len_in_bytes)?;
5353
Ok(core::array::from_fn(|_| {
54-
expander.fill_bytes(&mut tmp);
54+
expander
55+
.fill_bytes(&mut tmp)
56+
.expect("never exceeds `len_in_bytes`");
5557
T::from_okm(&tmp)
5658
}))
5759
}

hash2curve/src/hash2field/expand_msg.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,12 @@ pub trait ExpandMsg<K> {
3939

4040
/// Expander that, call `read` until enough bytes have been consumed.
4141
pub trait Expander {
42-
/// Fill the array with the expanded bytes
43-
fn fill_bytes(&mut self, okm: &mut [u8]);
42+
/// Fill the array with the expanded bytes, returning how many bytes were read.
43+
///
44+
/// # Errors
45+
///
46+
/// If no bytes are left.
47+
fn fill_bytes(&mut self, okm: &mut [u8]) -> Result<usize>;
4448
}
4549

4650
/// The domain separation tag

hash2curve/src/hash2field/expand_msg/xmd.rs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,16 @@ where
108108
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
109109
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>,
110110
{
111-
fn fill_bytes(&mut self, okm: &mut [u8]) {
111+
fn fill_bytes(&mut self, okm: &mut [u8]) -> Result<usize> {
112+
let mut read_bytes = 0;
113+
112114
for b in okm {
113115
if self.remaining == 0 {
114-
return;
116+
if read_bytes == 0 {
117+
return Err(Error);
118+
} else {
119+
return Ok(read_bytes);
120+
}
115121
}
116122

117123
if self.offset == self.b_vals.len() {
@@ -135,7 +141,10 @@ where
135141
*b = self.b_vals[self.offset];
136142
self.offset += 1;
137143
self.remaining -= 1;
144+
read_bytes += 1;
138145
}
146+
147+
Ok(read_bytes)
139148
}
140149
}
141150

@@ -210,7 +219,7 @@ mod test {
210219
)?;
211220

212221
let mut uniform_bytes = Array::<u8, L>::default();
213-
expander.fill_bytes(&mut uniform_bytes);
222+
expander.fill_bytes(&mut uniform_bytes).unwrap();
214223

215224
assert_eq!(uniform_bytes.as_slice(), self.uniform_bytes);
216225
Ok(())

hash2curve/src/hash2field/expand_msg/xof.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@ use core::{fmt, num::NonZero, ops::Mul};
55
use digest::{
66
CollisionResistance, ExtendableOutput, HashMarker, Update, XofReader, typenum::IsGreaterOrEqual,
77
};
8-
use elliptic_curve::Result;
98
use elliptic_curve::array::{
109
ArraySize,
1110
typenum::{Prod, True, U2},
1211
};
12+
use elliptic_curve::{Error, Result};
1313

1414
/// Implements `expand_message_xof` via the [`ExpandMsg`] trait:
1515
/// <https://www.rfc-editor.org/rfc/rfc9380.html#name-expand_message_xof>
@@ -78,14 +78,15 @@ impl<HashT> Expander for ExpandMsgXof<HashT>
7878
where
7979
HashT: Default + ExtendableOutput + Update + HashMarker,
8080
{
81-
fn fill_bytes(&mut self, okm: &mut [u8]) {
81+
fn fill_bytes(&mut self, okm: &mut [u8]) -> Result<usize> {
8282
if self.remaining == 0 {
83-
return;
83+
return Err(Error);
8484
}
8585

8686
let bytes_to_read = self.remaining.min(okm.len().try_into().unwrap_or(u16::MAX));
8787
self.reader.read(&mut okm[..bytes_to_read.into()]);
8888
self.remaining -= bytes_to_read;
89+
Ok(bytes_to_read.into())
8990
}
9091
}
9192

@@ -147,7 +148,7 @@ mod test {
147148
)?;
148149

149150
let mut uniform_bytes = Array::<u8, L>::default();
150-
expander.fill_bytes(&mut uniform_bytes);
151+
expander.fill_bytes(&mut uniform_bytes).unwrap();
151152

152153
assert_eq!(uniform_bytes.as_slice(), self.uniform_bytes);
153154
Ok(())

0 commit comments

Comments
 (0)