Skip to content

Commit 7b2a367

Browse files
committed
transformer interpreters update
1 parent 411b322 commit 7b2a367

File tree

3 files changed

+30
-16
lines changed

3 files changed

+30
-16
lines changed

interpretdl/interpreter/abc_interpreter.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -385,8 +385,15 @@ def __init__(self, paddle_model: callable, device: str, use_cuda: bool = None, *
385385
"paddle_model has to be " \
386386
"an instance of paddle.nn.Layer or a compatible one."
387387

388-
def _build_predict_fn(self, rebuild: bool = False, embedding_name: str or None = None, attn_map_name: str or None = None,
389-
attn_v_name: str or None = None, attn_proj_name: str or None = None, nlp: bool = False):
388+
def _build_predict_fn(
389+
self,
390+
rebuild: bool = False,
391+
embedding_name: str or None = None,
392+
attn_map_name: str or None = None,
393+
attn_v_name: str or None = None,
394+
attn_proj_name: str or None = None,
395+
gradient_of: str or None = None,
396+
nlp: bool = False):
390397

391398
"""Build ``predict_fn`` for transformer based algorithms.
392399
The model is supposed to be a classification model.
@@ -463,7 +470,7 @@ def block_value_hook(layer, input, output):
463470
if attn_map_name is not None and re.match(attn_map_name, n):
464471
h = v.register_forward_post_hook(block_attn_hook)
465472
hooks.append(h)
466-
elif scale is not None and re.match(embedding_name, n):
473+
elif scale is not None and embedding_name is not None and re.match(embedding_name, n):
467474
h = v.register_forward_post_hook(hook)
468475
hooks.append(h)
469476
elif attn_proj_name is not None and re.match(attn_proj_name, n):
@@ -474,21 +481,28 @@ def block_value_hook(layer, input, output):
474481
h = v.register_forward_post_hook(block_value_hook)
475482
hooks.append(h)
476483

477-
out = self.paddle_model(*inputs)
484+
logits = self.paddle_model(*inputs)
478485

479486
for h in hooks:
480487
h.remove()
481488

482-
proba = paddle.nn.functional.softmax(out, axis=1)
489+
proba = paddle.nn.functional.softmax(logits, axis=1)
483490
preds = paddle.argmax(proba, axis=1)
484491
if label is None:
485492
label = preds.numpy()
493+
label_onehot = paddle.nn.functional.one_hot(paddle.to_tensor(label), num_classes=logits.shape[1])
486494

487495
block_attns_grads = []
488-
489-
label_onehot = paddle.nn.functional.one_hot(paddle.to_tensor(label), num_classes=proba.shape[1])
490-
target = paddle.sum(proba * label_onehot, axis=1)
491-
target.backward()
496+
497+
if gradient_of == 'probability' or gradient_of is None:
498+
target = paddle.sum(proba * label_onehot, axis=1)
499+
target.backward()
500+
elif gradient_of == 'logit':
501+
target = paddle.sum(logits * label_onehot, axis=1)
502+
target.backward()
503+
else:
504+
raise ValueError("`gradient_of` should be one of ['logit', 'probability'].")
505+
492506
for i, attn in enumerate(block_attns):
493507
grad = attn.grad.numpy()
494508
block_attns_grads.append(grad)

interpretdl/interpreter/bidirectional_transformer.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def interpret(self,
167167
ap_mode: str = "head",
168168
start_layer: int = 11,
169169
steps: int = 20,
170-
embedding_name='^[a-z]*.embeddings.word_embeddings$',
170+
embedding_name='^[a-z]*.embeddings$',
171171
attn_map_name='^[a-z]*.encoder.layers.[0-9]*.self_attn.attn_drop$',
172172
attn_v_name='^[a-z]*.encoder.layers.[0-9]*.self_attn.v_proj$',
173173
attn_proj_name='^[a-z]*.encoder.layers.[0-9]*.self_attn.out_proj$',
@@ -182,7 +182,7 @@ def interpret(self,
182182
start_layer (int, optional): Compute the state from the start layer. Default: ``11``.
183183
steps (int, optional): number of steps in the Riemann approximation of the integral. Default: ``20``.
184184
embedding_name (str, optional): The layer name for embedding, head-wise/token-wise.
185-
Default: ``^ernie.embeddings.word_embeddings$``.
185+
Default: ``^ernie.embeddings$``.
186186
attn_map_name (str, optional): The layer name to obtain the attention weights, head-wise/token-wise.
187187
Default: ``^ernie.encoder.layers.*.self_attn.attn_drop$``.
188188
attn_v_name (str, optional): The layer name for value projection, token-wise.
@@ -216,7 +216,8 @@ def text_to_input_fn(raw_text):
216216
model_input = tuple(model_input, )
217217

218218
self._build_predict_fn(embedding_name=embedding_name, attn_map_name=attn_map_name,
219-
attn_v_name=attn_v_name, attn_proj_name=attn_proj_name, nlp=True)
219+
attn_v_name=attn_v_name, attn_proj_name=attn_proj_name,
220+
gradient_of='logit')
220221

221222
attns, grads, inputs, values, projs, proba, preds = self.predict_fn(model_input)
222223
assert start_layer < len(attns), "start_layer should be in the range of [0, num_block-1]"
@@ -269,6 +270,8 @@ def text_to_input_fn(raw_text):
269270
# intermediate results, for possible further usages.
270271
self.predicted_label = preds
271272
self.predicted_proba = proba
273+
self.ap = R[:, 0, :]
274+
self.rf = grad_head_mean[:, 0, :]
272275

273276
if visual:
274277
# TODO: visualize if tokenizer is given.

interpretdl/interpreter/generic_attention.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,6 @@ def interpret(self,
234234
text_to_input_fn: callable = None,
235235
label: int or None = None,
236236
start_layer: int = 11,
237-
embedding_name='^[a-z]*.embeddings.word_embeddings$',
238237
attn_map_name='^[a-z]*.encoder.layers.[0-9]*.self_attn.attn_drop$',
239238
max_seq_len=128,
240239
visual=False):
@@ -246,8 +245,6 @@ def interpret(self,
246245
label (list or tuple or numpy.ndarray, optional): The target labels to analyze. The number of labels
247246
should be equal to the number of texts. If None, the most likely label for each text will be used.
248247
Default: ``None``.
249-
embedding_name (str, optional): The layer name for word embedding.
250-
Default: ``^ernie.embeddings.word_embeddings$``.
251248
attn_map_name (str, optional): The layer name to obtain attention weights.
252249
Default: ``^ernie.encoder.layers.*.self_attn.attn_drop$``
253250
@@ -272,7 +269,7 @@ def text_to_input_fn(raw_text):
272269
model_input = tuple(inp for inp in model_input)
273270
else:
274271
model_input = tuple(model_input, )
275-
self._build_predict_fn(embedding_name=embedding_name, attn_map_name=attn_map_name, nlp=True)
272+
self._build_predict_fn(attn_map_name=attn_map_name, gradient_of='logit', nlp=True)
276273

277274
attns, grads, inputs, values, projs, proba, preds = self.predict_fn(model_input)
278275
assert start_layer < len(attns), "start_layer should be in the range of [0, num_block-1]"

0 commit comments

Comments
 (0)