Skip to content

Commit c56450a

Browse files
authored
Merge pull request #250 from appliedAI-Initiative/feature/ghorbani-stopping-step2
Generalised stopping criteria, truncation policies and improved data structures
2 parents fcb5165 + cbdccbe commit c56450a

36 files changed

+1848
-992
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
## Unreleased
44

5+
- Generalised stopping criteria for valuation algorithms. Improved classes
6+
`ValuationResult` and `Status` with more operations. Some minor issues fixed.
7+
[PR #252](https://github.com/appliedAI-Initiative/pyDVL/pull/250)
58
- Fixed a bug whereby `compute_shapley_values` would only spawn one process when
69
using `n_jobs=-1` and Monte Carlo methods.
710
[PR #270](https://github.com/appliedAI-Initiative/pyDVL/pull/270)

build_scripts/update_docs.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,6 @@ def module_template(module_qualname: str):
2424
:undoc-members:
2525
2626
----
27-
28-
Module members
29-
==============
3027
3128
.. footbibliography::
3229

docs/30-data-valuation.rst

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ definitions, but other methods are typically preferable.
174174
values = naive_loo(utility)
175175
176176
The return value of all valuation functions is an object of type
177-
:class:`~pydvl.value.results.ValuationResult`. This can be iterated over,
177+
:class:`~pydvl.value.result.ValuationResult`. This can be iterated over,
178178
indexed with integers, slices and Iterables, as well as converted to a
179179
`pandas DataFrame <https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html>`_.
180180

@@ -217,11 +217,11 @@ v_u(x_i) = \frac{1}{n} \sum_{S \subseteq D \setminus \{x_i\}}
217217
values = compute_shapley_values(utility, mode="combinatorial_exact")
218218
df = values.to_dataframe(column='value')
219219
220-
We convert the return value to a
220+
We can convert the return value to a
221221
`pandas DataFrame <https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html>`_
222222
and name the column with the results as `value`. Please refer to the
223223
documentation in :mod:`pydvl.value.shapley` and
224-
:class:`~pydvl.value.results.ValuationResult` for more information.
224+
:class:`~pydvl.value.result.ValuationResult` for more information.
225225

226226
Monte Carlo Combinatorial Shapley
227227
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -240,12 +240,19 @@ same pattern:
240240
model = ...
241241
data = Dataset(...)
242242
utility = Utility(model, data)
243-
values = compute_shapley_values(utility, mode="combinatorial_montecarlo")
243+
values = compute_shapley_values(
244+
utility, mode="combinatorial_montecarlo", done=MaxUpdates(1000)
245+
)
244246
df = values.to_dataframe(column='cmc')
245247
246248
The DataFrames returned by most Monte Carlo methods will contain approximate
247249
standard errors as an additional column, in this case named `cmc_stderr`.
248250

251+
Note the usage of the object :class:`~pydvl.value.stopping.MaxUpdates` as the
252+
stop condition. This is an instance of a
253+
:class:`~pydvl.value.stopping.StoppingCriterion`. Other examples are
254+
:class:`~pydvl.value.stopping.MaxTime` and :class:`~pydvl.value.stopping.StandardError`.
255+
249256

250257
Owen sampling
251258
^^^^^^^^^^^^^
@@ -281,6 +288,10 @@ sampling, and its variant *Antithetic Owen Sampling* in the documentation for th
281288
function doing the work behind the scenes:
282289
:func:`~pydvl.value.shapley.montecarlo.owen_sampling_shapley`.
283290

291+
Note that in this case we do not pass a
292+
:class:`~pydvl.value.stopping.StoppingCriterion` to the function, but instead
293+
the number of iterations and the maximum number of samples to use in the
294+
integration.
284295

285296
Permutation Shapley
286297
^^^^^^^^^^^^^^^^^^^
@@ -309,7 +320,7 @@ efficient enough to be useful in some applications.
309320
data = Dataset(...)
310321
utility = Utility(model, data)
311322
values = compute_shapley_values(
312-
u=utility, mode="truncated_montecarlo", n_iterations=100
323+
u=utility, mode="truncated_montecarlo", done=MaxUpdates(1000)
313324
)
314325
315326

requirements-dev.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
black[jupyter] == 22.10.0
2-
isort == 5.12
2+
isort == 5.12.0
33
jupyter
44
mypy == 0.982
5-
nbconvert
5+
nbconvert>=7.2.9
66
nbstripout == 0.6.1
77
bump2version
8-
pre-commit == 2.20.0
9-
pytest
8+
pre-commit==3.0.4
9+
pytest==7.2.1
1010
pytest-cov
11-
pytest-docker
11+
pytest-docker==0.12.0
1212
pytest-mock
1313
pytest-timeout
1414
ray[default] >= 0.8

src/pydvl/reporting/scores.py

Lines changed: 5 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,18 @@
1-
from collections import OrderedDict
2-
from operator import itemgetter
3-
from typing import Dict, Iterable, Mapping, Sequence, TypeVar, Union
1+
from typing import Dict, Iterable, Union
42

53
import numpy as np
64
from numpy.typing import NDArray
75

86
from pydvl.utils import Utility, maybe_progress
9-
from pydvl.value.results import ValuationResult
7+
from pydvl.value.result import ValuationResult
108

11-
__all__ = [
12-
"sort_values",
13-
"sort_values_array",
14-
"sort_values_history",
15-
"compute_removal_score",
16-
]
17-
18-
KT = TypeVar("KT")
19-
VT = TypeVar("VT")
20-
21-
22-
def sort_values_array(values: np.ndarray) -> Dict[int, "NDArray"]:
23-
vals = np.mean(values, axis=1)
24-
return OrderedDict(sorted(enumerate(vals), key=itemgetter(1)))
25-
26-
27-
def sort_values_history(values: Mapping[KT, Sequence[VT]]) -> Dict[KT, Sequence[VT]]:
28-
"""Sorts a dict of sample_id: [values] by the last item in each list."""
29-
return OrderedDict(sorted(values.items(), key=itemgetter(1, -1)))
30-
31-
32-
def sort_values(values: Mapping[KT, VT]) -> Dict[KT, VT]:
33-
"""Sorts a dict of sample_id: value_float by value."""
34-
return OrderedDict(sorted(values.items(), key=itemgetter(1)))
9+
__all__ = ["compute_removal_score"]
3510

3611

3712
def compute_removal_score(
3813
u: Utility,
3914
values: ValuationResult,
40-
percentages: Union["NDArray", Iterable[float]],
15+
percentages: Union[NDArray[np.float_], Iterable[float]],
4116
*,
4217
remove_best: bool = False,
4318
progress: bool = False,
@@ -66,11 +41,7 @@ def compute_removal_score(
6641
# We sort in descending order if we want to remove the best values
6742
values.sort(reverse=remove_best)
6843

69-
for pct in maybe_progress(
70-
percentages,
71-
display=progress,
72-
desc="Removal Scores",
73-
):
44+
for pct in maybe_progress(percentages, display=progress, desc="Removal Scores"):
7445
n_removal = int(pct * len(u.data))
7546
indices = values.indices[n_removal:]
7647
score = u(indices)

src/pydvl/utils/caching.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ def __call__(self, *args, **kwargs) -> T:
271271
):
272272
new_value = fun(*args, **kwargs)
273273
new_avg, new_var = running_moments(
274-
value, variance, cast(float, new_value), int(count)
274+
value, variance, int(count), cast(float, new_value)
275275
)
276276
result_dict["value"] = new_avg
277277
result_dict["count"] = count + 1

src/pydvl/utils/dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ def from_sklearn(
408408
train_size: float = 0.8,
409409
random_state: Optional[int] = None,
410410
stratify_by_target: bool = False,
411-
data_groups: Optional[List] = None,
411+
data_groups: Optional[Sequence] = None,
412412
) -> "GroupedDataset":
413413
"""Constructs a :class:`GroupedDataset` object from an sklearn bunch as returned by the
414414
`load_*` functions in `sklearn toy datasets
@@ -444,7 +444,7 @@ def from_arrays(
444444
train_size: float = 0.8,
445445
random_state: Optional[int] = None,
446446
stratify_by_target: bool = False,
447-
data_groups: Optional[List] = None,
447+
data_groups: Optional[Sequence] = None,
448448
) -> "Dataset":
449449
""".. versionadded:: 0.4.0
450450

src/pydvl/utils/numeric.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -69,20 +69,22 @@ def num_samples_permutation_hoeffding(eps: float, delta: float, u_range: float)
6969

7070

7171
def random_powerset(
72-
s: NDArray[T], max_subsets: Optional[int] = None, q: float = 0.5
72+
s: NDArray[T], n_samples: Optional[int] = None, q: float = 0.5
7373
) -> Generator[NDArray[T], None, None]:
7474
"""Samples subsets from the power set of the argument, without
7575
pre-generating all subsets and in no order.
7676
7777
See `powerset()` if you wish to deterministically generate all subsets.
7878
79-
To generate subsets, `len(s)` Bernoulli draws with probability `q` are drawn.
79+
To generate subsets, `len(s)` Bernoulli draws with probability `q` are
80+
drawn.
8081
The default value of `q = 0.5` provides a uniform distribution over the
8182
power set of `s`. Other choices can be used e.g. to implement
82-
:func:`Owen sampling <pydvl.value.shapley.montecarlo.owen_sampling_shapley>`.
83+
:func:`Owen sampling
84+
<pydvl.value.shapley.montecarlo.owen_sampling_shapley>`.
8385
8486
:param s: set to sample from
85-
:param max_subsets: if set, stop the generator after this many steps.
87+
:param n_samples: if set, stop the generator after this many steps.
8688
Defaults to `np.iinfo(np.int32).max`
8789
:param q: Sampling probability for elements. The default 0.5 yields a
8890
uniform distribution over the power set of s.
@@ -99,9 +101,9 @@ def random_powerset(
99101

100102
rng = np.random.default_rng()
101103
total = 1
102-
if max_subsets is None:
103-
max_subsets = np.iinfo(np.int32).max
104-
while total <= max_subsets:
104+
if n_samples is None:
105+
n_samples = np.iinfo(np.int32).max
106+
while total <= n_samples:
105107
selection = rng.uniform(size=len(s)) > q
106108
subset = s[selection]
107109
yield subset
@@ -228,8 +230,8 @@ def linear_regression_analytical_derivative_d_x_d_theta(
228230
def running_moments(
229231
previous_avg: FloatOrArray,
230232
previous_variance: FloatOrArray,
231-
new_value: FloatOrArray,
232233
count: IntOrArray,
234+
new_value: FloatOrArray,
233235
) -> Tuple: # [FloatOrArray, FloatOrArray]:
234236
"""Uses Welford's algorithm to calculate the running average and variance of
235237
a set of numbers.
@@ -248,9 +250,9 @@ def running_moments(
248250
249251
:param previous_avg: average value at previous step
250252
:param previous_variance: variance at previous step
251-
:param new_value: new value in the series of numbers
252253
:param count: number of points seen so far
253-
:return: new_average, new_variance, calculated with the new number
254+
:param new_value: new value in the series of numbers
255+
:return: new_average, new_variance, calculated with the new count
254256
"""
255257
# broadcasted operations seem not to be supported by mypy, so we ignore the type
256258
new_average = (new_value + count * previous_avg) / (count + 1) # type: ignore

src/pydvl/utils/parallel/actor.py

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import abc
22
import inspect
33
import logging
4-
from typing import Any, Dict, Optional, Type, Union, cast
4+
from time import sleep
5+
from typing import Generic, List, Optional, Type, TypeVar, cast
56

6-
from pydvl.utils.config import ParallelConfig
7-
from pydvl.utils.parallel.backend import RayParallelBackend, init_parallel_backend
7+
from ..config import ParallelConfig
8+
from ..status import Status
9+
from .backend import RayParallelBackend, init_parallel_backend
810

911
__all__ = ["RayActorWrapper", "Coordinator", "Worker"]
1012

@@ -65,50 +67,45 @@ def wrapper(
6567
setattr(self, name, remote_caller(name))
6668

6769

68-
class Coordinator(abc.ABC):
70+
Result = TypeVar("Result") # Avoids circular import with ValuationResult
71+
72+
73+
class Coordinator(Generic[Result], abc.ABC):
6974
"""The coordinator has two main tasks: aggregating the results of the
7075
workers and terminating the process once a certain accuracy or total
7176
number of iterations is reached.
72-
73-
:param progress: Whether to display a progress bar
7477
"""
7578

76-
def __init__(self, *, progress: Optional[bool] = True):
77-
self.progress = progress
78-
# For each worker: values, stddev, num_iterations
79-
self.workers_results: Dict[int, Dict[str, float]] = dict()
80-
self._total_iterations = 0
81-
self._is_done = False
79+
_status: Status
8280

83-
def add_results(self, worker_id: int, results: Dict[str, Union[float, int]]):
81+
def __init__(self):
82+
self.worker_results: List[Result] = []
83+
self._status = Status.Pending
84+
85+
def add_results(self, results: Result):
8486
"""Used by workers to report their results. Stores the results directly
85-
into the `worker_status` dictionary.
87+
into :attr:`worker_results`
8688
87-
:param worker_id: id of the worker
88-
:param results: results of worker calculations
89+
:param results: results of worker's calculations
8990
"""
90-
self.workers_results[worker_id] = results
91+
self.worker_results.append(results)
9192

9293
# this should be a @property, but with it ray.get messes up
9394
def is_done(self) -> bool:
9495
"""Used by workers to check whether to terminate their process.
9596
96-
:return: `True` if workers must terminate, `False` otherwise.
97+
:return: ``True`` if workers must terminate, ``False`` otherwise.
9798
"""
98-
return self._is_done
99+
return bool(self._status)
99100

100101
@abc.abstractmethod
101-
def get_results(self) -> Any:
102+
def accumulate(self) -> Result:
102103
"""Aggregates the results of the different workers."""
103104
raise NotImplementedError()
104105

105106
@abc.abstractmethod
106-
def check_done(self) -> bool:
107-
"""Checks whether the accuracy of the calculation or the total number
108-
of iterations have crossed the set thresholds.
109-
110-
If so, it sets the `is_done` label to `True`.
111-
"""
107+
def check_convergence(self) -> bool:
108+
"""Evaluates the convergence criteria on the aggregated results."""
112109
raise NotImplementedError()
113110

114111

@@ -117,25 +114,22 @@ class Worker(abc.ABC):
117114

118115
def __init__(
119116
self,
120-
coordinator: "Coordinator",
117+
coordinator: Coordinator,
121118
worker_id: int,
122119
*,
123-
progress: bool = False,
124120
update_period: int = 30,
125121
):
126122
"""A worker
127123
128124
:param coordinator: worker results will be pushed to this coordinator
129125
:param worker_id: id used for reporting through maybe_progress
130-
:param progress: set to True to report progress, else False
131126
:param update_period: interval in seconds between different updates
132127
to and from the coordinator
133128
"""
134129
super().__init__()
135130
self.worker_id = worker_id
136131
self.coordinator = coordinator
137132
self.update_period = update_period
138-
self.progress = progress
139133

140134
def run(self, *args, **kwargs):
141135
"""Runs the worker."""

src/pydvl/utils/parallel/map_reduce.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import inspect
2-
from collections.abc import Iterable
32
from functools import singledispatch, update_wrapper
43
from itertools import accumulate, repeat
54
from typing import (
@@ -70,6 +69,7 @@ def _(v: np.ndarray, *, timeout: Optional[float] = None) -> NDArray:
7069
return v
7170

7271

72+
# Careful to use list as hint. The dispatch does not work with typing generics
7373
@_get_value.register
7474
def _(v: list, *, timeout: Optional[float] = None) -> List[Any]:
7575
return [_get_value(x, timeout=timeout) for x in v]

0 commit comments

Comments
 (0)