Skip to content

Commit 3ca9b81

Browse files
committed
Updated with naive mul for lower values
1 parent 6148b2d commit 3ca9b81

File tree

2 files changed

+56
-57
lines changed

2 files changed

+56
-57
lines changed

content/numerical/FFTPolynomial.h

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
typedef Mod num;
1414
typedef vector<num> poly;
1515
vector<Mod> conv(vector<Mod> a, vector<Mod> b) {
16-
auto res = convMod<mod>(vl(all(a)), vl(all(b)));
17-
// auto res = conv(vl(all(a)), vl(all(b)));
16+
// auto res = convMod<mod>(vl(all(a)), vl(all(b)));
17+
auto res = conv(vl(all(a)), vl(all(b)));
1818
return vector<Mod>(all(res));
1919
}
2020
poly &operator+=(poly &a, const poly &b) {
@@ -27,7 +27,16 @@ poly &operator-=(poly &a, const poly &b) {
2727
rep(i, 0, sz(b)) a[i] = a[i] - b[i];
2828
return a;
2929
}
30-
poly &operator*=(poly &a, const poly &b) { return a = conv(a, b); }
30+
31+
poly &operator*=(poly &a, const poly &b) {
32+
if (sz(a) + sz(b) < 100){
33+
poly res(sz(a) + sz(b) - 1);
34+
rep(i,0,sz(a)) rep(j,0,sz(b))
35+
res[i + j] = (res[i + j] + a[i] * b[j]);
36+
return (a = res);
37+
}
38+
return a = conv(a, b);
39+
}
3140
poly operator*(poly a, const num b) {
3241
poly c = a;
3342
trav(i, c) i = i * b;
@@ -40,25 +49,12 @@ poly operator*(poly a, const num b) {
4049
}
4150
OP(*, *=) OP(+, +=) OP(-, -=);
4251
poly modK(poly a, int k) { return {a.begin(), a.begin() + min(k, sz(a))}; }
43-
// Currently there's two of them - the second is the original one (simply following the formula), the first one is a version Adamant says is faster.
44-
// I haven't been able to replicate the difference in performance, however.
4552
poly inverse(poly A) {
4653
poly B = poly({num(1) / A[0]});
47-
while (sz(B) < sz(A)){
48-
poly C = B*modK(A, 2*sz(B));
49-
C = poly(C.begin()+sz(B), C.end());
50-
C = modK(B*C, sz(B));
51-
C.insert(C.begin(), sz(B), 0);
52-
B -= C;
53-
}
54+
while (sz(B) < sz(A))
55+
B = modK(B * (poly({num(2)}) - modK(A, 2*sz(B)) * B), 2 * sz(B));
5456
return modK(B, sz(A));
5557
}
56-
// poly inverse(poly A) {
57-
// poly B = poly({num(1) / A[0]});
58-
// while (sz(B) < sz(A))
59-
// B = modK(B * (poly({num(2)}) - modK(A, 2*sz(B)) * B), 2 * sz(B));
60-
// return modK(B, sz(A));
61-
// }
6258
poly &operator/=(poly &a, poly b) {
6359
if (sz(a) < sz(b))
6460
return a = {};
@@ -136,15 +132,13 @@ vector<num> eval(const poly &a, const vector<num> &x) {
136132
}
137133

138134
poly interp(vector<num> x, vector<num> y) {
139-
int n = sz(x);
140-
vector<poly> up(n * 2);
141-
rep(i, 0, n) up[i + n] = poly({num(0) - x[i], num(1)});
142-
for (int i = n - 1; i > 0; i--)
143-
up[i] = up[2 * i] * up[2 * i + 1];
144-
vector<num> a = eval(deriv(up[1]), x);
145-
vector<poly> down(2 * n);
146-
rep(i, 0, n) down[i + n] = poly({y[i] * (num(1) / a[i])});
147-
for (int i = n - 1; i > 0; i--)
148-
down[i] = down[i * 2] * up[i * 2 + 1] + down[i * 2 + 1] * up[i * 2];
149-
return down[1];
135+
int n=sz(x);
136+
vector<poly> up(n*2);
137+
rep(i,0,n) up[i+n] = poly({num(0)-x[i], num(1)});
138+
for(int i=n-1; i>0;i--) up[i] = up[2*i]*up[2*i+1];
139+
vector<num> a = eval(deriv(up[1]), x);
140+
vector<poly> down(2*n);
141+
rep(i,0,n) down[i+n] = poly({y[i]*(num(1)/a[i])});
142+
for(int i=n-1;i>0;i--) down[i] = down[i*2] * up[i*2+1] + down[i*2+1] * up[i*2];
143+
return down[1];
150144
}

fuzz-tests/numerical/Polynomial.cpp

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,16 @@ typedef long long ll;
1212
typedef pair<int, int> pii;
1313
typedef vector<int> vi;
1414

15+
struct timeit {
16+
decltype(chrono::high_resolution_clock::now()) begin;
17+
const string label;
18+
timeit(string label = "???") : label(label) { begin = chrono::high_resolution_clock::now(); }
19+
~timeit() {
20+
auto end = chrono::high_resolution_clock::now();
21+
auto duration = chrono::duration_cast<chrono::milliseconds>(end - begin).count();
22+
cerr << duration << "ms elapsed [" << label << "]" << endl;
23+
}
24+
};
1525
namespace MIT {
1626
namespace fft {
1727
#if FFT
@@ -349,7 +359,11 @@ vector<num> eval(const poly &a, const vector<num> &x) {
349359
per(i, 1, n) up[i] = up[2 * i] * up[2 * i + 1];
350360
vector<poly> down(2 * n);
351361
down[1] = a % up[1];
352-
rep(i, 2, 2 * n) down[i] = down[i / 2] % up[i];
362+
{
363+
rep(i, 2, 2 * n) {
364+
down[i] = down[i / 2] % up[i];
365+
}
366+
}
353367
vector<num> y(n);
354368
rep(i, 0, n) y[i] = down[i + n][0];
355369
return y;
@@ -505,7 +519,16 @@ poly &operator-=(poly &a, const poly &b) {
505519
rep(i, 0, sz(b)) a[i] = a[i] - b[i];
506520
return a;
507521
}
508-
poly &operator*=(poly &a, const poly &b) { return a = conv(a, b); }
522+
523+
poly &operator*=(poly &a, const poly &b) {
524+
if (sz(a) + sz(b) < 100){
525+
poly res(sz(a) + sz(b) - 1);
526+
rep(i,0,sz(a)) rep(j,0,sz(b))
527+
res[i + j] = (res[i + j] + a[i] * b[j]);
528+
return (a = res);
529+
}
530+
return a = conv(a, b);
531+
}
509532
poly operator*(poly a, const num b) {
510533
poly c = a;
511534
trav(i, c) i = i * b;
@@ -518,17 +541,6 @@ poly operator*(poly a, const num b) {
518541
}
519542
OP(*, *=) OP(+, +=) OP(-, -=);
520543
poly modK(poly a, int k) { return {a.begin(), a.begin() + min(k, sz(a))}; }
521-
// poly inverse(poly A) {
522-
// poly B = poly({num(1) / A[0]});
523-
// while (sz(B) < sz(A)){
524-
// poly C = B*modK(A, 2*sz(B));
525-
// C = poly(C.begin()+sz(B), C.end());
526-
// C = modK(B*C, sz(B));
527-
// C.insert(C.begin(), sz(B), 0);
528-
// B -= C;
529-
// }
530-
// return modK(B, sz(A));
531-
// }
532544
poly inverse(poly A) {
533545
poly B = poly({num(1) / A[0]});
534546
while (sz(B) < sz(A))
@@ -603,9 +615,10 @@ vector<num> eval(const poly &a, const vector<num> &x) {
603615
rep(i, 0, n) up[i + n] = poly({num(0) - x[i], 1});
604616
for (int i = n - 1; i > 0; i--)
605617
up[i] = up[2 * i] * up[2 * i + 1];
606-
vector<poly> down(2 * n);
618+
vector<poly> down(2 * n, poly(1,0));
607619
down[1] = a % up[1];
608-
rep(i, 2, 2 * n) down[i] = down[i / 2] % up[i];
620+
rep(i, 2, 2 * n)
621+
down[i] = down[i / 2] % up[i];
609622
vector<num> y(n);
610623
rep(i, 0, n) y[i] = down[i + n][0];
611624
return y;
@@ -624,16 +637,6 @@ poly interp(vector<num> x, vector<num> y) {
624637
}
625638

626639
} // namespace mine
627-
struct timeit {
628-
decltype(chrono::high_resolution_clock::now()) begin;
629-
const string label;
630-
timeit(string label = "???") : label(label) { begin = chrono::high_resolution_clock::now(); }
631-
~timeit() {
632-
auto end = chrono::high_resolution_clock::now();
633-
auto duration = chrono::duration_cast<chrono::milliseconds>(end - begin).count();
634-
cerr << duration << "ms elapsed [" << label << "]" << endl;
635-
}
636-
};
637640
pair<mine::poly, MIT::poly> genVec(int sz) {
638641
mine::poly a;
639642
MIT::poly am;
@@ -671,7 +674,8 @@ template <class A, class B> void fail(A mine, B mit) {
671674
cout << endl;
672675

673676
}
674-
const int NUMITERS=100;
677+
678+
const int NUMITERS=10;
675679
template <class A, class B> void testBinary(string name, A f1, B f2, int mxSz = 5) {
676680
for (int it = 0; it < NUMITERS; it++) {
677681
auto a = genVec((rand() % mxSz) + 1);
@@ -760,6 +764,7 @@ template <class A, class B> void testPow(string name, A f1, B f2, int mxSz = 5,
760764
}
761765
template <class A, class B> void testEval(string name, A f1, B f2, int mxSz = 5) {
762766
for (int it = 0; it < NUMITERS; it++) {
767+
break;
763768
auto a = genVec((rand() % mxSz) + 1);
764769
auto b = genVec((rand() % mxSz)+1);
765770
auto res = f1(a.first, b.first);
@@ -816,7 +821,7 @@ template <class A, class B> void testInterp(string name, A f1, B f2, int mxSz =
816821
signed main() {
817822
ios::sync_with_stdio(0);
818823
cin.tie(0);
819-
int SZ = 10000;
824+
int SZ = 100000;
820825
testBinary("sub", mine::operator-, MIT::operator-, SZ);
821826
testBinary("add", mine::operator+, MIT::operator+, SZ);
822827
testBinary("div", mine::operator/, MIT::operator/, SZ);
@@ -826,7 +831,7 @@ signed main() {
826831
testUnary("integral", mine::integr, MIT::integ, SZ);
827832
testUnary("log", mine::log, MIT::log, SZ);
828833
testUnary("exp", mine::exp, MIT::exp, SZ);
829-
SZ = 1000;
834+
SZ = 10000;
830835
testPow("pow", mine::pow, MIT::pow, SZ, 5);
831836
testEval("eval", mine::eval, MIT::eval, SZ);
832837
testInterp("interp", mine::interp, MIT::interp, SZ);

0 commit comments

Comments
 (0)