Skip to content

Commit 47cabfd

Browse files
ibmp33 and eigmax authored
chore: add avx2 acceleration to poseidon hash function (#155)
* chore: add avx2 acceleration to poseidon hash function
* fix: add more samples to test overflow (mul/square)
* chore: add avx2 acceleration to poseidon hash function
* fix: remove warnings
* fix: overflow

---------

Co-authored-by: eigmax <[email protected]>
1 parent 4ed1da7 commit 47cabfd

File tree

15 files changed, with 1495 additions and 113 deletions.

algebraic/src/arch/x86_64/avx2_field_gl.rs

Lines changed: 60 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -6,12 +6,12 @@
66
//!
77
use crate::ff::*;
88
use crate::field_gl::{Fr, FrRepr as GoldilocksField};
9+
use crate::packed::PackedField;
910
use core::arch::x86_64::*;
1011
use core::fmt;
1112
use core::fmt::{Debug, Formatter};
1213
use core::mem::transmute;
1314
use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
14-
// use crate::packed::PackedField;
1515

1616
/// AVX2 Goldilocks Field
1717
///
@@ -24,8 +24,6 @@ use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAss
2424
#[repr(transparent)]
2525
pub struct Avx2GoldilocksField(pub [GoldilocksField; 4]);
2626

27-
const WIDTH: usize = 4;
28-
2927
impl Avx2GoldilocksField {
3028
#[inline]
3129
pub fn new(x: __m256i) -> Self {
@@ -35,30 +33,40 @@ impl Avx2GoldilocksField {
3533
pub fn get(&self) -> __m256i {
3634
unsafe { transmute(*self) }
3735
}
38-
// }
39-
// unsafe impl PackedField for Avx2GoldilocksField {
4036
#[inline]
41-
pub fn from_slice(slice: &[GoldilocksField]) -> &Self {
42-
assert_eq!(slice.len(), WIDTH);
37+
pub fn square(&self) -> Avx2GoldilocksField {
38+
Self::new(unsafe { square(self.get()) })
39+
}
40+
#[inline]
41+
pub fn reduce(x: __m256i, y: __m256i) -> Avx2GoldilocksField {
42+
Self::new(unsafe { reduce128((x, y)) })
43+
}
44+
}
45+
46+
unsafe impl PackedField for Avx2GoldilocksField {
47+
const WIDTH: usize = 4;
48+
type Scalar = GoldilocksField;
49+
const ZEROS: Self = Self([GoldilocksField([0]); 4]);
50+
const ONES: Self = Self([GoldilocksField([1]); 4]);
51+
52+
#[inline]
53+
fn from_slice(slice: &[GoldilocksField]) -> &Self {
54+
assert_eq!(slice.len(), Self::WIDTH);
4355
unsafe { &*slice.as_ptr().cast() }
4456
}
4557
#[inline]
46-
pub fn from_slice_mut(slice: &mut [GoldilocksField]) -> &mut Self {
47-
assert_eq!(slice.len(), WIDTH);
58+
fn from_slice_mut(slice: &mut [GoldilocksField]) -> &mut Self {
59+
assert_eq!(slice.len(), Self::WIDTH);
4860
unsafe { &mut *slice.as_mut_ptr().cast() }
4961
}
5062
#[inline]
51-
pub fn as_slice(&self) -> &[GoldilocksField] {
63+
fn as_slice(&self) -> &[GoldilocksField] {
5264
&self.0[..]
5365
}
5466
#[inline]
55-
pub fn as_slice_mut(&mut self) -> &mut [GoldilocksField] {
67+
fn as_slice_mut(&mut self) -> &mut [GoldilocksField] {
5668
&mut self.0[..]
5769
}
58-
#[inline]
59-
pub fn square(&self) -> Avx2GoldilocksField {
60-
Self::new(unsafe { square(self.get()) })
61-
}
6270

6371
#[inline]
6472
fn interleave(&self, other: Self, block_len: usize) -> (Self, Self) {
@@ -117,7 +125,7 @@ impl Debug for Avx2GoldilocksField {
117125
impl Default for Avx2GoldilocksField {
118126
#[inline]
119127
fn default() -> Self {
120-
Self([GoldilocksField::from(0); 4])
128+
Self::ZEROS
121129
}
122130
}
123131

@@ -325,7 +333,8 @@ unsafe fn add_no_double_overflow_64_64s_s(x: __m256i, y_s: __m256i) -> __m256i {
325333
unsafe fn add(x: __m256i, y: __m256i) -> __m256i {
326334
let y_s = shift(y);
327335
let res_s = add_no_double_overflow_64_64s_s(x, canonicalize_s(y_s));
328-
shift(res_s)
336+
// Added by Eigen
337+
shift(canonicalize_s(res_s))
329338
}
330339

331340
#[inline]
@@ -455,7 +464,8 @@ unsafe fn reduce128(x: (__m256i, __m256i)) -> __m256i {
455464
let lo1_s = sub_small_64s_64_s(lo0_s, hi_hi0);
456465
let t1 = _mm256_mul_epu32(hi0, EPSILON);
457466
let lo2_s = add_small_64s_64_s(lo1_s, t1);
458-
let lo2 = shift(lo2_s);
467+
// Added by Eigen
468+
let lo2 = shift(canonicalize_s(lo2_s));
459469
lo2
460470
}
461471

@@ -503,20 +513,20 @@ mod tests {
503513
use super::Avx2GoldilocksField;
504514
use crate::ff::*;
505515
use crate::field_gl::{Fr, FrRepr as GoldilocksField};
516+
use crate::packed::PackedField;
506517
use std::time::Instant;
507-
// use crate::packed::PackedField;
508518

509519
fn test_vals_a() -> [GoldilocksField; 4] {
510520
[
511-
GoldilocksField([14479013849828404771u64]),
521+
GoldilocksField([18446744069414584320u64]),
512522
GoldilocksField([9087029921428221768u64]),
513523
GoldilocksField([2441288194761790662u64]),
514524
GoldilocksField([5646033492608483824u64]),
515525
]
516526
}
517527
fn test_vals_b() -> [GoldilocksField; 4] {
518528
[
519-
GoldilocksField([17891926589593242302u64]),
529+
GoldilocksField([18446744069414584320u64]),
520530
GoldilocksField([11009798273260028228u64]),
521531
GoldilocksField([2028722748960791447u64]),
522532
GoldilocksField([7929433601095175579u64]),
@@ -530,32 +540,32 @@ mod tests {
530540
let start = Instant::now();
531541
let packed_a = Avx2GoldilocksField::from_slice(&a_arr);
532542
let packed_b = Avx2GoldilocksField::from_slice(&b_arr);
533-
let packed_res = *packed_a + *packed_b;
543+
let packed_res = *packed_a + *packed_b + *packed_a;
534544
let arr_res = packed_res.as_slice();
535545
let avx2_duration = start.elapsed();
536-
// println!("arr_res: {:?}", arr_res);
546+
// log::debug!("arr_res: {:?}", arr_res);
537547

538548
let start = Instant::now();
539-
let expected = a_arr
540-
.iter()
541-
.zip(b_arr)
542-
.map(|(&a, b)| Fr::from_repr(a).unwrap() + Fr::from_repr(b).unwrap());
549+
let expected = a_arr.iter().zip(b_arr).map(|(&a, b)| {
550+
Fr::from_repr(a).unwrap() + Fr::from_repr(a).unwrap() + Fr::from_repr(b).unwrap()
551+
});
543552
let expected_values: Vec<Fr> = expected.collect();
544-
// println!("expected values: {:?}", expected_values);
553+
log::debug!("expected values: {:?}", expected_values[0].as_int());
545554
let non_accelerated_duration = start.elapsed();
546555
for (exp, &res) in expected_values.iter().zip(arr_res) {
547556
assert_eq!(res, exp.into_repr());
548557
}
549558

550-
println!("test_add_AVX2_accelerated time: {:?}", avx2_duration);
551-
println!(
559+
log::debug!("test_add_AVX2_accelerated time: {:?}", avx2_duration);
560+
log::debug!(
552561
"test_add_Non_accelerated time: {:?}",
553562
non_accelerated_duration
554563
);
555564
}
556565

557566
#[test]
558567
fn test_mul() {
568+
env_logger::try_init().unwrap_or_default();
559569
let a_arr = test_vals_a();
560570
let b_arr = test_vals_b();
561571
let start = Instant::now();
@@ -564,7 +574,7 @@ mod tests {
564574
let packed_res = packed_a * packed_b;
565575
let arr_res = packed_res.as_slice();
566576
let avx2_duration = start.elapsed();
567-
// println!("arr_res: {:?}", arr_res);
577+
// log::debug!("arr_res: {:?}", arr_res);
568578

569579
let start = Instant::now();
570580
let expected = a_arr
@@ -573,14 +583,14 @@ mod tests {
573583
.map(|(&a, b)| Fr::from_repr(a).unwrap() * Fr::from_repr(b).unwrap());
574584
let expected_values: Vec<Fr> = expected.collect();
575585
let non_accelerated_duration = start.elapsed();
576-
// println!("expected values: {:?}", expected_values);
586+
log::debug!("expected values: {:?}", expected_values);
577587

578588
for (exp, &res) in expected_values.iter().zip(arr_res) {
579589
assert_eq!(res, exp.into_repr());
580590
}
581591

582-
println!("test_mul_AVX2_accelerated time: {:?}", avx2_duration);
583-
println!(
592+
log::debug!("test_mul_AVX2_accelerated time: {:?}", avx2_duration);
593+
log::debug!(
584594
"test_mul_Non_accelerated time: {:?}",
585595
non_accelerated_duration
586596
);
@@ -594,7 +604,7 @@ mod tests {
594604
let packed_res = packed_a / GoldilocksField([7929433601095175579u64]);
595605
let arr_res = packed_res.as_slice();
596606
let avx2_duration = start.elapsed();
597-
// println!("arr_res: {:?}", arr_res);
607+
// log::debug!("arr_res: {:?}", arr_res);
598608

599609
let start = Instant::now();
600610
let expected = a_arr.iter().map(|&a| {
@@ -603,14 +613,14 @@ mod tests {
603613
});
604614
let expected_values: Vec<Fr> = expected.collect();
605615
let non_accelerated_duration = start.elapsed();
606-
// println!("expected values: {:?}", expected_values);
616+
// log::debug!("expected values: {:?}", expected_values);
607617

608618
for (exp, &res) in expected_values.iter().zip(arr_res) {
609619
assert_eq!(res, exp.into_repr());
610620
}
611621

612-
println!("test_div_AVX2_accelerated time: {:?}", avx2_duration);
613-
println!(
622+
log::debug!("test_div_AVX2_accelerated time: {:?}", avx2_duration);
623+
log::debug!(
614624
"test_div_Non_accelerated time: {:?}",
615625
non_accelerated_duration
616626
);
@@ -624,7 +634,7 @@ mod tests {
624634
let packed_res = packed_a.square();
625635
let arr_res = packed_res.as_slice();
626636
let avx2_duration = start.elapsed();
627-
// println!("arr_res: {:?}", arr_res);
637+
// log::debug!("arr_res: {:?}", arr_res);
628638

629639
let start = Instant::now();
630640
let mut expected_values = Vec::new();
@@ -640,12 +650,12 @@ mod tests {
640650
}
641651
}
642652
let non_accelerated_duration = start.elapsed();
643-
// println!("expected values: {:?}", expected_values);
653+
// log::debug!("expected values: {:?}", expected_values);
644654
for (exp, &res) in expected_values.iter().zip(arr_res) {
645655
assert_eq!(res, exp.into_repr());
646656
}
647-
println!("test_square_AVX2_accelerated time: {:?}", avx2_duration);
648-
println!(
657+
log::debug!("test_square_AVX2_accelerated time: {:?}", avx2_duration);
658+
log::debug!(
649659
"test_square_Non_accelerated time: {:?}",
650660
non_accelerated_duration
651661
);
@@ -659,20 +669,20 @@ mod tests {
659669
let packed_res = -packed_a;
660670
let arr_res = packed_res.as_slice();
661671
let avx2_duration = start.elapsed();
662-
// println!("arr_res: {:?}", arr_res);
672+
// log::debug!("arr_res: {:?}", arr_res);
663673

664674
let start = Instant::now();
665675
let expected = a_arr.iter().map(|&a| -Fr::from_repr(a).unwrap());
666676
let expected_values: Vec<Fr> = expected.collect();
667677
let non_accelerated_duration = start.elapsed();
668-
// println!("expected values: {:?}", expected_values);
678+
// log::debug!("expected values: {:?}", expected_values);
669679

670680
for (exp, &res) in expected_values.iter().zip(arr_res) {
671681
assert_eq!(res, exp.into_repr());
672682
}
673683

674-
println!("test_neg_AVX2_accelerated time: {:?}", avx2_duration);
675-
println!(
684+
log::debug!("test_neg_AVX2_accelerated time: {:?}", avx2_duration);
685+
log::debug!(
676686
"test_neg_Non_accelerated time: {:?}",
677687
non_accelerated_duration
678688
);
@@ -688,7 +698,7 @@ mod tests {
688698
let packed_res = packed_a - packed_b;
689699
let arr_res = packed_res.as_slice();
690700
let avx2_duration = start.elapsed();
691-
// println!("arr_res: {:?}", arr_res);
701+
// log::debug!("arr_res: {:?}", arr_res);
692702

693703
let start = Instant::now();
694704
let expected = a_arr
@@ -697,14 +707,14 @@ mod tests {
697707
.map(|(&a, b)| Fr::from_repr(a).unwrap() - Fr::from_repr(b).unwrap());
698708
let expected_values: Vec<Fr> = expected.collect();
699709
let non_accelerated_duration = start.elapsed();
700-
// println!("expected values: {:?}", expected_values);
710+
// log::debug!("expected values: {:?}", expected_values);
701711

702712
for (exp, &res) in expected_values.iter().zip(arr_res) {
703713
assert_eq!(res, exp.into_repr());
704714
}
705715

706-
println!("test_sub_AVX2_accelerated time: {:?}", avx2_duration);
707-
println!(
716+
log::debug!("test_sub_AVX2_accelerated time: {:?}", avx2_duration);
717+
log::debug!(
708718
"test_sub_Non_accelerated time: {:?}",
709719
non_accelerated_duration
710720
);

algebraic/src/arch/x86_64/avx512_field_gl.rs

Lines changed: 20 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
//! RUSTFLAGS='-C target-feature=+avx512f,+avx512bw,+avx512cd,+avx512dq,+avx512vl' cargo build --release
66
use crate::ff::*;
77
use crate::field_gl::{Fr, FrRepr as GoldilocksField};
8+
use crate::packed::PackedField;
89
use core::arch::x86_64::*;
910
use core::fmt;
1011
use core::fmt::{Debug, Formatter};
@@ -34,27 +35,36 @@ impl Avx512GoldilocksField {
3435
unsafe { transmute(*self) }
3536
}
3637
#[inline]
37-
pub fn from_slice(slice: &[GoldilocksField]) -> &Self {
38+
pub fn square(&self) -> Avx512GoldilocksField {
39+
Self::new(unsafe { square(self.get()) })
40+
}
41+
}
42+
43+
unsafe impl PackedField for Avx512GoldilocksField {
44+
const WIDTH: usize = 8;
45+
46+
type Scalar = GoldilocksField;
47+
48+
const ZEROS: Self = Self([GoldilocksField([0]); 8]);
49+
const ONES: Self = Self([GoldilocksField([1]); 8]);
50+
#[inline]
51+
fn from_slice(slice: &[GoldilocksField]) -> &Self {
3852
assert_eq!(slice.len(), WIDTH);
3953
unsafe { &*slice.as_ptr().cast() }
4054
}
4155
#[inline]
42-
pub fn from_slice_mut(slice: &mut [GoldilocksField]) -> &mut Self {
56+
fn from_slice_mut(slice: &mut [GoldilocksField]) -> &mut Self {
4357
assert_eq!(slice.len(), WIDTH);
4458
unsafe { &mut *slice.as_mut_ptr().cast() }
4559
}
4660
#[inline]
47-
pub fn as_slice(&self) -> &[GoldilocksField] {
61+
fn as_slice(&self) -> &[GoldilocksField] {
4862
&self.0[..]
4963
}
5064
#[inline]
51-
pub fn as_slice_mut(&mut self) -> &mut [GoldilocksField] {
65+
fn as_slice_mut(&mut self) -> &mut [GoldilocksField] {
5266
&mut self.0[..]
5367
}
54-
#[inline]
55-
pub fn square(&self) -> Avx512GoldilocksField {
56-
Self::new(unsafe { square(self.get()) })
57-
}
5868

5969
#[inline]
6070
fn interleave(&self, other: Self, block_len: usize) -> (Self, Self) {
@@ -114,7 +124,7 @@ impl Debug for Avx512GoldilocksField {
114124
impl Default for Avx512GoldilocksField {
115125
#[inline]
116126
fn default() -> Self {
117-
Self([GoldilocksField::from(0); 8])
127+
Self::ZEROS
118128
}
119129
}
120130

@@ -397,6 +407,7 @@ mod tests {
397407
use super::Avx512GoldilocksField;
398408
use crate::ff::*;
399409
use crate::field_gl::{Fr, FrRepr as GoldilocksField};
410+
use crate::packed::PackedField;
400411
use std::time::Instant;
401412

402413
fn test_vals_a() -> [GoldilocksField; 8] {

0 commit comments

Comments
 (0)