Skip to content

Commit e83f82a

Browse files
update check network
1 parent 8030027 commit e83f82a

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

adapt/feature_based/_msda.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,4 +287,22 @@ def predict(self, X):
287287
Prediction of ``estimator_``.
288288
"""
289289
X = check_one_array(X)
290-
return self.estimator_.predict(self.encoder_.predict(X))
290+
return self.estimator_.predict(self.predict_features(X))
291+
292+
293+
def predict_features(self, X):
294+
"""
295+
Return the encoded features of X.
296+
297+
Parameters
298+
----------
299+
X: array
300+
input data
301+
302+
Returns
303+
-------
304+
X_enc: array
305+
predictions of encoder network
306+
"""
307+
X = check_one_array(X)
308+
return self.encoder_.predict(X)

adapt/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,8 @@ def check_network(network, copy=True,
202202
if copy:
203203
try:
204204
if hasattr(network, "input_shape"):
205-
new_network = clone_model(network, input_tensors=network.input)
205+
shape = network.input_shape[1:]
206+
new_network = clone_model(network, input_tensors=Input(shape))
206207
new_network.set_weights(network.get_weights())
207208
else:
208209
new_network = clone_model(network)

0 commit comments

Comments
 (0)