11"""
2- This module contains routines for numerical computations used across the
3- library.
2+ This module contains routines for numerical computations used across the library.
43"""
54
65from __future__ import annotations
109 Collection ,
1110 Generator ,
1211 Iterator ,
13- List ,
1412 Optional ,
1513 Sequence ,
16- Tuple ,
1714 TypeVar ,
18- overload ,
1915)
2016
2117import numpy as np
2218from numpy .typing import NDArray
19+ from scipy .special import gammaln
2320
2421from pydvl .utils .types import Seed
2522
2623__all__ = [
2724 "complement" ,
28- "running_moments" ,
25+ "logcomb" ,
26+ "logexp" ,
27+ "log_running_moments" ,
28+ "logsumexp_two" ,
2929 "num_samples_permutation_hoeffding" ,
3030 "powerset" ,
3131 "random_matrix_with_condition_number" ,
3232 "random_subset" ,
3333 "random_powerset" ,
3434 "random_powerset_label_min" ,
3535 "random_subset_of_size" ,
36+ "running_moments" ,
3637 "top_k_value_accuracy" ,
3738]
3839
@@ -202,7 +203,7 @@ def random_powerset_label_min(
     unique_labels = np.unique(labels)

     while True:
-        subsets: List[NDArray[T]] = []
+        subsets: list[NDArray[T]] = []
         for label in unique_labels:
             label_indices = np.asarray(np.where(labels == label)[0])
             subset_size = int(
@@ -291,53 +292,51 @@ def random_matrix_with_condition_number(
     return P


-@overload
-def running_moments(
-    previous_avg: float, previous_variance: float, count: int, new_value: float
-) -> Tuple[float, float]: ...
-
-
-@overload
-def running_moments(
-    previous_avg: NDArray[np.float64],
-    previous_variance: NDArray[np.float64],
-    count: int,
-    new_value: NDArray[np.float64],
-) -> Tuple[NDArray[np.float64], NDArray[np.float64]]: ...
-
-
 def running_moments(
-    previous_avg: float | NDArray[np.float64],
-    previous_variance: float | NDArray[np.float64],
+    previous_avg: float,
+    previous_variance: float,
     count: int,
-    new_value: float | NDArray[np.float64],
-) -> Tuple[float | NDArray[np.float64], float | NDArray[np.float64]]:
-    """Uses Welford's algorithm to calculate the running average and variance of
-    a set of numbers.
+    new_value: float,
+    unbiased: bool = True,
+) -> tuple[float, float]:
302+ """Calculates running average and variance of a series of numbers.
317303
318- See [Welford's algorithm in wikipedia](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm)
304+ See [Welford's algorithm in
305+ wikipedia](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm)
319306
320307 !!! Warning
321308 This is not really using Welford's correction for numerical stability
322309 for the variance. (FIXME)
323310
324311 !!! Todo
325- This could be generalised to arbitrary moments. See [this paper](https://www.osti.gov/biblio/1028931)
312+ This could be generalised to arbitrary moments. See [this
313+ paper](https://www.osti.gov/biblio/1028931)
326314
327315 Args:
328- previous_avg: average value at previous step
329- previous_variance: variance at previous step
330- count: number of points seen so far
331- new_value: new value in the series of numbers
332-
316+ previous_avg: average value at previous step.
317+ previous_variance: variance at previous step.
318+ count: number of points seen so far,
319+ new_value: new value in the series of numbers.
320+ unbiased: whether to use the unbiased variance estimator (same as `np.var` with
321+ `ddof=1`).
333322 Returns:
334323 new_average, new_variance, calculated with the new count
335324 """
-    # broadcasted operations seem not to be supported by mypy, so we ignore the type
-    new_average = (new_value + count * previous_avg) / (count + 1)  # type: ignore
-    new_variance = previous_variance + (
-        (new_value - previous_avg) * (new_value - new_average) - previous_variance
-    ) / (count + 1)
+    delta = new_value - previous_avg
+    new_average = previous_avg + delta / (count + 1)
+
+    if unbiased:
+        if count > 0:
+            new_variance = (
+                previous_variance + delta**2 / (count + 1) - previous_variance / count
+            )
+        else:
+            new_variance = 0.0
+    else:
+        new_variance = previous_variance + (
+            delta * (new_value - new_average) - previous_variance
+        ) / (count + 1)
+
     return new_average, new_variance


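# Usage sketch for the running update above: values are fed one at a time and the
# result is checked against numpy's batch estimates. This assumes the module is
# importable as `pydvl.utils.numeric` (an assumed path, not shown in this diff).
import numpy as np

from pydvl.utils.numeric import running_moments

values = [2.0, -1.0, 4.5, 0.5, 3.0]
avg, var = 0.0, 0.0
for n, x in enumerate(values):
    # `count` is the number of points already seen *before* this update.
    avg, var = running_moments(avg, var, count=n, new_value=x)
assert np.isclose(avg, np.mean(values))
assert np.isclose(var, np.var(values, ddof=1))  # unbiased=True matches ddof=1
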
@@ -359,3 +358,152 @@ def top_k_value_accuracy(
     top_k_pred_values = np.argsort(y_pred)[-k:]
     top_k_accuracy = len(np.intersect1d(top_k_exact_values, top_k_pred_values)) / k
     return top_k_accuracy
+
+
+def logcomb(n: int, k: int) -> float:
+    r"""Computes the log of the binomial coefficient (n choose k).
+
+    $$
+    \begin{array}{rcl}
+    \log\binom{n}{k} & = & \log(n!) - \log(k!) - \log((n-k)!) \\
+    & = & \log\Gamma(n+1) - \log\Gamma(k+1) - \log\Gamma(n-k+1).
+    \end{array}
+    $$
+
+    Args:
+        n: Total number of elements
+        k: Number of elements to choose
+    Returns:
+        The log of the binomial coefficient
+    """
+    if k < 0 or k > n or n < 0:
+        raise ValueError(f"Invalid arguments: n={n}, k={k}")
+    return float(gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1))
+
+
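# Quick sanity check for logcomb above: for small n the value can be compared with
# the exact binomial coefficient from the standard library. Assumes the module is
# importable as `pydvl.utils.numeric`.
import math

import numpy as np

from pydvl.utils.numeric import logcomb

assert np.isclose(np.exp(logcomb(10, 3)), math.comb(10, 3))  # exp(log C(10, 3)) = 120
assert np.isclose(logcomb(1000, 500), math.log(math.comb(1000, 500)))
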
+def logexp(x: float, a: float) -> float:
+    """Computes log(x^a).
+
+    Args:
+        x: Base
+        a: Exponent
+    Returns:
+        a * log(x)
+    """
+    return float(a * np.log(x))
+
+
+def logsumexp_two(log_a: float, log_b: float) -> float:
+    r"""Numerically stable computation of log(exp(log_a) + exp(log_b)).
+
+    Uses the standard log-sum-exp trick:
+
+    $$
+    \log(\exp(\log a) + \exp(\log b)) = m + \log(\exp(\log a - m) + \exp(\log b - m)),
+    $$
+
+    where $m = \max(\log a, \log b)$.
+
+    Args:
+        log_a: Log of the first value
+        log_b: Log of the second value
+    Returns:
+        The log of the sum of the exponentials
+    """
+    if log_a == -np.inf:
+        return log_b
+    if log_b == -np.inf:
+        return log_a
+    m = max(log_a, log_b)
+    return float(m + np.log(np.exp(log_a - m) + np.exp(log_b - m)))
+
+
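# Illustrative check for logsumexp_two above: the result agrees with numpy's
# np.logaddexp and stays finite where exponentiating directly would overflow.
# Assumes the module is importable as `pydvl.utils.numeric`.
import numpy as np

from pydvl.utils.numeric import logsumexp_two

assert np.isclose(logsumexp_two(np.log(2.0), np.log(3.0)), np.log(5.0))
# np.exp(1000.0) overflows to inf, but the max-shifted form stays exact:
assert np.isclose(logsumexp_two(1000.0, 1000.0), np.logaddexp(1000.0, 1000.0))
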
+def log_running_moments(
+    previous_log_sum_pos: float,
+    previous_log_sum_neg: float,
+    previous_log_sum2: float,
+    count: int,
+    new_log_value: float,
+    new_sign: int,
+    unbiased: bool = True,
+) -> tuple[float, float, float, float, float]:
430+ """
431+ Update running moments when the new value is provided in log space,
432+ allowing for negative values via an explicit sign.
433+
434+ Here the actual value is x = new_sign * exp(new_log_value). Rather than
435+ updating the arithmetic sum S = sum(x) and S2 = sum(x^2) directly, we maintain:
436+
437+ L_S+ = log(sum_{i: x_i >= 0} x_i)
438+ L_S- = log(sum_{i: x_i < 0} |x_i|)
439+ L_S2 = log(sum_i x_i^2)
440+
441+ The running mean is then computed as:
442+
443+ mean = exp(L_S+) - exp(L_S-)
444+
445+ and the second moment is:
446+
447+ second_moment = exp(L_S2 - log(count))
448+
449+ so that the variance is:
450+
451+ variance = second_moment - mean^2
452+
453+ For the unbiased (sample) estimator, we scale the variance by count/(count-1)
454+ when count > 1 (and define variance = 0 when count == 1).
455+
456+ Args:
457+ previous_log_sum_pos: running log(sum of positive contributions), or -inf if none.
458+ previous_log_sum_neg: running log(sum of negative contributions in absolute
459+ value), or -inf if none.
460+ previous_log_sum2: running log(sum of squares) so far (or -inf if none).
461+ count: number of points processed so far.
462+ new_log_value: log(|x_new|), where x_new is the new value.
463+ new_sign: sign of the new value (should be +1, 0, or -1).
464+ unbiased: if True, compute the unbiased estimator of the variance.
465+
466+ Returns:
467+ new_mean: running mean in the linear domain.
468+ new_variance: running variance in the linear domain.
469+ new_log_sum_pos: updated running log(sum of positive contributions).
470+ new_log_sum_neg: updated running log(sum of negative contributions).
471+ new_log_sum2: updated running log(sum of squares).
472+ new_count: updated count.
473+ """
+
+    if count == 0:
+        if new_sign >= 0:
+            new_log_sum_pos = new_log_value
+            new_log_sum_neg = -np.inf  # No negative contribution yet.
+        else:
+            new_log_sum_pos = -np.inf
+            new_log_sum_neg = new_log_value
+        new_log_sum2 = 2 * new_log_value
+    else:
+        if new_sign >= 0:
+            new_log_sum_pos = logsumexp_two(previous_log_sum_pos, new_log_value)
+            new_log_sum_neg = previous_log_sum_neg
+        else:
+            new_log_sum_neg = logsumexp_two(previous_log_sum_neg, new_log_value)
+            new_log_sum_pos = previous_log_sum_pos
+        new_log_sum2 = logsumexp_two(previous_log_sum2, 2 * new_log_value)
+    new_count = count + 1
+
+    # Compute 1st and 2nd moments in the linear domain.
+    pos_sum = np.exp(new_log_sum_pos) if new_log_sum_pos != -np.inf else 0.0
+    neg_sum = np.exp(new_log_sum_neg) if new_log_sum_neg != -np.inf else 0.0
+    new_mean = (pos_sum - neg_sum) / new_count
+
+    second_moment = np.exp(new_log_sum2 - np.log(new_count))
+
+    # Compute variance using either the population or unbiased estimator.
+    if unbiased:
+        if new_count > 1:
+            new_variance = new_count / (new_count - 1) * (second_moment - new_mean**2)
+        else:
+            new_variance = 0.0
+    else:
+        new_variance = second_moment - new_mean**2
+
+    return new_mean, new_variance, new_log_sum_pos, new_log_sum_neg, new_log_sum2
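
# Usage sketch for log_running_moments above: a small signed series is accumulated
# in log space and the result is checked against numpy's batch estimates. The count
# is tracked by the caller, since the updated count is not returned. Assumes the
# module is importable as `pydvl.utils.numeric`.
import numpy as np

from pydvl.utils.numeric import log_running_moments

values = [0.5, -2.0, 3.0, -0.25]
log_pos, log_neg, log_sum2, n = -np.inf, -np.inf, -np.inf, 0
for x in values:
    mean, var, log_pos, log_neg, log_sum2 = log_running_moments(
        log_pos, log_neg, log_sum2, n, np.log(abs(x)), int(np.sign(x))
    )
    n += 1
assert np.isclose(mean, np.mean(values))
assert np.isclose(var, np.var(values, ddof=1))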