Skip to content

Commit 771f5df

Browse files
authored
elliptic-curve: allow multiple dsts in the hash2curve API (RustCrypto#1238)
1 parent cd2ecd3 commit 771f5df

File tree

5 files changed

+115
-82
lines changed

5 files changed

+115
-82
lines changed

elliptic-curve/src/hash2curve/group_digest.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,10 @@ where
4848
/// [`ExpandMsgXof`]: crate::hash2curve::ExpandMsgXof
4949
fn hash_from_bytes<'a, X: ExpandMsg<'a>>(
5050
msgs: &[&[u8]],
51-
dst: &'a [u8],
51+
dsts: &'a [&'a [u8]],
5252
) -> Result<ProjectivePoint<Self>> {
5353
let mut u = [Self::FieldElement::default(), Self::FieldElement::default()];
54-
hash_to_field::<X, _>(msgs, dst, &mut u)?;
54+
hash_to_field::<X, _>(msgs, dsts, &mut u)?;
5555
let q0 = u[0].map_to_curve();
5656
let q1 = u[1].map_to_curve();
5757
// Ideally we could add and then clear cofactor once
@@ -88,10 +88,10 @@ where
8888
/// [`ExpandMsgXof`]: crate::hash2curve::ExpandMsgXof
8989
fn encode_from_bytes<'a, X: ExpandMsg<'a>>(
9090
msgs: &[&[u8]],
91-
dst: &'a [u8],
91+
dsts: &'a [&'a [u8]],
9292
) -> Result<ProjectivePoint<Self>> {
9393
let mut u = [Self::FieldElement::default()];
94-
hash_to_field::<X, _>(msgs, dst, &mut u)?;
94+
hash_to_field::<X, _>(msgs, dsts, &mut u)?;
9595
let q0 = u[0].map_to_curve();
9696
Ok(q0.clear_cofactor().into())
9797
}
@@ -109,12 +109,15 @@ where
109109
///
110110
/// [`ExpandMsgXmd`]: crate::hash2curve::ExpandMsgXmd
111111
/// [`ExpandMsgXof`]: crate::hash2curve::ExpandMsgXof
112-
fn hash_to_scalar<'a, X: ExpandMsg<'a>>(msgs: &[&[u8]], dst: &'a [u8]) -> Result<Self::Scalar>
112+
fn hash_to_scalar<'a, X: ExpandMsg<'a>>(
113+
msgs: &[&[u8]],
114+
dsts: &'a [&'a [u8]],
115+
) -> Result<Self::Scalar>
113116
where
114117
Self::Scalar: FromOkm,
115118
{
116119
let mut u = [Self::Scalar::default()];
117-
hash_to_field::<X, _>(msgs, dst, &mut u)?;
120+
hash_to_field::<X, _>(msgs, dsts, &mut u)?;
118121
Ok(u[0])
119122
}
120123
}

elliptic-curve/src/hash2curve/hash2field.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ pub trait FromOkm {
3232
/// [`ExpandMsgXmd`]: crate::hash2field::ExpandMsgXmd
3333
/// [`ExpandMsgXof`]: crate::hash2field::ExpandMsgXof
3434
#[doc(hidden)]
35-
pub fn hash_to_field<'a, E, T>(data: &[&[u8]], domain: &'a [u8], out: &mut [T]) -> Result<()>
35+
pub fn hash_to_field<'a, E, T>(data: &[&[u8]], domain: &'a [&'a [u8]], out: &mut [T]) -> Result<()>
3636
where
3737
E: ExpandMsg<'a>,
3838
T: FromOkm + Default,

elliptic-curve/src/hash2curve/hash2field/expand_msg.rs

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,11 @@ pub trait ExpandMsg<'a> {
2525
///
2626
/// Returns an expander that can be used to call `read` until enough
2727
/// bytes have been consumed
28-
fn expand_message(msgs: &[&[u8]], dst: &'a [u8], len_in_bytes: usize)
29-
-> Result<Self::Expander>;
28+
fn expand_message(
29+
msgs: &[&[u8]],
30+
dsts: &'a [&'a [u8]],
31+
len_in_bytes: usize,
32+
) -> Result<Self::Expander>;
3033
}
3134

3235
/// Expander that, call `read` until enough bytes have been consumed.
@@ -47,54 +50,66 @@ where
4750
/// > 255
4851
Hashed(GenericArray<u8, L>),
4952
/// <= 255
50-
Array(&'a [u8]),
53+
Array(&'a [&'a [u8]]),
5154
}
5255

5356
impl<'a, L> Domain<'a, L>
5457
where
5558
L: ArrayLength<u8> + IsLess<U256>,
5659
{
57-
pub fn xof<X>(dst: &'a [u8]) -> Result<Self>
60+
pub fn xof<X>(dsts: &'a [&'a [u8]]) -> Result<Self>
5861
where
5962
X: Default + ExtendableOutput + Update,
6063
{
61-
if dst.is_empty() {
64+
if dsts.is_empty() {
6265
Err(Error)
63-
} else if dst.len() > MAX_DST_LEN {
66+
} else if dsts.iter().map(|dst| dst.len()).sum::<usize>() > MAX_DST_LEN {
6467
let mut data = GenericArray::<u8, L>::default();
65-
X::default()
66-
.chain(OVERSIZE_DST_SALT)
67-
.chain(dst)
68-
.finalize_xof()
69-
.read(&mut data);
68+
let mut hash = X::default();
69+
hash.update(OVERSIZE_DST_SALT);
70+
71+
for dst in dsts {
72+
hash.update(dst);
73+
}
74+
75+
hash.finalize_xof().read(&mut data);
76+
7077
Ok(Self::Hashed(data))
7178
} else {
72-
Ok(Self::Array(dst))
79+
Ok(Self::Array(dsts))
7380
}
7481
}
7582

76-
pub fn xmd<X>(dst: &'a [u8]) -> Result<Self>
83+
pub fn xmd<X>(dsts: &'a [&'a [u8]]) -> Result<Self>
7784
where
7885
X: Digest<OutputSize = L>,
7986
{
80-
if dst.is_empty() {
87+
if dsts.is_empty() {
8188
Err(Error)
82-
} else if dst.len() > MAX_DST_LEN {
89+
} else if dsts.iter().map(|dst| dst.len()).sum::<usize>() > MAX_DST_LEN {
8390
Ok(Self::Hashed({
8491
let mut hash = X::new();
8592
hash.update(OVERSIZE_DST_SALT);
86-
hash.update(dst);
93+
94+
for dst in dsts {
95+
hash.update(dst);
96+
}
97+
8798
hash.finalize()
8899
}))
89100
} else {
90-
Ok(Self::Array(dst))
101+
Ok(Self::Array(dsts))
91102
}
92103
}
93104

94-
pub fn data(&self) -> &[u8] {
105+
pub fn update_hash<HashT: Update>(&self, hash: &mut HashT) {
95106
match self {
96-
Self::Hashed(d) => &d[..],
97-
Self::Array(d) => d,
107+
Self::Hashed(d) => hash.update(d),
108+
Self::Array(d) => {
109+
for d in d.iter() {
110+
hash.update(d)
111+
}
112+
}
98113
}
99114
}
100115

@@ -103,13 +118,28 @@ where
103118
// Can't overflow because it's enforced on a type level.
104119
Self::Hashed(_) => L::to_u8(),
105120
// Can't overflow because it's checked on creation.
106-
Self::Array(d) => u8::try_from(d.len()).expect("length overflow"),
121+
Self::Array(d) => {
122+
u8::try_from(d.iter().map(|d| d.len()).sum::<usize>()).expect("length overflow")
123+
}
107124
}
108125
}
109126

110127
#[cfg(test)]
111128
pub fn assert(&self, bytes: &[u8]) {
112-
assert_eq!(self.data(), &bytes[..bytes.len() - 1]);
129+
let data = match self {
130+
Domain::Hashed(d) => d.to_vec(),
131+
Domain::Array(d) => d.iter().copied().flatten().copied().collect(),
132+
};
133+
assert_eq!(data, bytes);
134+
}
135+
136+
#[cfg(test)]
137+
pub fn assert_dst(&self, bytes: &[u8]) {
138+
let data = match self {
139+
Domain::Hashed(d) => d.to_vec(),
140+
Domain::Array(d) => d.iter().copied().flatten().copied().collect(),
141+
};
142+
assert_eq!(data, &bytes[..bytes.len() - 1]);
113143
assert_eq!(self.len(), bytes[bytes.len() - 1]);
114144
}
115145
}

elliptic-curve/src/hash2curve/hash2field/expand_msg/xmd.rs

Lines changed: 38 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use digest::{
1010
typenum::{IsLess, IsLessOrEqual, Unsigned, U256},
1111
GenericArray,
1212
},
13-
Digest,
13+
FixedOutput, HashMarker,
1414
};
1515

1616
/// Placeholder type for implementing `expand_message_xmd` based on a hash function
@@ -22,14 +22,14 @@ use digest::{
2222
/// - `len_in_bytes > 255 * HashT::OutputSize`
2323
pub struct ExpandMsgXmd<HashT>(PhantomData<HashT>)
2424
where
25-
HashT: Digest + BlockSizeUser,
25+
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
2626
HashT::OutputSize: IsLess<U256>,
2727
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize>;
2828

2929
/// ExpandMsgXmd implements expand_message_xmd for the ExpandMsg trait
3030
impl<'a, HashT> ExpandMsg<'a> for ExpandMsgXmd<HashT>
3131
where
32-
HashT: Digest + BlockSizeUser,
32+
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
3333
// If `len_in_bytes` is bigger then 256, length of the `DST` will depend on
3434
// the output size of the hash, which is still not allowed to be bigger then 256:
3535
// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-13.html#section-5.4.1-6
@@ -42,7 +42,7 @@ where
4242

4343
fn expand_message(
4444
msgs: &[&[u8]],
45-
dst: &'a [u8],
45+
dsts: &'a [&'a [u8]],
4646
len_in_bytes: usize,
4747
) -> Result<Self::Expander> {
4848
if len_in_bytes == 0 {
@@ -54,26 +54,26 @@ where
5454
let b_in_bytes = HashT::OutputSize::to_usize();
5555
let ell = u8::try_from((len_in_bytes + b_in_bytes - 1) / b_in_bytes).map_err(|_| Error)?;
5656

57-
let domain = Domain::xmd::<HashT>(dst)?;
58-
let mut b_0 = HashT::new();
59-
b_0.update(GenericArray::<u8, HashT::BlockSize>::default());
57+
let domain = Domain::xmd::<HashT>(dsts)?;
58+
let mut b_0 = HashT::default();
59+
b_0.update(&GenericArray::<u8, HashT::BlockSize>::default());
6060

6161
for msg in msgs {
6262
b_0.update(msg);
6363
}
6464

65-
b_0.update(len_in_bytes_u16.to_be_bytes());
66-
b_0.update([0]);
67-
b_0.update(domain.data());
68-
b_0.update([domain.len()]);
69-
let b_0 = b_0.finalize();
65+
b_0.update(&len_in_bytes_u16.to_be_bytes());
66+
b_0.update(&[0]);
67+
domain.update_hash(&mut b_0);
68+
b_0.update(&[domain.len()]);
69+
let b_0 = b_0.finalize_fixed();
7070

71-
let mut b_vals = HashT::new();
71+
let mut b_vals = HashT::default();
7272
b_vals.update(&b_0[..]);
73-
b_vals.update([1u8]);
74-
b_vals.update(domain.data());
75-
b_vals.update([domain.len()]);
76-
let b_vals = b_vals.finalize();
73+
b_vals.update(&[1u8]);
74+
domain.update_hash(&mut b_vals);
75+
b_vals.update(&[domain.len()]);
76+
let b_vals = b_vals.finalize_fixed();
7777

7878
Ok(ExpanderXmd {
7979
b_0,
@@ -89,7 +89,7 @@ where
8989
/// [`Expander`] type for [`ExpandMsgXmd`].
9090
pub struct ExpanderXmd<'a, HashT>
9191
where
92-
HashT: Digest + BlockSizeUser,
92+
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
9393
HashT::OutputSize: IsLess<U256>,
9494
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize>,
9595
{
@@ -103,7 +103,7 @@ where
103103

104104
impl<'a, HashT> ExpanderXmd<'a, HashT>
105105
where
106-
HashT: Digest + BlockSizeUser,
106+
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
107107
HashT::OutputSize: IsLess<U256>,
108108
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize>,
109109
{
@@ -118,12 +118,12 @@ where
118118
.zip(&self.b_vals[..])
119119
.enumerate()
120120
.for_each(|(j, (b0val, bi1val))| tmp[j] = b0val ^ bi1val);
121-
let mut b_vals = HashT::new();
122-
b_vals.update(tmp);
123-
b_vals.update([self.index]);
124-
b_vals.update(self.domain.data());
125-
b_vals.update([self.domain.len()]);
126-
self.b_vals = b_vals.finalize();
121+
let mut b_vals = HashT::default();
122+
b_vals.update(&tmp);
123+
b_vals.update(&[self.index]);
124+
self.domain.update_hash(&mut b_vals);
125+
b_vals.update(&[self.domain.len()]);
126+
self.b_vals = b_vals.finalize_fixed();
127127
true
128128
} else {
129129
false
@@ -133,7 +133,7 @@ where
133133

134134
impl<'a, HashT> Expander for ExpanderXmd<'a, HashT>
135135
where
136-
HashT: Digest + BlockSizeUser,
136+
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
137137
HashT::OutputSize: IsLess<U256>,
138138
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize>,
139139
{
@@ -165,7 +165,7 @@ mod test {
165165
len_in_bytes: u16,
166166
bytes: &[u8],
167167
) where
168-
HashT: Digest + BlockSizeUser,
168+
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
169169
HashT::OutputSize: IsLess<U256>,
170170
{
171171
let block = HashT::BlockSize::to_usize();
@@ -183,8 +183,8 @@ mod test {
183183
let pad = l + mem::size_of::<u8>();
184184
assert_eq!([0], &bytes[l..pad]);
185185

186-
let dst = pad + domain.data().len();
187-
assert_eq!(domain.data(), &bytes[pad..dst]);
186+
let dst = pad + usize::from(domain.len());
187+
domain.assert(&bytes[pad..dst]);
188188

189189
let dst_len = dst + mem::size_of::<u8>();
190190
assert_eq!([domain.len()], &bytes[dst..dst_len]);
@@ -205,13 +205,14 @@ mod test {
205205
domain: &Domain<'_, HashT::OutputSize>,
206206
) -> Result<()>
207207
where
208-
HashT: Digest + BlockSizeUser,
208+
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
209209
HashT::OutputSize: IsLess<U256> + IsLessOrEqual<HashT::BlockSize>,
210210
{
211211
assert_message::<HashT>(self.msg, domain, L::to_u16(), self.msg_prime);
212212

213+
let dst = [dst];
213214
let mut expander =
214-
ExpandMsgXmd::<HashT>::expand_message(&[self.msg], dst, L::to_usize())?;
215+
ExpandMsgXmd::<HashT>::expand_message(&[self.msg], &dst, L::to_usize())?;
215216

216217
let mut uniform_bytes = GenericArray::<u8, L>::default();
217218
expander.fill_bytes(&mut uniform_bytes);
@@ -227,8 +228,8 @@ mod test {
227228
const DST_PRIME: &[u8] =
228229
&hex!("515555582d5630312d435330322d776974682d657870616e6465722d5348413235362d31323826");
229230

230-
let dst_prime = Domain::xmd::<Sha256>(DST)?;
231-
dst_prime.assert(DST_PRIME);
231+
let dst_prime = Domain::xmd::<Sha256>(&[DST])?;
232+
dst_prime.assert_dst(DST_PRIME);
232233

233234
const TEST_VECTORS_32: &[TestVector] = &[
234235
TestVector {
@@ -299,8 +300,8 @@ mod test {
299300
const DST_PRIME: &[u8] =
300301
&hex!("412717974da474d0f8c420f320ff81e8432adb7c927d9bd082b4fb4d16c0a23620");
301302

302-
let dst_prime = Domain::xmd::<Sha256>(DST)?;
303-
dst_prime.assert(DST_PRIME);
303+
let dst_prime = Domain::xmd::<Sha256>(&[DST])?;
304+
dst_prime.assert_dst(DST_PRIME);
304305

305306
const TEST_VECTORS_32: &[TestVector] = &[
306307
TestVector {
@@ -377,8 +378,8 @@ mod test {
377378
const DST_PRIME: &[u8] =
378379
&hex!("515555582d5630312d435330322d776974682d657870616e6465722d5348413531322d32353626");
379380

380-
let dst_prime = Domain::xmd::<Sha512>(DST)?;
381-
dst_prime.assert(DST_PRIME);
381+
let dst_prime = Domain::xmd::<Sha512>(&[DST])?;
382+
dst_prime.assert_dst(DST_PRIME);
382383

383384
const TEST_VECTORS_32: &[TestVector] = &[
384385
TestVector {

0 commit comments

Comments
 (0)