Skip to content

Commit deb482c

Browse files
add Input tensor in check network
1 parent 5530e07 commit deb482c

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

adapt/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from tensorflow.keras.wrappers.scikit_learn import KerasClassifier, KerasRegressor
1515
import tensorflow as tf
1616
from tensorflow.keras import Sequential, Model
17-
from tensorflow.keras.layers import Layer, Dense, Flatten
17+
from tensorflow.keras.layers import Layer, Dense, Flatten, Input
1818
from tensorflow.keras.models import clone_model
1919

2020

@@ -201,10 +201,12 @@ def check_network(network, copy=True,
201201

202202
if copy:
203203
try:
204-
new_network = clone_model(network)
205204
if hasattr(network, "input_shape"):
206-
new_network.build(input_shape=network.input_shape)
205+
inputs = Input(network.input_shape[1:])
206+
new_network = clone_model(network, input_tensors=inputs)
207207
new_network.set_weights(network.get_weights())
208+
else:
209+
new_network = clone_model(network)
208210
except Exception as e:
209211
if force_copy:
210212
raise ValueError("`%s` argument can't be duplicated. "

0 commit comments

Comments
 (0)