Skip to content

Commit 3302e43

Browse files
riptlripatel-fd
authored andcommitted
chacha20: add AVX-512 RNG implementation
1 parent e1dc541 commit 3302e43

File tree

6 files changed

+173
-42
lines changed

6 files changed

+173
-42
lines changed

src/ballet/chacha20/Local.mk

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
$(call add-hdrs,fd_chacha20.h fd_chacha20rng.h)
22
$(call add-objs,fd_chacha20rng,fd_ballet)
33

4+
ifdef FD_HAS_AVX512
5+
$(call add-objs,fd_chacha20_avx512,fd_ballet)
6+
endif
7+
48
ifdef FD_HAS_AVX
59
$(call add-objs,fd_chacha20_avx,fd_ballet)
610
endif

src/ballet/chacha20/fd_chacha20_avx.c

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,13 @@
99
static inline __attribute__((always_inline)) wu_t
1010
wu_rol8( wu_t x ) {
1111
wb_t const mask =
12-
wb( 3,0,1,2, 7,4,5,6, 11,8,9,10, 15,12,13,14,
13-
3,0,1,2, 7,4,5,6, 11,8,9,10, 15,12,13,14 );
12+
wb_bcast_hex( 3,0,1,2, 7,4,5,6, 11,8,9,10, 15,12,13,14 );
1413
return _mm256_shuffle_epi8( x, mask );
1514
}
1615

1716
void
1817
fd_chacha20rng_refill_avx( fd_chacha20rng_t * rng ) {
1918

20-
/* This function should only be called if the buffer is empty. */
21-
assert( rng->buf_off == rng->buf_fill );
22-
2319
wu_t iv0 = wu_bcast( 0x61707865U );
2420
wu_t iv1 = wu_bcast( 0x3320646eU );
2521
wu_t iv2 = wu_bcast( 0x79622d32U );
@@ -105,7 +101,8 @@ fd_chacha20rng_refill_avx( fd_chacha20rng_t * rng ) {
105101

106102
/* Update ring buffer */
107103

108-
uint * out = (uint *)rng->buf;
104+
ulong slot = rng->buf_fill % (8*FD_CHACHA20_BLOCK_SZ);
105+
uint * out = (uint *)rng->buf + (slot*2*FD_CHACHA20_BLOCK_SZ);
109106
wu_st( out+0x00, c0 ); wu_st( out+0x08, c8 );
110107
wu_st( out+0x10, c1 ); wu_st( out+0x18, c9 );
111108
wu_st( out+0x20, c2 ); wu_st( out+0x28, cA );
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
#include "fd_chacha20rng.h"
2+
#include "../../util/simd/fd_avx512.h"
3+
#include <assert.h>
4+
5+
#define wwu_rol16(a) wwb_exch_adj_pair( (a) )
6+
#define wwu_rol12(a) wwu_rol( (a), 12 )
7+
#define wwu_rol7(a) wwu_rol( (a), 7 )
8+
9+
static inline __attribute__((always_inline)) wwu_t
10+
wwu_rol8( wwu_t x ) {
11+
wwb_t const mask =
12+
wwb_bcast_hex( 3,0,1,2, 7,4,5,6, 11,8,9,10, 15,12,13,14 );
13+
return _mm512_shuffle_epi8( x, mask );
14+
}
15+
16+
void
17+
fd_chacha20rng_refill_avx512( fd_chacha20rng_t * rng ) {
18+
19+
/* This function should only be called if the buffer is empty. */
20+
if( FD_UNLIKELY( rng->buf_off != rng->buf_fill ) ) {
21+
FD_LOG_CRIT(( "refill out of sync: buf_off=%lu buf_fill=%lu", rng->buf_off, rng->buf_fill ));
22+
}
23+
24+
wwu_t iv0 = wwu_bcast( 0x61707865U );
25+
wwu_t iv1 = wwu_bcast( 0x3320646eU );
26+
wwu_t iv2 = wwu_bcast( 0x79622d32U );
27+
wwu_t iv3 = wwu_bcast( 0x6b206574U );
28+
wwu_t zero = wwu_zero();
29+
30+
/* Unpack key equivalent to:
31+
32+
c4 = wwu_bcast( (uint const *)(rng->key)[0] );
33+
c5 = wwu_bcast( (uint const *)(rng->key)[1] );
34+
...
35+
cB = wwu_bcast( (uint const *)(rng->key)[7] ); */
36+
37+
wwu_t key_lo = _mm512_broadcast_i32x4( _mm_load_epi32( rng->key ) ); /* [0,1,2,3,0,1,2,3] */
38+
wwu_t key_hi = _mm512_broadcast_i32x4( _mm_load_epi32( rng->key+16 ) ); /* [4,5,6,7,4,5,6,7] */
39+
wwu_t k0 = _mm512_shuffle_epi32( key_lo, 0x00 );
40+
wwu_t k1 = _mm512_shuffle_epi32( key_lo, 0x55 );
41+
wwu_t k2 = _mm512_shuffle_epi32( key_lo, 0xaa );
42+
wwu_t k3 = _mm512_shuffle_epi32( key_lo, 0xff );
43+
wwu_t k4 = _mm512_shuffle_epi32( key_hi, 0x00 );
44+
wwu_t k5 = _mm512_shuffle_epi32( key_hi, 0x55 );
45+
wwu_t k6 = _mm512_shuffle_epi32( key_hi, 0xaa );
46+
wwu_t k7 = _mm512_shuffle_epi32( key_hi, 0xff );
47+
48+
/* Derive block index */
49+
50+
ulong idx = rng->buf_fill / FD_CHACHA20_BLOCK_SZ; /* really a right shift */
51+
wwu_t idxs = wwu_add( wwu_bcast( idx ), wwu( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 ) );
52+
53+
/* Run through the round function */
54+
55+
wwu_t c0 = iv0; wwu_t c1 = iv1; wwu_t c2 = iv2; wwu_t c3 = iv3;
56+
wwu_t c4 = k0; wwu_t c5 = k1; wwu_t c6 = k2; wwu_t c7 = k3;
57+
wwu_t c8 = k4; wwu_t c9 = k5; wwu_t cA = k6; wwu_t cB = k7;
58+
wwu_t cC = idxs; wwu_t cD = zero; wwu_t cE = zero; wwu_t cF = zero;
59+
60+
# define QUARTER_ROUND(a,b,c,d) \
61+
do { \
62+
a = wwu_add( a, b ); d = wwu_xor( d, a ); d = wwu_rol16( d ); \
63+
c = wwu_add( c, d ); b = wwu_xor( b, c ); b = wwu_rol12( b ); \
64+
a = wwu_add( a, b ); d = wwu_xor( d, a ); d = wwu_rol8( d ); \
65+
c = wwu_add( c, d ); b = wwu_xor( b, c ); b = wwu_rol7( b ); \
66+
} while(0)
67+
68+
for( ulong i=0UL; i<10UL; i++ ) {
69+
QUARTER_ROUND( c0, c4, c8, cC );
70+
QUARTER_ROUND( c1, c5, c9, cD );
71+
QUARTER_ROUND( c2, c6, cA, cE );
72+
QUARTER_ROUND( c3, c7, cB, cF );
73+
QUARTER_ROUND( c0, c5, cA, cF );
74+
QUARTER_ROUND( c1, c6, cB, cC );
75+
QUARTER_ROUND( c2, c7, c8, cD );
76+
QUARTER_ROUND( c3, c4, c9, cE );
77+
}
78+
# undef QUARTER_ROUND
79+
80+
/* Finalize */
81+
82+
c0 = wwu_add( c0, iv0 );
83+
c1 = wwu_add( c1, iv1 );
84+
c2 = wwu_add( c2, iv2 );
85+
c3 = wwu_add( c3, iv3 );
86+
c4 = wwu_add( c4, k0 );
87+
c5 = wwu_add( c5, k1 );
88+
c6 = wwu_add( c6, k2 );
89+
c7 = wwu_add( c7, k3 );
90+
c8 = wwu_add( c8, k4 );
91+
c9 = wwu_add( c9, k5 );
92+
cA = wwu_add( cA, k6 );
93+
cB = wwu_add( cB, k7 );
94+
cC = wwu_add( cC, idxs );
95+
//cD = wwu_add( cD, zero );
96+
//cE = wwu_add( cE, zero );
97+
//cF = wwu_add( cF, zero );
98+
99+
/* Transpose matrix to get output vector */
100+
101+
wwu_transpose_16x16( c0, c1, c2, c3, c4, c5, c6, c7,
102+
c8, c9, cA, cB, cC, cD, cE, cF,
103+
c0, c1, c2, c3, c4, c5, c6, c7,
104+
c8, c9, cA, cB, cC, cD, cE, cF );
105+
106+
/* Update ring buffer */
107+
108+
uint * out = (uint *)rng->buf;
109+
wwu_st( out+0x00, c0 ); wwu_st( out+0x10, c1 );
110+
wwu_st( out+0x20, c2 ); wwu_st( out+0x30, c3 );
111+
wwu_st( out+0x40, c4 ); wwu_st( out+0x50, c5 );
112+
wwu_st( out+0x60, c6 ); wwu_st( out+0x70, c7 );
113+
wwu_st( out+0x80, c8 ); wwu_st( out+0x90, c9 );
114+
wwu_st( out+0xa0, cA ); wwu_st( out+0xb0, cB );
115+
wwu_st( out+0xc0, cC ); wwu_st( out+0xd0, cD );
116+
wwu_st( out+0xe0, cE ); wwu_st( out+0xf0, cF );
117+
118+
/* Update ring descriptor */
119+
120+
rng->buf_fill += 16*FD_CHACHA20_BLOCK_SZ;
121+
}

src/ballet/chacha20/fd_chacha20rng.c

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,6 @@ fd_chacha20rng_init( fd_chacha20rng_t * rng,
6868
return rng;
6969
}
7070

71-
#if FD_HAS_AVX
72-
73-
void
74-
fd_chacha20rng_refill_avx( fd_chacha20rng_t * rng );
75-
76-
#else
77-
7871
void
7972
fd_chacha20rng_refill_seq( fd_chacha20rng_t * rng ) {
8073
ulong fill_target = FD_CHACHA20RNG_BUFSZ - FD_CHACHA20_BLOCK_SZ;
@@ -90,5 +83,3 @@ fd_chacha20rng_refill_seq( fd_chacha20rng_t * rng ) {
9083
rng->buf_fill += (uint)FD_CHACHA20_BLOCK_SZ;
9184
}
9285
}
93-
94-
#endif

src/ballet/chacha20/fd_chacha20rng.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626
/* FD_CHACHA20RNG_BUFSZ is the internal buffer size of pre-generated
2727
ChaCha20 blocks. Multiple of block size (64 bytes) and a power of 2. */
2828

29-
#if FD_HAS_AVX
29+
#if FD_HAS_AVX512
30+
#define FD_CHACHA20RNG_BUFSZ (16*FD_CHACHA20_BLOCK_SZ)
31+
#elif FD_HAS_AVX
3032
#define FD_CHACHA20RNG_BUFSZ (8*FD_CHACHA20_BLOCK_SZ)
3133
#else
3234
#define FD_CHACHA20RNG_BUFSZ (256UL)
@@ -116,15 +118,24 @@ fd_chacha20rng_t *
116118
fd_chacha20rng_init( fd_chacha20rng_t * rng,
117119
void const * key );
118120

119-
/* The refill function . Not part of the public API. */
121+
/* The refill function. Not part of the public API. */
120122

123+
#if FD_HAS_AVX512
124+
void
125+
fd_chacha20rng_refill_avx512( fd_chacha20rng_t * rng );
126+
#endif
127+
128+
#if FD_HAS_AVX
121129
void
122130
fd_chacha20rng_refill_avx( fd_chacha20rng_t * rng );
131+
#endif
123132

124133
void
125134
fd_chacha20rng_refill_seq( fd_chacha20rng_t * rng );
126135

127-
#if FD_HAS_AVX
136+
#if FD_HAS_AVX512
137+
#define fd_chacha20rng_private_refill fd_chacha20rng_refill_avx512
138+
#elif FD_HAS_AVX
128139
#define fd_chacha20rng_private_refill fd_chacha20rng_refill_avx
129140
#else
130141
#define fd_chacha20rng_private_refill fd_chacha20rng_refill_seq

src/ballet/chacha20/test_chacha20rng.c

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -50,35 +50,42 @@ main( int argc,
5050
double gbps = ((double)(8UL*sizeof(ulong)*iter)) / ((double)dt);
5151
double ulongps = ((double)iter / (double)dt) * 1000.0;
5252
double ns = (double)dt / (double)iter;
53-
FD_LOG_NOTICE(( " ~%6.3f Gbps / core", gbps ));
53+
FD_LOG_NOTICE(( " ~%7.3f Gbps / core", gbps ));
5454
FD_LOG_NOTICE(( " ~%6.3f Mulong / second / core", ulongps ));
5555
FD_LOG_NOTICE(( " ~%6.3f ns / ulong", ns ));
5656
} while(0);
5757

58-
# if FD_HAS_AVX
59-
do {
60-
FD_LOG_NOTICE(( "Benchmarking fd_chacha20rng_refill_avx" ));
61-
key[ 0 ]++;
62-
FD_TEST( fd_chacha20rng_init( rng, key ) );
63-
64-
/* warmup */
65-
for( ulong rem=100000UL; rem; rem-- ) {
66-
rng->buf_off += 8*FD_CHACHA20_BLOCK_SZ;
67-
fd_chacha20rng_refill_avx( rng );
68-
}
69-
70-
/* for real */
71-
ulong iter = 1000000UL;
72-
long dt = -fd_log_wallclock();
73-
for( ulong rem=iter; rem; rem-- ) {
74-
rng->buf_off += 8*FD_CHACHA20_BLOCK_SZ;
75-
fd_chacha20rng_refill_avx( rng );
76-
}
77-
dt += fd_log_wallclock();
78-
double gbps = ((double)(8UL*8UL*FD_CHACHA20_BLOCK_SZ*iter)) / ((double)dt);
79-
FD_LOG_NOTICE(( " ~%6.3f Gbps / core", gbps ));
58+
#define REFILL_TEST( name, stride ) \
59+
do { \
60+
FD_LOG_NOTICE(( "Benchmarking " #name )); \
61+
key[ 0 ]++; \
62+
FD_TEST( fd_chacha20rng_init( rng, key ) ); \
63+
\
64+
/* warmup */ \
65+
for( ulong rem=100000UL; rem; rem-- ) { \
66+
rng->buf_off += (stride); \
67+
name( rng ); \
68+
} \
69+
\
70+
/* for real */ \
71+
ulong iter = 1000000UL; \
72+
long dt = -fd_log_wallclock(); \
73+
for( ulong rem=iter; rem; rem-- ) { \
74+
rng->buf_off += (stride); \
75+
name( rng ); \
76+
} \
77+
dt += fd_log_wallclock(); \
78+
double gbps = ((double)(8UL*(stride)*iter)) / ((double)dt); \
79+
FD_LOG_NOTICE(( " ~%7.3f Gbps / core", gbps )); \
8080
} while(0);
81-
# endif /* FD_HAS_AVX */
81+
82+
# if FD_HAS_AVX512
83+
REFILL_TEST( fd_chacha20rng_refill_avx512, 16*FD_CHACHA20_BLOCK_SZ );
84+
# endif
85+
# if FD_HAS_AVX
86+
REFILL_TEST( fd_chacha20rng_refill_avx, 8*FD_CHACHA20_BLOCK_SZ );
87+
# endif
88+
REFILL_TEST( fd_chacha20rng_refill_seq, 1*FD_CHACHA20_BLOCK_SZ );
8289

8390
/* Clean up */
8491

0 commit comments

Comments
 (0)