4
4
//! Adapted from the Crypto++ `chacha_simd` implementation by Jack Lloyd and
5
5
//! Jeffrey Walton (public domain).
6
6
7
- use crate :: { Rounds , STATE_WORDS } ;
7
+ use crate :: { Rounds , STATE_WORDS , Variant } ;
8
8
use core:: { arch:: aarch64:: * , marker:: PhantomData } ;
9
9
10
10
#[ cfg( feature = "rand_core" ) ]
11
- use crate :: { ChaChaCore , Variant } ;
11
+ use crate :: ChaChaCore ;
12
12
13
13
#[ cfg( feature = "cipher" ) ]
14
14
use crate :: chacha:: Block ;
@@ -19,13 +19,26 @@ use cipher::{
19
19
consts:: { U4 , U64 } ,
20
20
} ;
21
21
22
- struct Backend < R : Rounds > {
22
+ struct Backend < R : Rounds , V : Variant > {
23
23
state : [ uint32x4_t ; 4 ] ,
24
24
ctrs : [ uint32x4_t ; 4 ] ,
25
- _pd : PhantomData < R > ,
25
+ _pd : PhantomData < ( R , V ) > ,
26
26
}
27
27
28
- impl < R : Rounds > Backend < R > {
28
+ macro_rules! add_counter {
29
+ ( $a: expr, $b: expr, $variant: ty) => {
30
+ match size_of:: <<$variant>:: Counter >( ) {
31
+ 4 => vaddq_u32( $a, $b) ,
32
+ 8 => vreinterpretq_u32_u64( vaddq_u64(
33
+ vreinterpretq_u64_u32( $a) ,
34
+ vreinterpretq_u64_u32( $b) ,
35
+ ) ) ,
36
+ _ => unreachable!( ) ,
37
+ }
38
+ } ;
39
+ }
40
+
41
+ impl < R : Rounds , V : Variant > Backend < R , V > {
29
42
#[ inline]
30
43
unsafe fn new ( state : & mut [ u32 ; STATE_WORDS ] ) -> Self {
31
44
let state = [
@@ -40,7 +53,7 @@ impl<R: Rounds> Backend<R> {
40
53
vld1q_u32 ( [ 3 , 0 , 0 , 0 ] . as_ptr ( ) ) ,
41
54
vld1q_u32 ( [ 4 , 0 , 0 , 0 ] . as_ptr ( ) ) ,
42
55
] ;
43
- Backend :: < R > {
56
+ Backend :: < R , V > {
44
57
state,
45
58
ctrs,
46
59
_pd : PhantomData ,
@@ -51,16 +64,24 @@ impl<R: Rounds> Backend<R> {
51
64
#[ inline]
52
65
#[ cfg( feature = "cipher" ) ]
53
66
#[ target_feature( enable = "neon" ) ]
54
- pub ( crate ) unsafe fn inner < R , F > ( state : & mut [ u32 ; STATE_WORDS ] , f : F )
67
+ pub ( crate ) unsafe fn inner < R , F , V > ( state : & mut [ u32 ; STATE_WORDS ] , f : F )
55
68
where
56
69
R : Rounds ,
57
70
F : StreamCipherClosure < BlockSize = U64 > ,
71
+ V : Variant ,
58
72
{
59
- let mut backend = Backend :: < R > :: new ( state) ;
73
+ let mut backend = Backend :: < R , V > :: new ( state) ;
60
74
61
75
f. call ( & mut backend) ;
62
76
63
- vst1q_u32 ( state. as_mut_ptr ( ) . offset ( 12 ) , backend. state [ 3 ] ) ;
77
+ match size_of :: < V :: Counter > ( ) {
78
+ 4 => state[ 12 ] = vgetq_lane_u32 ( backend. state [ 3 ] , 0 ) ,
79
+ 8 => vst1q_u64 (
80
+ state. as_mut_ptr ( ) . offset ( 12 ) as * mut u64 ,
81
+ vreinterpretq_u64_u32 ( backend. state [ 3 ] ) ,
82
+ ) ,
83
+ _ => unreachable ! ( ) ,
84
+ }
64
85
}
65
86
66
87
#[ inline]
@@ -73,19 +94,22 @@ where
73
94
R : Rounds ,
74
95
V : Variant ,
75
96
{
76
- let mut backend = Backend :: < R > :: new ( & mut core. state ) ;
97
+ let mut backend = Backend :: < R , V > :: new ( & mut core. state ) ;
77
98
78
99
backend. write_par_ks_blocks ( buffer) ;
79
100
80
- vst1q_u32 ( core. state . as_mut_ptr ( ) . offset ( 12 ) , backend. state [ 3 ] ) ;
101
+ vst1q_u64 (
102
+ core. state . as_mut_ptr ( ) . offset ( 12 ) as * mut u64 ,
103
+ vreinterpretq_u64_u32 ( backend. state [ 3 ] ) ,
104
+ ) ;
81
105
}
82
106
83
107
#[ cfg( feature = "cipher" ) ]
84
- impl < R : Rounds > BlockSizeUser for Backend < R > {
108
+ impl < R : Rounds , V : Variant > BlockSizeUser for Backend < R , V > {
85
109
type BlockSize = U64 ;
86
110
}
87
111
#[ cfg( feature = "cipher" ) ]
88
- impl < R : Rounds > ParBlocksSizeUser for Backend < R > {
112
+ impl < R : Rounds , V : Variant > ParBlocksSizeUser for Backend < R , V > {
89
113
type ParBlocksSize = U4 ;
90
114
}
91
115
@@ -97,15 +121,15 @@ macro_rules! add_assign_vec {
97
121
}
98
122
99
123
#[ cfg( feature = "cipher" ) ]
100
- impl < R : Rounds > StreamCipherBackend for Backend < R > {
124
+ impl < R : Rounds , V : Variant > StreamCipherBackend for Backend < R , V > {
101
125
#[ inline( always) ]
102
126
fn gen_ks_block ( & mut self , block : & mut Block ) {
103
127
let state3 = self . state [ 3 ] ;
104
128
let mut par = ParBlocks :: < Self > :: default ( ) ;
105
129
self . gen_par_ks_blocks ( & mut par) ;
106
130
* block = par[ 0 ] ;
107
131
unsafe {
108
- self . state [ 3 ] = vaddq_u32 ( state3, vld1q_u32 ( [ 1 , 0 , 0 , 0 ] . as_ptr ( ) ) ) ;
132
+ self . state [ 3 ] = add_counter ! ( state3, vld1q_u32( [ 1 , 0 , 0 , 0 ] . as_ptr( ) ) , V ) ;
109
133
}
110
134
}
111
135
@@ -118,19 +142,19 @@ impl<R: Rounds> StreamCipherBackend for Backend<R> {
118
142
self . state [ 0 ] ,
119
143
self . state [ 1 ] ,
120
144
self . state [ 2 ] ,
121
- vaddq_u32 ( self . state [ 3 ] , self . ctrs [ 0 ] ) ,
145
+ add_counter ! ( self . state[ 3 ] , self . ctrs[ 0 ] , V ) ,
122
146
] ,
123
147
[
124
148
self . state [ 0 ] ,
125
149
self . state [ 1 ] ,
126
150
self . state [ 2 ] ,
127
- vaddq_u32 ( self . state [ 3 ] , self . ctrs [ 1 ] ) ,
151
+ add_counter ! ( self . state[ 3 ] , self . ctrs[ 1 ] , V ) ,
128
152
] ,
129
153
[
130
154
self . state [ 0 ] ,
131
155
self . state [ 1 ] ,
132
156
self . state [ 2 ] ,
133
- vaddq_u32 ( self . state [ 3 ] , self . ctrs [ 2 ] ) ,
157
+ add_counter ! ( self . state[ 3 ] , self . ctrs[ 2 ] , V ) ,
134
158
] ,
135
159
] ;
136
160
@@ -140,11 +164,16 @@ impl<R: Rounds> StreamCipherBackend for Backend<R> {
140
164
141
165
for block in 0 ..4 {
142
166
// add state to block
143
- for state_row in 0 ..4 {
167
+ for state_row in 0 ..3 {
144
168
add_assign_vec ! ( blocks[ block] [ state_row] , self . state[ state_row] ) ;
145
169
}
146
170
if block > 0 {
147
- blocks[ block] [ 3 ] = vaddq_u32 ( blocks[ block] [ 3 ] , self . ctrs [ block - 1 ] ) ;
171
+ add_assign_vec ! (
172
+ blocks[ block] [ 3 ] ,
173
+ add_counter!( self . state[ 3 ] , self . ctrs[ block - 1 ] , V )
174
+ ) ;
175
+ } else {
176
+ add_assign_vec ! ( blocks[ block] [ 3 ] , self . state[ 3 ] ) ;
148
177
}
149
178
// write blocks to dest
150
179
for state_row in 0 ..4 {
@@ -154,7 +183,7 @@ impl<R: Rounds> StreamCipherBackend for Backend<R> {
154
183
) ;
155
184
}
156
185
}
157
- self . state [ 3 ] = vaddq_u32 ( self . state [ 3 ] , self . ctrs [ 3 ] ) ;
186
+ self . state [ 3 ] = add_counter ! ( self . state[ 3 ] , self . ctrs[ 3 ] , V ) ;
158
187
}
159
188
}
160
189
}
@@ -180,7 +209,7 @@ macro_rules! extract {
180
209
} ;
181
210
}
182
211
183
- impl < R : Rounds > Backend < R > {
212
+ impl < R : Rounds , V : Variant > Backend < R , V > {
184
213
#[ inline( always) ]
185
214
/// Generates `num_blocks` blocks and blindly writes them to `dest_ptr`
186
215
///
@@ -197,19 +226,19 @@ impl<R: Rounds> Backend<R> {
197
226
self . state [ 0 ] ,
198
227
self . state [ 1 ] ,
199
228
self . state [ 2 ] ,
200
- vaddq_u32 ( self . state [ 3 ] , self . ctrs [ 0 ] ) ,
229
+ add_counter ! ( self . state[ 3 ] , self . ctrs[ 0 ] , V ) ,
201
230
] ,
202
231
[
203
232
self . state [ 0 ] ,
204
233
self . state [ 1 ] ,
205
234
self . state [ 2 ] ,
206
- vaddq_u32 ( self . state [ 3 ] , self . ctrs [ 1 ] ) ,
235
+ add_counter ! ( self . state[ 3 ] , self . ctrs[ 1 ] , V ) ,
207
236
] ,
208
237
[
209
238
self . state [ 0 ] ,
210
239
self . state [ 1 ] ,
211
240
self . state [ 2 ] ,
212
- vaddq_u32 ( self . state [ 3 ] , self . ctrs [ 2 ] ) ,
241
+ add_counter ! ( self . state[ 3 ] , self . ctrs[ 2 ] , V ) ,
213
242
] ,
214
243
] ;
215
244
@@ -220,11 +249,16 @@ impl<R: Rounds> Backend<R> {
220
249
let mut dest_ptr = buffer. as_mut_ptr ( ) as * mut u8 ;
221
250
for block in 0 ..4 {
222
251
// add state to block
223
- for state_row in 0 ..4 {
252
+ for state_row in 0 ..3 {
224
253
add_assign_vec ! ( blocks[ block] [ state_row] , self . state[ state_row] ) ;
225
254
}
226
255
if block > 0 {
227
- blocks[ block] [ 3 ] = vaddq_u32 ( blocks[ block] [ 3 ] , self . ctrs [ block - 1 ] ) ;
256
+ add_assign_vec ! (
257
+ blocks[ block] [ 3 ] ,
258
+ add_counter!( self . state[ 3 ] , self . ctrs[ block - 1 ] , V )
259
+ ) ;
260
+ } else {
261
+ add_assign_vec ! ( blocks[ block] [ 3 ] , self . state[ 3 ] ) ;
228
262
}
229
263
// write blocks to buffer
230
264
for state_row in 0 ..4 {
@@ -235,7 +269,7 @@ impl<R: Rounds> Backend<R> {
235
269
}
236
270
dest_ptr = dest_ptr. add ( 64 ) ;
237
271
}
238
- self . state [ 3 ] = vaddq_u32 ( self . state [ 3 ] , self . ctrs [ 3 ] ) ;
272
+ self . state [ 3 ] = add_counter ! ( self . state[ 3 ] , self . ctrs[ 3 ] , V ) ;
239
273
}
240
274
}
241
275
0 commit comments