
Commit 42381ba

Merge branch 'develop' into dependabot/pip/transformers-4.48.0

2 parents: ae8b0eb + bb77008

38 files changed: +1361 −598 lines

CHANGELOG.md

Lines changed: 3 additions & 3 deletions

@@ -76,6 +76,9 @@
 
 ### Changed
 
+- Switched all semi-value coefficients and sampler weights to log-space in
+  order to avoid overflows
+  [PR #643](https://github.com/aai-institute/pyDVL/pull/643)
 - Updated and rewrote some of the MSR banzhaf notebook
   [PR #641](https://github.com/aai-institute/pyDVL/pull/641)
 - Updated Least-Core notebook
@@ -84,9 +87,6 @@
   thus subsuming Variance-Reduced stratified sampling into a unified framework.
   Implemented the heuristics proposed in that paper
   [PR #641](https://github.com/aai-institute/pyDVL/pull/641)
-- Changed the way semi-value coefficients are composed with sampler weights in
-  order to avoid `OverflowError` for very small or large values
-  [PR #639](https://github.com/aai-institute/pyDVL/pull/639)
 - Uniformly distribute test points across processes for KNNShapley. Fail for
   `GroupedDataset` [PR #632](https://github.com/aai-institute/pyDVL/pull/632)
 - Introduced the concept of logical vs data indices for `Dataset`, and

src/pydvl/utils/numeric.py

Lines changed: 188 additions & 40 deletions

@@ -1,6 +1,5 @@
 """
-This module contains routines for numerical computations used across the
-library.
+This module contains routines for numerical computations used across the library.
 """
 
 from __future__ import annotations
@@ -10,29 +9,31 @@
     Collection,
     Generator,
     Iterator,
-    List,
     Optional,
     Sequence,
-    Tuple,
     TypeVar,
-    overload,
 )
 
 import numpy as np
 from numpy.typing import NDArray
+from scipy.special import gammaln
 
 from pydvl.utils.types import Seed
 
 __all__ = [
     "complement",
-    "running_moments",
+    "logcomb",
+    "logexp",
+    "log_running_moments",
+    "logsumexp_two",
     "num_samples_permutation_hoeffding",
     "powerset",
     "random_matrix_with_condition_number",
     "random_subset",
     "random_powerset",
     "random_powerset_label_min",
     "random_subset_of_size",
+    "running_moments",
     "top_k_value_accuracy",
 ]
 
@@ -202,7 +203,7 @@ def random_powerset_label_min(
     unique_labels = np.unique(labels)
 
     while True:
-        subsets: List[NDArray[T]] = []
+        subsets: list[NDArray[T]] = []
         for label in unique_labels:
             label_indices = np.asarray(np.where(labels == label)[0])
             subset_size = int(
@@ -291,53 +292,51 @@ def random_matrix_with_condition_number(
     return P
 
 
-@overload
-def running_moments(
-    previous_avg: float, previous_variance: float, count: int, new_value: float
-) -> Tuple[float, float]: ...
-
-
-@overload
-def running_moments(
-    previous_avg: NDArray[np.float64],
-    previous_variance: NDArray[np.float64],
-    count: int,
-    new_value: NDArray[np.float64],
-) -> Tuple[NDArray[np.float64], NDArray[np.float64]]: ...
-
-
 def running_moments(
-    previous_avg: float | NDArray[np.float64],
-    previous_variance: float | NDArray[np.float64],
+    previous_avg: float,
+    previous_variance: float,
     count: int,
-    new_value: float | NDArray[np.float64],
-) -> Tuple[float | NDArray[np.float64], float | NDArray[np.float64]]:
-    """Uses Welford's algorithm to calculate the running average and variance of
-    a set of numbers.
+    new_value: float,
+    unbiased: bool = True,
+) -> tuple[float, float]:
+    """Calculates running average and variance of a series of numbers.
 
-    See [Welford's algorithm in wikipedia](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm)
+    See [Welford's algorithm in
+    wikipedia](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm)
 
     !!! Warning
         This is not really using Welford's correction for numerical stability
         for the variance. (FIXME)
 
     !!! Todo
-        This could be generalised to arbitrary moments. See [this paper](https://www.osti.gov/biblio/1028931)
+        This could be generalised to arbitrary moments. See [this
+        paper](https://www.osti.gov/biblio/1028931)
 
     Args:
-        previous_avg: average value at previous step
-        previous_variance: variance at previous step
-        count: number of points seen so far
-        new_value: new value in the series of numbers
-
+        previous_avg: average value at previous step.
+        previous_variance: variance at previous step.
+        count: number of points seen so far.
+        new_value: new value in the series of numbers.
+        unbiased: whether to use the unbiased variance estimator (same as `np.var`
+            with `ddof=1`).
     Returns:
         new_average, new_variance, calculated with the new count
     """
-    # broadcasted operations seem not to be supported by mypy, so we ignore the type
-    new_average = (new_value + count * previous_avg) / (count + 1)  # type: ignore
-    new_variance = previous_variance + (
-        (new_value - previous_avg) * (new_value - new_average) - previous_variance
-    ) / (count + 1)
+    delta = new_value - previous_avg
+    new_average = previous_avg + delta / (count + 1)
+
+    if unbiased:
+        if count > 0:
+            new_variance = (
+                previous_variance + delta**2 / (count + 1) - previous_variance / count
+            )
+        else:
+            new_variance = 0.0
+    else:
+        new_variance = previous_variance + (
+            delta * (new_value - new_average) - previous_variance
+        ) / (count + 1)
 
     return new_average, new_variance
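A quick sanity check of the rewritten update rule, before the new log-space helpers in the next hunk. This is a sketch, not part of the commit; it assumes a pyDVL install that includes this change, so that running_moments is importable with the new signature. Streaming the update over a series should reproduce numpy's batch estimates:

    import numpy as np

    from pydvl.utils.numeric import running_moments

    rng = np.random.default_rng(42)
    xs = rng.normal(size=100)

    avg, var = 0.0, 0.0
    for count, x in enumerate(xs):
        # count is the number of points seen *before* x, as the signature expects
        avg, var = running_moments(avg, var, count, float(x), unbiased=True)

    assert np.isclose(avg, xs.mean())
    assert np.isclose(var, xs.var(ddof=1))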

@@ -359,3 +358,152 @@ def top_k_value_accuracy(
     top_k_pred_values = np.argsort(y_pred)[-k:]
     top_k_accuracy = len(np.intersect1d(top_k_exact_values, top_k_pred_values)) / k
     return top_k_accuracy
+
+
+def logcomb(n: int, k: int) -> float:
+    r"""Computes the log of the binomial coefficient (n choose k).
+
+    $$
+    \begin{array}{rcl}
+    \log\binom{n}{k} & = & \log(n!) - \log(k!) - \log((n-k)!) \\
+                     & = & \log\Gamma(n+1) - \log\Gamma(k+1) - \log\Gamma(n-k+1).
+    \end{array}
+    $$
+
+    Args:
+        n: Total number of elements
+        k: Number of elements to choose
+    Returns:
+        The log of the binomial coefficient
+    """
+    if k < 0 or k > n or n < 0:
+        raise ValueError(f"Invalid arguments: n={n}, k={k}")
+    return float(gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1))
+
+
+def logexp(x: float, a: float) -> float:
+    """Computes log(x^a).
+
+    Args:
+        x: Base
+        a: Exponent
+    Returns:
+        a * log(x)
+    """
+    return float(a * np.log(x))
+
+
+def logsumexp_two(log_a: float, log_b: float) -> float:
+    r"""Numerically stable computation of log(exp(log_a) + exp(log_b)).
+
+    Uses the standard log-sum-exp trick:
+
+    $$
+    \log(\exp(\log a) + \exp(\log b)) = m + \log(\exp(\log a - m) + \exp(\log b - m)),
+    $$
+
+    where $m = \max(\log a, \log b)$.
+
+    Args:
+        log_a: Log of the first value
+        log_b: Log of the second value
+    Returns:
+        The log of the sum of the exponentials
+    """
+    if log_a == -np.inf:
+        return log_b
+    if log_b == -np.inf:
+        return log_a
+    m = max(log_a, log_b)
+    return float(m + np.log(np.exp(log_a - m) + np.exp(log_b - m)))
+
+
+def log_running_moments(
+    previous_log_sum_pos: float,
+    previous_log_sum_neg: float,
+    previous_log_sum2: float,
+    count: int,
+    new_log_value: float,
+    new_sign: int,
+    unbiased: bool = True,
+) -> tuple[float, float, float, float, float]:
+    """
+    Update running moments when the new value is provided in log space,
+    allowing for negative values via an explicit sign.
+
+    Here the actual value is x = new_sign * exp(new_log_value). Rather than
+    updating the arithmetic sums S = sum(x) and S2 = sum(x^2) directly, we maintain:
+
+        L_S+ = log(sum_{i: x_i >= 0} x_i)
+        L_S- = log(sum_{i: x_i < 0} |x_i|)
+        L_S2 = log(sum_i x_i^2)
+
+    The running mean is then computed as:
+
+        mean = (exp(L_S+) - exp(L_S-)) / count
+
+    and the second moment is:
+
+        second_moment = exp(L_S2 - log(count))
+
+    so that the variance is:
+
+        variance = second_moment - mean^2
+
+    For the unbiased (sample) estimator, we scale the variance by count/(count-1)
+    when count > 1 (and define variance = 0 when count == 1).
+
+    Args:
+        previous_log_sum_pos: running log(sum of positive contributions), or -inf
+            if none.
+        previous_log_sum_neg: running log(sum of negative contributions in absolute
+            value), or -inf if none.
+        previous_log_sum2: running log(sum of squares) so far (or -inf if none).
+        count: number of points processed so far.
+        new_log_value: log(|x_new|), where x_new is the new value.
+        new_sign: sign of the new value (should be +1, 0, or -1).
+        unbiased: if True, compute the unbiased estimator of the variance.
+
+    Returns:
+        new_mean: running mean in the linear domain.
+        new_variance: running variance in the linear domain.
+        new_log_sum_pos: updated running log(sum of positive contributions).
+        new_log_sum_neg: updated running log(sum of negative contributions).
+        new_log_sum2: updated running log(sum of squares).
+    """
+
+    if count == 0:
+        if new_sign >= 0:
+            new_log_sum_pos = new_log_value
+            new_log_sum_neg = -np.inf  # No negative contribution yet.
+        else:
+            new_log_sum_pos = -np.inf
+            new_log_sum_neg = new_log_value
+        new_log_sum2 = 2 * new_log_value
+    else:
+        if new_sign >= 0:
+            new_log_sum_pos = logsumexp_two(previous_log_sum_pos, new_log_value)
+            new_log_sum_neg = previous_log_sum_neg
+        else:
+            new_log_sum_neg = logsumexp_two(previous_log_sum_neg, new_log_value)
+            new_log_sum_pos = previous_log_sum_pos
+        new_log_sum2 = logsumexp_two(previous_log_sum2, 2 * new_log_value)
+    new_count = count + 1
+
+    # Compute 1st and 2nd moments in the linear domain.
+    pos_sum = np.exp(new_log_sum_pos) if new_log_sum_pos != -np.inf else 0.0
+    neg_sum = np.exp(new_log_sum_neg) if new_log_sum_neg != -np.inf else 0.0
+    new_mean = (pos_sum - neg_sum) / new_count
+
+    second_moment = np.exp(new_log_sum2 - np.log(new_count))
+
+    # Compute variance using either the population or unbiased estimator.
+    if unbiased:
+        if new_count > 1:
+            new_variance = new_count / (new_count - 1) * (second_moment - new_mean**2)
+        else:
+            new_variance = 0.0
+    else:
+        new_variance = second_moment - new_mean**2
+
+    return new_mean, new_variance, new_log_sum_pos, new_log_sum_neg, new_log_sum2
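Taken together, the new helpers let callers accumulate means and variances of values whose magnitudes would overflow or underflow a float. A usage sketch (not part of the commit; it assumes a pyDVL install that includes this change):

    import math

    import numpy as np

    from pydvl.utils.numeric import log_running_moments, logcomb, logsumexp_two

    # logcomb matches the exact binomial coefficient where it fits in a float...
    assert np.isclose(logcomb(10, 3), math.log(math.comb(10, 3)))
    # ...and stays finite where float(math.comb(10_000, 5_000)) would overflow.
    print(logcomb(10_000, 5_000))  # ≈ 6926.6

    # logsumexp_two: log(exp(a) + exp(b)) without leaving log space.
    assert np.isclose(logsumexp_two(math.log(2.0), math.log(3.0)), math.log(5.0))

    # log_running_moments tracks moments of signed values given as (log|x|, sign);
    # the caller keeps the count, since it is not among the returned values.
    xs = [1.5, -0.5, 2.0, -1.0]
    log_pos, log_neg, log_sq = -np.inf, -np.inf, -np.inf
    for count, x in enumerate(xs):
        mean, var, log_pos, log_neg, log_sq = log_running_moments(
            log_pos, log_neg, log_sq, count, math.log(abs(x)), int(np.sign(x))
        )

    assert np.isclose(mean, np.mean(xs))
    assert np.isclose(var, np.var(xs, ddof=1))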

src/pydvl/valuation/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -2,4 +2,5 @@
 from pydvl.valuation.methods import *
 from pydvl.valuation.samplers import *
 from pydvl.valuation.scorers import *
+from pydvl.valuation.stopping import *
 from pydvl.valuation.utility import *

src/pydvl/valuation/methods/beta_shapley.py

Lines changed: 5 additions & 4 deletions

@@ -28,9 +28,10 @@ def __init__(
         self.alpha = alpha
         self.beta = beta
         self.const = sp.special.beta(alpha, beta)
+        self.log_const = sp.special.betaln(alpha, beta)
 
-    def coefficient(self, n: int, k: int, weight: float) -> float:
+    def log_coefficient(self, n: int, k: int) -> float:
         j = k + 1
-        w = sp.special.beta(j + self.beta - 1, n - j + self.alpha) / self.const
-        # return math.comb(n - 1, j - 1) * w * n * other
-        return float(w) * weight
+        return float(
+            sp.special.betaln(j + self.beta - 1, n - j + self.alpha) - self.log_const
+        )
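The refactoring preserves the coefficient exactly: exp of the new log-space expression equals the old linear-space ratio. A standalone check with SciPy (not part of the commit; α=16, β=1 is used here only as an example of a Beta-Shapley weighting):

    import numpy as np
    from scipy import special

    alpha, beta = 16.0, 1.0  # example Beta(α, β) weighting
    n, k = 100, 10
    j = k + 1

    old = special.beta(j + beta - 1, n - j + alpha) / special.beta(alpha, beta)
    new = special.betaln(j + beta - 1, n - j + alpha) - special.betaln(alpha, beta)

    # Same coefficient, but the log form cannot overflow, or underflow to 0.
    assert np.isclose(np.exp(new), old)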

src/pydvl/valuation/methods/classwise_shapley.py

Lines changed: 5 additions & 3 deletions

@@ -149,10 +149,12 @@ def fit(self, data: Dataset):
         self.is_done.reset()
         self.utility.training_data = data
 
-        sample_generator = self.sampler.from_data(data)
         strategy = self.sampler.make_strategy(self.utility)
+        updater = self.sampler.result_updater(self.result)
         processor = delayed(strategy.process)
 
+        sample_generator = self.sampler.from_data(data)
+
         with Parallel(return_as="generator_unordered") as parallel:
             with make_parallel_flag() as flag:
                 delayed_evals = parallel(
@@ -162,7 +164,7 @@ def fit(self, data: Dataset):
 
                 for batch in Progress(delayed_evals, self.is_done, **self.tqdm_args):
                     for evaluation in batch:
-                        self.result.update(evaluation.idx, evaluation.update)
+                        self.result = updater(evaluation)
                         if self.is_done(self.result):
                             flag.set()
                             self.sampler.interrupt()
@@ -211,6 +213,6 @@ def _normalize(self) -> ValuationResult:
 
         sigma = np.sum(self.result.values[indices_label_set])
         if sigma != 0:
-            self.result.scale(in_class_acc / sigma, indices=indices_label_set)
+            self.result.scale(in_class_acc / sigma, data_indices=indices_label_set)
 
         return self.result

src/pydvl/valuation/methods/data_banzhaf.py

Lines changed: 4 additions & 2 deletions

@@ -25,6 +25,8 @@
     6388–6421. PMLR, 2023.
 """
 
+import numpy as np
+
 from pydvl.valuation.methods.semivalue import SemivalueValuation
 
 __all__ = ["DataBanzhafValuation"]
@@ -35,5 +37,5 @@ class DataBanzhafValuation(SemivalueValuation):
 
     algorithm_name = "Data-Banzhaf"
 
-    def coefficient(self, n: int, k: int, weight: float) -> float:
-        return float(weight / 2 ** (n - 1))
+    def log_coefficient(self, n: int, k: int) -> float:
+        return float(-(n - 1) * np.log(2))
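The motivation here is the same as in the rest of the commit: the Banzhaf coefficient 1/2^(n-1) underflows to zero in linear space once n reaches the thousands, while its logarithm stays representable. A small illustration (not part of the commit):

    import numpy as np

    n = 50
    log_coeff = -(n - 1) * np.log(2)
    assert np.isclose(np.exp(log_coeff), 1 / 2 ** (n - 1))

    # For large n the linear form underflows while the log form remains finite:
    n = 10_000
    print(2.0 ** -(n - 1))       # 0.0 (underflow)
    print(-(n - 1) * np.log(2))  # about -6930.8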
