---
title: Class-wise Shapley
---

# Class-wise Shapley

Class-wise Shapley (CWS) [@schoch_csshapley_2022] offers a Shapley framework
tailored for classification problems. Given a sample $x_i$ with label $y_i \in
\mathbb{N}$, let $D_{y_i}$ be the subset of $D$ with label $y_i$, and
$D_{-y_i}$ be the complement of $D_{y_i}$ in $D$. The key idea is that the
sample $(x_i, y_i)$ might improve the overall model performance on $D$, while
being detrimental to the performance on $D_{y_i}$, e.g. because of a wrong
label. To address this issue, the authors introduce

$$
v_u(i) = \frac{1}{2^{|D_{-y_i}|}} \sum_{S_{-y_i}}
\left [
\frac{1}{|D_{y_i}|}\sum_{S_{y_i}} \binom{|D_{y_i}|-1}{|S_{y_i}|}^{-1}
\delta(S_{y_i} | S_{-y_i})
\right ],
$$

where $S_{y_i} \subseteq D_{y_i} \setminus \{i\}$ and $S_{-y_i} \subseteq
D_{-y_i}$ is _arbitrary_ (in particular, not the complement of $S_{y_i}$). The
function $\delta$ is called the **set-conditional marginal Shapley value** and
is defined as

$$
\delta(S | C) = u( S_{+i} | C ) - u(S | C),
$$

for any sets $S$ and $C$ such that $i \notin S \cup C$ and $S \cap C = \emptyset$.

In practical applications, this quantity is estimated with Monte Carlo sampling
of both the powerset of $D_{-y_i}$ and the set of index permutations of
$D_{y_i}$ [@castro_polynomial_2009]. Typically, this requires fewer samples
than the original Shapley value, although the actual speed-up depends on the
model and the dataset.
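
!!! Example "A Monte Carlo sketch of the estimator"
    To illustrate the sampling scheme, here is a minimal, self-contained sketch
    (not pyDVL's implementation): out-of-class sets $S_{-y_i}$ are drawn
    uniformly from the powerset of $D_{-y_i}$, while in-class sets $S_{y_i}$
    are drawn as random prefixes of permutations of $D_{y_i} \setminus \{i\}$,
    which reproduces the binomial weights in the definition above. The callable
    `utility(S, C)`, standing for $u(S | C)$, is a hypothetical placeholder.

    ```python
    import numpy as np

    def estimate_cws_value(i, in_class, out_of_class, utility, n_samples=1000, seed=None):
        """Monte Carlo estimate of v_u(i). `in_class` holds the indices with the
        same label as i (excluding i), `out_of_class` all remaining indices."""
        rng = np.random.default_rng(seed)
        total = 0.0
        for _ in range(n_samples):
            # Sample S_{-y_i} uniformly from the powerset of D_{-y_i}
            s_out = [j for j in out_of_class if rng.random() < 0.5]
            # Sample S_{y_i}: a uniformly sized random subset of D_{y_i} \ {i}
            perm = rng.permutation(in_class)
            k = rng.integers(0, len(in_class) + 1)
            s_in = list(perm[:k])
            # delta(S_{y_i} | S_{-y_i}) = u(S_{y_i} + {i} | S_{-y_i}) - u(S_{y_i} | S_{-y_i})
            total += utility(s_in + [i], s_out) - utility(s_in, s_out)
        return total / n_samples
    ```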


!!! Example "Computing classwise Shapley values"
    Like all other game-theoretic valuation methods, CWS requires a
    [Utility][pydvl.utils.utility.Utility] object constructed with model and
    dataset, with the peculiarity of requiring a specific
    [ClasswiseScorer][pydvl.value.shapley.classwise.ClasswiseScorer]. The entry
    point is the function
    [compute_classwise_shapley_values][pydvl.value.shapley.classwise.compute_classwise_shapley_values]:

    ```python
    from pydvl.value import *

    model = ...
    data = Dataset(...)
    scorer = ClasswiseScorer(...)
    utility = Utility(model, data, scorer)
    values = compute_classwise_shapley_values(
        utility,
        # Stop when value updates stabilize, or after at most 5000 updates
        done=HistoryDeviation(n_steps=500, rtol=5e-2) | MaxUpdates(5000),
        # Early stopping of the processing of each permutation
        truncation=RelativeTruncation(utility, rtol=0.01),
        # Stopping criterion for the sampling of out-of-class sets
        done_sample_complements=MaxChecks(1),
        normalize_values=True
    )
    ```


### The class-wise scorer

In order to use the class-wise Shapley value, one needs to define a
[ClasswiseScorer][pydvl.value.shapley.classwise.ClasswiseScorer]. This scorer
is defined as

$$
u(S) = f(a_S(D_{y_i})) g(a_S(D_{-y_i})),
$$

where $f$ and $g$ are monotonically increasing functions, $a_S(D_{y_i})$ is the
**in-class accuracy**, and $a_S(D_{-y_i})$ is the **out-of-class accuracy** (the
names originate from the authors' choice of accuracy as the score, but in
principle any other score, like $F_1$, can be used).

The authors show that $f(x)=x$ and $g(x)=e^x$ have favorable properties and are
therefore the defaults, but we leave the option to set different functions $f$
and $g$ for exploration with different base scores.

!!! Example "The default class-wise scorer"
    Constructing the CWS scorer requires choosing a metric and the functions $f$
    and $g$:

    ```python
    import numpy as np
    from pydvl.value.shapley.classwise import ClasswiseScorer

    # These are the defaults
    identity = lambda x: x
    scorer = ClasswiseScorer(
        "accuracy",
        in_class_discount_fn=identity,
        out_of_class_discount_fn=np.exp
    )
    ```
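
!!! Example "A sketch of the discounted utility"
    To make the formula above concrete, here is a minimal, hypothetical sketch
    (not the pyDVL implementation) of the discounted utility computed directly
    from predictions, with accuracy as the base score and the default discounts
    $f(x)=x$ and $g(x)=e^x$:

    ```python
    import numpy as np

    def discounted_utility(y_true, y_pred, label, f=lambda x: x, g=np.exp):
        """u(S) = f(in-class accuracy) * g(out-of-class accuracy) w.r.t. `label`."""
        in_class = y_true == label
        in_acc = np.mean(y_pred[in_class] == y_true[in_class])     # a_S(D_{y_i})
        out_acc = np.mean(y_pred[~in_class] == y_true[~in_class])  # a_S(D_{-y_i})
        return f(in_acc) * g(out_acc)
    ```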

??? "Surface of the discounted utility function"
    The level curves of the discounted utility for $f(x)=x$ and $g(x)=e^x$ are
    depicted below, annotated with their respective gradients.

    *(Figure: level curves of the discounted utility function.)*

## Evaluation

We illustrate the method with two experiments: point removal and noise removal,
as well as an analysis of the distribution of the values. For this we employ
the nine datasets used in [@schoch_csshapley_2022], with the same
pre-processing. For images, PCA is used to reduce the features found by a
pre-trained `Resnet18` model down to 32 dimensions. Standard loc-scale
normalization is performed for all models except gradient boosting, since the
latter is not sensitive to the scale of the features.
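
A hypothetical sketch of this image pre-processing with scikit-learn (the exact
pipeline used in the benchmark may differ):

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Stand-in for embeddings extracted with a pre-trained Resnet18
resnet_features = np.random.default_rng(0).normal(size=(100, 512))

# Reduce to 32 dimensions, then apply standard loc-scale normalization
preprocessing = make_pipeline(PCA(n_components=32), StandardScaler())
features = preprocessing.fit_transform(resnet_features)
```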

??? info "Datasets used for evaluation"
    | Dataset        | Data Type | Classes | Input Dims | OpenML ID |
    |----------------|-----------|---------|------------|-----------|
    | Diabetes       | Tabular   | 2       | 8          | 37        |
    | Click          | Tabular   | 2       | 11         | 1216      |
    | CPU            | Tabular   | 2       | 21         | 197       |
    | Covertype      | Tabular   | 7       | 54         | 1596      |
    | Phoneme        | Tabular   | 2       | 5          | 1489      |
    | FMNIST         | Image     | 2       | 32         | 40996     |
    | CIFAR10        | Image     | 2       | 32         | 40927     |
    | MNIST (binary) | Image     | 2       | 32         | 554       |
    | MNIST (multi)  | Image     | 10      | 32         | 554       |

We show the mean and the coefficient of variation (CV) $\frac{\sigma}{\mu}$ of
an "inner metric". The former shows the performance of the method, whereas the
latter displays its stability: we normalize by the mean to see the relative
effect of the standard deviation. Ideally the mean value is maximal and the CV
minimal.

Finally, we note that for all sampling-based valuation methods the same number
of _evaluations of the marginal utility_ was used. This is important to make
the algorithms comparable, but in practice one should consider using a more
sophisticated stopping criterion.
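
For instance, one could (hypothetically) tighten the convergence check while
raising the hard cap on the number of updates, reusing the criteria shown in
the example above:

```python
from pydvl.value import HistoryDeviation, MaxUpdates

# Stop when values have stabilized to within 0.1% over the last 1000 steps,
# or after at most 20000 updates, whichever comes first
done = HistoryDeviation(n_steps=1000, rtol=1e-3) | MaxUpdates(20_000)
```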

### Dataset pruning for logistic regression (point removal)

In (best-)point removal, one first computes values for the training set and
then removes the points with the highest values, one after the other. After
each removal, the remaining points are used to train the model from scratch,
and performance is measured on a test set. This produces a curve of performance
vs. number of points removed, which we show below.
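
A minimal sketch of this procedure (with hypothetical names, not part of the
pyDVL API) could look as follows:

```python
import numpy as np

def point_removal_curve(values, model, x_train, y_train, x_test, y_test, n_removals):
    """Remove points from highest to lowest value, retraining from scratch each time."""
    order = np.argsort(values)[::-1]  # indices in decreasing order of value
    accuracies = []
    for j in range(n_removals + 1):
        keep = np.setdiff1d(np.arange(len(values)), order[:j])
        model.fit(x_train[keep], y_train[keep])
        accuracies.append(model.score(x_test, y_test))
    return np.array(accuracies)  # accuracies[j] is the test score after removing j points
```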

As a scalar summary of this curve, [@schoch_csshapley_2022] define **Weighted
Accuracy Drop** (WAD) as:

$$
\text{WAD} = \sum_{j=1}^{n} \left ( \frac{1}{j} \sum_{i=1}^{j} \left [
a_{T_{-\{1 \colon i-1 \}}}(D) - a_{T_{-\{1 \colon i \}}}(D) \right ] \right)
= \sum_{j=1}^{n} \frac{a_T(D) - a_{T_{-\{1 \colon j \}}}(D)}{j},
$$

where $a_T(D)$ is the accuracy of the model (trained on $T$) evaluated on $D$,
and $T_{-\{1 \colon j \}}$ is the set $T$ without its $j$ highest-valued points
(indices $1, \dots, j$ when sorted by decreasing value). The second equality
follows from telescoping the inner sum.
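
Given the sequence of accuracies measured after each removal (e.g. the output
of the removal sketch above), WAD reduces to a weighted sum of accuracy drops:

```python
import numpy as np

def weighted_accuracy_drop(accuracies):
    """accuracies[j] is the accuracy after removing the j highest-valued points,
    so that accuracies[0] = a_T(D)."""
    accuracies = np.asarray(accuracies)
    drops = accuracies[0] - accuracies[1:]         # a_T(D) - a_{T_{-{1:j}}}(D)
    weights = 1.0 / np.arange(1, len(accuracies))  # 1/j
    return float(np.sum(weights * drops))
```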

We run the point removal experiment for a logistic regression model five times
and compute WAD for each run, then report the mean $\mu_\text{WAD}$ and
standard deviation $\sigma_\text{WAD}$.

*(Figure: mean WAD per method and dataset.)*

We see that CWS is competitive with all three other methods. It outperforms
TMCS in all problems except `MNIST (multi)`, where TMCS has a slight advantage.

In order to understand the variability of WAD, we look at its coefficient of
variation (lower is better):

*(Figure: coefficient of variation of WAD per method and dataset.)*

CWS is not the best method in terms of CV. For `CIFAR10`, `Click`, `CPU` and
`MNIST (binary)` Beta Shapley has the lowest CV. For `Diabetes`, `MNIST (multi)`
and `Phoneme` CWS is the winner, and for `FMNIST` and `Covertype` TMCS takes
the lead. Besides LOO, TMCS has the highest relative standard deviation.

The following plot shows accuracy vs. number of samples removed. Random values
serve as a baseline. The shaded area represents the 95% bootstrap confidence
interval of the mean across 5 runs.

*(Figure: accuracy vs. number of samples removed, per method and dataset.)*

Because samples are removed from high to low valuation order, we expect a steep
decrease in the curve: the steeper the initial drop, the better the method is
at identifying the most valuable points.

Overall we conclude that in terms of mean WAD, CWS and TMCS perform best, with
CWS's CV on par with Beta Shapley's, making CWS a competitive method.


### Dataset pruning for a neural network by value transfer

Transfer of values from one model to another is probably of greater practical
relevance: values are computed using a cheap model and then used to prune the
dataset before training a more expensive one.

The following plot shows accuracy vs. number of samples removed for transfer
from logistic regression to a neural network. The shaded area represents the
95% bootstrap confidence interval of the mean across 5 runs.

*(Figure: accuracy vs. number of samples removed, for values transferred from
logistic regression to a neural network.)*

As in the previous experiment, samples are removed from high to low valuation
order, and hence we expect a steep decrease in the curve. CWS is competitive
with the other methods, especially on very unbalanced datasets like `Click`.
On other datasets, like `Covertype`, `Diabetes` and `MNIST (multi)`, the
performance is on par with TMCS.


### Detection of mis-labeled data points

The next experiment tries to detect mis-labeled data points in binary
classification tasks. 20% of the labels are flipped at random (we don't
consider multi-class datasets because there isn't a unique flipping strategy).
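
As a rough sketch (with hypothetical helper names), the flipping and the
detection score could look as follows. Mis-labeled points should receive low
values, so the negated value serves as a detection score:

```python
import numpy as np

def flip_labels(y, fraction=0.2, seed=None):
    """Flip a random fraction of binary labels (a 0/1 array) and return the flip mask."""
    rng = np.random.default_rng(seed)
    flipped = rng.choice(len(y), size=int(fraction * len(y)), replace=False)
    y_noisy = y.copy()
    y_noisy[flipped] = 1 - y_noisy[flipped]
    is_flipped = np.zeros(len(y), dtype=bool)
    is_flipped[flipped] = True
    return y_noisy, is_flipped

# After computing values on the noisy data, detection can be scored e.g. with
# sklearn.metrics.roc_auc_score(is_flipped, -values)
```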

The following table shows the mean of the area under the curve (AUC) over five
runs.

*(Table: mean detection AUC per method and dataset.)*

In the majority of cases TMCS has a slight advantage over CWS, except for
`Click`, where CWS has a slight edge, most probably due to the unbalanced
nature of the dataset. The following plot shows the CV of the AUC over the
five runs.

*(Figure: coefficient of variation of the detection AUC per method and dataset.)*

In terms of CV, CWS has a clear edge over TMCS and Beta Shapley.

Finally, we look at the ROC curves obtained when training the classifier on
the first $n$ samples in _increasing_ order of valuation (i.e. starting with
the worst):

*(Figure: ROC curves.)*

Although at first sight TMCS seems to be the winner, CWS stays competitive
after factoring in running time. For a perfectly balanced dataset, CWS needs
on average fewer samples than TMCS.

### Value distribution

For illustration, we compare the distribution of values computed by TMCS and
CWS.

*(Figure: distribution of the values computed by TMCS and by CWS, per dataset.)*

For `Click` TMCS has a multi-modal distribution of values. We hypothesize that
this is due to the highly unbalanced nature of the dataset, and note that CWS
has a single mode, which may explain its better performance on this dataset.

## Conclusion

CWS is an effective way to handle classification problems, in particular for
unbalanced datasets. It reduces the computing requirements by considering
in-class and out-of-class points separately.