
Commit 1949fcc (2 parents: 4c10cc3 + bebfd9c)

Merge pull request #338 from aai-institute/259-implement-class-wise-shapley

Implement class wise shapley

29 files changed: +69,481 / -59 lines

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
@@ -2,6 +2,8 @@
 
 ## Unreleased
 
+- New method: Class-wise Shapley values
+  [PR #338](https://github.com/aai-institute/pyDVL/pull/338)
 - No longer using docker within tests to start a memcached server
   [PR #444](https://github.com/aai-institute/pyDVL/pull/444)
 - Faster semi-value computation with per-index check of stopping criteria (optional)

@@ -43,6 +45,9 @@ randomness.
   `compute_beta_shapley_semivalues`, `compute_shapley_semivalues` and
   `compute_generic_semivalues`.
   [PR #428](https://github.com/aai-institute/pyDVL/pull/428)
+- Added classwise Shapley as proposed by (Schoch et al. 2022)
+  [https://arxiv.org/abs/2211.06800]
+  [PR #338](https://github.com/aai-institute/pyDVL/pull/338)
 
 ### Changed

README.md

Lines changed: 5 additions & 0 deletions
@@ -71,6 +71,11 @@ methods from the following papers:
   Efficient Data Value](https://proceedings.mlr.press/v202/kwon23e.html). In
   Proceedings of the 40th International Conference on Machine Learning, 18135–52.
   PMLR, 2023.
+- Schoch, Stephanie, Haifeng Xu, and Yangfeng Ji. [CS-Shapley: Class-Wise
+  Shapley Values for Data Valuation in
+  Classification](https://openreview.net/forum?id=KTOcrOR5mQ9). In Proc. of the
+  Thirty-Sixth Conference on Neural Information Processing Systems (NeurIPS).
+  New Orleans, Louisiana, USA, 2022.
 
 Influence Functions compute the effect that single points have on an estimator /
 model. We implement methods from the following papers:

docs/api/pydvl/value/shapley/classwise/img/classwise-shapley-discounted-utility-function.svg

Lines changed: 68001 additions & 0 deletions

docs/value/classwise-shapley.md

Lines changed: 269 additions & 0 deletions
@@ -0,0 +1,269 @@
---
title: Class-wise Shapley
---

# Class-wise Shapley

Class-wise Shapley (CWS) [@schoch_csshapley_2022] offers a Shapley framework
tailored for classification problems. Given a sample $x_i$ with label $y_i \in
\mathbb{N}$, let $D_{y_i}$ be the subset of $D$ with labels $y_i$, and
$D_{-y_i}$ be the complement of $D_{y_i}$ in $D$. The key idea is that the
sample $(x_i, y_i)$ might improve the overall model performance on $D$, while
being detrimental for the performance on $D_{y_i}$, e.g. because of a wrong
label. To address this issue, the authors introduced

$$
v_u(i) = \frac{1}{2^{|D_{-y_i}|}} \sum_{S_{-y_i}}
\left [
\frac{1}{|D_{y_i}|}\sum_{S_{y_i}} \binom{|D_{y_i}|-1}{|S_{y_i}|}^{-1}
\delta(S_{y_i} | S_{-y_i})
\right ],
$$

where $S_{y_i} \subseteq D_{y_i} \setminus \{i\}$ and $S_{-y_i} \subseteq
D_{-y_i}$ is _arbitrary_ (in particular, not the complement of $S_{y_i}$). The
function $\delta$ is called the **set-conditional marginal Shapley value** and
is defined as

$$
\delta(S | C) = u(S_{+i} | C) - u(S | C),
$$

for any set $S$ such that $i \notin S \cup C$ and $S \cap C = \emptyset$.

In practical applications, this quantity is estimated with Monte Carlo sampling
of both the powerset and the set of index permutations
[@castro_polynomial_2009]. Typically, this requires fewer samples than the
original Shapley value, although the actual speed-up depends on the model and
the dataset.
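
To make the definition concrete, here is a brute-force sketch that evaluates
$v_u(i)$ exactly by enumerating all subsets (exponential cost, for illustration
only). The names are illustrative and not part of the pyDVL API; `utility` is
any callable implementing a conditional utility $u(S | C)$:

```python
from itertools import combinations
from math import comb

def classwise_value(i, in_class, out_of_class, utility):
    """Brute-force evaluation of v_u(i) for a toy conditional utility
    utility(S, C) = u(S | C). For illustration only."""
    others = [j for j in in_class if j != i]
    total = 0.0
    # Outer sum over all subsets S_{-y_i} of the out-of-class indices
    for r in range(len(out_of_class) + 1):
        for S_out in combinations(out_of_class, r):
            inner = 0.0
            # Inner sum over in-class subsets not containing i, weighted by
            # 1 / binom(|D_{y_i}| - 1, |S_{y_i}|) as in the formula above
            for k in range(len(others) + 1):
                for S_in in combinations(others, k):
                    weight = 1 / comb(len(in_class) - 1, k)
                    with_i = utility(set(S_in) | {i}, set(S_out))
                    without_i = utility(set(S_in), set(S_out))
                    inner += weight * (with_i - without_i)
            total += inner / len(in_class)
    return total / 2 ** len(out_of_class)

# Toy utility rewarding in-class points only (a real utility would train a
# model on S ∪ C and score it)
toy_u = lambda S, C: len(S) / 3
print(classwise_value(0, in_class=[0, 1, 2], out_of_class=[10, 11], utility=toy_u))
```
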
!!! Example "Computing classwise Shapley values"
    Like all other game-theoretic valuation methods, CWS requires a
    [Utility][pydvl.utils.utility.Utility] object constructed with model and
    dataset, with the peculiarity of requiring a specific
    [ClasswiseScorer][pydvl.value.shapley.classwise.ClasswiseScorer]. The entry
    point is the function
    [compute_classwise_shapley_values][pydvl.value.shapley.classwise.compute_classwise_shapley_values]:

    ```python
    from pydvl.value import *

    model = ...
    data = Dataset(...)
    scorer = ClasswiseScorer(...)
    utility = Utility(model, data, scorer)
    values = compute_classwise_shapley_values(
        utility,
        done=HistoryDeviation(n_steps=500, rtol=5e-2) | MaxUpdates(5000),
        truncation=RelativeTruncation(utility, rtol=0.01),
        done_sample_complements=MaxChecks(1),
        normalize_values=True
    )
    ```

### The class-wise scorer

In order to use the classwise Shapley value, one needs to define a
[ClasswiseScorer][pydvl.value.shapley.classwise.ClasswiseScorer]. This scorer
is defined as

$$
u(S) = f(a_S(D_{y_i})) g(a_S(D_{-y_i})),
$$

where $f$ and $g$ are monotonically increasing functions, $a_S(D_{y_i})$ is the
**in-class accuracy**, and $a_S(D_{-y_i})$ is the **out-of-class accuracy** (the
names originate from a choice by the authors to use accuracy, but in principle
any other score, like $F_1$, can be used).

The authors show that $f(x)=x$ and $g(x)=e^x$ have favorable properties and are
therefore the defaults, but we leave the option to set different functions $f$
and $g$ for an exploration with different base scores.
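
As a quick numerical illustration with the default choices $f(x)=x$ and
$g(x)=e^x$ (the accuracies below are made up for the example, not computed from
any model):

```python
import numpy as np

in_class_acc = 0.8       # a_S(D_{y_i}), hypothetical value
out_of_class_acc = 0.6   # a_S(D_{-y_i}), hypothetical value

# u(S) = f(a_S(D_{y_i})) * g(a_S(D_{-y_i})) with f(x) = x and g(x) = exp(x)
u = in_class_acc * np.exp(out_of_class_acc)
print(f"u(S) = {u:.3f}")  # approximately 1.458
```

Note that with these defaults the out-of-class accuracy enters exponentially,
while the in-class accuracy enters linearly.
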
!!! Example "The default class-wise scorer"
    Constructing the CWS scorer requires choosing a metric and the functions $f$
    and $g$:

    ```python
    import numpy as np
    from pydvl.value.shapley.classwise import ClasswiseScorer

    # These are the defaults
    identity = lambda x: x
    scorer = ClasswiseScorer(
        "accuracy",
        in_class_discount_fn=identity,
        out_of_class_discount_fn=np.exp
    )
    ```

??? "Surface of the discounted utility function"
    The level curves for $f(x)=x$ and $g(x)=e^x$ are depicted below. The lines
    illustrate the contour lines, annotated with their respective gradients.

    ![Level curves of the class-wise
    utility](img/classwise-shapley-discounted-utility-function.svg){ align=left width=33% class=invertible }

## Evaluation

We illustrate the method with two experiments: point removal and noise removal,
as well as an analysis of the distribution of the values. For this we employ the
nine datasets used in [@schoch_csshapley_2022], using the same pre-processing.
For images, PCA is used to reduce the features found by a pre-trained `Resnet18`
model down to 32 dimensions. Standard loc-scale normalization is performed for
all models except gradient boosting, since the latter is not sensitive to the
scale of the features.
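
A minimal sketch of this kind of pre-processing with scikit-learn, assuming the
image features have already been extracted into an array (the array below is a
random stand-in, not actual `Resnet18` features):

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(42)
features = rng.normal(size=(1000, 512))  # stand-in for pre-trained features

# Reduce to 32 dimensions, then apply loc-scale (standard) normalization
X = PCA(n_components=32).fit_transform(features)
X = StandardScaler().fit_transform(X)
```
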
??? info "Datasets used for evaluation"
    | Dataset        | Data Type | Classes | Input Dims | OpenML ID |
    |----------------|-----------|---------|------------|-----------|
    | Diabetes       | Tabular   | 2       | 8          | 37        |
    | Click          | Tabular   | 2       | 11         | 1216      |
    | CPU            | Tabular   | 2       | 21         | 197       |
    | Covertype      | Tabular   | 7       | 54         | 1596      |
    | Phoneme        | Tabular   | 2       | 5          | 1489      |
    | FMNIST         | Image     | 2       | 32         | 40996     |
    | CIFAR10        | Image     | 2       | 32         | 40927     |
    | MNIST (binary) | Image     | 2       | 32         | 554       |
    | MNIST (multi)  | Image     | 10      | 32         | 554       |

We show mean and coefficient of variation (CV) $\frac{\sigma}{\mu}$ of an "inner
metric". The former shows the performance of the method, whereas the latter
displays its stability: we normalize by the mean to see the relative effect of
the standard deviation. Ideally the mean value is maximal and CV minimal.

Finally, we note that for all sampling-based valuation methods the same number
of _evaluations of the marginal utility_ was used. This is important to make the
algorithms comparable, but in practice one should consider using a more
sophisticated stopping criterion.
### Dataset pruning for logistic regression (point removal)

In (best-)point removal, one first computes values for the training set and then
removes in sequence the points with the highest values. After each removal, the
remaining points are used to train the model from scratch and performance is
measured on a test set. This produces a curve of performance vs. number of
points removed which we show below.
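
A schematic version of this removal loop, using a stand-in array of values and
scikit-learn (illustrative only, not the exact experiment code):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=600, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Stand-in for data values aligned with the training points (in practice,
# e.g. the output of compute_classwise_shapley_values)
values = np.random.default_rng(0).normal(size=len(X_train))

order = np.argsort(values)[::-1]  # highest-valued points first
accuracies = []
for n_removed in range(0, len(X_train) - 10, 10):
    keep = order[n_removed:]      # drop the n_removed highest-valued points
    model = LogisticRegression(max_iter=1000).fit(X_train[keep], y_train[keep])
    accuracies.append(model.score(X_test, y_test))
```
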
As a scalar summary of this curve, [@schoch_csshapley_2022] define **Weighted
Accuracy Drop** (WAD) as:

$$
\text{WAD} = \sum_{j=1}^{n} \left ( \frac{1}{j} \sum_{i=1}^{j}
a_{T_{-\{1 \colon i-1 \}}}(D) - a_{T_{-\{1 \colon i \}}}(D) \right)
= a_T(D) - \sum_{j=1}^{n} \frac{a_{T_{-\{1 \colon j \}}}(D)}{j} ,
$$

where $a_T(D)$ is the accuracy of the model (trained on $T$) evaluated on $D$
and $T_{-\{1 \colon j \}}$ is the set $T$ without elements from $\{1, \dots , j
\}$.
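
A small numerical sketch of the first expression, assuming `accs[j]` holds
$a_{T_{-\{1 \colon j\}}}(D)$, i.e. the accuracy after removing the $j$
highest-valued points, with `accs[0]` the accuracy on the full set $T$ (the
numbers are made up):

```python
# Hypothetical accuracies after removing 0, 1, 2, ... highest-valued points
accs = [0.90, 0.85, 0.80, 0.70, 0.65]

n = len(accs) - 1
# Per-step accuracy drops a_{T_{-{1:i-1}}}(D) - a_{T_{-{1:i}}}(D), each prefix
# of drops weighted by 1/j and accumulated
wad = sum(
    (1 / j) * sum(accs[i - 1] - accs[i] for i in range(1, j + 1))
    for j in range(1, n + 1)
)
print(wad)
```
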
We run the point removal experiment for a logistic regression model five times
and compute WAD for each run, then report the mean $\mu_\text{WAD}$ and standard
deviation $\sigma_\text{WAD}$.

![Mean WAD for best-point removal on logistic regression. Values
computed using LOO, CWS, Beta Shapley, and TMCS
](img/classwise-shapley-metric-wad-mean.svg){ class=invertible }

We see that CWS is competitive with all three other methods. In all problems
except `MNIST (multi)` it outperforms TMCS; in that one case TMCS has a slight
advantage.

In order to understand the variability of WAD we look at its coefficient of
variation (lower is better):

![Coefficient of Variation of WAD for best-point removal on logistic regression.
Values computed using LOO, CWS, Beta Shapley, and TMCS
](img/classwise-shapley-metric-wad-cv.svg){ class=invertible }

CWS is not the best method in terms of CV. For `CIFAR10`, `Click`, `CPU` and
`MNIST (binary)`, Beta Shapley has the lowest CV. For `Diabetes`, `MNIST (multi)`
and `Phoneme`, CWS is the winner, and for `FMNIST` and `Covertype`, TMCS takes
the lead. Besides LOO, TMCS has the highest relative standard deviation.

The following plot shows accuracy vs. number of samples removed. Random values
serve as a baseline. The shaded area represents the 95% bootstrap confidence
interval of the mean across 5 runs.

![Accuracy after best-sample removal using values from logistic
regression](img/classwise-shapley-weighted-accuracy-drop-logistic-regression-to-logistic-regression.svg){ class=invertible }

Because samples are removed from high to low valuation order, we expect a steep
decrease in the curve.

Overall we conclude that in terms of mean WAD, CWS and TMCS perform best, with
CWS's CV on par with Beta Shapley's, making CWS a competitive method.

### Dataset pruning for a neural network by value transfer

Transfer of values from one model to another is probably of greater practical
relevance: values are computed using a cheap model and used to prune the dataset
before training a more expensive one.
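
Schematically, the procedure looks as follows (a sketch with stand-in values
and scikit-learn models, not the actual experiment code):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier

X, y = make_classification(n_samples=800, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# 1. Value the training points with a cheap model (stand-in values here; in
#    practice e.g. classwise Shapley values from a logistic regression utility)
values = np.random.default_rng(0).normal(size=len(X_train))

# 2. Remove the highest-valued points (as in the removal experiment), then
#    train the more expensive model on the rest
keep = np.argsort(values)[::-1][100:]
mlp = MLPClassifier(max_iter=500).fit(X_train[keep], y_train[keep])
print(mlp.score(X_test, y_test))
```
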
The following plot shows accuracy vs. number of samples removed for transfer
from logistic regression to a neural network. The shaded area represents the
95% bootstrap confidence interval of the mean across 5 runs.

![Accuracy after sample removal using values transferred from logistic
regression to an MLP
](img/classwise-shapley-weighted-accuracy-drop-logistic-regression-to-mlp.svg){ class=invertible }

As in the previous experiment, samples are removed from high to low valuation
order and hence we expect a steep decrease in the curve. CWS is competitive with
the other methods, especially in very unbalanced datasets like `Click`. In other
datasets, like `Covertype`, `Diabetes` and `MNIST (multi)`, the performance is
on par with TMCS.

### Detection of mis-labeled data points

The next experiment tries to detect mis-labeled data points in binary
classification tasks. 20% of the labels are flipped at random (we don't consider
multi-class datasets because there isn't a unique flipping strategy). The
following plot shows the mean of the area under the curve (AUC) over five runs.
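
A schematic version of the detection protocol, with a stand-in for the computed
values (illustrative only):

```python
import numpy as np
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(0)
n = 1000
flipped = rng.random(n) < 0.2  # 20% of the labels flipped at random

# Stand-in for data values: mis-labeled points tend to receive lower values
values = rng.normal(size=n) - flipped.astype(float)

# Rank points from lowest to highest value and measure how well this ranking
# separates flipped from clean labels
auc = roc_auc_score(flipped, -values)
print(f"AUC for detecting flipped labels: {auc:.2f}")
```
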
![Mean AUC for mis-labeled data point detection. Values computed using LOO, CWS,
Beta Shapley, and
TMCS](img/classwise-shapley-metric-auc-mean.svg){ class=invertible }

In the majority of cases TMCS has a slight advantage over CWS, except for
`Click`, where CWS has a slight edge, most probably due to the unbalanced nature
of the dataset. The following plot shows the CV of the AUC over the five runs.

![Coefficient of variation of AUC for mis-labeled data point detection. Values
computed using LOO, CWS, Beta Shapley, and TMCS
](img/classwise-shapley-metric-auc-cv.svg){ class=invertible }

In terms of CV, CWS has a clear edge over TMCS and Beta Shapley.

Finally, we look at the ROC curves obtained by training the classifier on the
$n$ first samples in _increasing_ order of valuation (i.e. starting with the
worst):

![Mean ROC across 5 runs with 95% bootstrap
CI](img/classwise-shapley-roc-auc-logistic-regression.svg){ class=invertible }

Although at first sight TMCS seems to be the winner, CWS stays competitive after
factoring in running time. For a perfectly balanced dataset, CWS needs on
average fewer samples than TMCS.

### Value distribution

For illustration, we compare the distribution of values computed by TMCS and
CWS.

![Histogram and estimated density of the values computed by TMCS and
CWS on all nine datasets](img/classwise-shapley-density.svg){ class=invertible }

For `Click`, TMCS has a multi-modal distribution of values. We hypothesize that
this is due to the highly unbalanced nature of the dataset, and notice that CWS
has a single mode, leading to its better performance on this dataset.

## Conclusion

CWS is an effective way to handle classification problems, in particular for
unbalanced datasets. It reduces the computing requirements by considering
in-class and out-of-class points separately.
docs/value/img/classwise-shapley-density.svg

Lines changed: 1 addition & 0 deletions

docs/value/img/classwise-shapley-discounted-utility-function.svg

Lines changed: 1 addition & 0 deletions

docs/value/img/classwise-shapley-metric-auc-cv.svg

Lines changed: 1 addition & 0 deletions

docs/value/img/classwise-shapley-metric-auc-mean.svg

Lines changed: 1 addition & 0 deletions

docs/value/img/classwise-shapley-metric-wad-cv.svg

Lines changed: 1 addition & 0 deletions

docs/value/img/classwise-shapley-metric-wad-mean.svg

Lines changed: 1 addition & 0 deletions
