Skip to content

Commit 8030027

Browse files
add same input in check network
1 parent deb482c commit 8030027

File tree

2 files changed

+2
-3
lines changed

2 files changed

+2
-3
lines changed

adapt/feature_based/_deep.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1296,7 +1296,7 @@ def get_metrics(self, inputs_ys, inputs_yt,
12961296
disc = 0.25 * K.mean(K.square(subtract([cov_src, cov_tgt])))
12971297

12981298
metrics["task_s"] = K.mean(task_s)
1299-
metrics["disc"] = K.mean(disc)
1299+
metrics["disc"] = self.lambda_ * K.mean(disc)
13001300
if inputs_yt is not None:
13011301
task_t = self.loss_(inputs_yt, task_tgt)
13021302
metrics["task_t"] = K.mean(task_t)

adapt/utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,8 +202,7 @@ def check_network(network, copy=True,
202202
if copy:
203203
try:
204204
if hasattr(network, "input_shape"):
205-
inputs = Input(network.input_shape[1:])
206-
new_network = clone_model(network, input_tensors=inputs)
205+
new_network = clone_model(network, input_tensors=network.input)
207206
new_network.set_weights(network.get_weights())
208207
else:
209208
new_network = clone_model(network)

0 commit comments

Comments
 (0)