Skip to content

Commit cb54fef

Browse files
committed
Update scores only when new samples are available
1 parent: 856f567 · commit: cb54fef

File tree

1 file changed

+50
-18
lines changed

1 file changed

+50
-18
lines changed

pysteps/verification/detcontscores.py

Lines changed: 50 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,7 @@ def det_cont_fct(pred, obs, scores="", axis=None, conditioning=None, thr=0.0):
7979
conditioning="single", only pairs with either pred or obs > thr are
8080
included. With conditioning="double", only pairs with both pred and
8181
obs > thr are included.
82-
82+
8383
thr : float
8484
Optional threshold value for conditioning. Defaults to 0.
8585
@@ -130,10 +130,14 @@ def get_iterable(x):
130130
# split between online and offline scores
131131
loffline = ["scatter", "corr_s"]
132132
onscores = [
133-
score for score in scores if str(score).lower() not in loffline or score == ""
133+
score
134+
for score in scores
135+
if str(score).lower() not in loffline or score == ""
134136
]
135137
offscores = [
136-
score for score in scores if str(score).lower() in loffline or score == ""
138+
score
139+
for score in scores
140+
if str(score).lower() in loffline or score == ""
137141
]
138142

139143
# unique lists
@@ -364,20 +368,20 @@ def det_cont_fct_accum(err, pred, obs):
364368
mpred = mpred.squeeze()
365369

366370
# update variances
367-
err["vobs"] = _parallel_var(err["mobs"], err["n"], err["vobs"], mobs, n, vobs)
368-
err["vpred"] = _parallel_var(err["mpred"], err["n"], err["vpred"], mpred, n, vpred)
371+
_parallel_var(err["mobs"], err["n"], err["vobs"], mobs, n, vobs)
372+
_parallel_var(err["mpred"], err["n"], err["vpred"], mpred, n, vpred)
369373

370374
# update covariance
371-
err["cov"] = _parallel_cov(
375+
_parallel_cov(
372376
err["cov"], err["mobs"], err["mpred"], err["n"], cov, mobs, mpred, n
373377
)
374378

375379
# update means
376-
err["mobs"] = _parallel_mean(err["mobs"], err["n"], mobs, n)
377-
err["mpred"] = _parallel_mean(err["mpred"], err["n"], mpred, n)
378-
err["me"] = _parallel_mean(err["me"], err["n"], me, n)
379-
err["mse"] = _parallel_mean(err["mse"], err["n"], mse, n)
380-
err["mae"] = _parallel_mean(err["mae"], err["n"], mae, n)
380+
_parallel_mean(err["mobs"], err["n"], mobs, n)
381+
_parallel_mean(err["mpred"], err["n"], mpred, n)
382+
_parallel_mean(err["me"], err["n"], me, n)
383+
_parallel_mean(err["mse"], err["n"], mse, n)
384+
_parallel_mean(err["mae"], err["n"], mae, n)
381385

382386
# update number of samples
383387
err["n"] += n
@@ -495,25 +499,53 @@ def get_iterable(x):
495499

496500

497501
def _parallel_mean(avg_a, count_a, avg_b, count_b):
498-
return (count_a * avg_a + count_b * avg_b) / (count_a + count_b)
502+
"""Update avg_a with avg_b.
503+
"""
504+
idx = count_b > 0
505+
avg_a[idx] = (count_a[idx] * avg_a[idx] + count_b[idx] * avg_b[idx]) / (
506+
count_a[idx] + count_b[idx]
507+
)
499508

500509

501510
def _parallel_var(avg_a, count_a, var_a, avg_b, count_b, var_b):
502-
# source: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
511+
"""Update var_a with var_b.
512+
source: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
513+
"""
514+
idx = count_b > 0
503515
delta = avg_b - avg_a
504516
m_a = var_a * count_a
505517
m_b = var_b * count_b
506-
M2 = m_a + m_b + delta ** 2 * count_a * count_b / (count_a + count_b)
507-
return M2 / (count_a + count_b)
518+
var_a[idx] = (
519+
m_a[idx]
520+
+ m_b[idx]
521+
+ delta[idx] ** 2
522+
* count_a[idx]
523+
* count_b[idx]
524+
/ (count_a[idx] + count_b[idx])
525+
)
526+
var_a[idx] = var_a[idx] / (count_a[idx] + count_b[idx])
508527

509528

510-
def _parallel_cov(cov_a, avg_xa, avg_ya, count_a, cov_b, avg_xb, avg_yb, count_b):
529+
def _parallel_cov(
530+
cov_a, avg_xa, avg_ya, count_a, cov_b, avg_xb, avg_yb, count_b
531+
):
532+
"""Update cov_a with cov_b.
533+
"""
534+
idx = count_b > 0
511535
deltax = avg_xb - avg_xa
512536
deltay = avg_yb - avg_ya
513537
c_a = cov_a * count_a
514538
c_b = cov_b * count_b
515-
C2 = c_a + c_b + deltax * deltay * count_a * count_b / (count_a + count_b)
516-
return C2 / (count_a + count_b)
539+
cov_a[idx] = (
540+
c_a[idx]
541+
+ c_b[idx]
542+
+ deltax[idx]
543+
* deltay[idx]
544+
* count_a[idx]
545+
* count_b[idx]
546+
/ (count_a[idx] + count_b[idx])
547+
)
548+
cov_a[idx] = cov_a[idx] / (count_a[idx] + count_b[idx])
517549

518550

519551
def _uniquelist(mylist):

0 commit comments

Comments
 (0)