Skip to content

Commit 55b0de6

Browse files
authored
rolling prod more accurate (#7351)
* zeros and num stability * frollprod adaptive fast redirect to exact * more tests * codecov * correct expected ans * Revert "zeros and num stability" This reverts commit 4074e12. * handle zeros in frollprod fast
1 parent c5e8152 commit 55b0de6

File tree

4 files changed

+152
-138
lines changed

4 files changed

+152
-138
lines changed

inst/tests/froll.Rraw

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1202,8 +1202,8 @@ test(6000.931, frollprod(1:3, 2), c(NA, 2, 6), output="frollprodFast: running fo
12021202
test(6000.932, frollprod(1:3, 2, align="left"), c(2, 6, NA), output="frollfun: align")
12031203
test(6000.933, frollprod(c(1,2,NA), 2), c(NA, 2, NA), output="non-finite values are present in input, re-running with extra care for NFs")
12041204
test(6000.934, frollprod(c(NA,2,3), 2), c(NA, NA, 6), output="non-finite values are present in input, skip non-finite inaware attempt and run with extra care for NFs straighaway")
1205-
test(6000.935, frollprod(1:3, c(2,2,2), adaptive=TRUE), c(NA, 2, 6), output="frolladaptiveprodFast: running for input length")
1206-
test(6000.936, frollprod(c(NA,2,3), c(2,2,2), adaptive=TRUE), c(NA, NA, 6), output="non-finite values are present in input, re-running with extra care for NFs")
1205+
test(6000.935, frollprod(1:3, c(2,2,2), adaptive=TRUE), c(NA, 2, 6), output="algo 0 not implemented, fall back to 1")
1206+
test(6000.936, frollprod(c(NA,2,3), c(2,2,2), adaptive=TRUE), c(NA, NA, 6), output="non-finite values are present in input, na.rm=FALSE and algo='exact' propagates NFs properply, no need to re-run")
12071207
options(datatable.verbose=FALSE)
12081208
# floating point overflow
12091209
test(6000.941, frollprod(c(1e100, 1e100, 1e100, 1e100, 1e100), 5), c(NA,NA,NA,NA,Inf))
@@ -1222,6 +1222,15 @@ test(6000.953, frollprod(c(1e100, 1e100, 1e100, 1e100, 1e100), rep(5, 5), algo="
12221222
test(6000.954, frollprod(c(1e100, 1e100, 1e100, 1e100, 1e100), rep(4, 5), algo="exact", adaptive=TRUE), c(NA,NA,NA,Inf,Inf))
12231223
test(6000.955, frollprod(c(1e100, 1e100, 1e100, 1e100, -1e100), rep(5, 5), algo="exact", adaptive=TRUE), c(NA,NA,NA,NA,-Inf))
12241224
test(6000.956, frollprod(c(1e100, 1e100, 1e100, 1e100, -1e100), rep(4, 5), algo="exact", adaptive=TRUE), c(NA,NA,NA,Inf,-Inf))
1225+
# rolling product and numerical stability #7349
1226+
test(6000.9601, frollprod(c(2,2,0,2,2), 2), c(NA,4,0,0,4))
1227+
test(6000.9602, frollprod(c(2,2,0,-2,2), 2), c(NA,4,0,0,-4))
1228+
test(6000.9603, frollprod(c(2,2,0,-2,-2), 2), c(NA,4,0,0,4))
1229+
test(6000.9604, frollprod(c(2,2,0,Inf,2), 2), c(NA,4,0,NaN,Inf))
1230+
test(6000.9605, frollprod(c(2,2,0,-Inf,2), 2), c(NA,4,0,NaN,-Inf))
1231+
test(6000.9606, frollprod(c(2,2,0,-Inf,-2), 2), c(NA,4,0,NaN,Inf))
1232+
test(6000.9607, frollprod(c(0,2,2,2,2), 2), c(NA,0,4,4,4))
1233+
test(6000.9608, frollprod(c(2,0,2,2,2), 2), c(NA,0,0,4,4))
12251234

12261235
# n==0, k==0, k[i]==0
12271236
test(6001.111, frollmean(1:3, 0), c(NaN,NaN,NaN), options=c("datatable.verbose"=TRUE), output="window width of size 0")
@@ -2242,7 +2251,7 @@ rollfun = function(x, n, FUN, fill=NA_real_, na.rm=FALSE, nf.rm=FALSE, partial=F
22422251
ans
22432252
}
22442253
base_compare = function(x, n, funs=c("mean","sum","max","min","prod","median"), algos=c("fast","exact")) {
2245-
num.step = 0.001
2254+
num.step = 0.0001
22462255
for (fun in funs) {
22472256
for (na.rm in c(FALSE, TRUE)) {
22482257
for (fill in c(NA_real_, 0)) {
@@ -2276,6 +2285,39 @@ base_compare = function(x, n, funs=c("mean","sum","max","min","prod","median"),
22762285
}
22772286
}
22782287
}
2288+
num = 7000.0
2289+
x = rnorm(1e3); n = 50
2290+
base_compare(x, n)
2291+
x = rnorm(1e3+1); n = 50 ## uneven len
2292+
base_compare(x, n)
2293+
x = rnorm(1e3); n = 51 ## uneven window
2294+
base_compare(x, n)
2295+
x = rnorm(1e3+1); n = 51
2296+
base_compare(x, n)
2297+
x = sort(rnorm(1e3)); n = 50 ## inc
2298+
base_compare(x, n)
2299+
x = sort(rnorm(1e3+1)); n = 50
2300+
base_compare(x, n)
2301+
x = sort(rnorm(1e3)); n = 51
2302+
base_compare(x, n)
2303+
x = sort(rnorm(1e3+1)); n = 51
2304+
base_compare(x, n)
2305+
x = rev(sort(rnorm(1e3))); n = 50 ## desc
2306+
base_compare(x, n)
2307+
x = rev(sort(rnorm(1e3+1))); n = 50
2308+
base_compare(x, n)
2309+
x = rev(sort(rnorm(1e3))); n = 51
2310+
base_compare(x, n)
2311+
x = rev(sort(rnorm(1e3+1))); n = 51
2312+
base_compare(x, n)
2313+
x = rep(rnorm(1), 1e3); n = 50 ## const
2314+
base_compare(x, n)
2315+
x = rep(rnorm(1), 1e3+1); n = 50
2316+
base_compare(x, n)
2317+
x = rep(rnorm(1), 1e3); n = 51
2318+
base_compare(x, n)
2319+
x = rep(rnorm(1), 1e3+1); n = 51
2320+
base_compare(x, n)
22792321
num = 7100.0
22802322
## random NA non-finite
22812323
x = makeNA(rnorm(1e3), nf=TRUE); n = 50
@@ -2433,6 +2475,39 @@ afun_compare = function(x, n, funs=c("mean","sum","max","min","prod","median"),
24332475
}
24342476
}
24352477
num = 7300.0
2478+
x = rnorm(1e3); n = sample(50, length(x), TRUE)
2479+
afun_compare(x, n)
2480+
x = rnorm(1e3+1); n = sample(50, length(x), TRUE) ## uneven len
2481+
afun_compare(x, n)
2482+
x = rnorm(1e3); n = sample(51, length(x), TRUE) ## uneven window
2483+
afun_compare(x, n)
2484+
x = rnorm(1e3+1); n = sample(51, length(x), TRUE)
2485+
afun_compare(x, n)
2486+
x = sort(rnorm(1e3)); n = sample(50, length(x), TRUE) ## inc
2487+
afun_compare(x, n)
2488+
x = sort(rnorm(1e3+1)); n = sample(50, length(x), TRUE)
2489+
afun_compare(x, n)
2490+
x = sort(rnorm(1e3)); n = sample(51, length(x), TRUE)
2491+
afun_compare(x, n)
2492+
x = sort(rnorm(1e3+1)); n = sample(51, length(x), TRUE)
2493+
afun_compare(x, n)
2494+
x = rev(sort(rnorm(1e3))); n = sample(50, length(x), TRUE) ## desc
2495+
afun_compare(x, n)
2496+
x = rev(sort(rnorm(1e3+1))); n = sample(50, length(x), TRUE)
2497+
afun_compare(x, n)
2498+
x = rev(sort(rnorm(1e3))); n = sample(51, length(x), TRUE)
2499+
afun_compare(x, n)
2500+
x = rev(sort(rnorm(1e3+1))); n = sample(51, length(x), TRUE)
2501+
afun_compare(x, n)
2502+
x = rep(rnorm(1), 1e3); n = sample(50, length(x), TRUE) ## const
2503+
afun_compare(x, n)
2504+
x = rep(rnorm(1), 1e3+1); n = sample(50, length(x), TRUE)
2505+
afun_compare(x, n)
2506+
x = rep(rnorm(1), 1e3); n = sample(51, length(x), TRUE)
2507+
afun_compare(x, n)
2508+
x = rep(rnorm(1), 1e3+1); n = sample(51, length(x), TRUE)
2509+
afun_compare(x, n)
2510+
num = 7400.0
24362511
#### no NA
24372512
x = rnorm(1e3); n = sample(50, length(x), TRUE) # x even, n even
24382513
afun_compare(x, n)

src/data.table.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ void frolladaptivesumExact(const double *x, uint64_t nx, ans_t *ans, const int *
258258
void frolladaptivemaxExact(const double *x, uint64_t nx, ans_t *ans, const int *k, double fill, bool narm, int hasnf, bool verbose);
259259
//void frolladaptiveminFast(const double *x, uint64_t nx, ans_t *ans, const int *k, double fill, bool narm, int hasnf, bool verbose); // does not exists as of now
260260
void frolladaptiveminExact(const double *x, uint64_t nx, ans_t *ans, const int *k, double fill, bool narm, int hasnf, bool verbose);
261-
void frolladaptiveprodFast(const double *x, uint64_t nx, ans_t *ans, const int *k, double fill, bool narm, int hasnf, bool verbose);
261+
//void frolladaptiveprodFast(const double *x, uint64_t nx, ans_t *ans, const int *k, double fill, bool narm, int hasnf, bool verbose); // does not exists as of now
262262
void frolladaptiveprodExact(const double *x, uint64_t nx, ans_t *ans, const int *k, double fill, bool narm, int hasnf, bool verbose);
263263
//void frolladaptivemedianFast(const double *x, uint64_t nx, ans_t *ans, const int *k, double fill, bool narm, int hasnf, bool verbose); // does not exists as of now
264264
void frolladaptivemedianExact(const double *x, uint64_t nx, ans_t *ans, const int *k, double fill, bool narm, int hasnf, bool verbose);

src/froll.c

Lines changed: 63 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -895,7 +895,11 @@ void frollminExact(const double *x, uint64_t nx, ans_t *ans, int k, double fill,
895895
#undef PROD_WINDOW_STEP_FRONT
896896
#define PROD_WINDOW_STEP_FRONT \
897897
if (R_FINITE(x[i])) { \
898-
w *= x[i]; \
898+
if (x[i] == 0.0) { \
899+
zerc++; \
900+
} else { \
901+
w *= x[i]; \
902+
} \
899903
} else if (ISNAN(x[i])) { \
900904
nc++; \
901905
} else if (x[i]==R_PosInf) { \
@@ -906,7 +910,11 @@ void frollminExact(const double *x, uint64_t nx, ans_t *ans, int k, double fill,
906910
#undef PROD_WINDOW_STEP_BACK
907911
#define PROD_WINDOW_STEP_BACK \
908912
if (R_FINITE(x[i-k])) { \
909-
w /= x[i-k]; \
913+
if (x[i-k] == 0.0) { \
914+
zerc--; \
915+
} else { \
916+
w /= x[i-k]; \
917+
} \
910918
} else if (ISNAN(x[i-k])) { \
911919
nc--; \
912920
} else if (x[i-k]==R_PosInf) { \
@@ -918,18 +926,34 @@ void frollminExact(const double *x, uint64_t nx, ans_t *ans, int k, double fill,
918926
#define PROD_WINDOW_STEP_VALUE \
919927
if (nc == 0) { \
920928
if (pinf == 0 && ninf == 0) { \
921-
ans->dbl_v[i] = (double) w; \
929+
if (zerc) { \
930+
ans->dbl_v[i] = 0.0; \
931+
} else { \
932+
ans->dbl_v[i] = (double) w; \
933+
} \
922934
} else { \
923-
ans->dbl_v[i] = (ninf+(w<0))%2 ? R_NegInf : R_PosInf; \
935+
if (zerc) { \
936+
ans->dbl_v[i] = R_NaN; \
937+
} else { \
938+
ans->dbl_v[i] = (ninf+(w<0))%2 ? R_NegInf : R_PosInf; \
939+
} \
924940
} \
925941
} else if (nc == k) { \
926942
ans->dbl_v[i] = narm ? 1.0 : NA_REAL; \
927943
} else { \
928944
if (narm) { \
929945
if (pinf == 0 && ninf == 0) { \
930-
ans->dbl_v[i] = (double) w; \
946+
if (zerc) { \
947+
ans->dbl_v[i] = 0.0; \
948+
} else { \
949+
ans->dbl_v[i] = (double) w; \
950+
} \
931951
} else { \
932-
ans->dbl_v[i] = (ninf+(w<0))%2 ? R_NegInf : R_PosInf; \
952+
if (zerc) { \
953+
ans->dbl_v[i] = R_NaN; \
954+
} else { \
955+
ans->dbl_v[i] = (ninf+(w<0))%2 ? R_NegInf : R_PosInf; \
956+
} \
933957
} \
934958
} else { \
935959
ans->dbl_v[i] = NA_REAL; \
@@ -951,34 +975,59 @@ void frollprodFast(const double *x, uint64_t nx, ans_t *ans, int k, double fill,
951975
return;
952976
}
953977
long double w = 1.0;
978+
int zerc = 0;
954979
bool truehasnf = hasnf>0;
955980
if (!truehasnf) {
956981
int i;
957982
for (i=0; i<k-1; i++) { // #loop_counter_not_local_scope_ok
958-
w *= x[i];
983+
if (x[i] == 0.0) {
984+
zerc++;
985+
} else {
986+
w *= x[i];
987+
}
959988
ans->dbl_v[i] = fill;
960989
}
961-
w *= x[i];
962-
ans->dbl_v[i] = (double) w;
990+
if (x[i] == 0.0) {
991+
zerc++;
992+
} else {
993+
w *= x[i];
994+
}
995+
if (zerc) {
996+
ans->dbl_v[i] = 0.0;
997+
} else {
998+
ans->dbl_v[i] = (double) w;
999+
}
9631000
if (R_FINITE((double) w)) {
9641001
for (uint64_t i=k; i<nx; i++) {
965-
w /= x[i-k];
966-
w *= x[i];
967-
ans->dbl_v[i] = (double) w;
1002+
if (x[i-k] == 0.0) {
1003+
zerc--;
1004+
} else {
1005+
w /= x[i-k];
1006+
}
1007+
if (x[i] == 0.0) {
1008+
zerc++;
1009+
} else {
1010+
w *= x[i];
1011+
}
1012+
if (zerc) {
1013+
ans->dbl_v[i] = 0.0;
1014+
} else {
1015+
ans->dbl_v[i] = (double) w;
1016+
}
9681017
}
9691018
if (!R_FINITE((double) w)) {
9701019
if (hasnf==-1)
9711020
ansSetMsg(ans, 2, "%s: has.nf=FALSE used but non-finite values are present in input, use default has.nf=NA to avoid this warning", __func__);
9721021
if (verbose)
9731022
ansSetMsg(ans, 0, "%s: non-finite values are present in input, re-running with extra care for NFs\n", __func__);
974-
w = 1.0; truehasnf = true;
1023+
w = 1.0; zerc = 0; truehasnf = true;
9751024
}
9761025
} else {
9771026
if (hasnf==-1)
9781027
ansSetMsg(ans, 2, "%s: has.nf=FALSE used but non-finite values are present in input, use default has.nf=NA to avoid this warning", __func__);
9791028
if (verbose)
9801029
ansSetMsg(ans, 0, "%s: non-finite values are present in input, skip non-finite inaware attempt and run with extra care for NFs straighaway\n", __func__);
981-
w = 1.0; truehasnf = true;
1030+
w = 1.0; zerc = 0; truehasnf = true;
9821031
}
9831032
}
9841033
if (truehasnf) {

0 commit comments

Comments
 (0)