Skip to content

Commit 078b648

Browse files
nstilt1 and newpavlov
authored
chacha20: 64-bit counter support (#439)
Closes #334 Co-authored-by: Артём Павлов [Artyom Pavlov] <[email protected]>
1 parent 33a80ac commit 078b648

File tree

8 files changed

+515
-280
lines changed

8 files changed

+515
-280
lines changed

chacha20/src/backends/avx2.rs

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
#![allow(unsafe_op_in_unsafe_fn)]
2-
use crate::Rounds;
2+
use crate::{Rounds, Variant};
33
use core::marker::PhantomData;
44

55
#[cfg(feature = "rng")]
6-
use crate::{ChaChaCore, Variant};
6+
use crate::ChaChaCore;
77

88
#[cfg(feature = "cipher")]
99
use crate::{chacha::Block, STATE_WORDS};
@@ -27,10 +27,11 @@ const N: usize = PAR_BLOCKS / 2;
2727
#[inline]
2828
#[target_feature(enable = "avx2")]
2929
#[cfg(feature = "cipher")]
30-
pub(crate) unsafe fn inner<R, F>(state: &mut [u32; STATE_WORDS], f: F)
30+
pub(crate) unsafe fn inner<R, F, V>(state: &mut [u32; STATE_WORDS], f: F)
3131
where
3232
R: Rounds,
3333
F: StreamCipherClosure<BlockSize = U64>,
34+
V: Variant,
3435
{
3536
let state_ptr = state.as_ptr() as *const __m128i;
3637
let v = [
@@ -39,13 +40,21 @@ where
3940
_mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(2))),
4041
];
4142
let mut c = _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(3)));
42-
c = _mm256_add_epi32(c, _mm256_set_epi32(0, 0, 0, 1, 0, 0, 0, 0));
43+
c = match size_of::<V::Counter>() {
44+
4 => _mm256_add_epi32(c, _mm256_set_epi32(0, 0, 0, 1, 0, 0, 0, 0)),
45+
8 => _mm256_add_epi64(c, _mm256_set_epi64x(0, 1, 0, 0)),
46+
_ => unreachable!()
47+
};
4348
let mut ctr = [c; N];
4449
for i in 0..N {
4550
ctr[i] = c;
46-
c = _mm256_add_epi32(c, _mm256_set_epi32(0, 0, 0, 2, 0, 0, 0, 2));
51+
c = match size_of::<V::Counter>() {
52+
4 => _mm256_add_epi32(c, _mm256_set_epi32(0, 0, 0, 2, 0, 0, 0, 2)),
53+
8 => _mm256_add_epi64(c, _mm256_set_epi64x(0, 2, 0, 2)),
54+
_ => unreachable!(),
55+
};
4756
}
48-
let mut backend = Backend::<R> {
57+
let mut backend = Backend::<R, V> {
4958
v,
5059
ctr,
5160
_pd: PhantomData,
@@ -54,6 +63,11 @@ where
5463
f.call(&mut backend);
5564

5665
state[12] = _mm256_extract_epi32(backend.ctr[0], 0) as u32;
66+
match size_of::<V::Counter>() {
67+
4 => {},
68+
8 => state[13] = _mm256_extract_epi32(backend.ctr[0], 1) as u32,
69+
_ => unreachable!()
70+
}
5771
}
5872

5973
#[inline]
@@ -71,13 +85,13 @@ where
7185
_mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(2))),
7286
];
7387
let mut c = _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(3)));
74-
c = _mm256_add_epi32(c, _mm256_set_epi32(0, 0, 0, 1, 0, 0, 0, 0));
88+
c = _mm256_add_epi64(c, _mm256_set_epi64x(0, 1, 0, 0));
7589
let mut ctr = [c; N];
7690
for i in 0..N {
7791
ctr[i] = c;
78-
c = _mm256_add_epi32(c, _mm256_set_epi32(0, 0, 0, 2, 0, 0, 0, 2));
92+
c = _mm256_add_epi64(c, _mm256_set_epi64x(0, 2, 0, 2));
7993
}
80-
let mut backend = Backend::<R> {
94+
let mut backend = Backend::<R, V> {
8195
v,
8296
ctr,
8397
_pd: PhantomData,
@@ -86,32 +100,37 @@ where
86100
backend.rng_gen_par_ks_blocks(buffer);
87101

88102
core.state[12] = _mm256_extract_epi32(backend.ctr[0], 0) as u32;
103+
core.state[13] = _mm256_extract_epi32(backend.ctr[0], 1) as u32;
89104
}
90105

91-
struct Backend<R: Rounds> {
106+
struct Backend<R: Rounds, V: Variant> {
92107
v: [__m256i; 3],
93108
ctr: [__m256i; N],
94-
_pd: PhantomData<R>,
109+
_pd: PhantomData<(R, V)>,
95110
}
96111

97112
#[cfg(feature = "cipher")]
98-
impl<R: Rounds> BlockSizeUser for Backend<R> {
113+
impl<R: Rounds, V: Variant> BlockSizeUser for Backend<R, V> {
99114
type BlockSize = U64;
100115
}
101116

102117
#[cfg(feature = "cipher")]
103-
impl<R: Rounds> ParBlocksSizeUser for Backend<R> {
118+
impl<R: Rounds, V: Variant> ParBlocksSizeUser for Backend<R, V> {
104119
type ParBlocksSize = U4;
105120
}
106121

107122
#[cfg(feature = "cipher")]
108-
impl<R: Rounds> StreamCipherBackend for Backend<R> {
123+
impl<R: Rounds, V: Variant> StreamCipherBackend for Backend<R, V> {
109124
#[inline(always)]
110125
fn gen_ks_block(&mut self, block: &mut Block) {
111126
unsafe {
112127
let res = rounds::<R>(&self.v, &self.ctr);
113128
for c in self.ctr.iter_mut() {
114-
*c = _mm256_add_epi32(*c, _mm256_set_epi32(0, 0, 0, 1, 0, 0, 0, 1));
129+
*c = match size_of::<V::Counter>() {
130+
4 => _mm256_add_epi32(*c, _mm256_set_epi32(0, 0, 0, 1, 0, 0, 0, 1)),
131+
8 => _mm256_add_epi64(*c, _mm256_set_epi64x(0, 1, 0, 1)),
132+
_ => unreachable!()
133+
};
115134
}
116135

117136
let res0: [__m128i; 8] = core::mem::transmute(res[0]);
@@ -130,7 +149,11 @@ impl<R: Rounds> StreamCipherBackend for Backend<R> {
130149

131150
let pb = PAR_BLOCKS as i32;
132151
for c in self.ctr.iter_mut() {
133-
*c = _mm256_add_epi32(*c, _mm256_set_epi32(0, 0, 0, pb, 0, 0, 0, pb));
152+
*c = match size_of::<V::Counter>() {
153+
4 => _mm256_add_epi32(*c, _mm256_set_epi32(0, 0, 0, pb, 0, 0, 0, pb)),
154+
8 => _mm256_add_epi64(*c, _mm256_set_epi64x(0, pb as i64, 0, pb as i64)),
155+
_ => unreachable!()
156+
}
134157
}
135158

136159
let mut block_ptr = blocks.as_mut_ptr() as *mut __m128i;
@@ -147,15 +170,15 @@ impl<R: Rounds> StreamCipherBackend for Backend<R> {
147170
}
148171

149172
#[cfg(feature = "rng")]
150-
impl<R: Rounds> Backend<R> {
173+
impl<R: Rounds, V: Variant> Backend<R, V> {
151174
#[inline(always)]
152175
fn rng_gen_par_ks_blocks(&mut self, blocks: &mut [u32; 64]) {
153176
unsafe {
154177
let vs = rounds::<R>(&self.v, &self.ctr);
155178

156179
let pb = PAR_BLOCKS as i32;
157180
for c in self.ctr.iter_mut() {
158-
*c = _mm256_add_epi32(*c, _mm256_set_epi32(0, 0, 0, pb, 0, 0, 0, pb));
181+
*c = _mm256_add_epi64(*c, _mm256_set_epi64x(0, pb as i64, 0, pb as i64));
159182
}
160183

161184
let mut block_ptr = blocks.as_mut_ptr() as *mut __m128i;

chacha20/src/backends/neon.rs

Lines changed: 62 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
//! Adapted from the Crypto++ `chacha_simd` implementation by Jack Lloyd and
55
//! Jeffrey Walton (public domain).
66
7-
use crate::{Rounds, STATE_WORDS};
7+
use crate::{Rounds, STATE_WORDS, Variant};
88
use core::{arch::aarch64::*, marker::PhantomData};
99

1010
#[cfg(feature = "rand_core")]
11-
use crate::{ChaChaCore, Variant};
11+
use crate::ChaChaCore;
1212

1313
#[cfg(feature = "cipher")]
1414
use crate::chacha::Block;
@@ -19,13 +19,26 @@ use cipher::{
1919
consts::{U4, U64},
2020
};
2121

22-
struct Backend<R: Rounds> {
22+
struct Backend<R: Rounds, V: Variant> {
2323
state: [uint32x4_t; 4],
2424
ctrs: [uint32x4_t; 4],
25-
_pd: PhantomData<R>,
25+
_pd: PhantomData<(R, V)>,
2626
}
2727

28-
impl<R: Rounds> Backend<R> {
28+
macro_rules! add_counter {
29+
($a:expr, $b:expr, $variant:ty) => {
30+
match size_of::<<$variant>::Counter>() {
31+
4 => vaddq_u32($a, $b),
32+
8 => vreinterpretq_u32_u64(vaddq_u64(
33+
vreinterpretq_u64_u32($a),
34+
vreinterpretq_u64_u32($b),
35+
)),
36+
_ => unreachable!(),
37+
}
38+
};
39+
}
40+
41+
impl<R: Rounds, V: Variant> Backend<R, V> {
2942
#[inline]
3043
unsafe fn new(state: &mut [u32; STATE_WORDS]) -> Self {
3144
let state = [
@@ -40,7 +53,7 @@ impl<R: Rounds> Backend<R> {
4053
vld1q_u32([3, 0, 0, 0].as_ptr()),
4154
vld1q_u32([4, 0, 0, 0].as_ptr()),
4255
];
43-
Backend::<R> {
56+
Backend::<R, V> {
4457
state,
4558
ctrs,
4659
_pd: PhantomData,
@@ -51,16 +64,24 @@ impl<R: Rounds> Backend<R> {
5164
#[inline]
5265
#[cfg(feature = "cipher")]
5366
#[target_feature(enable = "neon")]
54-
pub(crate) unsafe fn inner<R, F>(state: &mut [u32; STATE_WORDS], f: F)
67+
pub(crate) unsafe fn inner<R, F, V>(state: &mut [u32; STATE_WORDS], f: F)
5568
where
5669
R: Rounds,
5770
F: StreamCipherClosure<BlockSize = U64>,
71+
V: Variant,
5872
{
59-
let mut backend = Backend::<R>::new(state);
73+
let mut backend = Backend::<R, V>::new(state);
6074

6175
f.call(&mut backend);
6276

63-
vst1q_u32(state.as_mut_ptr().offset(12), backend.state[3]);
77+
match size_of::<V::Counter>() {
78+
4 => state[12] = vgetq_lane_u32(backend.state[3], 0),
79+
8 => vst1q_u64(
80+
state.as_mut_ptr().offset(12) as *mut u64,
81+
vreinterpretq_u64_u32(backend.state[3]),
82+
),
83+
_ => unreachable!(),
84+
}
6485
}
6586

6687
#[inline]
@@ -73,19 +94,22 @@ where
7394
R: Rounds,
7495
V: Variant,
7596
{
76-
let mut backend = Backend::<R>::new(&mut core.state);
97+
let mut backend = Backend::<R, V>::new(&mut core.state);
7798

7899
backend.write_par_ks_blocks(buffer);
79100

80-
vst1q_u32(core.state.as_mut_ptr().offset(12), backend.state[3]);
101+
vst1q_u64(
102+
core.state.as_mut_ptr().offset(12) as *mut u64,
103+
vreinterpretq_u64_u32(backend.state[3]),
104+
);
81105
}
82106

83107
#[cfg(feature = "cipher")]
84-
impl<R: Rounds> BlockSizeUser for Backend<R> {
108+
impl<R: Rounds, V: Variant> BlockSizeUser for Backend<R, V> {
85109
type BlockSize = U64;
86110
}
87111
#[cfg(feature = "cipher")]
88-
impl<R: Rounds> ParBlocksSizeUser for Backend<R> {
112+
impl<R: Rounds, V: Variant> ParBlocksSizeUser for Backend<R, V> {
89113
type ParBlocksSize = U4;
90114
}
91115

@@ -97,15 +121,15 @@ macro_rules! add_assign_vec {
97121
}
98122

99123
#[cfg(feature = "cipher")]
100-
impl<R: Rounds> StreamCipherBackend for Backend<R> {
124+
impl<R: Rounds, V: Variant> StreamCipherBackend for Backend<R, V> {
101125
#[inline(always)]
102126
fn gen_ks_block(&mut self, block: &mut Block) {
103127
let state3 = self.state[3];
104128
let mut par = ParBlocks::<Self>::default();
105129
self.gen_par_ks_blocks(&mut par);
106130
*block = par[0];
107131
unsafe {
108-
self.state[3] = vaddq_u32(state3, vld1q_u32([1, 0, 0, 0].as_ptr()));
132+
self.state[3] = add_counter!(state3, vld1q_u32([1, 0, 0, 0].as_ptr()), V);
109133
}
110134
}
111135

@@ -118,19 +142,19 @@ impl<R: Rounds> StreamCipherBackend for Backend<R> {
118142
self.state[0],
119143
self.state[1],
120144
self.state[2],
121-
vaddq_u32(self.state[3], self.ctrs[0]),
145+
add_counter!(self.state[3], self.ctrs[0], V),
122146
],
123147
[
124148
self.state[0],
125149
self.state[1],
126150
self.state[2],
127-
vaddq_u32(self.state[3], self.ctrs[1]),
151+
add_counter!(self.state[3], self.ctrs[1], V),
128152
],
129153
[
130154
self.state[0],
131155
self.state[1],
132156
self.state[2],
133-
vaddq_u32(self.state[3], self.ctrs[2]),
157+
add_counter!(self.state[3], self.ctrs[2], V),
134158
],
135159
];
136160

@@ -140,11 +164,16 @@ impl<R: Rounds> StreamCipherBackend for Backend<R> {
140164

141165
for block in 0..4 {
142166
// add state to block
143-
for state_row in 0..4 {
167+
for state_row in 0..3 {
144168
add_assign_vec!(blocks[block][state_row], self.state[state_row]);
145169
}
146170
if block > 0 {
147-
blocks[block][3] = vaddq_u32(blocks[block][3], self.ctrs[block - 1]);
171+
add_assign_vec!(
172+
blocks[block][3],
173+
add_counter!(self.state[3], self.ctrs[block - 1], V)
174+
);
175+
} else {
176+
add_assign_vec!(blocks[block][3], self.state[3]);
148177
}
149178
// write blocks to dest
150179
for state_row in 0..4 {
@@ -154,7 +183,7 @@ impl<R: Rounds> StreamCipherBackend for Backend<R> {
154183
);
155184
}
156185
}
157-
self.state[3] = vaddq_u32(self.state[3], self.ctrs[3]);
186+
self.state[3] = add_counter!(self.state[3], self.ctrs[3], V);
158187
}
159188
}
160189
}
@@ -180,7 +209,7 @@ macro_rules! extract {
180209
};
181210
}
182211

183-
impl<R: Rounds> Backend<R> {
212+
impl<R: Rounds, V: Variant> Backend<R, V> {
184213
#[inline(always)]
185214
/// Generates `num_blocks` blocks and blindly writes them to `dest_ptr`
186215
///
@@ -197,19 +226,19 @@ impl<R: Rounds> Backend<R> {
197226
self.state[0],
198227
self.state[1],
199228
self.state[2],
200-
vaddq_u32(self.state[3], self.ctrs[0]),
229+
add_counter!(self.state[3], self.ctrs[0], V),
201230
],
202231
[
203232
self.state[0],
204233
self.state[1],
205234
self.state[2],
206-
vaddq_u32(self.state[3], self.ctrs[1]),
235+
add_counter!(self.state[3], self.ctrs[1], V),
207236
],
208237
[
209238
self.state[0],
210239
self.state[1],
211240
self.state[2],
212-
vaddq_u32(self.state[3], self.ctrs[2]),
241+
add_counter!(self.state[3], self.ctrs[2], V),
213242
],
214243
];
215244

@@ -220,11 +249,16 @@ impl<R: Rounds> Backend<R> {
220249
let mut dest_ptr = buffer.as_mut_ptr() as *mut u8;
221250
for block in 0..4 {
222251
// add state to block
223-
for state_row in 0..4 {
252+
for state_row in 0..3 {
224253
add_assign_vec!(blocks[block][state_row], self.state[state_row]);
225254
}
226255
if block > 0 {
227-
blocks[block][3] = vaddq_u32(blocks[block][3], self.ctrs[block - 1]);
256+
add_assign_vec!(
257+
blocks[block][3],
258+
add_counter!(self.state[3], self.ctrs[block - 1], V)
259+
);
260+
} else {
261+
add_assign_vec!(blocks[block][3], self.state[3]);
228262
}
229263
// write blocks to buffer
230264
for state_row in 0..4 {
@@ -235,7 +269,7 @@ impl<R: Rounds> Backend<R> {
235269
}
236270
dest_ptr = dest_ptr.add(64);
237271
}
238-
self.state[3] = vaddq_u32(self.state[3], self.ctrs[3]);
272+
self.state[3] = add_counter!(self.state[3], self.ctrs[3], V);
239273
}
240274
}
241275

0 commit comments

Comments
 (0)