Skip to content

Commit 79e616a

Browse files
authored
Merge pull request #267 from appliedAI-Initiative/fix/missing-tests
Fix parallelization for Owen, improve and add tests, and some fixes
2 parents 783b0a0 + 285c8ec commit 79e616a

File tree

14 files changed

+348
-134
lines changed

14 files changed

+348
-134
lines changed

CHANGELOG.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22

33
## Unreleased
44

5-
- Added `Scorer` class for a cleaner interface. Fix minor bugs around
6-
Group-Testing Shapley and switched to cvxpy for the constraint solver.
5+
- Fixed parallel and antithetic Owen sampling for Shapley values. Simplified
6+
and extended tests.
7+
[PR #267](https://github.com/appliedAI-Initiative/pyDVL/pull/267)
8+
- Added `Scorer` class for a cleaner interface. Fixed minor bugs around
9+
Group-Testing Shapley, added more tests and switched to cvxpy for the solver.
710
[PR #264](https://github.com/appliedAI-Initiative/pyDVL/pull/264)
811
- Generalised stopping criteria for valuation algorithms. Improved classes
912
`ValuationResult` and `Status` with more operations. Some minor issues fixed.

src/pydvl/utils/score.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
<https://scikit-learn.org/stable/modules/generated/sklearn.metrics.make_scorer.html>`_.
99
1010
:class:`Scorer` provides additional information about the scoring function, like
11-
its range and default values.
11+
its range and default values, which can be used by some data valuation
12+
methods (like :func:`~pydvl.value.shapley.gt.group_testing_shapley`) to estimate
13+
the number of samples required for a certain quality of approximation.
1214
"""
1315
from typing import Callable, Optional, Protocol, Tuple, Union
1416

@@ -60,6 +62,8 @@ def __init__(
6062
range: Tuple = (-np.inf, np.inf),
6163
name: Optional[str] = None,
6264
):
65+
if name is None and isinstance(scoring, str):
66+
name = scoring
6367
self._scorer = get_scorer(scoring)
6468
self.default = default
6569
# TODO: auto-fill from known scorers ?
@@ -102,12 +106,12 @@ def compose_score(
102106
:return: The composite :class:`Scorer`.
103107
"""
104108

105-
class NewScorer(Scorer):
109+
class CompositeScorer(Scorer):
106110
def __call__(self, model: SupervisedModel, X: NDArray, y: NDArray) -> float:
107111
score = self._scorer(model=model, X=X, y=y)
108112
return transformation(score)
109113

110-
return NewScorer(scorer, range=range, name=name)
114+
return CompositeScorer(scorer, range=range, name=name)
111115

112116

113117
def _sigmoid(x: float) -> float:

src/pydvl/value/result.py

Lines changed: 116 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,26 @@
66
to raw values, as well as convenient behaviour as a ``Sequence`` with extended
77
indexing and updating abilities, and conversion to `pandas DataFrames
88
<https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html>`_.
9+
10+
.. rubric:: Operating on results
11+
12+
Results can be added together with the standard ``+`` operator. Because values
13+
are typically running averages of iterative algorithms, addition behaves like a
14+
weighted average of the two results, with the weights being the number of
15+
updates in each result: adding two results is the same as generating one result
16+
with the mean of the values of the two results as values. The variances are
17+
updated accordingly. See :class:`ValuationResult` for details.
18+
19+
Results can also be sorted by value, variance or number of updates, see
20+
:meth:`ValuationResult.sort`. The arrays of :attr:`ValuationResult.values`,
21+
:attr:`ValuationResult.variances`, :attr:`ValuationResult.counts`,
22+
:attr:`ValuationResult.indices` and :attr:`ValuationResult.names` are sorted in
23+
the same way.
24+
25+
Indexing and slicing of results is supported and :class:`ValueItem` objects are
26+
returned. These objects can be compared with the usual operators, which take
27+
only the :attr:`ValueItem.value` into account.
28+
929
"""
1030
import collections.abc
1131
import logging
@@ -37,7 +57,7 @@
3757
except ImportError:
3858
pass
3959

40-
__all__ = ["ValuationResult"]
60+
__all__ = ["ValuationResult", "ValueItem"]
4161

4262
logger = logging.getLogger(__name__)
4363

@@ -47,15 +67,16 @@
4767
class ValueItem:
4868
"""The result of a value computation for one datum.
4969
50-
ValueItems can be compared with the usual operators. These take only the
51-
:attribute:`value` into account
70+
``ValueItems`` can be compared with the usual operators, forming a total
71+
order. Comparisons take only the :attr:`value` into account.
5272
5373
.. todo::
5474
Maybe have a mode of comparing similar to `np.isclose`, or taking the
55-
:attribute:`variance` into account.
75+
:attr:`variance` into account.
5676
"""
5777

58-
#: Index of the sample with this value in the original :class:`Dataset`
78+
#: Index of the sample with this value in the original
79+
#: :class:`~pydvl.utils.dataset.Dataset`
5980
index: int
6081
#: Name of the sample if it was provided. Otherwise, `str(index)`
6182
name: str
@@ -88,18 +109,34 @@ class ValuationResult(collections.abc.Sequence):
88109
89110
These include indices in the original :class:`Dataset`, any data names (e.g.
90111
group names in :class:`GroupedDataset`), the values themselves, and variance
91-
of the computation in the case of Monte Carlo methods. These can iterated
92-
over like any ``Sequence``: ``iter(valuation_result)`` returns a generator
93-
of :class:`ValueItem` in the order in which the object is sorted.
112+
of the computation in the case of Monte Carlo methods. ``ValuationResults``
113+
can be iterated over like any ``Sequence``: ``iter(valuation_result)``
114+
returns a generator of :class:`ValueItem` in the order in which the object
115+
is sorted.
116+
117+
.. rubric:: Indexing
118+
119+
Indexing can be position-based, when accessing any of the attributes
120+
:attr:`values`, :attr:`variances`, :attr:`counts` and :attr:`indices`, as
121+
well as when iterating over the object, or using the item access operator,
122+
both getter and setter. The "position" is either the original sequence in
123+
which the data was passed to the constructor, or the sequence in which the
124+
object is sorted, see below.
125+
126+
Alternatively, indexing can be data-based, i.e. using the indices in the
127+
original dataset. This is the case for the methods :meth:`get` and
128+
:meth:`update`.
94129
95130
.. rubric:: Sorting
96131
97-
Results can be sorted in-place with :meth:`sort` or using python's standard
98-
``sorted()`` and ``reversed()`` Note that sorting values affects how
99-
iterators and the object itself as ``Sequence`` behave: ``values[0]``
100-
returns a :class:`ValueItem` with the highest or lowest ranking point if
101-
this object is sorted by descending or ascending value, respectively. If
102-
unsorted, ``values[0]`` returns a ``ValueItem`` for index 0.
132+
Results can be sorted in-place with :meth:`sort`, or alternatively using
133+
python's standard ``sorted()`` and ``reversed()``. Note that sorting values
134+
affects how iterators and the object itself as ``Sequence`` behave:
135+
``values[0]`` returns a :class:`ValueItem` with the highest or lowest
136+
ranking point if this object is sorted by descending or ascending value,
137+
respectively. If unsorted, ``values[0]`` returns the ``ValueItem`` at
138+
position 0, which has data index ``indices[0]`` in the
139+
:class:`~pydvl.utils.dataset.Dataset`.
103140
104141
The same applies to direct indexing of the ``ValuationResult``: the index
105142
is positional, according to the sorting. It does not refer to the "data
@@ -175,7 +212,7 @@ def __init__(
175212
raise ValueError("Lengths of values and indices do not match")
176213

177214
self._algorithm = algorithm
178-
self._status = status
215+
self._status = Status(status) # Just in case we are given a string
179216
self._values = values
180217
self._variances = np.zeros_like(values) if variances is None else variances
181218
self._counts = np.ones_like(values) if counts is None else counts
@@ -189,20 +226,25 @@ def __init__(
189226
if indices is None:
190227
indices = np.arange(len(self._values), dtype=np.int_)
191228
self._indices = indices
229+
self._positions = {idx: pos for pos, idx in enumerate(indices)}
192230

193-
self._sort_indices = np.arange(len(self._values), dtype=np.int_)
231+
self._sort_positions = np.arange(len(self._values), dtype=np.int_)
194232
if sort:
195233
self.sort()
196234

197235
def sort(
198236
self,
199237
reverse: bool = False,
200238
# Need a "Comparable" type here
201-
key: Literal["value", "index", "name"] = "value",
239+
key: Literal["value", "variance", "index", "name"] = "value",
202240
) -> None:
203-
"""Sorts the indices in place by ascending value.
241+
"""Sorts the indices in place by ``key``.
204242
205-
Once sorted, iteration over the results will follow the order.
243+
Once sorted, iteration over the results, and indexing of all the
244+
properties :attr:`ValuationResult.values`,
245+
:attr:`ValuationResult.variances`, :attr:`ValuationResult.counts`,
246+
:attr:`ValuationResult.indices` and :attr:`ValuationResult.names` will
247+
follow the same order.
206248
207249
:param reverse: Whether to sort in descending order of ``key``.
208250
:param key: The key to sort by. Defaults to :attr:`ValueItem.value`.
@@ -213,20 +255,20 @@ def sort(
213255
"variance": "_variances",
214256
"name": "_names",
215257
}
216-
self._sort_indices = np.argsort(getattr(self, keymap[key]))
258+
self._sort_positions = np.argsort(getattr(self, keymap[key]))
217259
if reverse:
218-
self._sort_indices = self._sort_indices[::-1]
260+
self._sort_positions = self._sort_positions[::-1]
219261
self._sort_order = reverse
220262

221263
@property
222264
def values(self) -> NDArray[np.float_]:
223265
"""The values, possibly sorted."""
224-
return self._values[self._sort_indices]
266+
return self._values[self._sort_positions]
225267

226268
@property
227269
def variances(self) -> NDArray[np.float_]:
228270
"""The variances, possibly sorted."""
229-
return self._variances[self._sort_indices]
271+
return self._variances[self._sort_positions]
230272

231273
@property
232274
def stderr(self) -> NDArray[np.float_]:
@@ -238,7 +280,7 @@ def stderr(self) -> NDArray[np.float_]:
238280
@property
239281
def counts(self) -> NDArray[np.int_]:
240282
"""The raw counts, possibly sorted."""
241-
return self._counts[self._sort_indices]
283+
return self._counts[self._sort_positions]
242284

243285
@property
244286
def indices(self) -> NDArray[np.int_]:
@@ -247,15 +289,15 @@ def indices(self) -> NDArray[np.int_]:
247289
If the object is unsorted, then these are the same as declared at
248290
construction or ``np.arange(len(values))`` if none were passed.
249291
"""
250-
return self._indices[self._sort_indices]
292+
return self._indices[self._sort_positions]
251293

252294
@property
253295
def names(self) -> NDArray[np.str_]:
254296
"""The names for the values, possibly sorted.
255297
If the object is unsorted, then these are the same as declared at
256298
construction or ``np.arange(len(values))`` if none were passed.
257299
"""
258-
return self._names[self._sort_indices]
300+
return self._names[self._sort_positions]
259301

260302
@property
261303
def status(self) -> Status:
@@ -301,7 +343,7 @@ def __getitem__(
301343
key += len(self)
302344
if key < 0 or int(key) >= len(self):
303345
raise IndexError(f"Index {key} out of range (0, {len(self)}).")
304-
idx = self._sort_indices[key]
346+
idx = self._sort_positions[key]
305347
return ValueItem(
306348
int(self._indices[idx]),
307349
str(self._names[idx]),
@@ -338,26 +380,26 @@ def __setitem__(
338380
key += len(self)
339381
if key < 0 or int(key) >= len(self):
340382
raise IndexError(f"Index {key} out of range (0, {len(self)}).")
341-
idx = self._sort_indices[key]
342-
self._indices[idx] = value.index
343-
self._names[idx] = value.name
344-
self._values[idx] = value.value
345-
self._variances[idx] = value.variance
346-
self._counts[idx] = value.count
383+
pos = self._sort_positions[key]
384+
self._indices[pos] = value.index
385+
self._names[pos] = value.name
386+
self._values[pos] = value.value
387+
self._variances[pos] = value.variance
388+
self._counts[pos] = value.count
347389
else:
348390
raise TypeError("Indices must be integers, iterable or slices")
349391

350392
def __iter__(self) -> Generator[ValueItem, Any, None]:
351393
"""Iterate over the results returning :class:`ValueItem` objects.
352394
To sort in place before iteration, use :meth:`sort`.
353395
"""
354-
for idx in self._sort_indices:
396+
for pos in self._sort_positions:
355397
yield ValueItem(
356-
self._indices[idx],
357-
self._names[idx],
358-
self._values[idx],
359-
self._variances[idx],
360-
self._counts[idx],
398+
self._indices[pos],
399+
self._names[pos],
400+
self._values[pos],
401+
self._variances[pos],
402+
self._counts[pos],
361403
)
362404

363405
def __len__(self):
@@ -434,8 +476,8 @@ def __add__(self, other: "ValuationResult") -> "ValuationResult":
434476
self._check_compatible(other)
435477

436478
indices = np.union1d(self._indices, other._indices)
437-
this_indices = np.searchsorted(indices, self._indices)
438-
other_indices = np.searchsorted(indices, other._indices)
479+
this_pos = np.searchsorted(indices, self._indices)
480+
other_pos = np.searchsorted(indices, other._indices)
439481

440482
n = np.zeros_like(indices, dtype=int)
441483
m = np.zeros_like(indices, dtype=int)
@@ -444,12 +486,12 @@ def __add__(self, other: "ValuationResult") -> "ValuationResult":
444486
vn = np.zeros_like(indices, dtype=float)
445487
vm = np.zeros_like(indices, dtype=float)
446488

447-
n[this_indices] = self._counts
448-
xn[this_indices] = self._values
449-
vn[this_indices] = self._variances
450-
m[other_indices] = other._counts
451-
xm[other_indices] = other._values
452-
vm[other_indices] = other._variances
489+
n[this_pos] = self._counts
490+
xn[this_pos] = self._values
491+
vn[this_pos] = self._variances
492+
m[other_pos] = other._counts
493+
xm[other_pos] = other._values
494+
vm[other_pos] = other._variances
453495

454496
# Sample mean of n+m samples from two means of n and m samples
455497
xnm = (n * xn + m * xm) / (n + m)
@@ -458,8 +500,8 @@ def __add__(self, other: "ValuationResult") -> "ValuationResult":
458500

459501
this_names = np.empty_like(indices, dtype=np.str_)
460502
other_names = np.empty_like(indices, dtype=np.str_)
461-
this_names[this_indices] = self._names
462-
other_names[other_indices] = other._names
503+
this_names[this_pos] = self._names
504+
other_names[other_pos] = other._names
463505
names = np.where(n > 0, this_names, other_names)
464506
both = np.where((n > 0) & (m > 0))
465507
if np.any(other_names[both] != this_names[both]):
@@ -489,24 +531,38 @@ def update(self, idx: int, new_value: float) -> "ValuationResult":
489531
"""Updates the result in place with a new value, using running mean
490532
and variance.
491533
492-
:param idx: Index of the value to update.
534+
:param idx: Data index of the value to update.
493535
:param new_value: New value to add to the result.
494536
:return: A reference to the same, modified result.
537+
:raises IndexError: If the index is not found.
495538
"""
539+
try:
540+
pos = self._positions[idx]
541+
except KeyError:
542+
raise IndexError(f"Index {idx} not found in ValuationResult")
496543
val, var = running_moments(
497-
self._values[idx],
498-
self._variances[idx],
499-
self._counts[idx],
500-
new_value,
544+
self._values[pos], self._variances[pos], self._counts[pos], new_value
501545
)
502-
self[idx] = ValueItem(idx, self._names[idx], val, var, self._counts[idx] + 1)
546+
self[pos] = ValueItem(idx, self._names[pos], val, var, self._counts[pos] + 1)
503547
return self
504548

505549
def get(self, idx: Integral) -> ValueItem:
506550
"""Retrieves a ValueItem by data index, as opposed to sort index, like
507551
the indexing operator does.
552+
:raises IndexError: If the index is not found.
508553
"""
509-
raise NotImplementedError()
554+
try:
555+
pos = self._positions[idx]
556+
except KeyError:
557+
raise IndexError(f"Index {idx} not found in ValuationResult")
558+
559+
return ValueItem(
560+
self._indices[pos],
561+
self._names[pos],
562+
self._values[pos],
563+
self._variances[pos],
564+
self._counts[pos],
565+
)
510566

511567
def to_dataframe(
512568
self, column: Optional[str] = None, use_names: bool = False
@@ -526,13 +582,13 @@ def to_dataframe(
526582
raise ImportError("Pandas required for DataFrame export")
527583
column = column or self._algorithm
528584
df = pandas.DataFrame(
529-
self._values[self._sort_indices],
530-
index=self._names[self._sort_indices]
585+
self._values[self._sort_positions],
586+
index=self._names[self._sort_positions]
531587
if use_names
532-
else self._indices[self._sort_indices],
588+
else self._indices[self._sort_positions],
533589
columns=[column],
534590
)
535-
df[column + "_stderr"] = self.stderr[self._sort_indices]
591+
df[column + "_stderr"] = self.stderr[self._sort_positions]
536592
return df
537593

538594
@classmethod

0 commit comments

Comments
 (0)