Skip to content

Commit f4fd647

Browse files
committed
better many_facts
1 parent b833f7c commit f4fd647

File tree

1 file changed

+38
-28
lines changed

1 file changed

+38
-28
lines changed

verify/simd/many_facts.test.cpp

Lines changed: 38 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -28,38 +28,48 @@ void facts_inplace(vector<int> &args) {
2828
}
2929
uint64_t b2x32 = (1ULL << 32) % mod;
3030
uint64_t fact = 1;
31-
for(uint64_t b = 0; b <= limit; b += block) {
32-
u64x4 cur = {b, b + block / 4, b + block / 2, b + 3 * block / 4};
33-
static array<u64x4, block / 4> prods;
34-
prods[0] = u64x4{cur[0] + !b, cur[1], cur[2], cur[3]};
35-
cur = cur * b2x32 % mod;
36-
for(int i = 1; i < block / 4; i++) {
37-
cur += b2x32;
38-
cur = cur >= mod ? cur - mod : cur;
39-
prods[i] = montgomery_mul(prods[i - 1], cur, mod4, imod4);
40-
}
41-
for(auto i: args_per_block[b / block]) {
42-
size_t x = args[i];
43-
if(x >= mod / 2) {
44-
x = mod - x - 1;
45-
}
46-
x -= b;
47-
auto pre_blocks = x / (block / 4);
48-
auto in_block = x % (block / 4);
49-
auto ans = fact * prods[in_block][pre_blocks] % mod;
50-
for(size_t z = 0; z < pre_blocks; z++) {
51-
ans = ans * prods.back()[z] % mod;
31+
for(uint64_t b = 0; b <= limit; b += 4 * block) {
32+
u64x4 cur[4];
33+
static array<u64x4, block / 4> prods[4];
34+
for(int z = 0; z < 4; z++) {
35+
for(int j = 0; j < 4; j++) {
36+
cur[z][j] = b + z * block + j * block / 4;
37+
prods[z][0][j] = cur[z][j] + !(b || z || j);
38+
cur[z][j] = cur[z][j] * b2x32 % mod;
5239
}
53-
if(args[i] >= mod / 2) {
54-
ans = math::bpow(ans, mod - 2, 1ULL, [](auto a, auto b){return a * b % mod;});
55-
args[i] = int(x % 2 ? ans : mod - ans);
56-
} else {
57-
args[i] = int(ans);
40+
}
41+
for(int i = 1; i < block / 4; i++) {
42+
for(int z = 0; z < 4; z++) {
43+
cur[z] += b2x32;
44+
cur[z] = cur[z] >= mod ? cur[z] - mod : cur[z];
45+
prods[z][i] = montgomery_mul(prods[z][i - 1], cur[z], mod4, imod4);
5846
}
5947
}
60-
args_per_block[b / block].clear();
6148
for(int z = 0; z < 4; z++) {
62-
fact = fact * prods.back()[z] % mod;
49+
uint64_t bl = b + z * block;
50+
for(auto i: args_per_block[bl / block]) {
51+
size_t x = args[i];
52+
if(x >= mod / 2) {
53+
x = mod - x - 1;
54+
}
55+
x -= bl;
56+
auto pre_blocks = x / (block / 4);
57+
auto in_block = x % (block / 4);
58+
auto ans = fact * prods[z][in_block][pre_blocks] % mod;
59+
for(size_t j = 0; j < pre_blocks; j++) {
60+
ans = ans * prods[z].back()[j] % mod;
61+
}
62+
if(args[i] >= mod / 2) {
63+
ans = math::bpow(ans, mod - 2, 1ULL, [](auto a, auto b){return a * b % mod;});
64+
args[i] = int(x % 2 ? ans : mod - ans);
65+
} else {
66+
args[i] = int(ans);
67+
}
68+
}
69+
args_per_block[bl / block].clear();
70+
for(int j = 0; j < 4; j++) {
71+
fact = fact * prods[z].back()[j] % mod;
72+
}
6373
}
6474
}
6575
}

0 commit comments

Comments
 (0)