@@ -33,22 +33,22 @@ class Infidelity(InterpreterEvaluator):
3333 pixels; while the difference (the latter term) should also be large because the model depends on important pixels
3434 to make decisions. Like this, large values would be offset by large values if the explanation is faithful to the
3535 model. Otherwise, for uniform explanations (all being constant), the former term would be a constant value and the
36- infidelity would become large。
36+ infidelity would become large.
3737
3838 More details about the measure can be found in the original paper: https://arxiv.org/abs/1901.09392.
3939 """
4040 def __init__ (self ,
41- paddle_model : callable ,
41+ model : callable ,
4242 device : str = 'gpu:0' ,
4343 ** kwargs ):
4444 """
4545
4646 Args:
47- paddle_model (callable): _description_
47+ model (callable): the model being evaluated; a callable that maps batched inputs to predictions (e.g. logits).
4848 device (str, optional): device to run the computation on, e.g. 'gpu:0' or 'cpu'. Defaults to 'gpu:0'.
4949 """
5050
51- super ().__init__ (paddle_model , device , None , ** kwargs )
51+ super ().__init__ (model , device , ** kwargs )
5252 self .results = {}
5353
5454 def _build_predict_fn (self , rebuild : bool = False ):
@@ -74,7 +74,7 @@ def _build_predict_fn(self, rebuild: bool = False):
7474 paddle .set_device (self .device )
7575
7676 # set ``eval`` mode; predictions are made under ``no_grad`` below, so gradients are not needed.
77- self .paddle_model .eval ()
77+ self .model .eval ()
7878
7979 def predict_fn (data ):
8080 """predict_fn for input gradients based interpreters,
@@ -91,7 +91,7 @@ def predict_fn(data):
9191 with paddle .no_grad ():
9292 # Follow the `official implementation <https://github.com/chihkuanyeh/saliency_evaluation>`_
9393 # to use logits as output.
94- logits = self .paddle_model (paddle .to_tensor (data )) # get logits, [bs, num_c]
94+ logits = self .model (paddle .to_tensor (data )) # get logits, [bs, num_c]
9595 # probas = paddle.nn.functional.softmax(logits, axis=1) # get probabilities.
9696 return logits .numpy ()
9797
@@ -234,34 +234,92 @@ def evaluate(self,
234234
235235
236236class InfidelityNLP (InterpreterEvaluator ):
237- def __init__ (self , paddle_model : callable or None , device : str = 'gpu:0' , ** kwargs ):
238- super ().__init__ (paddle_model , device , ** kwargs )
237+ def __init__ (self , model : callable or None , device : str = 'gpu:0' , ** kwargs ):
238+ super ().__init__ (model , device , ** kwargs )
239239 self .results = {}
240240
241- def _generate_samples (self , input_ids , masked_id = 0 ):
241+ def _generate_samples (self , input_ids , masked_id : int , is_random_samples : bool ):
242242 num_tokens = len (input_ids )
243243
244- # like 1d-conv, stride=1, kernel-size={1,2,3,4,5}
245- generated_samples = []
246- input_ids_array = np .array ([input_ids ])
247- for ks in range (1 , 6 ):
248- if ks > num_tokens - 2 :
249- break
250- for i in range (1 , num_tokens - ks ):
251- tmp = np .copy (input_ids_array )
252- tmp [0 , i :i + ks ] = masked_id
253- generated_samples .append (tmp )
254-
255- perturbed_samples = np .concatenate (generated_samples , axis = 0 )
256- Is = perturbed_samples != input_ids_array
257-
258- return perturbed_samples , Is
259-
260- def evaluate (self , raw_text : str , explanation : list or np .ndarray , tokenizer : callable , recompute : bool = False ):
244+ if is_random_samples :
245+ # This is more suitable for long documents.
246+ # we concat three kinds of perturbations:
247+ # randomly perturbing 1%, 2%, 3%, 4% or 5% tokens respectively
248+ # with 40 times
249+ num_repeats = 40
250+ results = []
251+ ids_array = np .array ([input_ids ]* num_repeats )
252+ for p in range (1 , 6 ):
253+ _k = int (num_tokens * p / 100 )
254+
255+ # never pick positions 0 and -1, i.e., the [CLS] and [SEP] tokens
256+ # https://stackoverflow.com/a/53893160/4834515
257+ pert_k = np .random .rand (num_repeats , num_tokens - 2 ).argpartition (_k , axis = 1 )[:,:_k ] + 1
258+
259+ pert_array = np .copy (ids_array )
260+ # vectorized slicing.
261+ # https://stackoverflow.com/a/74024396/4834515
262+ row_indexes = np .arange (num_repeats )[:, None ]
263+ pert_array [row_indexes , pert_k ] = masked_id
264+
265+ results .append (pert_array )
266+
267+ perturbed_samples = np .concatenate (results ) # [200, num_tokens]
268+ Is = perturbed_samples != np .array ([input_ids ]) # [200, num_tokens]
269+
270+ return perturbed_samples , Is
271+ else :
272+ # This is more suitable for short documents.
273+ # like 1d-conv, stride=1, kernel-size={1,2,3,4,5}
274+ generated_samples = []
275+ input_ids_array = np .array ([input_ids ])
276+ for ks in range (1 , 6 ):
277+ if ks > num_tokens - 2 :
278+ break
279+ for i in range (1 , num_tokens - ks ):
280+ tmp = np .copy (input_ids_array )
281+ tmp [0 , i :i + ks ] = masked_id
282+ generated_samples .append (tmp )
283+
284+ perturbed_samples = np .concatenate (generated_samples , axis = 0 )
285+ Is = perturbed_samples != input_ids_array
286+
287+ return perturbed_samples , Is
288+
289+ # def _generate_samples(self, input_ids, masked_id=0):
290+ # num_tokens = len(input_ids)
291+
292+ # # we concat three kinds of perturbations:
293+ # # randomly perturbing 1, 2 or 3 tokens respectively
294+ # # with 33 times
295+ # num_repeats = 33
296+
297+ # ids_array = np.array([input_ids]*num_repeats)
298+
299+ # # not choose from {0, -1}, [CLS] and [SEP]
300+ # # https://stackoverflow.com/a/53893160/4834515
301+ # pert_1 = np.random.rand(num_repeats, num_tokens-2).argpartition(1, axis=1)[:,:1] + 1
302+ # pert_2 = np.random.rand(num_repeats, num_tokens-2).argpartition(2, axis=1)[:,:2] + 1
303+ # pert_3 = np.random.rand(num_repeats, num_tokens-2).argpartition(3, axis=1)[:,:3] + 1
304+
305+ # pert_1_array = np.copy(ids_array)
306+ # pert_2_array = np.copy(ids_array)
307+ # pert_3_array = np.copy(ids_array)
308+
309+ # # https://stackoverflow.com/a/74024396/4834515
310+ # row_indexes = np.arange(num_repeats)[:, None]
311+ # pert_1_array[row_indexes, pert_1] = masked_id
312+ # pert_2_array[row_indexes, pert_2] = masked_id
313+ # pert_3_array[row_indexes, pert_3] = masked_id
314+
315+ # perturbed_samples = np.concatenate([pert_1_array, pert_2_array, pert_3_array])
316+ # return perturbed_samples, perturbed_samples != ids_array
317+
318+ def evaluate (self , raw_text : str , explanation : list or np .ndarray , tokenizer : callable , max_seq_len = 128 , is_random_samples = False , recompute : bool = False ):
261319 self ._build_predict_fn ()
262320
263321 # tokenizer text to ids
264- encoded_inputs = tokenizer (raw_text , max_seq_len = 128 )
322+ encoded_inputs = tokenizer (raw_text , max_seq_len = max_seq_len )
265323 # order is important. *_batched_and_to_tuple will be the input for the model.
266324 _batched_and_to_tuple = tuple ([np .array ([v ]) for v in encoded_inputs .values ()])
267325
@@ -276,7 +334,7 @@ def evaluate(self, raw_text: str, explanation: list or np.ndarray, tokenizer: ca
276334 # generate perturbation samples.
277335 if 'proba_diff' not in self .results or recompute :
278336 ## x and I related.
279- generated_samples , Is = self ._generate_samples (encoded_inputs ['input_ids' ], tokenizer .pad_token_id )
337+ generated_samples , Is = self ._generate_samples (encoded_inputs ['input_ids' ], tokenizer .pad_token_id , is_random_samples )
280338 self .results ['generated_samples' ] = generated_samples
281339 self .results ['Is' ] = Is
282340 proba_pert = self .predict_fn (generated_samples )[:, label ]
0 commit comments