Skip to content

Commit 47838f4

Browse files
committed
Initial fix (w/ prints)
1 parent a51b2dc commit 47838f4

File tree

1 file changed

+89
-53
lines changed

1 file changed

+89
-53
lines changed

pandas/_libs/window/aggregations.pyx

Lines changed: 89 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
11
# cython: boundscheck=False, wraparound=False, cdivision=True
2-
2+
from libc.stdio cimport printf
33
from libc.math cimport (
44
round,
55
signbit,
66
sqrt,
7+
pow,
8+
log10,
9+
abs,
10+
isfinite,
711
)
812
from libcpp.deque cimport deque
913
from libcpp.stack cimport stack
1014
from libcpp.unordered_map cimport unordered_map
15+
from libcpp cimport bool
16+
1117

1218
from pandas._libs.algos cimport TiebreakEnumType
1319

@@ -21,6 +27,8 @@ from numpy cimport (
2127
ndarray,
2228
)
2329

30+
31+
2432
cnp.import_array()
2533

2634
import cython
@@ -724,6 +732,55 @@ cdef float64_t calc_kurt(int64_t minp, int64_t nobs,
724732

725733
return result
726734

735+
cdef void update_sum_of_window(float64_t val,
                               float64_t **x_value,
                               float64_t **comp_value,
                               int power_of_element,
                               bool add_mode,  # 1 for add_kurt, 0 for remove_kurt
                               ) noexcept nogil:
    """
    Add or remove ``val ** power_of_element`` from a running power sum.

    The running sum lives in ``x_value[0][0]`` and a side accumulator lives
    in ``comp_value[0][0]``; both are updated in place.  A quantity is
    treated as "extreme" when ``abs(log10(abs(q)))`` is finite and greater
    than 15, i.e. its magnitude is above 1e15 or below 1e-15.  Extreme and
    non-extreme contributions are kept in separate accumulators so that a
    non-extreme term is not added directly onto an extreme running sum.

    Parameters
    ----------
    val : float64_t
        Observation entering or leaving the window.
    x_value : float64_t **
        Address of the pointer to the running sum; only ``x_value[0][0]``
        is read and written.
    comp_value : float64_t **
        Address of the pointer to the side accumulator; only
        ``comp_value[0][0]`` is read and written.
    power_of_element : int
        Exponent applied to ``val`` before it is accumulated.
    add_mode : bool
        True to add the term (add_kurt), False to subtract it (remove_kurt).
    """
    cdef:
        float64_t val_raised, x_magnitude, val_magnitude
        bool val_length_flag, x_length_flag

    val_raised = pow(val, power_of_element)
    if not add_mode:
        val_raised = -val_raised

    # Decimal order of magnitude of each operand.  For an operand equal to
    # zero log10 yields an infinity, which the isfinite() test below rejects,
    # so zero is never classified as extreme.
    x_magnitude = abs(log10(abs(x_value[0][0])))
    val_magnitude = abs(log10(abs(val_raised)))

    # isfinite() is only guaranteed to return *nonzero* for finite input,
    # so test its truthiness instead of comparing the result against 1.
    x_length_flag = x_magnitude > 15 and isfinite(x_magnitude)
    val_length_flag = val_magnitude > 15 and isfinite(val_magnitude)

    if x_length_flag and val_length_flag:
        # Both the running sum and the new term are extreme: combine them
        # directly.
        x_value[0][0] += val_raised

    elif x_length_flag:
        # Only the running sum is extreme: park the new term in the side
        # accumulator.
        comp_value[0][0] += val_raised

    elif val_length_flag:
        # Only the new term is extreme: move the current running sum into
        # the side accumulator and let the new term become the running sum.
        comp_value[0][0] += x_value[0][0]
        x_value[0][0] = val_raised

    else:
        # Neither is extreme, safe to add directly.
        x_value[0][0] += val_raised

        # Fold any parked contributions back into the running sum.
        if comp_value[0][0] != 0:
            x_value[0][0] += comp_value[0][0]
            comp_value[0][0] = 0
780+
781+
782+
783+
727784

728785
cdef void add_kurt(float64_t val, int64_t *nobs,
729786
float64_t *x, float64_t *xx,
@@ -736,29 +793,15 @@ cdef void add_kurt(float64_t val, int64_t *nobs,
736793
float64_t *prev_value
737794
) noexcept nogil:
738795
""" add a value from the kurotic calc """
739-
cdef:
740-
float64_t y, t
741796

742797
# Not NaN
743798
if val == val:
744799
nobs[0] = nobs[0] + 1
745800

746-
y = val - compensation_x[0]
747-
t = x[0] + y
748-
compensation_x[0] = t - x[0] - y
749-
x[0] = t
750-
y = val * val - compensation_xx[0]
751-
t = xx[0] + y
752-
compensation_xx[0] = t - xx[0] - y
753-
xx[0] = t
754-
y = val * val * val - compensation_xxx[0]
755-
t = xxx[0] + y
756-
compensation_xxx[0] = t - xxx[0] - y
757-
xxx[0] = t
758-
y = val * val * val * val - compensation_xxxx[0]
759-
t = xxxx[0] + y
760-
compensation_xxxx[0] = t - xxxx[0] - y
761-
xxxx[0] = t
801+
update_sum_of_window(val, &x, &compensation_x, 1, 1)
802+
update_sum_of_window(val, &xx, &compensation_xx, 2, 1)
803+
update_sum_of_window(val, &xxx, &compensation_xxx, 3, 1)
804+
update_sum_of_window(val, &xxxx, &compensation_xxxx, 4, 1)
762805

763806
# GH#42064, record num of same values to remove floating point artifacts
764807
if val == prev_value[0]:
@@ -768,7 +811,6 @@ cdef void add_kurt(float64_t val, int64_t *nobs,
768811
num_consecutive_same_value[0] = 1
769812
prev_value[0] = val
770813

771-
772814
cdef void remove_kurt(float64_t val, int64_t *nobs,
773815
float64_t *x, float64_t *xx,
774816
float64_t *xxx, float64_t *xxxx,
@@ -777,40 +819,25 @@ cdef void remove_kurt(float64_t val, int64_t *nobs,
777819
float64_t *compensation_xxx,
778820
float64_t *compensation_xxxx) noexcept nogil:
779821
""" remove a value from the kurotic calc """
780-
cdef:
781-
float64_t y, t
782822

783823
# Not NaN
784824
if val == val:
785825
nobs[0] = nobs[0] - 1
786826

787-
y = - val - compensation_x[0]
788-
t = x[0] + y
789-
compensation_x[0] = t - x[0] - y
790-
x[0] = t
791-
y = - val * val - compensation_xx[0]
792-
t = xx[0] + y
793-
compensation_xx[0] = t - xx[0] - y
794-
xx[0] = t
795-
y = - val * val * val - compensation_xxx[0]
796-
t = xxx[0] + y
797-
compensation_xxx[0] = t - xxx[0] - y
798-
xxx[0] = t
799-
y = - val * val * val * val - compensation_xxxx[0]
800-
t = xxxx[0] + y
801-
compensation_xxxx[0] = t - xxxx[0] - y
802-
xxxx[0] = t
803-
827+
update_sum_of_window(val, &x, &compensation_x, 1, 0)
828+
update_sum_of_window(val, &xx, &compensation_xx, 2, 0)
829+
update_sum_of_window(val, &xxx, &compensation_xxx, 3, 0)
830+
update_sum_of_window(val, &xxxx, &compensation_xxxx, 4, 0)
804831

805832
def roll_kurt(ndarray[float64_t] values, ndarray[int64_t] start,
806833
ndarray[int64_t] end, int64_t minp) -> np.ndarray:
807834
cdef:
808835
Py_ssize_t i, j
809836
float64_t val, mean_val, min_val, sum_val = 0
810-
float64_t compensation_xxxx_add, compensation_xxxx_remove
811-
float64_t compensation_xxx_remove, compensation_xxx_add
812-
float64_t compensation_xx_remove, compensation_xx_add
813-
float64_t compensation_x_remove, compensation_x_add
837+
float64_t compensation_xxxx
838+
float64_t compensation_xxx
839+
float64_t compensation_xx
840+
float64_t compensation_x
814841
float64_t x, xx, xxx, xxxx
815842
float64_t prev_value
816843
int64_t nobs, s, e, num_consecutive_same_value
@@ -843,6 +870,7 @@ def roll_kurt(ndarray[float64_t] values, ndarray[int64_t] start,
843870

844871
s = start[i]
845872
e = end[i]
873+
printf("\n%d| S: %d, E: %d\n", i, s, e)
846874

847875
# Over the first window, observations can only be added
848876
# never removed
@@ -851,17 +879,19 @@ def roll_kurt(ndarray[float64_t] values, ndarray[int64_t] start,
851879
prev_value = values[s]
852880
num_consecutive_same_value = 0
853881

854-
compensation_xxxx_add = compensation_xxxx_remove = 0
855-
compensation_xxx_remove = compensation_xxx_add = 0
856-
compensation_xx_remove = compensation_xx_add = 0
857-
compensation_x_remove = compensation_x_add = 0
882+
compensation_xxxx = 0
883+
compensation_xxx = 0
884+
compensation_xx = 0
885+
compensation_x = 0
858886
x = xx = xxx = xxxx = 0
859887
nobs = 0
860888
for j in range(s, e):
861889
add_kurt(values_copy[j], &nobs, &x, &xx, &xxx, &xxxx,
862-
&compensation_x_add, &compensation_xx_add,
863-
&compensation_xxx_add, &compensation_xxxx_add,
890+
&compensation_x, &compensation_xx,
891+
&compensation_xxx, &compensation_xxxx,
864892
&num_consecutive_same_value, &prev_value)
893+
printf(" %g|A|x: %g, xx: %g, xxx: %g, xxxx: %g, num_cons: %ld\n", values_copy[j], x,xx,xxx,xxxx, num_consecutive_same_value)
894+
865895

866896
else:
867897

@@ -870,15 +900,19 @@ def roll_kurt(ndarray[float64_t] values, ndarray[int64_t] start,
870900
# calculate deletes
871901
for j in range(start[i - 1], s):
872902
remove_kurt(values_copy[j], &nobs, &x, &xx, &xxx, &xxxx,
873-
&compensation_x_remove, &compensation_xx_remove,
874-
&compensation_xxx_remove, &compensation_xxxx_remove)
903+
&compensation_x, &compensation_xx,
904+
&compensation_xxx, &compensation_xxxx)
905+
906+
printf(" %g|R|x: %g, xx: %g, xxx: %g, xxxx: %g, num_cons: %ld\n", values_copy[j], x,xx,xxx,xxxx, num_consecutive_same_value)
875907

876908
# calculate adds
877909
for j in range(end[i - 1], e):
878910
add_kurt(values_copy[j], &nobs, &x, &xx, &xxx, &xxxx,
879-
&compensation_x_add, &compensation_xx_add,
880-
&compensation_xxx_add, &compensation_xxxx_add,
911+
&compensation_x, &compensation_xx,
912+
&compensation_xxx, &compensation_xxxx,
881913
&num_consecutive_same_value, &prev_value)
914+
printf(" %g|A|x: %g, xx: %g, xxx: %g, xxxx: %g, num_cons: %ld\n", values_copy[j], x,xx,xxx,xxxx, num_consecutive_same_value)
915+
882916

883917
output[i] = calc_kurt(minp, nobs, x, xx, xxx, xxxx,
884918
num_consecutive_same_value)
@@ -890,6 +924,8 @@ def roll_kurt(ndarray[float64_t] values, ndarray[int64_t] start,
890924
xxx = 0.0
891925
xxxx = 0.0
892926

927+
print("\n",output,"\n----------------------------------------------------------------------------")
928+
893929
return output
894930

895931

0 commit comments

Comments
 (0)