Skip to content

Commit 955957f

Browse files
committed
Fix error with model saving introduced with new optimizers in tf 2.11.0. Import Adam optimizer from tf.keras.optimizers.legacy if possible - on import error attempt to user tf.keras.optimizers instead.
1 parent d8d2167 commit 955957f

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

ivis/ivis.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,11 @@ def _fit(self, X, Y=None, shuffle_mode=True):
271271
triplet_loss_func = triplet_loss(distance=self.distance)
272272

273273
if self.model_ is None:
274+
try:
275+
optimizer_fn = tf.keras.optimizers.legacy.Adam()
276+
except ImportError:
277+
optimizer_fn = tf.keras.optimizers.Adam()
278+
274279
if isinstance(self.model, str):
275280
input_size = (X.shape[-1],)
276281
self.model_, (anchor_embedding, *_) = \
@@ -282,7 +287,7 @@ def _fit(self, X, Y=None, shuffle_mode=True):
282287
embedding_dims=self.embedding_dims)
283288

284289
if Y is None:
285-
self.model_.compile(optimizer='adam', loss=triplet_loss_func)
290+
self.model_.compile(optimizer=optimizer_fn, loss=triplet_loss_func)
286291
else:
287292
supervised_layer = build_supervised_layer(self.supervision_metric,
288293
Y, name='supervised')
@@ -297,7 +302,7 @@ def _fit(self, X, Y=None, shuffle_mode=True):
297302
outputs=[self.model_.output,
298303
supervised_out])
299304
self.model_.compile(
300-
optimizer='adam',
305+
optimizer=optimizer_fn,
301306
loss={
302307
'stacked_triplets': triplet_loss_func,
303308
'supervised': supervised_loss

0 commit comments

Comments
 (0)