Skip to content

Commit f541a9a

Browse files
committed
Limit ExpandMsg output by len_in_bytes
1 parent e797acd commit f541a9a

File tree

2 files changed

+39
-36
lines changed

2 files changed

+39
-36
lines changed

hash2curve/src/hash2field/expand_msg/xmd.rs

Lines changed: 27 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,10 @@ where
5050
return Err(ExpandMsgXmdError::Length);
5151
}
5252

53-
let ell = u8::try_from(usize::from(len_in_bytes.get()).div_ceil(b_in_bytes))
54-
.expect("should never pass the previous check");
53+
debug_assert!(
54+
usize::from(len_in_bytes.get()).div_ceil(b_in_bytes) <= u8::MAX.into(),
55+
"should never pass the previous check"
56+
);
5557

5658
let domain = Domain::xmd::<HashT>(dst)?;
5759
let mut b_0 = HashT::default();
@@ -80,7 +82,7 @@ where
8082
domain,
8183
index: 1,
8284
offset: 0,
83-
ell,
85+
remaining: len_in_bytes.get(),
8486
})
8587
}
8688
}
@@ -97,36 +99,7 @@ where
9799
domain: Domain<'a, HashT::OutputSize>,
98100
index: u8,
99101
offset: usize,
100-
ell: u8,
101-
}
102-
103-
impl<HashT> ExpanderXmd<'_, HashT>
104-
where
105-
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
106-
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>,
107-
{
108-
fn next(&mut self) -> bool {
109-
if self.index < self.ell {
110-
self.index += 1;
111-
self.offset = 0;
112-
// b_0 XOR b_(idx - 1)
113-
let mut tmp = Array::<u8, HashT::OutputSize>::default();
114-
self.b_0
115-
.iter()
116-
.zip(&self.b_vals[..])
117-
.enumerate()
118-
.for_each(|(j, (b0val, bi1val))| tmp[j] = b0val ^ bi1val);
119-
let mut b_vals = HashT::default();
120-
b_vals.update(&tmp);
121-
b_vals.update(&[self.index]);
122-
self.domain.update_hash(&mut b_vals);
123-
b_vals.update(&[self.domain.len()]);
124-
self.b_vals = b_vals.finalize_fixed();
125-
true
126-
} else {
127-
false
128-
}
129-
}
102+
remaining: u16,
130103
}
131104

132105
impl<HashT> Expander for ExpanderXmd<'_, HashT>
@@ -136,11 +109,31 @@ where
136109
{
137110
fn fill_bytes(&mut self, okm: &mut [u8]) {
138111
for b in okm {
139-
if self.offset == self.b_vals.len() && !self.next() {
112+
if self.remaining == 0 {
140113
return;
141114
}
115+
116+
if self.offset == self.b_vals.len() {
117+
self.index += 1;
118+
self.offset = 0;
119+
// b_0 XOR b_(idx - 1)
120+
let mut tmp = Array::<u8, HashT::OutputSize>::default();
121+
self.b_0
122+
.iter()
123+
.zip(&self.b_vals[..])
124+
.enumerate()
125+
.for_each(|(j, (b0val, bi1val))| tmp[j] = b0val ^ bi1val);
126+
let mut b_vals = HashT::default();
127+
b_vals.update(&tmp);
128+
b_vals.update(&[self.index]);
129+
self.domain.update_hash(&mut b_vals);
130+
b_vals.update(&[self.domain.len()]);
131+
self.b_vals = b_vals.finalize_fixed();
132+
}
133+
142134
*b = self.b_vals[self.offset];
143135
self.offset += 1;
136+
self.remaining -= 1;
144137
}
145138
}
146139
}

hash2curve/src/hash2field/expand_msg/xof.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ where
1919
HashT: Default + ExtendableOutput + Update + HashMarker,
2020
{
2121
reader: <HashT as ExtendableOutput>::Reader,
22+
remaining: u16,
2223
}
2324

2425
impl<HashT> fmt::Debug for ExpandMsgXof<HashT>
@@ -64,7 +65,10 @@ where
6465
domain.update_hash(&mut reader);
6566
reader.update(&[domain.len()]);
6667
let reader = reader.finalize_xof();
67-
Ok(Self { reader })
68+
Ok(Self {
69+
reader,
70+
remaining: len_in_bytes,
71+
})
6872
}
6973
}
7074

@@ -73,7 +77,13 @@ where
7377
HashT: Default + ExtendableOutput + Update + HashMarker,
7478
{
7579
fn fill_bytes(&mut self, okm: &mut [u8]) {
76-
self.reader.read(okm);
80+
if self.remaining == 0 {
81+
return;
82+
}
83+
84+
let bytes_to_read = self.remaining.min(okm.len().try_into().unwrap_or(u16::MAX));
85+
self.reader.read(&mut okm[..bytes_to_read.into()]);
86+
self.remaining -= bytes_to_read;
7787
}
7888
}
7989

0 commit comments

Comments
 (0)