@@ -97,6 +97,8 @@ def __init__(
9797 layer with this index will be considered for computing gradients. For models with only one
9898 output layer this values is not required.
9999 """
100+ import tensorflow as tf
101+
100102 super ().__init__ (
101103 model = model ,
102104 clip_values = clip_values ,
@@ -130,6 +132,12 @@ def __init__(
130132 self ._input_shape = tuple (self ._input .shape [1 :])
131133 self ._layer_names = self ._get_layers ()
132134
135+ @tf .function (reduce_retracing = True ) # Compile this for speed
136+ def _forward_pass (model , x , training , batch_size ):
137+ return model (x , training = training , batch_size = batch_size , verbose = False )
138+
139+ self ._forward_pass = _forward_pass
140+
133141 @property
134142 def input_shape (self ) -> tuple [int , ...]:
135143 """
@@ -397,15 +405,14 @@ def predict(self, x: np.ndarray, batch_size: int = 128, training_mode: bool = Fa
397405 x_preprocessed , _ = self ._apply_preprocessing (x , y = None , fit = False )
398406
399407 # Run predictions with batching
400- if training_mode :
401- predictions = self ._model (x_preprocessed , training = training_mode , verbose = False )
402- else :
403- predictions = self ._model .predict (x_preprocessed , batch_size = batch_size , verbose = False )
408+ predictions = self ._forward_pass (
409+ self ._model , x_preprocessed , training = training_mode , batch_size = batch_size
410+ ) # Fast, compiled call
404411
405412 # Apply postprocessing
406- predictions = self ._apply_postprocessing (preds = predictions , fit = False )
413+ predictions_post = self ._apply_postprocessing (preds = predictions . numpy () , fit = False )
407414
408- return predictions
415+ return predictions_post
409416
410417 def fit (
411418 self , x : np .ndarray , y : np .ndarray , batch_size : int = 128 , nb_epochs : int = 20 , verbose : bool = False , ** kwargs
0 commit comments