Skip to content

Commit 24ac7df

Browse files
authored
Fix Keras Import for Custom Loss Gradient
This PR fixes an issue with inputs when using custom loss gradients.
1 parent c15ea16 commit 24ac7df

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

art/estimators/classification/keras.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,10 @@ def custom_loss_gradient(self, nn_function, tensors, input_values, name="default
630630
:return: the gradient of the function w.r.t vars
631631
:rtype: `np.ndarray`
632632
"""
633-
import keras.backend as k
633+
if self.is_tensorflow:
634+
import tensorflow.keras.backend as k
635+
else:
636+
import keras.backend as k
634637

635638
if not hasattr(self, "_custom_loss_func"):
636639
self._custom_loss_func = {}

0 commit comments

Comments
 (0)