50
50
return Err ( ExpandMsgXmdError :: Length ) ;
51
51
}
52
52
53
- let ell = u8:: try_from ( usize:: from ( len_in_bytes. get ( ) ) . div_ceil ( b_in_bytes) )
54
- . expect ( "should never pass the previous check" ) ;
53
+ debug_assert ! (
54
+ usize :: from( len_in_bytes. get( ) ) . div_ceil( b_in_bytes) <= u8 :: MAX . into( ) ,
55
+ "should never pass the previous check"
56
+ ) ;
55
57
56
58
let domain = Domain :: xmd :: < HashT > ( dst) ?;
57
59
let mut b_0 = HashT :: default ( ) ;
80
82
domain,
81
83
index : 1 ,
82
84
offset : 0 ,
83
- ell ,
85
+ remaining : len_in_bytes . get ( ) ,
84
86
} )
85
87
}
86
88
}
97
99
domain : Domain < ' a , HashT :: OutputSize > ,
98
100
index : u8 ,
99
101
offset : usize ,
100
- ell : u8 ,
101
- }
102
-
103
- impl < HashT > ExpanderXmd < ' _ , HashT >
104
- where
105
- HashT : BlockSizeUser + Default + FixedOutput + HashMarker ,
106
- HashT :: OutputSize : IsLessOrEqual < HashT :: BlockSize , Output = True > ,
107
- {
108
- fn next ( & mut self ) -> bool {
109
- if self . index < self . ell {
110
- self . index += 1 ;
111
- self . offset = 0 ;
112
- // b_0 XOR b_(idx - 1)
113
- let mut tmp = Array :: < u8 , HashT :: OutputSize > :: default ( ) ;
114
- self . b_0
115
- . iter ( )
116
- . zip ( & self . b_vals [ ..] )
117
- . enumerate ( )
118
- . for_each ( |( j, ( b0val, bi1val) ) | tmp[ j] = b0val ^ bi1val) ;
119
- let mut b_vals = HashT :: default ( ) ;
120
- b_vals. update ( & tmp) ;
121
- b_vals. update ( & [ self . index ] ) ;
122
- self . domain . update_hash ( & mut b_vals) ;
123
- b_vals. update ( & [ self . domain . len ( ) ] ) ;
124
- self . b_vals = b_vals. finalize_fixed ( ) ;
125
- true
126
- } else {
127
- false
128
- }
129
- }
102
+ remaining : u16 ,
130
103
}
131
104
132
105
impl < HashT > Expander for ExpanderXmd < ' _ , HashT >
@@ -136,11 +109,31 @@ where
136
109
{
137
110
fn fill_bytes ( & mut self , okm : & mut [ u8 ] ) {
138
111
for b in okm {
139
- if self . offset == self . b_vals . len ( ) && ! self . next ( ) {
112
+ if self . remaining == 0 {
140
113
return ;
141
114
}
115
+
116
+ if self . offset == self . b_vals . len ( ) {
117
+ self . index += 1 ;
118
+ self . offset = 0 ;
119
+ // b_0 XOR b_(idx - 1)
120
+ let mut tmp = Array :: < u8 , HashT :: OutputSize > :: default ( ) ;
121
+ self . b_0
122
+ . iter ( )
123
+ . zip ( & self . b_vals [ ..] )
124
+ . enumerate ( )
125
+ . for_each ( |( j, ( b0val, bi1val) ) | tmp[ j] = b0val ^ bi1val) ;
126
+ let mut b_vals = HashT :: default ( ) ;
127
+ b_vals. update ( & tmp) ;
128
+ b_vals. update ( & [ self . index ] ) ;
129
+ self . domain . update_hash ( & mut b_vals) ;
130
+ b_vals. update ( & [ self . domain . len ( ) ] ) ;
131
+ self . b_vals = b_vals. finalize_fixed ( ) ;
132
+ }
133
+
142
134
* b = self . b_vals [ self . offset ] ;
143
135
self . offset += 1 ;
136
+ self . remaining -= 1 ;
144
137
}
145
138
}
146
139
}
0 commit comments