@@ -385,8 +385,15 @@ def __init__(self, paddle_model: callable, device: str, use_cuda: bool = None, *
385385 "paddle_model has to be " \
386386 "an instance of paddle.nn.Layer or a compatible one."
387387
388- def _build_predict_fn (self , rebuild : bool = False , embedding_name : str or None = None , attn_map_name : str or None = None ,
389- attn_v_name : str or None = None , attn_proj_name : str or None = None , nlp : bool = False ):
388+ def _build_predict_fn (
389+ self ,
390+ rebuild : bool = False ,
391+ embedding_name : str or None = None ,
392+ attn_map_name : str or None = None ,
393+ attn_v_name : str or None = None ,
394+ attn_proj_name : str or None = None ,
395+ gradient_of : str or None = None ,
396+ nlp : bool = False ):
390397
391398 """Build ``predict_fn`` for transformer based algorithms.
392399 The model is supposed to be a classification model.
@@ -463,7 +470,7 @@ def block_value_hook(layer, input, output):
463470 if attn_map_name is not None and re .match (attn_map_name , n ):
464471 h = v .register_forward_post_hook (block_attn_hook )
465472 hooks .append (h )
466- elif scale is not None and re .match (embedding_name , n ):
473+ elif scale is not None and embedding_name is not None and re .match (embedding_name , n ):
467474 h = v .register_forward_post_hook (hook )
468475 hooks .append (h )
469476 elif attn_proj_name is not None and re .match (attn_proj_name , n ):
@@ -474,21 +481,28 @@ def block_value_hook(layer, input, output):
474481 h = v .register_forward_post_hook (block_value_hook )
475482 hooks .append (h )
476483
477- out = self .paddle_model (* inputs )
484+ logits = self .paddle_model (* inputs )
478485
479486 for h in hooks :
480487 h .remove ()
481488
482- proba = paddle .nn .functional .softmax (out , axis = 1 )
489+ proba = paddle .nn .functional .softmax (logits , axis = 1 )
483490 preds = paddle .argmax (proba , axis = 1 )
484491 if label is None :
485492 label = preds .numpy ()
493+ label_onehot = paddle .nn .functional .one_hot (paddle .to_tensor (label ), num_classes = logits .shape [1 ])
486494
487495 block_attns_grads = []
488-
489- label_onehot = paddle .nn .functional .one_hot (paddle .to_tensor (label ), num_classes = proba .shape [1 ])
490- target = paddle .sum (proba * label_onehot , axis = 1 )
491- target .backward ()
496+
497+ if gradient_of == 'probability' or gradient_of is None :
498+ target = paddle .sum (proba * label_onehot , axis = 1 )
499+ target .backward ()
500+ elif gradient_of == 'logit' :
501+ target = paddle .sum (logits * label_onehot , axis = 1 )
502+ target .backward ()
503+ else :
504+ raise ValueError ("`gradient_of` should be one of ['logit', 'probability']." )
505+
492506 for i , attn in enumerate (block_attns ):
493507 grad = attn .grad .numpy ()
494508 block_attns_grads .append (grad )
0 commit comments