@@ -22,15 +22,15 @@ class GAInterpreter(InputGradientInterpreter):
 
     """
 
-    def __init__(self, paddle_model: callable, device: str = 'gpu:0') -> None:
+    def __init__(self, model: callable, device: str = 'gpu:0') -> None:
         """
 
         Args:
-            paddle_model (callable): A model with :py:func:`forward` and possibly :py:func:`backward` functions.
-            device (str): The device used for running ``paddle_model``, options: ``"cpu"``, ``"gpu:0"``, ``"gpu:1"``
+            model (callable): A model with :py:func:`forward` and possibly :py:func:`backward` functions.
+            device (str): The device used for running ``model``, options: ``"cpu"``, ``"gpu:0"``, ``"gpu:1"``
                 etc.
         """
-        Interpreter.__init__(self, paddle_model, device)
+        Interpreter.__init__(self, model, device)
 
     def interpret(self,
                   image_input: str or np.ndarray,
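With this rename, callers pass the network via ``model`` instead of ``paddle_model``. A minimal before/after sketch, assuming the class is importable as ``interpretdl.GAInterpreter`` and that a Paddle CLIP model has already been loaded (the ``clip_model`` name is a placeholder, not part of this diff):

import interpretdl as it

clip_model = ...  # placeholder: an already-loaded Paddle CLIP model goes here

# before this change:
#   ga = it.GAInterpreter(paddle_model=clip_model, device='gpu:0')
# after this change:
ga = it.GAInterpreter(model=clip_model, device='gpu:0')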
@@ -140,7 +140,7 @@ def _build_predict_fn(self, vis_attn_layer_pattern: str, txt_attn_layer_pattern:
 
         if self.predict_fn is None or rebuild:
             import paddle
-            self._paddle_env_setup()  # inherit from InputGradientInterpreter
+            self._env_setup()  # inherit from InputGradientInterpreter
 
             def predict_fn(image, text_tokenized):
                 image = paddle.to_tensor(image)
@@ -158,7 +158,7 @@ def txt_hook(layer, input, output):
                     txt_attns.append(output)
 
                 hooks = []  # for remove.
-                for n, v in self.paddle_model.named_sublayers():
+                for n, v in self.model.named_sublayers():
                     if re.match(vis_attn_layer_pattern, n):
                         h = v.register_forward_post_hook(img_hook)
                         hooks.append(h)
@@ -167,7 +167,7 @@ def txt_hook(layer, input, output):
                         h = v.register_forward_post_hook(txt_hook)
                         hooks.append(h)
 
-                logits_per_image, logits_per_text = self.paddle_model(image, text_tokenized)
+                logits_per_image, logits_per_text = self.model(image, text_tokenized)
 
                 for h in hooks:
                     h.remove()
@@ -180,7 +180,7 @@ def txt_hook(layer, input, output):
                 one_hot[paddle.arange(logits_per_image.shape[0]), index] = 1
                 one_hot = paddle.to_tensor(one_hot)
                 one_hot = paddle.sum(one_hot * logits_per_image)
-                self.paddle_model.clear_gradients()
+                self.model.clear_gradients()
                 one_hot.backward()
 
                 img_attns_grads = []
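The predict_fn in this hunk follows a common Paddle pattern: register forward post-hooks on the sublayers whose names match a regex, run a forward pass, reduce the logits to a scalar through a one-hot mask, and call backward(). A self-contained sketch of that pattern on a toy network; ``TinyNet`` and the ``'fc.*'`` pattern are stand-ins, not code from this repo:

import re
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

class TinyNet(nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 8)
        self.fc2 = nn.Linear(8, 3)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))

net = TinyNet()
captured = []   # outputs captured by the forward post-hooks
hooks = []      # handles, so the hooks can be removed afterwards

def save_output(layer, inputs, output):
    captured.append(output)

# register a post-forward hook on every sublayer whose name matches the pattern
for name, sublayer in net.named_sublayers():
    if re.match(r'fc.*', name):
        hooks.append(sublayer.register_forward_post_hook(save_output))

x = paddle.randn([2, 8])
logits = net(x)

for h in hooks:
    h.remove()

# one-hot mask over the predicted classes, reduced to a scalar for backward()
index = logits.numpy().argmax(axis=1)
one_hot = np.zeros(logits.shape, dtype=np.float32)
one_hot[np.arange(logits.shape[0]), index] = 1
target = paddle.sum(paddle.to_tensor(one_hot) * logits)

net.clear_gradients()
target.backward()
# after backward(), parameter gradients (e.g. net.fc2.weight.grad) are populated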
@@ -218,15 +218,15 @@ class GANLPInterpreter(TransformerInterpreter):
 
     """
 
-    def __init__(self, paddle_model: callable, device: str = 'gpu:0', use_cuda=None) -> None:
+    def __init__(self, model: callable, device: str = 'gpu:0') -> None:
         """
 
         Args:
-            paddle_model (callable): A model with :py:func:`forward` and possibly :py:func:`backward` functions.
-            device (str): The device used for running ``paddle_model``, options: ``"cpu"``, ``"gpu:0"``, ``"gpu:1"``
+            model (callable): A model with :py:func:`forward` and possibly :py:func:`backward` functions.
+            device (str): The device used for running ``model``, options: ``"cpu"``, ``"gpu:0"``, ``"gpu:1"``
                 etc.
         """
-        TransformerInterpreter.__init__(self, paddle_model, device, use_cuda)
+        TransformerInterpreter.__init__(self, model, device)
 
     def interpret(self,
                   raw_text: str,
@@ -315,15 +315,15 @@ class GACVInterpreter(TransformerInterpreter):
     The following implementation is specially designed for Vision Transformer.
     """
 
-    def __init__(self, paddle_model: callable, device: str = 'gpu:0', use_cuda=None) -> None:
+    def __init__(self, model: callable, device: str = 'gpu:0') -> None:
         """
 
         Args:
-            paddle_model (callable): A model with :py:func:`forward` and possibly :py:func:`backward` functions.
-            device (str): The device used for running ``paddle_model``, options: ``"cpu"``, ``"gpu:0"``, ``"gpu:1"``
+            model (callable): A model with :py:func:`forward` and possibly :py:func:`backward` functions.
+            device (str): The device used for running ``model``, options: ``"cpu"``, ``"gpu:0"``, ``"gpu:1"``
                 etc.
         """
-        TransformerInterpreter.__init__(self, paddle_model, device, use_cuda)
+        TransformerInterpreter.__init__(self, model, device)
 
     def interpret(self,
                   inputs: str or list(str) or np.ndarray,
@@ -381,7 +381,7 @@ def interpret(self,
 
             R = R + np.matmul(attn, R)
 
-        if hasattr(self.paddle_model, 'global_pool') and self.paddle_model.global_pool:
+        if hasattr(self.model, 'global_pool') and self.model.global_pool:
             # For MAE ViT, but GA does not work well.
             R = R[:, 1:, :].mean(axis=1)
         else:
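For context, the ``R = R + np.matmul(attn, R)`` update in this hunk is the generic-attention rollout: relevance starts as the identity over tokens and each layer's gradient-weighted attention is accumulated multiplicatively. A standalone NumPy sketch of the recurrence with random stand-ins for the attention maps and gradients; the head-averaging/clamping step and the CLS-row readout are assumptions taken from the GA method in general, not verbatim from this diff:

import numpy as np

batch, heads, tokens, layers = 1, 4, 10, 6
rng = np.random.default_rng(0)

# stand-ins for per-layer attention maps and their gradients: [batch, heads, tokens, tokens]
attns = [rng.random((batch, heads, tokens, tokens)).astype('float32') for _ in range(layers)]
grads = [rng.standard_normal((batch, heads, tokens, tokens)).astype('float32') for _ in range(layers)]

# start from the identity: every token is initially only relevant to itself
R = np.eye(tokens, dtype='float32')[None, ...]        # [batch, tokens, tokens]

for attn, grad in zip(attns, grads):
    cam = np.maximum(attn * grad, 0).mean(axis=1)      # gradient-weighted, clamped, head-averaged
    R = R + np.matmul(cam, R)                          # the update shown in the diff

# relevance of each patch token to the CLS token (first row), dropping CLS itself
explanation = R[:, 0, 1:]
print(explanation.shape)                               # (1, tokens - 1)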