---
title: Class-wise Shapley
---
# Class-wise Shapley

Class-wise Shapley (CWS) [@schoch_csshapley_2022] offers a Shapley framework
tailored for classification problems. Given a dataset $D$ and a sample $x_i$
with label $y_i \in \mathbb{N}$, let $D_{y_i}$ be the subset of $D$ with labels
$y_i$, and $D_{-y_i}$ be the complement of $D_{y_i}$ in $D$. The key idea is
that the sample $(x_i, y_i)$ might improve the overall model performance on
$D$, while being detrimental for the performance on $D_{y_i}$, e.g. because of
a wrong label. To address this issue, the authors introduced

$$
v_u(i) = \frac{1}{2^{|D_{-y_i}|}} \sum_{S_{-y_i}}
\frac{1}{|D_{y_i}|} \sum_{S_{y_i}}
\binom{|D_{y_i}|-1}{|S_{y_i}|}^{-1}
\delta(S_{y_i} | S_{-y_i}),
$$

where $S_{y_i} \subseteq D_{y_i} \setminus \{i\}$ and $S_{-y_i} \subseteq
D_{-y_i}$ is _arbitrary_ (in particular, not the complement of $S_{y_i}$). The
function $\delta$ is called the **set-conditional marginal Shapley value** and
is defined as

$$
\delta(S | C) = u(S_{+i} | C) - u(S | C),
$$

for any sets $S, C$ with $i \notin S \cup C$ and $S \cap C = \emptyset$, where
$S_{+i}$ denotes $S \cup \{i\}$.

In practical applications, estimating this quantity is done both with Monte
Carlo sampling of the powerset, and of the set of index permutations. This
typically requires fewer utility evaluations than the original Shapley value,
although the actual speed-up depends on the model and the dataset.
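
For intuition, here is a minimal sketch of such a two-level Monte Carlo
estimate for a single index $i$. This only illustrates the sampling scheme and
is not pyDVL's implementation: the conditional utility `u` is assumed to be
supplied by the caller, e.g. by training on $S \cup C$ and applying the
class-wise scorer described below, and the sampling budgets are arbitrary.

```python
from typing import Callable

import numpy as np


def cws_value_mc(
    i: int,
    in_class: list[int],      # indices of D_{y_i} without i
    out_of_class: list[int],  # indices of D_{-y_i}
    u: Callable[[set[int], set[int]], float],  # conditional utility u(S | C)
    n_outer: int = 100,
    n_inner: int = 10,
    seed: int = 42,
) -> float:
    """Two-level Monte Carlo estimate of the class-wise Shapley value of i."""
    rng = np.random.default_rng(seed)
    estimates = []
    for _ in range(n_outer):
        # Sample S_{-y_i} uniformly from the powerset of D_{-y_i}
        s_out = {j for j in out_of_class if rng.random() < 0.5}
        for _ in range(n_inner):
            # Sample S_{y_i} as a uniform prefix of a random permutation of
            # D_{y_i} \ {i}, which reproduces the binomial weights of the
            # inner sum
            perm = rng.permutation(in_class)
            k = int(rng.integers(0, len(in_class) + 1))
            s_in = set(perm[:k].tolist())
            # Set-conditional marginal delta(S_{y_i} | S_{-y_i})
            estimates.append(u(s_in | {i}, s_out) - u(s_in, s_out))
    return float(np.mean(estimates))
```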


!!! Example "Computing classwise Shapley values"
    Like all other game-theoretic valuation methods, CWS requires a
    [Utility][pydvl.utils.utility.Utility] object constructed with model and
    dataset, with the peculiarity of requiring a specific
    [ClasswiseScorer][pydvl.value.shapley.classwise.ClasswiseScorer]:

    ```python
    from pydvl.value import *

    model = ...  # any supervised model
    data = Dataset(...)  # training and test data
    scorer = ClasswiseScorer("accuracy")
    utility = Utility(model, data, scorer)
    # The stopping criterion and truncation policy are illustrative choices
    values = compute_classwise_shapley_values(
        utility,
        done=HistoryDeviation(n_steps=500, rtol=5e-2),
        truncation=RelativeTruncation(utility, rtol=0.01),
        done_sample_complements=MaxChecks(1),
        normalize_values=True,
    )
    ```


### The class-wise scorer

In order to use the classwise Shapley value, one needs to define a
[ClasswiseScorer][pydvl.value.shapley.classwise.ClasswiseScorer]. This scorer
is defined as

$$
u(S) = f(a_S(D_{y_i})) g(a_S(D_{-y_i})),
$$

where $f$ and $g$ are monotonically increasing functions, $a_S(D_{y_i})$ is the
in-class accuracy of the model trained on $S$, i.e. its accuracy over
$D_{y_i}$, and $a_S(D_{-y_i})$ is its out-of-class accuracy over $D_{-y_i}$.

The authors show that $f(x)=x$ and $g(x)=e^x$ have favorable properties and are
therefore the defaults, but we leave the option to set different functions $f$
and $g$ for an exploration with different base scores.

!!! Example "The default class-wise scorer"
    Constructing the CWS scorer requires choosing a metric and the functions
    $f$ and $g$:

    ```python
    import numpy as np
    from pydvl.value.shapley.classwise import ClasswiseScorer

    # These are the defaults, i.e. f(x) = x and g(x) = e^x
    identity = lambda x: x
    scorer = ClasswiseScorer(
        "accuracy",
        in_class_discount_fn=identity,
        out_of_class_discount_fn=np.exp,
    )
    ```

## Evaluation

We illustrate the method with two experiments: point removal and noise removal,
as well as an analysis of the distribution of the values. For this we employ
the nine datasets used in [@schoch_csshapley_2022], using the same
pre-processing. For images, PCA is used to project the features found by a
pre-trained `Resnet18` model down to 32 components. Standard loc-scale
normalization is performed for all models except gradient boosting, since the
latter is not sensitive to the scale of the features.

??? info "Datasets used for evaluation"
    | Dataset        | Data Type | Classes | Input Dims | OpenML ID |
    |----------------|-----------|---------|------------|-----------|
    | Diabetes       | Tabular   | 2       | 8          | 37        |
    | Click          | Tabular   | 2       | 11         | 1216      |
    | CPU            | Tabular   | 2       | 21         | 197       |
    | Covertype      | Tabular   | 7       | 54         | 1596      |
    | Phoneme        | Tabular   | 2       | 5          | 1489      |
    | FMNIST         | Image     | 2       | 32         | 40996     |
    | CIFAR10        | Image     | 2       | 32         | 40927     |
    | MNIST (binary) | Image     | 2       | 32         | 554       |
    | MNIST (multi)  | Image     | 10      | 32         | 554       |

We show mean and coefficient of variation (CV) $\frac{\sigma}{\mu}$ of an
"inner metric". The former shows the performance of the method, whereas the
latter displays its stability: we normalize by the mean to see the relative
effect of the standard deviation. Ideally the mean value is maximal and the CV
minimal.
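
In code, the two summary statistics for a metric across repeated runs are
simply (an illustrative helper, not part of pyDVL):

```python
import numpy as np


def mean_and_cv(runs: np.ndarray) -> tuple[float, float]:
    """Mean and coefficient of variation of a metric across repeated runs."""
    mu = float(runs.mean())
    return mu, float(runs.std() / mu)
```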

Finally, we note that for all sampling-based valuation methods the same number
of _evaluations of the marginal utility_ was used. This is important to make
the algorithms comparable, but in practice one should consider using a more
sophisticated stopping criterion.

### Dataset pruning for logistic regression (point removal)

In (best-)point removal, one first computes values for the training set and
then removes in sequence the points with the highest values. After each
removal, the remaining points are used to train the model from scratch, and
performance is measured on a test set. This produces a curve of performance vs.
number of points removed, which we show below.
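
A sketch of the procedure, with a scikit-learn-style `model` and one value per
training point (all names here are placeholders):

```python
import numpy as np
from sklearn.base import clone


def point_removal_curve(model, x_train, y_train, x_test, y_test, values):
    """Test accuracy after removing points from highest to lowest value."""
    order = np.argsort(values)[::-1]  # indices sorted by decreasing value
    accuracies = []
    for n_removed in range(len(order)):
        keep = order[n_removed:]  # drop the n_removed most valuable points
        fitted = clone(model).fit(x_train[keep], y_train[keep])
        accuracies.append(fitted.score(x_test, y_test))
    return np.array(accuracies)
```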

As a scalar summary of this curve, [@schoch_csshapley_2022] define **Weighted
Accuracy Drop** (WAD) as:

$$
\text{WAD} = \sum_{j=1}^{n} \left( \frac{1}{j} \sum_{i=1}^{j}
\left[ a_{T_{-\{1 \colon i-1\}}}(D) - a_{T_{-\{1 \colon i\}}}(D) \right] \right)
= a_T(D) - \sum_{j=1}^{n} \frac{a_{T_{-\{1 \colon j\}}}(D)}{j},
$$

where $a_T(D)$ is the accuracy of the model (trained on $T$) evaluated on $D$,
and $T_{-\{1 \colon j\}}$ is the set $T$ without elements from $\{1, \dots,
j\}$.

We run the point removal experiment for a logistic regression model five times
and compute WAD for each run, then report the mean $\mu_\text{WAD}$ and
standard deviation $\sigma_\text{WAD}$.

![Mean WAD for best-point removal on logistic regression. Values computed
using LOO, CWS, Beta Shapley, and
TMCS](img/classwise-shapley-metric-wad-mean.svg){ class=invertible }

We see that CWS is competitive with all three other methods. In all problems
except `MNIST (multi)` it outperforms TMCS, while in that case TMCS has a
slight advantage.

In order to understand the variability of WAD we look at its coefficient of
variation (lower is better):

![Coefficient of variation of WAD for best-point removal on logistic
regression. Values computed using LOO, CWS, Beta Shapley, and
TMCS](img/classwise-shapley-metric-wad-cv.svg){ class=invertible }

CWS is not the best method in terms of CV. For `CIFAR10`, `Click`, `CPU` and
`MNIST (binary)` Beta Shapley has the lowest CV. For `Diabetes`, `MNIST
(multi)` and `Phoneme` CWS is the winner, and for `FMNIST` and `Covertype`
TMCS takes the lead. Besides LOO, TMCS has the highest relative standard
deviation.

The following plot shows accuracy vs. number of samples removed. Random values
serve as a baseline. The shaded area represents the 95% bootstrap confidence
interval of the mean across 5 runs.

![Accuracy after best-sample removal using values from logistic
regression](img/classwise-shapley-weighted-accuracy-drop-logistic-regression-to-logistic-regression.svg){ class=invertible }

Because samples are removed from high to low valuation order, we expect a steep
decrease in the curve.

Overall we conclude that in terms of mean WAD, CWS and TMCS perform best, with
CWS's CV on par with Beta Shapley's, making CWS a competitive method.


### Dataset pruning for a neural network by value transfer

Transfer of values from one model to another is probably of greater practical
relevance: values are computed using a cheap model and then used to prune the
dataset before training a more expensive one.
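
A sketch of this pipeline, with placeholder names and an arbitrary pruning
fraction (the values would come e.g. from CWS on logistic regression, as
above):

```python
import numpy as np
from sklearn.neural_network import MLPClassifier


def prune_and_train(x_train, y_train, values, drop_fraction=0.2):
    """Drop the lowest-valued points, then train the expensive model."""
    order = np.argsort(values)  # ascending: least valuable first
    keep = order[int(drop_fraction * len(order)):]
    return MLPClassifier(max_iter=500).fit(x_train[keep], y_train[keep])
```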

The following plot shows accuracy vs. number of samples removed for transfer
from logistic regression to a neural network. The shaded area represents the
95% bootstrap confidence interval of the mean across 5 runs.

![Accuracy after sample removal using values transferred from logistic
regression to an
MLP](img/classwise-shapley-weighted-accuracy-drop-logistic-regression-to-mlp.svg){ class=invertible }

As in the previous experiment, samples are removed from high to low valuation
order and hence we expect a steep decrease in the curve. CWS is competitive
with the other methods, especially on very unbalanced datasets like `Click`. On
other datasets, like `Covertype`, `Diabetes` and `MNIST (multi)`, the
performance is on par with TMCS.


### Detection of mis-labeled data points

The next experiment tries to detect mis-labeled data points in binary
classification tasks. 20% of the labels are flipped at random (we don't
consider multi-class datasets because there isn't a unique flipping strategy).
The following table shows the mean of the area under the curve (AUC) for five
runs.
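
The corruption step is straightforward for binary labels (the 20% fraction
follows the experiment; the seed is arbitrary):

```python
import numpy as np


def flip_labels(y: np.ndarray, fraction: float = 0.2, seed: int = 42):
    """Flip a random fraction of binary labels; return the noisy labels and
    the indices of the corrupted samples."""
    rng = np.random.default_rng(seed)
    flipped = rng.choice(len(y), size=int(fraction * len(y)), replace=False)
    y_noisy = y.copy()
    y_noisy[flipped] = 1 - y_noisy[flipped]  # assumes labels in {0, 1}
    return y_noisy, flipped
```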

![Mean AUC for mis-labeled data point detection. Values computed using LOO,
CWS, Beta Shapley, and
TMCS](img/classwise-shapley-metric-auc-mean.svg){ class=invertible }

In the majority of cases TMCS has a slight advantage over CWS, except on
`Click`, where CWS has a slight edge, most probably due to the unbalanced
nature of the dataset. The following plot shows the CV for the AUC of the five
runs.

![Coefficient of variation of AUC for mis-labeled data point detection. Values
computed using LOO, CWS, Beta Shapley, and
TMCS](img/classwise-shapley-metric-auc-cv.svg){ class=invertible }

In terms of CV, CWS has a clear edge over TMCS and Beta Shapley.

Finally, we look at the ROC curves when training the classifier on the $n$
first samples in _increasing_ order of valuation (i.e. starting with the
worst):

![Mean ROC across 5 runs with 95% bootstrap
CI](img/classwise-shapley-roc-auc-logistic-regression.svg){ class=invertible }

Although at first sight TMCS seems to be the winner, CWS stays competitive
after factoring in running time. For a perfectly balanced dataset, CWS needs on
average fewer samples than TMCS.

### Value distribution

For illustration, we compare the distribution of values computed by TMCS and
CWS.

![Histogram and estimated density of the values computed by TMCS and CWS on
all nine datasets](img/classwise-shapley-density.svg){ class=invertible }

For `Click`, TMCS yields a multi-modal distribution of values. We hypothesize
that this is due to the highly unbalanced nature of the dataset, and notice
that CWS has a single mode, leading to its greater performance on this dataset.
## Conclusion

CWS is an effective way to handle classification problems, in particular for
unbalanced datasets. It reduces the computing requirements by considering
in-class and out-of-class points separately.