Skip to content

Commit 6f8c859

Browse files
fix: numerical instability issues
1 parent 9aab379 commit 6f8c859

File tree

1 file changed

+28
-7
lines changed

1 file changed

+28
-7
lines changed

pandas/_libs/algos.pyx

Lines changed: 28 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,6 @@
11
cimport cython
22
from cython cimport Py_ssize_t
3-
from cython.parallel cimport (
4-
prange,
5-
)
3+
from cython.parallel cimport prange
64
from libc.math cimport (
75
fabs,
86
sqrt,
@@ -365,7 +363,7 @@ def nancorr(
365363
bint no_nans
366364
int64_t nobs = 0
367365
float64_t mean, ssqd, val
368-
float64_t vx, vy, dx, dy, meanx, meany, divisor, ssqdmx, ssqdmy, covxy
366+
float64_t vx, vy, dx, dy, meanx, meany, divisor, ssqdmx, ssqdmy, covxy, corr_val
369367

370368
N, K = (<object>mat).shape
371369
if minp is None:
@@ -393,7 +391,6 @@ def nancorr(
393391
means[j] = mean
394392
ssqds[j] = ssqd
395393

396-
# ONLY CHANGE: Add parallel option to the main correlation loop
397394
if use_parallel:
398395
for xi in prange(K, schedule="dynamic", nogil=True):
399396
for yi in range(xi + 1):
@@ -427,7 +424,19 @@ def nancorr(
427424
else:
428425
divisor = (nobs - 1.0) if cov else sqrt(ssqdmx * ssqdmy)
429426
if divisor != 0:
430-
result[xi, yi] = result[yi, xi] = covxy / divisor
427+
if cov:
428+
result[xi, yi] = result[yi, xi] = covxy / divisor
429+
else:
430+
# For correlation, pin the diagonal to exactly 1.0 (a column vs. itself)
431+
if xi == yi:
432+
result[xi, yi] = 1.0
433+
else:
434+
corr_val = covxy / divisor
435+
if corr_val > 1.0:
436+
corr_val = 1.0
437+
elif corr_val < -1.0:
438+
corr_val = -1.0
439+
result[xi, yi] = result[yi, xi] = corr_val
431440
else:
432441
result[xi, yi] = result[yi, xi] = NaN
433442
else:
@@ -464,7 +473,19 @@ def nancorr(
464473
else:
465474
divisor = (nobs - 1.0) if cov else sqrt(ssqdmx * ssqdmy)
466475
if divisor != 0:
467-
result[xi, yi] = result[yi, xi] = covxy / divisor
476+
if cov:
477+
result[xi, yi] = result[yi, xi] = covxy / divisor
478+
else:
479+
# For correlation, ensure diagonal is exactly 1.0
480+
if xi == yi:
481+
result[xi, yi] = 1.0
482+
else:
483+
corr_val = covxy / divisor
484+
if corr_val > 1.0:
485+
corr_val = 1.0
486+
elif corr_val < -1.0:
487+
corr_val = -1.0
488+
result[xi, yi] = result[yi, xi] = corr_val
468489
else:
469490
result[xi, yi] = result[yi, xi] = NaN
470491

0 commit comments

Comments
 (0)