|
1 | | -from __future__ import print_function |
2 | | - |
3 | | -from aix360.algorithms.lwbe import LocalWBExplainer |
4 | | - |
5 | | -from .CEM_aen import AEADEN |
6 | | - |
7 | | -import random |
8 | | -import numpy as np |
9 | | - |
10 | | - |
class CEMExplainer(LocalWBExplainer):
    """
    Contrastive Explanations Method (CEM) explainer for image and tabular data.

    Produces two complementary explanation types for a classifier's decision:
    Pertinent Positives (PP), the minimal portion of the input sufficient to
    keep the original classification, and Pertinent Negatives (PN), what must
    necessarily be absent to maintain it.  Minimality of both is encouraged by
    an elastic-net penalty, and an optional autoencoder can steer explanations
    toward more realistic inputs. [#]_

    References:
        .. [#] `Amit Dhurandhar, Pin-Yu Chen, Ronny Luss, Chun-Chen Tu,
           Paishun Ting, Karthikeyan Shanmugam, Payel Das, "Explanations based on
           the Missing: Towards Contrastive Explanations with Pertinent Negatives,"
           Advances in Neural Information Processing Systems (NeurIPS), 2018.
           <https://arxiv.org/abs/1802.07623>`_
    """

    def __init__(self, model):
        """
        Initialize the explainer with the model to be explained.

        Args:
            model: KerasClassifier model whose predictions needs to be explained
        """
        super(CEMExplainer, self).__init__()
        self._wbmodel = model

    def set_params(self, *argv, **kwargs):
        """
        Set parameters for the explainer.

        No tunable state is kept; present to satisfy the LocalWBExplainer API.
        """
        pass

    def explain_instance(self, input_X,
                         arg_mode, AE_model, arg_kappa, arg_b,
                         arg_max_iter, arg_init_const, arg_beta, arg_gamma):
        """
        Compute a contrastive explanation (PP or PN) for a single instance.

        Note that this assumes that the classifier was trained with inputs
        normalized in [-0.5, 0.5] range.

        Args:
            input_X (numpy.ndarray): input instance to be explained
            arg_mode (str): 'PP' or 'PN'
            AE_model: Auto-encoder model
            arg_kappa (double): Confidence gap between desired class and other classes
            arg_b (double): Number of different weightings of loss function to try
            arg_max_iter (int): For each weighting of loss function number of iterations to search
            arg_init_const (double): Initial weighting of loss function
            arg_beta (double): Weighting of L1 loss
            arg_gamma (double): Weighting of auto-encoder

        Returns:
            tuple:
                * **adv_X** (`numpy ndarray`) -- Perturbed input instance for PP/PN
                * **delta_X** (`numpy ndarray`) -- Difference between input and Perturbed instance
                * **INFO** (`str`) -- Other information about PP/PN
        """
        # Fixed seeds so repeated calls produce the same perturbation search.
        random.seed(121)
        np.random.seed(1211)

        # The search targets the class the model currently assigns to input_X.
        _, orig_class, orig_prob_str = self._wbmodel.predict_long(input_X)
        one_hot_target = np.array([np.eye(self._wbmodel._nb_classes)[orig_class]])

        # Only single-instance explanations are supported (batch of one).
        # input_X.shape example: (1, 28, 28, 1) for MNIST, (1, n_columns) for tabular.
        optimizer = AEADEN(self._wbmodel, input_X.shape,
                           mode=arg_mode, AE=AE_model, batch_size=1,
                           kappa=arg_kappa, init_learning_rate=1e-2,
                           binary_search_steps=arg_b, max_iterations=arg_max_iter,
                           initial_const=arg_init_const, beta=arg_beta, gamma=arg_gamma)

        adv_X = optimizer.attack(input_X, one_hot_target)
        _, adv_class, adv_prob_str = self._wbmodel.predict_long(adv_X)

        # The explanation itself is the difference between input and perturbation.
        delta_X = input_X - adv_X
        _, delta_class, delta_prob_str = self._wbmodel.predict_long(delta_X)

        INFO = "[INFO]kappa:{}, Orig class:{}, Perturbed class:{}, Delta class: {}, Orig prob:{}, Perturbed prob:{}, Delta prob:{}".format(
            arg_kappa, orig_class, adv_class, delta_class, orig_prob_str, adv_prob_str, delta_prob_str)

        return (adv_X, delta_X, INFO)
| 1 | +from __future__ import print_function |
| 2 | + |
| 3 | +from aix360.algorithms.lwbe import LocalWBExplainer |
| 4 | + |
| 5 | +from .CEM_aen import AEADEN |
| 6 | + |
| 7 | +import random |
| 8 | +import numpy as np |
| 9 | + |
| 10 | + |
class CEMExplainer(LocalWBExplainer):
    """
    Contrastive Explanations Method (CEM) explainer for image and tabular data.

    Produces two complementary explanation types for a classifier's decision:
    Pertinent Positives (PP), the minimal portion of the input sufficient to
    keep the original classification, and Pertinent Negatives (PN), what must
    necessarily be absent to maintain it.  Minimality of both is encouraged by
    an elastic-net penalty, and an optional autoencoder can steer explanations
    toward more realistic inputs. [#]_

    References:
        .. [#] `Amit Dhurandhar, Pin-Yu Chen, Ronny Luss, Chun-Chen Tu,
           Paishun Ting, Karthikeyan Shanmugam, Payel Das, "Explanations based on
           the Missing: Towards Contrastive Explanations with Pertinent Negatives,"
           Advances in Neural Information Processing Systems (NeurIPS), 2018.
           <https://arxiv.org/abs/1802.07623>`_
    """

    def __init__(self, model):
        """
        Initialize the explainer with the model to be explained.

        Args:
            model: KerasClassifier model whose predictions needs to be explained
        """
        super(CEMExplainer, self).__init__()
        self._wbmodel = model

    def set_params(self, *argv, **kwargs):
        """
        Set parameters for the explainer.

        No tunable state is kept; present to satisfy the LocalWBExplainer API.
        """
        pass

    def explain_instance(self, input_X,
                         arg_mode, AE_model, arg_kappa, arg_b,
                         arg_max_iter, arg_init_const, arg_beta, arg_gamma,
                         arg_alpha=0, arg_threshold=1, arg_offset=0):
        """
        Compute a contrastive explanation (PP or PN) for a single instance.

        Note that this assumes that the classifier was trained with inputs
        normalized in [-arg_offset, arg_offset] range, where arg_offset is 0 or .5.

        Args:
            input_X (numpy.ndarray): input instance to be explained
            arg_mode (str): 'PP' or 'PN'
            AE_model: Auto-encoder model
            arg_kappa (double): Confidence gap between desired class and other classes
            arg_b (double): Number of different weightings of loss function to try
            arg_max_iter (int): For each weighting of loss function number of iterations to search
            arg_init_const (double): Initial weighting of loss function
            arg_beta (double): Weighting of L1 loss
            arg_gamma (double): Weighting of auto-encoder
            arg_alpha (double): Weighting of L2 loss
            arg_threshold (double): automatically turn off all features less than arg_threshold since nothing to turn off
            arg_offset (double): input_X is in [0,1]. we subtract offset when passed to classifier

        Returns:
            tuple:
                * **adv_X** (`numpy ndarray`) -- Perturbed input instance for PP/PN
                * **delta_X** (`numpy ndarray`) -- Difference between input and Perturbed instance
                * **INFO** (`str`) -- Other information about PP/PN
        """
        # Fixed seeds so repeated calls produce the same perturbation search.
        random.seed(121)
        np.random.seed(1211)

        # The search targets the class the model currently assigns to input_X.
        _, orig_class, orig_prob_str = self._wbmodel.predict_long(input_X)
        one_hot_target = np.array([np.eye(self._wbmodel._nb_classes)[orig_class]])

        # Only single-instance explanations are supported (batch of one).
        # input_X.shape example: (1, 28, 28, 1) for MNIST, (1, n_columns) for tabular.
        optimizer = AEADEN(self._wbmodel, input_X.shape,
                           mode=arg_mode, AE=AE_model, batch_size=1,
                           kappa=arg_kappa, init_learning_rate=1e-2,
                           binary_search_steps=arg_b, max_iterations=arg_max_iter,
                           initial_const=arg_init_const, beta=arg_beta, gamma=arg_gamma,
                           alpha=arg_alpha, threshold=arg_threshold, offset=arg_offset)

        # Warm-up prediction helps compile the model before the attack loop.
        self._wbmodel.predict(input_X)

        # The attack runs in the shifted [0, 2*arg_offset]-style space, so the
        # offset is added on the way in; adv_raw is still in that shifted space.
        adv_raw = optimizer.attack(input_X + arg_offset, target=one_hot_target)
        _, adv_class, adv_prob_str = self._wbmodel.predict_long(adv_raw)

        # (input_X + offset) - adv_raw - offset simplifies to input_X - adv_raw:
        # the offset added for the attack cancels when shifting the delta back.
        delta_X = input_X - adv_raw
        # Shift the perturbed instance back to [-arg_offset, arg_offset].
        adv_X = adv_raw - arg_offset

        _, delta_class, delta_prob_str = self._wbmodel.predict_long(delta_X)

        INFO = "[INFO]kappa:{}, Orig class:{}, Perturbed class:{}, Delta class: {}, Orig prob:{}, Perturbed prob:{}, Delta prob:{}".format(
            arg_kappa, orig_class, adv_class, delta_class, orig_prob_str, adv_prob_str, delta_prob_str)

        return (adv_X, delta_X, INFO)
0 commit comments