11//! `expand_message_xmd` based on a hash function.
22
3- use core:: { marker :: PhantomData , num:: NonZero , ops:: Mul } ;
3+ use core:: { num:: NonZero , ops:: Mul } ;
44
5- use super :: { Domain , ExpandMsg , Expander } ;
5+ use super :: { Domain , ExpandMsg } ;
66use digest:: {
77 FixedOutput , HashMarker ,
88 array:: {
@@ -21,12 +21,20 @@ use elliptic_curve::{Error, Result};
2121/// - `dst > 255 && HashT::OutputSize > 255`
2222/// - `len_in_bytes > 255 * HashT::OutputSize`
2323#[ derive( Debug ) ]
24- pub struct ExpandMsgXmd < HashT > ( PhantomData < HashT > )
24+ pub struct ExpandMsgXmd < ' a , HashT >
2525where
2626 HashT : BlockSizeUser + Default + FixedOutput + HashMarker ,
27- HashT :: OutputSize : IsLessOrEqual < HashT :: BlockSize , Output = True > ;
27+ HashT :: OutputSize : IsLessOrEqual < HashT :: BlockSize , Output = True > ,
28+ {
29+ b_0 : Array < u8 , HashT :: OutputSize > ,
30+ b_vals : Array < u8 , HashT :: OutputSize > ,
31+ domain : Domain < ' a , HashT :: OutputSize > ,
32+ index : u8 ,
33+ offset : usize ,
34+ length : u16 ,
35+ }
2836
29- impl < HashT , K > ExpandMsg < K > for ExpandMsgXmd < HashT >
37+ impl < ' dst , HashT , K > ExpandMsg < ' dst , K > for ExpandMsgXmd < ' dst , HashT >
3038where
3139 HashT : BlockSizeUser + Default + FixedOutput + HashMarker ,
3240 // The number of bits output by `HashT` MUST be at most `HashT::BlockSize`.
@@ -37,23 +45,18 @@ where
3745 K : Mul < U2 > ,
3846 HashT :: OutputSize : IsGreaterOrEqual < Prod < K , U2 > , Output = True > ,
3947{
40- type Expander < ' dst > = ExpanderXmd < ' dst , HashT > ;
41-
42- fn expand_message < ' dst > (
48+ fn expand_message (
4349 msg : & [ & [ u8 ] ] ,
4450 dst : & ' dst [ & [ u8 ] ] ,
4551 len_in_bytes : NonZero < u16 > ,
46- ) -> Result < Self :: Expander < ' dst > > {
52+ ) -> Result < Self > {
4753 let b_in_bytes = HashT :: OutputSize :: USIZE ;
4854
4955 // `255 * <b_in_bytes>` can not exceed `u16::MAX`
5056 if usize:: from ( len_in_bytes. get ( ) ) > 255 * b_in_bytes {
5157 return Err ( Error ) ;
5258 }
5359
54- let ell = u8:: try_from ( usize:: from ( len_in_bytes. get ( ) ) . div_ceil ( b_in_bytes) )
55- . expect ( "should never pass the previous check" ) ;
56-
5760 let domain = Domain :: xmd :: < HashT > ( dst) ?;
5861 let mut b_0 = HashT :: default ( ) ;
5962 b_0. update ( & Array :: < u8 , HashT :: BlockSize > :: default ( ) ) ;
@@ -75,74 +78,51 @@ where
7578 b_vals. update ( & [ domain. len ( ) ] ) ;
7679 let b_vals = b_vals. finalize_fixed ( ) ;
7780
78- Ok ( ExpanderXmd {
81+ Ok ( Self {
7982 b_0,
8083 b_vals,
8184 domain,
8285 index : 1 ,
8386 offset : 0 ,
84- ell ,
87+ length : len_in_bytes . get ( ) ,
8588 } )
8689 }
8790}
8891
89- /// [`Expander`] type for [`ExpandMsgXmd`].
90- #[ derive( Debug ) ]
91- pub struct ExpanderXmd < ' a , HashT >
92+ impl < HashT > Iterator for ExpandMsgXmd < ' _ , HashT >
9293where
9394 HashT : BlockSizeUser + Default + FixedOutput + HashMarker ,
9495 HashT :: OutputSize : IsLessOrEqual < HashT :: BlockSize , Output = True > ,
9596{
96- b_0 : Array < u8 , HashT :: OutputSize > ,
97- b_vals : Array < u8 , HashT :: OutputSize > ,
98- domain : Domain < ' a , HashT :: OutputSize > ,
99- index : u8 ,
100- offset : usize ,
101- ell : u8 ,
102- }
97+ type Item = u8 ;
10398
104- impl < HashT > ExpanderXmd < ' _ , HashT >
105- where
106- HashT : BlockSizeUser + Default + FixedOutput + HashMarker ,
107- HashT :: OutputSize : IsLessOrEqual < HashT :: BlockSize , Output = True > ,
108- {
109- fn next ( & mut self ) -> bool {
110- if self . index < self . ell {
111- self . index += 1 ;
112- self . offset = 0 ;
113- // b_0 XOR b_(idx - 1)
114- let mut tmp = Array :: < u8 , HashT :: OutputSize > :: default ( ) ;
115- self . b_0
116- . iter ( )
117- . zip ( & self . b_vals [ ..] )
118- . enumerate ( )
119- . for_each ( |( j, ( b0val, bi1val) ) | tmp[ j] = b0val ^ bi1val) ;
120- let mut b_vals = HashT :: default ( ) ;
121- b_vals. update ( & tmp) ;
122- b_vals. update ( & [ self . index ] ) ;
123- self . domain . update_hash ( & mut b_vals) ;
124- b_vals. update ( & [ self . domain . len ( ) ] ) ;
125- self . b_vals = b_vals. finalize_fixed ( ) ;
126- true
127- } else {
128- false
129- }
130- }
131- }
132-
133- impl < HashT > Expander for ExpanderXmd < ' _ , HashT >
134- where
135- HashT : BlockSizeUser + Default + FixedOutput + HashMarker ,
136- HashT :: OutputSize : IsLessOrEqual < HashT :: BlockSize , Output = True > ,
137- {
138- fn fill_bytes ( & mut self , okm : & mut [ u8 ] ) {
139- for b in okm {
140- if self . offset == self . b_vals . len ( ) && !self . next ( ) {
141- return ;
142- }
143- * b = self . b_vals [ self . offset ] ;
99+ fn next ( & mut self ) -> Option < u8 > {
100+ if ( self . index as u16 - 1 ) * HashT :: OutputSize :: U16 + self . offset as u16
101+ == self . length
102+ {
103+ return None ;
104+ } else if self . offset != self . b_vals . len ( ) {
105+ let byte = self . b_vals [ self . offset ] ;
144106 self . offset += 1 ;
107+ return Some ( byte) ;
145108 }
109+
110+ self . index += 1 ;
111+ self . offset = 1 ;
112+ // b_0 XOR b_(idx - 1)
113+ let mut tmp = Array :: < u8 , HashT :: OutputSize > :: default ( ) ;
114+ self . b_0
115+ . iter ( )
116+ . zip ( & self . b_vals [ ..] )
117+ . enumerate ( )
118+ . for_each ( |( j, ( b0val, bi1val) ) | tmp[ j] = b0val ^ bi1val) ;
119+ let mut b_vals = HashT :: default ( ) ;
120+ b_vals. update ( & tmp) ;
121+ b_vals. update ( & [ self . index ] ) ;
122+ self . domain . update_hash ( & mut b_vals) ;
123+ b_vals. update ( & [ self . domain . len ( ) ] ) ;
124+ self . b_vals = b_vals. finalize_fixed ( ) ;
125+ Some ( self . b_vals [ 0 ] )
146126 }
147127}
148128
@@ -210,15 +190,13 @@ mod test {
210190 assert_message :: < HashT > ( self . msg , domain, L :: U16 , self . msg_prime ) ;
211191
212192 let dst = [ dst] ;
213- let mut expander = <ExpandMsgXmd < HashT > as ExpandMsg < U4 > >:: expand_message (
193+ let expander = <ExpandMsgXmd < HashT > as ExpandMsg < U4 > >:: expand_message (
214194 & [ self . msg ] ,
215195 & dst,
216196 NonZero :: new ( L :: U16 ) . ok_or ( Error ) ?,
217197 ) ?;
218198
219- let mut uniform_bytes = Array :: < u8 , L > :: default ( ) ;
220- expander. fill_bytes ( & mut uniform_bytes) ;
221-
199+ let uniform_bytes = Array :: < u8 , L > :: from_iter ( expander) ;
222200 assert_eq ! ( uniform_bytes. as_slice( ) , self . uniform_bytes) ;
223201 Ok ( ( ) )
224202 }
0 commit comments