@@ -205,10 +205,16 @@ def get_gradient(x, grad):
 
             return x_grad
 
-        if x.shape == grad.shape:
+        if x.dtype == np.object:
+            x_grad_list = list()
+            for i, x_i in enumerate(x):
+                x_grad_list.append(get_gradient(x=x_i, grad=grad[i]))
+            x_grad = np.empty(x.shape[0], dtype=object)
+            x_grad[:] = list(x_grad_list)
+        elif x.shape == grad.shape:
             x_grad = get_gradient(x=x, grad=grad)
         else:
-            # Special case for lass gradients
+            # Special case for loss gradients
             x_grad = np.zeros_like(grad)
             for i in range(grad.shape[1]):
                 x_grad[:, i, ...] = get_gradient(x=x, grad=grad[:, i, ...])
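The new `np.object` branch handles ragged batches, where samples have different shapes and the batch is stored as a one-dimensional object array: gradients are computed sample by sample and repacked into a fresh object array. (The `else` branch apparently covers class-wise gradients, where `grad` carries an extra class dimension that `x` lacks.) Below is a minimal runnable sketch of the repacking step, with `get_gradient` stubbed as an identity purely for illustration; ART runs the real per-sample backward pass. Note that `np.object` is deprecated since NumPy 1.20, and plain `object` is the equivalent dtype.

```python
import numpy as np

def get_gradient(x, grad):
    # Stand-in for the real per-sample gradient computation (hypothetical).
    return grad

# A ragged batch: two samples of different lengths, held in an object array.
x = np.empty(2, dtype=object)
x[:] = [np.zeros(3), np.zeros(5)]
grad = np.empty(2, dtype=object)
grad[:] = [np.ones(3), np.ones(5)]

if x.dtype == object:  # np.object is deprecated; object is equivalent
    x_grad_list = [get_gradient(x=x_i, grad=grad[i]) for i, x_i in enumerate(x)]
    # Slice-assigning the list preserves the 1-D object layout; np.array()
    # on a ragged list could raise or build an unintended shape instead.
    x_grad = np.empty(x.shape[0], dtype=object)
    x_grad[:] = x_grad_list

print([g.shape for g in x_grad])  # [(3,), (5,)]
```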
@@ -268,35 +274,41 @@ def __call__(self, x: np.ndarray, y: Optional[np.ndarray] = None) -> Tuple[np.nd
         return result, y
 
     # Backward compatibility.
-    def _get_gradient(self, x: np.ndarray, grad: np.ndarray) -> np.ndarray:
-        """
-        Helper function for estimate_gradient
-        """
+    def estimate_gradient(self, x: np.ndarray, grad: np.ndarray) -> np.ndarray:
         import tensorflow as tf  # lgtm [py/repeated-import]
 
-        with tf.GradientTape() as tape:
-            x = tf.convert_to_tensor(x, dtype=config.ART_NUMPY_DTYPE)
-            tape.watch(x)
-            grad = tf.convert_to_tensor(grad, dtype=config.ART_NUMPY_DTYPE)
+        def get_gradient(x: np.ndarray, grad: np.ndarray) -> np.ndarray:
+            """
+            Helper function for estimate_gradient
+            """
 
-            x_prime = self.estimate_forward(x)
+            with tf.GradientTape() as tape:
+                x = tf.convert_to_tensor(x, dtype=config.ART_NUMPY_DTYPE)
+                tape.watch(x)
+                grad = tf.convert_to_tensor(grad, dtype=config.ART_NUMPY_DTYPE)
 
-        x_grad = tape.gradient(target=x_prime, sources=x, output_gradients=grad)
+                x_prime = self.estimate_forward(x)
 
-        x_grad = x_grad.numpy()
-        if x_grad.shape != x.shape:
-            raise ValueError("The input shape is {} while the gradient shape is {}".format(x.shape, x_grad.shape))
+            x_grad = tape.gradient(target=x_prime, sources=x, output_gradients=grad)
 
-        return x_grad
+            x_grad = x_grad.numpy()
+            if x_grad.shape != x.shape:
+                raise ValueError("The input shape is {} while the gradient shape is {}".format(x.shape, x_grad.shape))
 
-    # Backward compatibility.
-    def estimate_gradient(self, x: np.ndarray, grad: np.ndarray) -> np.ndarray:
-        if x.shape == grad.shape:
-            x_grad = self._get_gradient(x=x, grad=grad)
+            return x_grad
+
+        if x.dtype == np.object:
+            x_grad_list = list()
+            for i, x_i in enumerate(x):
+                x_grad_list.append(get_gradient(x=x_i, grad=grad[i]))
+            x_grad = np.empty(x.shape[0], dtype=object)
+            x_grad[:] = list(x_grad_list)
+        elif x.shape == grad.shape:
+            x_grad = get_gradient(x=x, grad=grad)
         else:
-            # Special case for lass gradients
+            # Special case for loss gradients
             x_grad = np.zeros_like(grad)
             for i in range(grad.shape[1]):
-                x_grad[:, i, ...] = self._get_gradient(x=x, grad=grad[:, i, ...])
+                x_grad[:, i, ...] = get_gradient(x=x, grad=grad[:, i, ...])
 
         return x_grad
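For reference, the nested `get_gradient` helper computes a vector-Jacobian product: passing `output_gradients=grad` to `tape.gradient` chains the incoming gradient through the differentiable preprocessing approximation, rather than summing the output. A minimal sketch of that pattern, assuming a stand-in `estimate_forward` (here `x**2`, purely for illustration) and float32 as the config dtype:

```python
import numpy as np
import tensorflow as tf

ART_NUMPY_DTYPE = np.float32  # assumed stand-in for config.ART_NUMPY_DTYPE

def estimate_forward(x):
    # Hypothetical differentiable preprocessing; ART subclasses provide this.
    return x ** 2

x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=ART_NUMPY_DTYPE)
grad_np = np.ones_like(x_np)  # incoming gradient from downstream layers

with tf.GradientTape() as tape:
    x = tf.convert_to_tensor(x_np, dtype=ART_NUMPY_DTYPE)
    tape.watch(x)  # x is a plain tensor, so it must be watched explicitly
    grad = tf.convert_to_tensor(grad_np, dtype=ART_NUMPY_DTYPE)

    # The forward pass must run inside the tape context so it is recorded.
    x_prime = estimate_forward(x)

# output_gradients=grad yields the vector-Jacobian product grad . d(x')/dx,
# chaining the incoming gradient through the recorded forward pass.
x_grad = tape.gradient(target=x_prime, sources=x, output_gradients=grad)
print(x_grad.numpy())  # 2 * x * grad -> [[2. 4.], [6. 8.]]
```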