Skip to content

Commit 51ea8ec

Browse files
committed
Use const generics for PPartMultiplier mod4 option
This only uses min_const_generics which will be stable in Rust 1.51. This avoids the performance penalty in "normal" code.
1 parent 8adebfd commit 51ea8ec

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

ext/crates/algebra/src/algebra/milnor_algebra.rs

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -676,7 +676,7 @@ impl MilnorAlgebra {
676676
if self.generic {
677677
let m1f = self.multiply_qpart(m1, m2.q_part);
678678
for (cc, basis) in m1f {
679-
let mut multiplier = PPartMultiplier::new(self.prime(), &(basis.p_part), &(m2.p_part), false);
679+
let mut multiplier = PPartMultiplier::<false>::new(self.prime(), &(basis.p_part), &(m2.p_part));
680680
let mut new = MilnorBasisElement {
681681
degree : target_dim,
682682
q_part : basis.q_part,
@@ -688,7 +688,7 @@ impl MilnorAlgebra {
688688
}
689689
}
690690
} else {
691-
let mut multiplier = PPartMultiplier::new(self.prime(), &(m1.p_part), &(m2.p_part), false);
691+
let mut multiplier = PPartMultiplier::<false>::new(self.prime(), &(m1.p_part), &(m2.p_part));
692692
let mut new = MilnorBasisElement {
693693
degree: target_dim,
694694
q_part: 0,
@@ -732,11 +732,10 @@ impl std::ops::IndexMut<usize> for Matrix2D {
732732
}
733733

734734
#[allow(non_snake_case)]
735-
pub struct PPartMultiplier<'a> {
735+
pub struct PPartMultiplier<'a, const MOD4: bool> {
736736
p : ValidPrime,
737737
M : Matrix2D,
738738
r : &'a PPart,
739-
mod_4: bool,
740739
rows : usize,
741740
cols : usize,
742741
diag_num : usize,
@@ -745,14 +744,14 @@ pub struct PPartMultiplier<'a> {
745744
}
746745

747746
#[allow(non_snake_case)]
748-
impl<'a> PPartMultiplier<'a> {
747+
impl<'a, const MOD4: bool> PPartMultiplier<'a, MOD4> {
749748
fn prime(&self) -> ValidPrime {
750749
self.p
751750
}
752751

753752
#[allow(clippy::ptr_arg)]
754-
pub fn new (p : ValidPrime, r : &'a PPart, s : &'a PPart, mod_4: bool) -> PPartMultiplier<'a> {
755-
if mod_4 {
753+
pub fn new (p : ValidPrime, r : &'a PPart, s : &'a PPart) -> Self {
754+
if MOD4 {
756755
assert_eq!(*p, 2);
757756
}
758757
let rows = r.len() + 1;
@@ -767,7 +766,7 @@ impl<'a> PPartMultiplier<'a> {
767766
}
768767
M[0][1..cols].clone_from_slice(&s[0..(cols - 1)]);
769768

770-
PPartMultiplier { p, M, r, rows, cols, diag_num, diagonal, init : true, mod_4 }
769+
PPartMultiplier { p, M, r, rows, cols, diag_num, diagonal, init : true }
771770
}
772771

773772
fn update(&mut self) -> bool {
@@ -817,7 +816,7 @@ impl<'a> PPartMultiplier<'a> {
817816
if self.init {
818817
self.init = false;
819818
for i in 1 .. std::cmp::min(self.cols, self.rows) {
820-
if self.mod_4 {
819+
if MOD4 {
821820
coef *= fp::prime::binomial4(self.M[i][0] + self.M[0][i], self.M[0][i]);
822821
} else {
823822
coef *= fp::prime::binomial(self.prime(), (self.M[i][0] + self.M[0][i]) as i32, self.M[0][i] as i32);
@@ -852,7 +851,7 @@ impl<'a> PPartMultiplier<'a> {
852851
if sum == 0 {
853852
continue;
854853
}
855-
if self.mod_4 {
854+
if MOD4 {
856855
if coef == 2 {
857856
coef *= fp::prime::multinomial2(&self.diagonal);
858857
} else {

0 commit comments

Comments
 (0)