
Commit 5ee24bd

[skip ci] Bunch of fixes to the doc
1 parent 72d0105 commit 5ee24bd

2 files changed: +152 additions, -133 deletions

docs/value/classwise-shapley.md

Lines changed: 148 additions & 132 deletions
@@ -5,11 +5,12 @@ title: Class-wise Shapley
 # Class-wise Shapley
 
 Class-wise Shapley (CWS) [@schoch_csshapley_2022] offers a Shapley framework
-tailored for classification problems. Let $D$ be a dataset, $D_{y_i}$ be the
-subset of $D$ with labels $y_i$, and $D_{-y_i}$ be the complement of $D_{y_i}$
-in $D$. The key idea is that a sample $(x_i, y_i)$ might enhance the overall
-performance on $D$, while being detrimental for the performance on $D_{y_i}$. To
-address this issue, the authors introduced
+tailored for classification problems. Given a sample $x_i$ with label $y_i \in
+\mathbb{N}$, let $D_{y_i}$ be the subset of $D$ with labels $y_i$, and
+$D_{-y_i}$ be the complement of $D_{y_i}$ in $D$. The key idea is that the
+sample $(x_i, y_i)$ might improve the overall model performance on $D$, while
+being detrimental for the performance on $D_{y_i},$ e.g. because of a wrong
+label. To address this issue, the authors introduced
 
 $$
 v_u(i) = \frac{1}{2^{|D_{-y_i}|}} \sum_{S_{-y_i}}
@@ -20,14 +21,15 @@ v_u(i) = \frac{1}{2^{|D_{-y_i}|}} \sum_{S_{-y_i}}
 $$
 
 where $S_{y_i} \subseteq D_{y_i} \setminus \{i\}$ and $S_{-y_i} \subseteq
-D_{-y_i}$, and the function $\delta$ is called **set-conditional marginal
-Shapley value**. It is defined as
+D_{-y_i}$ is _arbitrary_ (in particular, not the complement of $S_{y_i}$). The
+function $\delta$ is called **set-conditional marginal Shapley value** and is
+defined as
 
 $$
-\delta(S | C) = u( S \cup \{i\} | C ) − u(S | C),
+\delta(S | C) = u( S_{+i} | C ) - u(S | C),
 $$
 
-where $i \notin S, C$ and $S \bigcap C = \emptyset$.
+for any set $S$ such that $i \notin S, C$ and $S \cap C = \emptyset$.
 
 In practical applications, estimating this quantity is done both with Monte
 Carlo sampling of the powerset, and the set of index permutations
@@ -36,7 +38,12 @@ original Shapley value, although the actual speed-up depends on the model and
 the dataset.
 
 
-??? Example "Computing classwise Shapley values"
+!!! Example "Computing classwise Shapley values"
+    Like all other game-theoretic valuation methods, CWS requires a
+    [Utility][pydvl.utils.utility.Utility] object constructed with model and
+    dataset, with the peculiarity of requiring a specific
+    [ClasswiseScorer][pydvl.value.shapley.classwise.ClasswiseScorer]:
+
     ```python
     from pydvl.value import *
 
@@ -54,15 +61,14 @@ the dataset.
     ```
 
 
-### Class-wise scorer
+### The class-wise scorer
 
 In order to use the classwise Shapley value, one needs to define a
-[ClasswiseScorer][pydvl.value.shapley.classwise.ClasswiseScorer]. Given a sample
-$x_i$ with label $y_i \in \mathbb{N}$, we define two disjoint sets $D_{y_i}$ and
-$D_{-y_i}$ and define
+[ClasswiseScorer][pydvl.value.shapley.classwise.ClasswiseScorer]. This scorer
+is defined as
 
 $$
-u(S) = f(a_S(D_{y_i}))) g(a_S(D_{-y_i}))),
+u(S) = f(a_S(D_{y_i})) g(a_S(D_{-y_i})),
 $$
 
 where $f$ and $g$ are monotonically increasing functions, $a_S(D_{y_i})$ is the
@@ -74,7 +80,10 @@ The authors show that $f(x)=x$ and $g(x)=e^x$ have favorable properties and are
 therefore the defaults, but we leave the option to set different functions $f$
 and $g$ for an exploration with different base scores.
 
-??? Example "The default class-wise scorer"
+!!! Example "The default class-wise scorer"
+    Constructing the CWS scorer requires choosing a metric and the functions $f$
+    and $g$:
+
     ```python
     import numpy as np
     from pydvl.value.shapley.classwise import ClasswiseScorer
@@ -96,156 +105,163 @@ and $g$ for an exploration with different base scores.
 
 ## Evaluation
 
-We evaluate the method on the nine datasets used in [@schoch_csshapley_2022],
-using the same pre-processing. For images, PCA is used to project the feature, found
-by a pre-trained `Resnet18` model, to 32 principal components. A loc-scale normalization
-is performed for all models, except gradient boosting. The latter is not sensitive to
-the scale of the features. The following table shows the datasets used in the
-
-| Dataset | Data Type | Classes | Input Dims | OpenML ID |
-|----------------|-----------|---------|------------|-----------|
-| Diabetes | Tabular | 2 | 8 | 37 |
-| Click | Tabular | 2 | 11 | 1216 |
-| CPU | Tabular | 2 | 21 | 197 |
-| Covertype | Tabular | 7 | 54 | 1596 |
-| Phoneme | Tabular | 2 | 5 | 1489 |
-| FMNIST | Image | 2 | 32 | 40996 |
-| CIFAR10 | Image | 2 | 32 | 40927 |
-| MNIST (binary) | Image | 2 | 32 | 554 |
-| MNIST (multi) | Image | 10 | 32 | 554 |
-
-experiments. In general there are three different experiments: point removal, noise
-removal and a distribution analysis. Metrics are evaluated as tables for mean and
-coefficient of variation (CV) $\frac{\sigma}{\mu}$ of an inner metric. The former
-displays the performance of the method, whereas the latter displays the stability of a
-method. We normalize by the mean to make standard deviations for different runs
-comparable.
-
-the method. We assume the mean has to be maximized and the CV has to be minimized.
-Furthermore, we remark that for all sampling-based valuation methods the same number of
-_evaluations of the marginal utility_ was used. This is important, to make the
-algorithms comparable. In practice one should consider using a more sophisticated
-stopping criterion.
-
-### Dataset pruning for logistic regression
-
-Weighted accuracy drop (WAD) [@schoch_csshapley_2022] is defined as
+We illustrate the method with two experiments: point removal and noise removal,
+as well as an analysis of the distribution of the values. For this we employ the
+nine datasets used in [@schoch_csshapley_2022], using the same pre-processing.
+For images, PCA is used to reduce the features found by a pre-trained
+`Resnet18` model down to 32 dimensions. Standard loc-scale normalization is
+performed for all models except gradient boosting, since the latter is not
+sensitive to the scale of the features.
+
+??? info "Datasets used for evaluation"
+    | Dataset        | Data Type | Classes | Input Dims | OpenML ID |
+    |----------------|-----------|---------|------------|-----------|
+    | Diabetes       | Tabular   | 2       | 8          | 37        |
+    | Click          | Tabular   | 2       | 11         | 1216      |
+    | CPU            | Tabular   | 2       | 21         | 197       |
+    | Covertype      | Tabular   | 7       | 54         | 1596      |
+    | Phoneme        | Tabular   | 2       | 5          | 1489      |
+    | FMNIST         | Image     | 2       | 32         | 40996     |
+    | CIFAR10        | Image     | 2       | 32         | 40927     |
+    | MNIST (binary) | Image     | 2       | 32         | 554       |
+    | MNIST (multi)  | Image     | 10      | 32         | 554       |
+
+We show mean and coefficient of variation (CV) $\frac{\sigma}{\mu}$ of an "inner
+metric". The former shows the performance of the method, whereas the latter
+displays its stability: we normalize by the mean to see the relative effect of
+the standard deviation. Ideally the mean value is maximal and the CV minimal.
+
+Finally, we note that for all sampling-based valuation methods the same number
+of _evaluations of the marginal utility_ was used. This is important to make the
+algorithms comparable, but in practice one should consider using a more
+sophisticated stopping criterion.
+
+### Dataset pruning for logistic regression (point removal)
+
+In (best-)point removal, one first computes values for the training set and then
+removes in sequence the points with the highest values. After each removal, the
+remaining points are used to train the model from scratch and performance is
+measured on a test set. This produces a curve of performance vs. number of
+points removed, which we show below.
+
+As a scalar summary of this curve, [@schoch_csshapley_2022] define **Weighted
+Accuracy Drop** (WAD) as:
 
 $$
 \text{WAD} = \sum_{j=1}^{n} \left ( \frac{1}{j} \sum_{i=1}^{j}
 a_{T_{-\{1 \colon i-1 \}}}(D) - a_{T_{-\{1 \colon i \}}}(D) \right)
 = a_T(D) - \sum_{j=1}^{n} \frac{a_{T_{-\{1 \colon j \}}}(D)}{j} ,
 $$
 
-where $a_T(D)$ is the accuracy of the model (trained on $T$) evaluated on $D$ and
-$T_{-\{1 \colon j \}}$ is the set $T$ without elements from $\{1 \colon j \}$. The
-metric was evaluated over five runs and is summarized by mean $\mu_\text{WAD}$ and
-standard deviation $\sigma_\text{WAD}$. The valuation of the training samples and the
-evaluation on the validation samples are both calculated based on a logistic regression
-model. Let's have a look at the mean
+where $a_T(D)$ is the accuracy of the model (trained on $T$) evaluated on $D$
+and $T_{-\{1 \colon j \}}$ is the set $T$ without elements from $\{1, \dots , j
+\}$.
 
-![Weighted accuracy drop
-(Mean)](img/classwise-shapley-metric-wad-mean.svg){ align=left width=50% class=invertible }
+We run the point removal experiment for a logistic regression model five times
+and compute WAD for each run, then report the mean $\mu_\text{WAD}$ and standard
+deviation $\sigma_\text{WAD}$.
 
-of the metric WAD. The table shows that CWS is competitive with all three other methods.
-In all problems except `MNIST (multi)` it is better than TMCS, whereas in that
-case TMCS has a slight advantage. Another important quantity is the CV. The results are
-shown below.
+![Mean WAD for best-point removal on logistic regression. Values
+computed using LOO, CWS, Beta Shapley, and TMCS
+](img/classwise-shapley-metric-wad-mean.svg){ class=invertible }
 
-![Weighted accuracy drop
-(CV)](img/classwise-shapley-metric-wad-cv.svg){ align=left width=50% class=invertible }
+We see that CWS is competitive with all three other methods. In all problems
+except `MNIST (multi)` it outperforms TMCS, while in that case TMCS has a slight
+advantage.
 
-It is noteworthy that CWS is not the best method in terms of CV (Lower CV means better
-performance). For `CIFAR10`, `Click`, `CPU` and `MNIST (binary)` Beta Shapley has the
-lowest CV. For `Diabetes`, `MNIST (multi)` and `Phoneme` CWS is the winner and for
-`FMNIST` and `Covertype` TMCS takes the lead. Without considering LOO, TMCS has the
-highest relative standard deviation.
+In order to understand the variability of WAD we look at its coefficient of
+variation (lower is better):
 
-The following plot shows valuation-set accuracy of logistic regression on the y-axis.
-The x-axis shows the number of samples removed. Random values serve as a baseline.
-Each line represents five runs, whereas bootstrapping was used to estimate the 95%
-confidence intervals.
+![Coefficient of Variation of WAD for best-point removal on logistic regression.
+Values computed using LOO, CWS, Beta Shapley, and TMCS
+](img/classwise-shapley-metric-wad-cv.svg){ class=invertible }
 
+CWS is not the best method in terms of CV. For `CIFAR10`, `Click`, `CPU` and
+`MNIST (binary)` Beta Shapley has the lowest CV. For `Diabetes`, `MNIST (multi)`
+and `Phoneme` CWS is the winner and for `FMNIST` and `Covertype` TMCS takes the
+lead. Besides LOO, TMCS has the highest relative standard deviation.
 
-![Accuracy after sample removal using values from logistic
+The following plot shows accuracy vs. number of samples removed. Random values
+serve as a baseline. The shaded area represents the 95% bootstrap confidence
+interval of the mean across 5 runs.
+
+![Accuracy after best-sample removal using values from logistic
 regression](img/classwise-shapley-weighted-accuracy-drop-logistic-regression-to-logistic-regression.svg){ class=invertible }
 
-Samples are removed from high to low valuation order and hence we expect a steep
-decrease in the curve. Overall we conclude that in terms of mean WAD CWS and TMCS are
-the best methods. In terms of CV, CWS and Beta Shapley are the clear winners. Hence, CWS
-is a competitive CV.
+Because samples are removed from high to low valuation order, we expect a steep
+decrease in the curve.
+
+Overall we conclude that in terms of mean WAD, CWS and TMCS perform best, with
+CWS's CV on par with Beta Shapley's, making CWS a competitive method.
+
+
+### Dataset pruning for a neural network by value transfer
+
+Transfer of values from one model to another is probably of greater practical
+relevance: values are computed using a cheap model and used to prune the dataset
+before training a more expensive one.
 
-### Dataset pruning for neural network by value transfer
+The following plot shows accuracy vs. number of samples removed for transfer
+from logistic regression to a neural network. The shaded area represents the 95%
+bootstrap confidence interval of the mean across 5 runs.
 
-Practically more relevant is the transfer of values from one model to another one. As
-before the values are calculated using logistic regression. However, this time they are
-used to prune the training set for a neural network. The following plot shows
-valuation-set accuracy of the network on the y-axis, and the number of samples removed
-on the x-axis.
+![Accuracy after sample removal using values transferred from logistic
+regression to an MLP
+](img/classwise-shapley-weighted-accuracy-drop-logistic-regression-to-mlp.svg){ class=invertible }
 
-![Accuracy after sample removal using values transferred from logistic regression to an
-MLP](img/classwise-shapley-weighted-accuracy-drop-logistic-regression-to-mlp.svg){ class=invertible }
+As in the previous experiment, samples are removed from high to low valuation
+order and hence we expect a steep decrease in the curve. CWS is competitive with
+the other methods, especially in very unbalanced datasets like `Click`. In other
+datasets, like `Covertype`, `Diabetes` and `MNIST (multi)` the performance is on
+par with TMCS.
 
-As in the previous experiment samples are removed from high to low valuation order and
-hence we expect a steep decrease in the curve. CWS is competitive with the compared
-methods. Especially in very unbalanced datasets, like `Click`, the performance of CWS
-seems superior. In other datasets, like `Covertype` and `Diabetes` and `MNIST (multi)`
-the performance is on par with TMC. For `MNIST (binary)` and `Phoneme` the performance
-is competitive.
 
-### Detection of mis-labelled data points
+### Detection of mis-labeled data points
 
-The next experiment uses the algorithms to detect mis-labelled data points. 20% of the
-indices is selected by choice. Multi-class datasets are discarded, because they do not
-possess a unique flipping strategy. The following table shows the mean of the area under
-the curve (AUC) for five runs.
+The next experiment tries to detect mis-labeled data points in binary
+classification tasks. The labels of 20% of the points are flipped at random (we
+don't consider multi-class datasets because there isn't a unique flipping
+strategy). The following plot shows the mean of the area under the curve (AUC)
+for five runs.
 
-![Area under the Curve
-(Mean)](img/classwise-shapley-metric-auc-mean.svg){ align=left width=50% class=invertible }
+![Mean AUC for mis-labeled data point detection. Values computed using LOO, CWS,
+Beta Shapley, and
+TMCS](img/classwise-shapley-metric-auc-mean.svg){ class=invertible }
 
-In the majority of the cases TMCS has a slight advantage over CWS on average. For
-`Click` CWS has a slight edge, most probably due to the unbalanced nature of `Click`.
-The following plot shows the CV for the AUC of the five runs.
+In the majority of cases TMCS has a slight advantage over CWS, except for
+`Click`, where CWS has a slight edge, most probably due to the unbalanced nature
+of the dataset. The following plot shows the CV for the AUC of the five runs.
 
-![Area under the Curve
-(CV)](img/classwise-shapley-metric-auc-cv.svg){ align=left width=50% class=invertible }
+![Coefficient of variation of AUC for mis-labeled data point detection. Values
+computed using LOO, CWS, Beta Shapley, and TMCS
+](img/classwise-shapley-metric-auc-cv.svg){ class=invertible }
 
-In terms of CV, CWS has a clear edge over TMCS and Beta Shapley. The receiving operator
-characteristic (ROC) curve is a plot of the precision to the recall. The classifier
-uses the $n$-smallest values
-respect to the order of the valuation. The following plot shows thec (ROC) for the mean
-of five runs.
+In terms of CV, CWS has a clear edge over TMCS and Beta Shapley.
 
-![Receiver Operating
-Characteristic](img/classwise-shapley-roc-auc-logistic-regression.svg){ align=left width=50% class=invertible }
+Finally, we look at the ROC curves obtained by training the classifier on the
+$n$ first samples in _increasing_ order of valuation (i.e. starting with the
+worst):
 
-Although it seems that TMCS is the winner: If you consider sample efficiency,
-CWS stays competitive. For a perfectly balanced dataset, CWS needs fewer samples than
-TCMS on average. Furthermore, CWS is almost on par with TCMS performance-wise.
+![Mean ROC across 5 runs with 95% bootstrap
+CI](img/classwise-shapley-roc-auc-logistic-regression.svg){ class=invertible }
 
-### Density of values
+Although at first sight TMCS seems to be the winner, CWS stays competitive after
+factoring in running time. For a perfectly balanced dataset, CWS needs on
+average fewer samples than TMCS.
 
-This experiment compares the distribution of values for TMCS (green) and CWS
-(red). Both methods are chosen due to their competitiveness. The plot shows a
-histogram as well as the density estimated by kernel density estimation (KDE) for each
-dataset.
+### Value distribution
 
-![Density of TMCS and
-CWS](img/classwise-shapley-density.svg){ class=invertible }
+For illustration, we compare the distribution of values computed by TMCS and
+CWS.
 
-Similar to the behaviour of the CV from the previous section, the variance of CWS is
-lower than for TCMS. They seem to approximate the same mode although their utility
-function is very different.
+![Histogram and estimated density of the values computed by TMCS and
+CWS on all nine datasets](img/classwise-shapley-density.svg){ class=invertible }
 
-For `Click` TMCS has a multi-modal distribution of values. This is inferior to CWS which
-has only one-mode and is more stable on that dataset. `Click` is a very unbalanced
-dataset, and we conclude that CWS seems to be more robust on unbalanced datasets.
+For `Click` TMCS has a multi-modal distribution of values. We hypothesize that
+this is due to the highly unbalanced nature of the dataset, and notice that CWS
+has a single mode, leading to its greater performance on this dataset.
 
 ## Conclusion
 
-CWS is a reasonable and effective way to handle classification problems. It reduces the
-computing power and variance by splitting up the data set into classes. Given the
-underlying similarities in the architecture of TMCS, Beta Shapley, and CWS, there's a
-clear pathway for improving convergence rates, sample efficiency, and stabilizing
-variance for TMCS and Beta Shapley.
+CWS is an effective way to handle classification problems, in particular for
+unbalanced datasets. It reduces the computing requirements by considering
+in-class and out-of-class points separately.

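To make the WAD metric from the diff above concrete, here is a hedged sketch. We assume `accuracies[j]` holds the test accuracy of the model retrained after removing the `j` highest-valued points, so that `accuracies[0]` is $a_T(D)$; the helper name is ours, not from pydvl:

```python
def weighted_accuracy_drop(accuracies: list[float]) -> float:
    """Sketch of WAD = sum_{j=1}^n (1/j) sum_{i=1}^j (a_{-(1:i-1)} - a_{-(1:i)}).

    The inner sum telescopes to accuracies[0] - accuracies[j], i.e. the total
    accuracy drop after removing the j highest-valued points, weighted by 1/j.
    """
    n = len(accuracies) - 1
    return sum((accuracies[0] - accuracies[j]) / j for j in range(1, n + 1))

# Example: accuracy drops 0.9 -> 0.8 -> 0.7 as the two highest-valued
# points are removed, giving WAD = 0.1/1 + 0.2/2 = 0.2.
wad = weighted_accuracy_drop([0.9, 0.8, 0.7])
```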
docs_includes/abbreviations.md

Lines changed: 4 additions & 1 deletion
@@ -11,8 +11,11 @@
 *[MCLC]: Monte Carlo Least Core
 *[MCS]: Monte Carlo Shapley
 *[ML]: Machine Learning
+*[MLP]: Multi-Layer Perceptron
 *[MLRC]: Machine Learning Reproducibility Challenge
 *[MSE]: Mean Squared Error
 *[PCA]: Principal Component Analysis
+*[ROC]: Receiver Operating Characteristic
 *[SV]: Shapley Value
-*[TMCS]: Truncated Monte Carlo Shapley
+*[TMCS]: Truncated Monte Carlo Shapley
+*[WAD]: Weighted Accuracy Drop
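Finally, an editorial sketch of the set-conditional marginal $\delta(S | C) = u(S_{+i} | C) - u(S | C)$ appearing in the first file's diff above. The toy utility below is hypothetical and only illustrates the bookkeeping of the definition:

```python
def conditional_marginal(u, i, S: set, C: set) -> float:
    """delta(S | C) = u(S + {i} | C) - u(S | C), for i not in S or C
    and S, C disjoint, as required by the definition."""
    assert i not in S and i not in C and not (S & C)
    return u(S | {i}, C) - u(S, C)

# Hypothetical toy utility: grows with the size of S and of the
# conditioning set C.
def toy_utility(S: set, C: set) -> float:
    return len(S) + 0.5 * len(C)

# Adding i=3 to S={1} increases the toy utility by exactly 1.0.
delta = conditional_marginal(toy_utility, 3, {1}, {2})
```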
