Skip to content

Commit dc61819

Browse files
committed
add nCr, use std::vector instead of static arrays, use std::complex in fft
1 parent 4d117b8 commit dc61819

File tree

2 files changed

+21
-30
lines changed

2 files changed

+21
-30
lines changed

cp-algo/algebra/common.hpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ namespace cp_algo::algebra {
3232

3333
template<typename T>
3434
T fact(int n) {
35-
static T F[maxn];
35+
static std::vector<T> F(maxn);
3636
static bool init = false;
3737
if(!init) {
3838
F[0] = T(1);
@@ -46,7 +46,7 @@ namespace cp_algo::algebra {
4646

4747
template<typename T>
4848
T rfact(int n) {
49-
static T F[maxn];
49+
static std::vector<T> F(maxn);
5050
static bool init = false;
5151
if(!init) {
5252
F[maxn - 1] = T(1) / fact<T>(maxn - 1);
@@ -60,7 +60,7 @@ namespace cp_algo::algebra {
6060

6161
template<typename T>
6262
T small_inv(int n) {
63-
static T F[maxn];
63+
static std::vector<T> F(maxn);
6464
static bool init = false;
6565
if(!init) {
6666
for(int i = 1; i < maxn; i++) {
@@ -70,5 +70,14 @@ namespace cp_algo::algebra {
7070
}
7171
return F[n];
7272
}
73+
74+
template<typename T>
75+
T nCr(int n, int r) {
76+
if(r < 0 || r > n) {
77+
return T(0);
78+
} else {
79+
return fact<T>(n) * rfact<T>(r) * rfact<T>(n-r);
80+
}
81+
}
7382
}
7483
#endif // CP_ALGO_ALGEBRA_COMMON_HPP

cp-algo/algebra/fft.hpp

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,43 +3,25 @@
33
#include "common.hpp"
44
#include "modular.hpp"
55
#include <algorithm>
6+
#include <complex>
67
#include <cassert>
78
#include <vector>
89
namespace cp_algo::algebra::fft {
910
using ftype = double;
10-
struct point {
11-
ftype x, y;
12-
13-
ftype real() {return x;}
14-
ftype imag() {return y;}
15-
16-
point(): x(0), y(0){}
17-
point(ftype x, ftype y = 0): x(x), y(y){}
18-
19-
static point polar(ftype rho, ftype ang) {
20-
return point{rho * cos(ang), rho * sin(ang)};
21-
}
22-
23-
point conj() const {
24-
return {x, -y};
25-
}
26-
27-
point operator +=(const point &t) {x += t.x, y += t.y; return *this;}
28-
point operator +(const point &t) const {return point(*this) += t;}
29-
point operator -(const point &t) const {return {x - t.x, y - t.y};}
30-
point operator *(const point &t) const {return {x * t.x - y * t.y, x * t.y + y * t.x};}
31-
};
11+
using point = std::complex<ftype>;
3212

33-
point w[maxn]; // w[2^n + k] = exp(pi * k / (2^n))
34-
int bitr[maxn];// b[2^n + k] = bitreverse(k)
13+
std::vector<point> w; // w[2^n + k] = exp(pi * k / (2^n))
14+
std::vector<int> bitr;// b[2^n + k] = bitreverse(k)
3515
const ftype pi = acos(-1);
3616
bool initiated = 0;
3717
void init() {
3818
if(!initiated) {
19+
w.resize(maxn);
20+
bitr.resize(maxn);
3921
for(int i = 1; i < maxn; i *= 2) {
4022
int ti = i / 2;
4123
for(int j = 0; j < i; j++) {
42-
w[i + j] = point::polar(ftype(1), pi * j / i);
24+
w[i + j] = std::polar(ftype(1), pi * j / i);
4325
if(ti) {
4426
bitr[i + j] = 2 * bitr[ti + j % ti] + (j >= ti);
4527
}
@@ -113,8 +95,8 @@ namespace cp_algo::algebra::fft {
11395
}
11496
std::vector<point> C(n), D(n);
11597
for(size_t i = 0; i < n; i++) {
116-
C[i] = A[i] * (B[i] + B[(n - i) % n].conj());
117-
D[i] = A[i] * (B[i] - B[(n - i) % n].conj());
98+
C[i] = A[i] * (B[i] + conj(B[(n - i) % n]));
99+
D[i] = A[i] * (B[i] - conj(B[(n - i) % n]));
118100
}
119101
fft(C, n);
120102
fft(D, n);

0 commit comments

Comments
 (0)