Skip to content

Commit bac69a9

Browse files
update compile in check network
1 parent 7fa7708 commit bac69a9

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

adapt/parameter_based/_regular.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import tensorflow.keras.backend as K
1515
import tensorflow as tf
1616
from tensorflow.keras.optimizers import Adam
17+
from tensorflow.keras import Sequential
18+
from tensorflow.keras.layers import Flatten, Dense
1719

1820
from adapt.utils import (check_arrays,
1921
check_one_array,

adapt/utils.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,9 +226,22 @@ def check_network(network, copy=True,
226226
new_network = network
227227
if compile_:
228228
if network.optimizer:
229-
new_network.compile(optimizer=deepcopy(network.optimizer),
230-
loss=deepcopy(network.loss),
231-
metrics=deepcopy(network.metrics))
229+
# TODO, find a way of giving metrics (for now this
230+
# induces weird behaviour with fitted model having
231+
# their loss in metrics)
232+
try:
233+
# TODO, can we be sure that deepcopy will always work?
234+
try:
235+
optimizer=deepcopy(network.optimizer)
236+
loss=deepcopy(network.loss)
237+
new_network.compile(optimizer=optimizer,
238+
loss=loss)
239+
except:
240+
new_network.compile(optimizer=network.optimizer,
241+
loss=network.loss)
242+
except:
243+
raise ValueError("Unable to compile the given `%s` argument."%
244+
(display_name))
232245
else:
233246
raise ValueError("The given `%s` argument is not compiled yet. "
234247
"Please use `model.compile(optimizer, loss)`."%

0 commit comments

Comments
 (0)