Skip to content

Commit e5ed5eb

Browse files
committed
SIMD implementation of md5
1 parent 0fcfe00 commit e5ed5eb

File tree

5 files changed

+239
-3
lines changed

5 files changed

+239
-3
lines changed

.github/workflows/checks.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,5 @@ jobs:
1212
steps:
1313
- uses: actions/checkout@v4
1414
- run: cargo fmt --check
15-
- run: cargo clippy --tests --all-features -- --deny warnings
15+
- run: cargo clippy --tests --features frivolity -- --deny warnings
1616
- run: cargo test

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ rust-version = "1.80"
66

77
[features]
88
frivolity = []
9+
simd = []
910

1011
[lints.rustdoc]
1112
private_intra_doc_links = "allow"

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ Complete 2023 to 2015 entries for the annual [Advent of Code] challenge, written
99
* Consistently formatted with `rustfmt` and linted by `clippy`.
1010
* Thoroughly commented with `rustdoc` generated [documentation online][docs-link].
1111
* Test coverage with continuous integration provided by [GitHub Actions][checks-link].
12-
* Self contained depending only on the stable `std` Rust library. No use of `unsafe` features.
12+
* Self contained depending only on the `std` Rust library. No use of `unsafe` features.
1313

1414
## Quickstart
1515

src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
//! [badge]: https://img.shields.io/badge/github-blue?style=for-the-badge&logo=github&labelColor=grey
66
//! [link]: https://github.com/maneatingape/advent-of-code-rust
77
8-
//! <!-- Configure rustdoc -->
8+
// Portable SIMD API is enabled by "simd" feature.
9+
#![cfg_attr(feature = "simd", allow(unstable_features), feature(portable_simd))]
10+
// Configure rustdoc.
911
#![doc(html_logo_url = "https://maneatingape.github.io/advent-of-code-rust/logo.png")]
1012

1113
/// # Utility modules to handle common recurring Advent of Code patterns.

src/util/md5.rs

Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
//! To maximize speed the loop for each of the four rounds used to create the hash is unrolled and
1010
//! all internal utility functions marked as
1111
//! [`#[inline]`](https://doc.rust-lang.org/reference/attributes/codegen.html#the-inline-attribute).
12+
//!
13+
//! An optional SIMD variant that computes multiple hashes in parallel is also implemented.
1214
1315
pub fn buffer_size(n: usize) -> usize {
1416
(n + 9).next_multiple_of(64)
@@ -145,3 +147,234 @@ fn round4(a: u32, b: u32, c: u32, d: u32, m: u32, s: u32, k: u32) -> u32 {
145147
fn common(f: u32, a: u32, b: u32, m: u32, s: u32, k: u32) -> u32 {
146148
f.wrapping_add(a).wrapping_add(k).wrapping_add(m).rotate_left(s).wrapping_add(b)
147149
}
150+
151+
#[cfg(feature = "simd")]
152+
pub mod simd {
153+
use std::array;
154+
use std::simd::num::SimdUint;
155+
use std::simd::{LaneCount, Simd, SupportedLaneCount};
156+
157+
#[inline]
158+
#[allow(clippy::too_many_lines)]
159+
pub fn hash<const N: usize>(
160+
buffers: &mut [[u8; 64]],
161+
size: usize,
162+
) -> ([u32; N], [u32; N], [u32; N], [u32; N])
163+
where
164+
LaneCount<N>: SupportedLaneCount,
165+
{
166+
// Assume all buffers are the same size.
167+
let end = 64 - 8;
168+
let bits = size * 8;
169+
170+
for buffer in buffers.iter_mut() {
171+
buffer[size] = 0x80;
172+
buffer[end..].copy_from_slice(&bits.to_le_bytes());
173+
}
174+
175+
let mut a0: Simd<u32, N> = Simd::splat(0x67452301);
176+
let mut b0: Simd<u32, N> = Simd::splat(0xefcdab89);
177+
let mut c0: Simd<u32, N> = Simd::splat(0x98badcfe);
178+
let mut d0: Simd<u32, N> = Simd::splat(0x10325476);
179+
180+
let mut a = a0;
181+
let mut b = b0;
182+
let mut c = c0;
183+
let mut d = d0;
184+
185+
let m0 = message(buffers, 0);
186+
a = round1(a, b, c, d, m0, 7, 0xd76aa478);
187+
let m1 = message(buffers, 1);
188+
d = round1(d, a, b, c, m1, 12, 0xe8c7b756);
189+
let m2 = message(buffers, 2);
190+
c = round1(c, d, a, b, m2, 17, 0x242070db);
191+
let m3 = message(buffers, 3);
192+
b = round1(b, c, d, a, m3, 22, 0xc1bdceee);
193+
let m4 = message(buffers, 4);
194+
a = round1(a, b, c, d, m4, 7, 0xf57c0faf);
195+
let m5 = message(buffers, 5);
196+
d = round1(d, a, b, c, m5, 12, 0x4787c62a);
197+
let m6 = message(buffers, 6);
198+
c = round1(c, d, a, b, m6, 17, 0xa8304613);
199+
let m7 = message(buffers, 7);
200+
b = round1(b, c, d, a, m7, 22, 0xfd469501);
201+
let m8 = message(buffers, 8);
202+
a = round1(a, b, c, d, m8, 7, 0x698098d8);
203+
let m9 = message(buffers, 9);
204+
d = round1(d, a, b, c, m9, 12, 0x8b44f7af);
205+
let m10 = message(buffers, 10);
206+
c = round1(c, d, a, b, m10, 17, 0xffff5bb1);
207+
let m11 = message(buffers, 11);
208+
b = round1(b, c, d, a, m11, 22, 0x895cd7be);
209+
let m12 = message(buffers, 12);
210+
a = round1(a, b, c, d, m12, 7, 0x6b901122);
211+
let m13 = message(buffers, 13);
212+
d = round1(d, a, b, c, m13, 12, 0xfd987193);
213+
let m14 = message(buffers, 14);
214+
c = round1(c, d, a, b, m14, 17, 0xa679438e);
215+
let m15 = message(buffers, 15);
216+
b = round1(b, c, d, a, m15, 22, 0x49b40821);
217+
218+
a = round2(a, b, c, d, m1, 5, 0xf61e2562);
219+
d = round2(d, a, b, c, m6, 9, 0xc040b340);
220+
c = round2(c, d, a, b, m11, 14, 0x265e5a51);
221+
b = round2(b, c, d, a, m0, 20, 0xe9b6c7aa);
222+
a = round2(a, b, c, d, m5, 5, 0xd62f105d);
223+
d = round2(d, a, b, c, m10, 9, 0x02441453);
224+
c = round2(c, d, a, b, m15, 14, 0xd8a1e681);
225+
b = round2(b, c, d, a, m4, 20, 0xe7d3fbc8);
226+
a = round2(a, b, c, d, m9, 5, 0x21e1cde6);
227+
d = round2(d, a, b, c, m14, 9, 0xc33707d6);
228+
c = round2(c, d, a, b, m3, 14, 0xf4d50d87);
229+
b = round2(b, c, d, a, m8, 20, 0x455a14ed);
230+
a = round2(a, b, c, d, m13, 5, 0xa9e3e905);
231+
d = round2(d, a, b, c, m2, 9, 0xfcefa3f8);
232+
c = round2(c, d, a, b, m7, 14, 0x676f02d9);
233+
b = round2(b, c, d, a, m12, 20, 0x8d2a4c8a);
234+
235+
a = round3(a, b, c, d, m5, 4, 0xfffa3942);
236+
d = round3(d, a, b, c, m8, 11, 0x8771f681);
237+
c = round3(c, d, a, b, m11, 16, 0x6d9d6122);
238+
b = round3(b, c, d, a, m14, 23, 0xfde5380c);
239+
a = round3(a, b, c, d, m1, 4, 0xa4beea44);
240+
d = round3(d, a, b, c, m4, 11, 0x4bdecfa9);
241+
c = round3(c, d, a, b, m7, 16, 0xf6bb4b60);
242+
b = round3(b, c, d, a, m10, 23, 0xbebfbc70);
243+
a = round3(a, b, c, d, m13, 4, 0x289b7ec6);
244+
d = round3(d, a, b, c, m0, 11, 0xeaa127fa);
245+
c = round3(c, d, a, b, m3, 16, 0xd4ef3085);
246+
b = round3(b, c, d, a, m6, 23, 0x04881d05);
247+
a = round3(a, b, c, d, m9, 4, 0xd9d4d039);
248+
d = round3(d, a, b, c, m12, 11, 0xe6db99e5);
249+
c = round3(c, d, a, b, m15, 16, 0x1fa27cf8);
250+
b = round3(b, c, d, a, m2, 23, 0xc4ac5665);
251+
252+
a = round4(a, b, c, d, m0, 6, 0xf4292244);
253+
d = round4(d, a, b, c, m7, 10, 0x432aff97);
254+
c = round4(c, d, a, b, m14, 15, 0xab9423a7);
255+
b = round4(b, c, d, a, m5, 21, 0xfc93a039);
256+
a = round4(a, b, c, d, m12, 6, 0x655b59c3);
257+
d = round4(d, a, b, c, m3, 10, 0x8f0ccc92);
258+
c = round4(c, d, a, b, m10, 15, 0xffeff47d);
259+
b = round4(b, c, d, a, m1, 21, 0x85845dd1);
260+
a = round4(a, b, c, d, m8, 6, 0x6fa87e4f);
261+
d = round4(d, a, b, c, m15, 10, 0xfe2ce6e0);
262+
c = round4(c, d, a, b, m6, 15, 0xa3014314);
263+
b = round4(b, c, d, a, m13, 21, 0x4e0811a1);
264+
a = round4(a, b, c, d, m4, 6, 0xf7537e82);
265+
d = round4(d, a, b, c, m11, 10, 0xbd3af235);
266+
c = round4(c, d, a, b, m2, 15, 0x2ad7d2bb);
267+
b = round4(b, c, d, a, m9, 21, 0xeb86d391);
268+
269+
a0 += a;
270+
b0 += b;
271+
c0 += c;
272+
d0 += d;
273+
274+
(
275+
a0.swap_bytes().to_array(),
276+
b0.swap_bytes().to_array(),
277+
c0.swap_bytes().to_array(),
278+
d0.swap_bytes().to_array(),
279+
)
280+
}
281+
282+
#[inline]
283+
fn message<const N: usize>(buffers: &mut [[u8; 64]], i: usize) -> Simd<u32, N>
284+
where
285+
LaneCount<N>: SupportedLaneCount,
286+
{
287+
let start = 4 * i;
288+
let end = start + 4;
289+
Simd::from_array(array::from_fn(|lane| {
290+
let slice = &buffers[lane][start..end];
291+
u32::from_le_bytes(slice.try_into().unwrap())
292+
}))
293+
}
294+
295+
#[inline]
296+
fn round1<const N: usize>(
297+
a: Simd<u32, N>,
298+
b: Simd<u32, N>,
299+
c: Simd<u32, N>,
300+
d: Simd<u32, N>,
301+
m: Simd<u32, N>,
302+
s: u32,
303+
k: u32,
304+
) -> Simd<u32, N>
305+
where
306+
LaneCount<N>: SupportedLaneCount,
307+
{
308+
let f = (b & c) | (!b & d);
309+
common(f, a, b, m, s, k)
310+
}
311+
312+
#[inline]
313+
fn round2<const N: usize>(
314+
a: Simd<u32, N>,
315+
b: Simd<u32, N>,
316+
c: Simd<u32, N>,
317+
d: Simd<u32, N>,
318+
m: Simd<u32, N>,
319+
s: u32,
320+
k: u32,
321+
) -> Simd<u32, N>
322+
where
323+
LaneCount<N>: SupportedLaneCount,
324+
{
325+
let f = (b & d) | (c & !d);
326+
common(f, a, b, m, s, k)
327+
}
328+
329+
#[inline]
330+
fn round3<const N: usize>(
331+
a: Simd<u32, N>,
332+
b: Simd<u32, N>,
333+
c: Simd<u32, N>,
334+
d: Simd<u32, N>,
335+
m: Simd<u32, N>,
336+
s: u32,
337+
k: u32,
338+
) -> Simd<u32, N>
339+
where
340+
LaneCount<N>: SupportedLaneCount,
341+
{
342+
let f = b ^ c ^ d;
343+
common(f, a, b, m, s, k)
344+
}
345+
346+
#[inline]
347+
fn round4<const N: usize>(
348+
a: Simd<u32, N>,
349+
b: Simd<u32, N>,
350+
c: Simd<u32, N>,
351+
d: Simd<u32, N>,
352+
m: Simd<u32, N>,
353+
s: u32,
354+
k: u32,
355+
) -> Simd<u32, N>
356+
where
357+
LaneCount<N>: SupportedLaneCount,
358+
{
359+
let f = c ^ (b | !d);
360+
common(f, a, b, m, s, k)
361+
}
362+
363+
#[inline]
364+
fn common<const N: usize>(
365+
f: Simd<u32, N>,
366+
a: Simd<u32, N>,
367+
b: Simd<u32, N>,
368+
m: Simd<u32, N>,
369+
s: u32,
370+
k: u32,
371+
) -> Simd<u32, N>
372+
where
373+
LaneCount<N>: SupportedLaneCount,
374+
{
375+
let k = Simd::splat(k);
376+
let first = f + a + k + m;
377+
let second = (first << s) | (first >> (32 - s));
378+
second + b
379+
}
380+
}

0 commit comments

Comments
 (0)