Skip to content

Commit 82312d7

Browse files
Properly document get_test_data() method of Dataset
1 parent 074b6b2 commit 82312d7

File tree

1 file changed

+59
-13
lines changed

1 file changed

+59
-13
lines changed

src/pydvl/utils/dataset.py

Lines changed: 59 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
"""
22
This module contains convenience classes to handle data and groups thereof.
33
4-
Shapley value computations require evaluation of a scoring function (the
5-
*utility*). This is typically the performance of the model on a test set (as an
6-
approximation to its true expected performance). It is therefore convenient to
7-
keep both the training data and the test data together to be passed around to
8-
methods in :mod:`~pydvl.value.shapley`. This is done with
9-
:class:`~pydvl.utils.dataset.Dataset`.
4+
Shapley and Least Core value computations require evaluation of a scoring function
5+
(the *utility*). This is typically the performance of the model on a test set
6+
(as an approximation to its true expected performance). It is therefore convenient
7+
to keep both the training data and the test data together to be passed around to
8+
methods in :mod:`~pydvl.value.shapley` and :mod:`~pydvl.value.least_core`.
9+
This is done with :class:`~pydvl.utils.dataset.Dataset`.
1010
1111
This abstraction layer also allows seamlessly grouping data points together if one is
1212
interested in computing their value as a group, see
@@ -143,26 +143,72 @@ def feature(self, name: str) -> Tuple[slice, int]:
143143
raise ValueError(f"Feature {name} is not in {self.feature_names}")
144144

145145
def get_training_data(
146-
self, indices: Optional[Iterable[int]]
146+
self, indices: Optional[Iterable[int]] = None
147147
) -> Tuple[NDArray, NDArray]:
148148
"""Given a set of indices, returns the training data that refer to those
149149
indices.
150150
151151
This is used when calling different sub-sets of indices to calculate
152152
Shapley values. Notice that train_indices is not typically equal to the
153153
full indices, but only a subset of it.
154+
155+
:param indices: Optional indices that will be used
156+
to select data points from the training data.
157+
:return: If indices is not None, the selected x and y arrays from
158+
the training data. Otherwise, the entire training data.
154159
"""
155160
if indices is None:
156161
return self.x_train, self.y_train
157-
else:
158-
x = self.x_train[indices]
159-
y = self.y_train[indices]
160-
return x, y
162+
x = self.x_train[indices]
163+
y = self.y_train[indices]
164+
return x, y
161165

162166
def get_test_data(
163-
self, indices: Optional[Iterable[int]]
167+
self, indices: Optional[Iterable[int]] = None
164168
) -> Tuple[NDArray, NDArray]:
165-
"""Returns the entire test set regardless of the passed indices."""
169+
"""Returns the entire test set regardless of the passed indices.
170+
171+
The passed indices will not be used because for data valuation
172+
we generally want to score the trained model on the entire test data.
173+
174+
Additionally, given the way this method is used in the
175+
:class:`~pydvl.utils.utility.Utility` class, the passed indices will
176+
be those of the training data and would not work on the test data.
177+
178+
There may be cases where it is desired to use parts of the test data.
179+
In those cases, it is recommended to inherit from the :class:`Dataset`
180+
class and to override the :meth:`~Dataset.get_test_data` method.
181+
182+
For example, the following snippet shows how one could go about
183+
mapping the training data indices into test data indices
184+
inside :meth:`~Dataset.get_test_data`:
185+
186+
:Example:
187+
188+
>>> from pydvl.utils import Dataset
189+
>>> import numpy as np
190+
>>> class DatasetWithTestDataIndices(Dataset):
191+
... def get_test_data(self, indices=None):
192+
... if indices is None:
193+
... return self.x_test, self.y_test
194+
... fraction = len(list(indices)) / len(self)
195+
... mapped_indices = len(self.x_test) / len(self) * np.asarray(indices)
196+
... mapped_indices = np.unique(mapped_indices.astype(int))
197+
... return self.x_test[mapped_indices], self.y_test[mapped_indices]
198+
...
199+
>>> X = np.random.rand(100, 10)
200+
>>> y = np.random.randint(0, 2, 100)
201+
>>> dataset = DatasetWithTestDataIndices.from_arrays(X, y)
202+
>>> indices = np.random.choice(dataset.indices, 30, replace=False)
203+
>>> _ = dataset.get_training_data(indices)
204+
>>> _ = dataset.get_test_data(indices)
205+
206+
207+
:param indices: Optional indices into the test data. This argument
208+
is unused and is left as is to keep the same interface as
209+
:meth:`Dataset.get_training_data`.
210+
:return: The entire test data.
211+
"""
166212
return self.x_test, self.y_test
167213

168214
def target(self, name: str) -> Tuple[slice, int]:

0 commit comments

Comments
 (0)