Skip to content

Commit 2fd7454

Browse files
author
Beat Buesser
committed
Update Keras imports
Signed-off-by: Beat Buesser <[email protected]>
1 parent aa8fe22 commit 2fd7454

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

tests/metrics/test_metrics.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,18 @@ def test_loss_sensitivity(self):
9292
@staticmethod
9393
def _cnn_mnist_k(input_shape):
9494
import tensorflow as tf
95-
import keras
96-
from keras.models import Sequential
97-
from keras.layers import Dense, Flatten, Conv2D, MaxPooling2D
95+
tf_version = [int(v) for v in tf.__version__.split(".")]
96+
if tf_version[0] == 2 and tf_version[1] >= 3:
97+
is_tf23_keras24 = True
98+
tf.compat.v1.disable_eager_execution()
99+
from tensorflow import keras
100+
from tensorflow.keras.models import Sequential
101+
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D
102+
else:
103+
is_tf23_keras24 = False
104+
import keras
105+
from keras.models import Sequential
106+
from keras.layers import Dense, Flatten, Conv2D, MaxPooling2D
98107

99108
# Create simple CNN
100109
model = Sequential()

0 commit comments

Comments
 (0)