@@ -79,7 +79,7 @@ def det_cont_fct(pred, obs, scores="", axis=None, conditioning=None, thr=0.0):
7979 conditioning="single", only pairs with either pred or obs > thr are
8080 included. With conditioning="double", only pairs with both pred and
8181 obs > thr are included.
82-
82+
8383 thr : float
8484 Optional threshold value for conditioning. Defaults to 0.
8585
@@ -130,10 +130,14 @@ def get_iterable(x):
130130 # split between online and offline scores
131131 loffline = ["scatter" , "corr_s" ]
132132 onscores = [
133- score for score in scores if str (score ).lower () not in loffline or score == ""
133+ score
134+ for score in scores
135+ if str (score ).lower () not in loffline or score == ""
134136 ]
135137 offscores = [
136- score for score in scores if str (score ).lower () in loffline or score == ""
138+ score
139+ for score in scores
140+ if str (score ).lower () in loffline or score == ""
137141 ]
138142
139143 # unique lists
@@ -364,20 +368,20 @@ def det_cont_fct_accum(err, pred, obs):
364368 mpred = mpred .squeeze ()
365369
366370 # update variances
367- err [ "vobs" ] = _parallel_var (err ["mobs" ], err ["n" ], err ["vobs" ], mobs , n , vobs )
368- err [ "vpred" ] = _parallel_var (err ["mpred" ], err ["n" ], err ["vpred" ], mpred , n , vpred )
371+ _parallel_var (err ["mobs" ], err ["n" ], err ["vobs" ], mobs , n , vobs )
372+ _parallel_var (err ["mpred" ], err ["n" ], err ["vpred" ], mpred , n , vpred )
369373
370374 # update covariance
371- err [ "cov" ] = _parallel_cov (
375+ _parallel_cov (
372376 err ["cov" ], err ["mobs" ], err ["mpred" ], err ["n" ], cov , mobs , mpred , n
373377 )
374378
375379 # update means
376- err [ "mobs" ] = _parallel_mean (err ["mobs" ], err ["n" ], mobs , n )
377- err [ "mpred" ] = _parallel_mean (err ["mpred" ], err ["n" ], mpred , n )
378- err [ "me" ] = _parallel_mean (err ["me" ], err ["n" ], me , n )
379- err [ "mse" ] = _parallel_mean (err ["mse" ], err ["n" ], mse , n )
380- err [ "mae" ] = _parallel_mean (err ["mae" ], err ["n" ], mae , n )
380+ _parallel_mean (err ["mobs" ], err ["n" ], mobs , n )
381+ _parallel_mean (err ["mpred" ], err ["n" ], mpred , n )
382+ _parallel_mean (err ["me" ], err ["n" ], me , n )
383+ _parallel_mean (err ["mse" ], err ["n" ], mse , n )
384+ _parallel_mean (err ["mae" ], err ["n" ], mae , n )
381385
382386 # update number of samples
383387 err ["n" ] += n
@@ -495,25 +499,53 @@ def get_iterable(x):
495499
496500
497501def _parallel_mean (avg_a , count_a , avg_b , count_b ):
498- return (count_a * avg_a + count_b * avg_b ) / (count_a + count_b )
502+ """Update avg_a with avg_b.
503+ """
504+ idx = count_b > 0
505+ avg_a [idx ] = (count_a [idx ] * avg_a [idx ] + count_b [idx ] * avg_b [idx ]) / (
506+ count_a [idx ] + count_b [idx ]
507+ )
499508
500509
501510def _parallel_var (avg_a , count_a , var_a , avg_b , count_b , var_b ):
502- # source: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
511+ """Update var_a with var_b.
512+ source: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
513+ """
514+ idx = count_b > 0
503515 delta = avg_b - avg_a
504516 m_a = var_a * count_a
505517 m_b = var_b * count_b
506- M2 = m_a + m_b + delta ** 2 * count_a * count_b / (count_a + count_b )
507- return M2 / (count_a + count_b )
518+ var_a [idx ] = (
519+ m_a [idx ]
520+ + m_b [idx ]
521+ + delta [idx ] ** 2
522+ * count_a [idx ]
523+ * count_b [idx ]
524+ / (count_a [idx ] + count_b [idx ])
525+ )
526+ var_a [idx ] = var_a [idx ] / (count_a [idx ] + count_b [idx ])
508527
509528
510- def _parallel_cov (cov_a , avg_xa , avg_ya , count_a , cov_b , avg_xb , avg_yb , count_b ):
529+ def _parallel_cov (
530+ cov_a , avg_xa , avg_ya , count_a , cov_b , avg_xb , avg_yb , count_b
531+ ):
532+ """Update cov_a with cov_b.
533+ """
534+ idx = count_b > 0
511535 deltax = avg_xb - avg_xa
512536 deltay = avg_yb - avg_ya
513537 c_a = cov_a * count_a
514538 c_b = cov_b * count_b
515- C2 = c_a + c_b + deltax * deltay * count_a * count_b / (count_a + count_b )
516- return C2 / (count_a + count_b )
539+ cov_a [idx ] = (
540+ c_a [idx ]
541+ + c_b [idx ]
542+ + deltax [idx ]
543+ * deltay [idx ]
544+ * count_a [idx ]
545+ * count_b [idx ]
546+ / (count_a [idx ] + count_b [idx ])
547+ )
548+ cov_a [idx ] = cov_a [idx ] / (count_a [idx ] + count_b [idx ])
517549
518550
519551def _uniquelist (mylist ):
0 commit comments