
Commit 1b21fde

updates tests
1 parent 43149e2 commit 1b21fde

File tree

15 files changed: +113 additions, -102 deletions


adapt/feature_based/_deepcoral.py

Lines changed: 1 addition & 4 deletions
@@ -190,10 +190,7 @@ def train_step(self, data):
         self.optimizer_enc.apply_gradients(zip(gradients_enc, trainable_vars_enc))
 
         # Update metrics
-        self.compiled_metrics.update_state(ys, ys_pred)
-        self.compiled_loss(ys, ys_pred)
-        # Return a dict mapping metric names to current value
-        logs = {m.name: m.result() for m in self.metrics}
+        logs = self._update_logs(ys, ys_pred)
         logs.update({"disc_loss": disc_loss})
         return logs

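This change, repeated across several files in the commit, replaces the inline metric bookkeeping with a shared _update_logs helper. The helper's implementation is not shown in this diff; a minimal sketch of what it plausibly does on the adapt base class, assuming it simply factors out the removed lines:

    def _update_logs(self, ys, ys_pred):
        # Hypothetical sketch, not the actual adapt.base code: centralize the
        # metric/loss bookkeeping each train_step previously repeated inline.
        self.compiled_metrics.update_state(ys, ys_pred)
        self.compiled_loss(ys, ys_pred)
        # Return a dict mapping metric names to their current value
        return {m.name: m.result() for m in self.metrics}
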
adapt/feature_based/_mcd.py

Lines changed: 7 additions & 18 deletions
@@ -122,15 +122,12 @@ def pretrain_step(self, data):
         gradients_disc = disc_tape.gradient(disc_loss, trainable_vars_disc)
 
         # Update weights
-        self.optimizer.apply_gradients(zip(gradients_task, trainable_vars_task))
-        self.optimizer_enc.apply_gradients(zip(gradients_enc, trainable_vars_enc))
-        self.optimizer_disc.apply_gradients(zip(gradients_disc, trainable_vars_disc))
+        self.pretrain_optimizer.apply_gradients(zip(gradients_task, trainable_vars_task))
+        self.pretrain_optimizer_enc.apply_gradients(zip(gradients_enc, trainable_vars_enc))
+        self.pretrain_optimizer_disc.apply_gradients(zip(gradients_disc, trainable_vars_disc))
 
         # Update metrics
-        self.compiled_metrics.update_state(ys, ys_pred)
-        self.compiled_loss(ys, ys_pred)
-        # Return a dict mapping metric names to current value
-        logs = {m.name: m.result() for m in self.metrics}
+        logs = self._update_logs(ys, ys_pred)
         return logs
 
@@ -162,7 +159,7 @@ def train_step(self, data):
         # Compute gradients
         trainable_vars_enc = self.encoder_.trainable_variables
         gradients_enc = enc_tape.gradient(enc_loss, trainable_vars_enc)
-        self.optimizer.apply_gradients(zip(gradients_enc, trainable_vars_enc))
+        self.optimizer_enc.apply_gradients(zip(gradients_enc, trainable_vars_enc))
 
         # loss
         with tf.GradientTape() as task_tape, tf.GradientTape() as enc_tape, tf.GradientTape() as disc_tape:
 
@@ -212,10 +209,7 @@ def train_step(self, data):
         self.optimizer_disc.apply_gradients(zip(gradients_disc, trainable_vars_disc))
 
         # Update metrics
-        self.compiled_metrics.update_state(ys, ys_pred)
-        self.compiled_loss(ys, ys_pred)
-        # Return a dict mapping metric names to current value
-        logs = {m.name: m.result() for m in self.metrics}
+        logs = self._update_logs(ys, ys_pred)
         logs.update({"disc_loss": discrepancy})
         return logs
 
@@ -264,12 +258,7 @@ def _initialize_networks(self):
 
 
     def _initialize_weights(self, shape_X):
-        # Init weights encoder
-        self(np.zeros((1,) + shape_X))
-        X_enc = self.encoder_(np.zeros((1,) + shape_X))
-        self.task_(X_enc)
-        self.discriminator_(X_enc)
-
+        super()._initialize_weights(shape_X)
         # Add noise to discriminator in order to
         # differentiate from task
         weights = self.discriminator_.get_weights()

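MCD now delegates weight initialization to the parent class before perturbing the discriminator. Judging from the removed lines, the base-class initializer builds every sub-network with a single dummy forward pass on zeros. A runnable illustration of that pattern with placeholder networks (these are not adapt's defaults):

    import numpy as np
    import tensorflow as tf

    # Placeholder sub-networks standing in for encoder_, task_, discriminator_
    encoder = tf.keras.Sequential([tf.keras.layers.Dense(8, activation="relu")])
    task = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    discriminator = tf.keras.Sequential([tf.keras.layers.Dense(1)])

    shape_X = (5,)
    X_enc = encoder(np.zeros((1,) + shape_X))  # builds the encoder weights
    task(X_enc)                                # builds the task head
    discriminator(X_enc)                       # builds the discriminator head
    print(len(encoder.weights), len(task.weights), len(discriminator.weights))  # 2 2 2
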
adapt/feature_based/_mdd.py

Lines changed: 2 additions & 10 deletions
@@ -157,11 +157,7 @@ def train_step(self, data):
         self.optimizer_disc.apply_gradients(zip(gradients_disc, trainable_vars_disc))
 
         # Update metrics
-        self.compiled_metrics.update_state(ys, ys_pred)
-        self.compiled_loss(ys, ys_pred)
-        # Return a dict mapping metric names to current value
-        logs = {m.name: m.result() for m in self.metrics}
-        # disc_metrics = self._get_disc_metrics(ys_disc, yt_disc)
+        logs = self._update_logs(ys, ys_pred)
         logs.update({"disc_loss": disc_loss})
         return logs
 
@@ -189,11 +185,7 @@ def _initialize_networks(self):
 
     def _initialize_weights(self, shape_X):
         # Init weights encoder
-        self(np.zeros((1,) + shape_X))
-        X_enc = self.encoder_(np.zeros((1,) + shape_X))
-        self.task_(X_enc)
-        self.discriminator_(X_enc)
-
+        super()._initialize_weights(shape_X)
         # Add noise to discriminator in order to
         # differentiate from task
         weights = self.discriminator_.get_weights()

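Both MCD and MDD still add noise to the discriminator after the shared initialization so it does not start as an exact copy of the task network. The diff only shows the first line of that step; a standalone illustration of the idea (the 0.01 noise scale is an assumption, not adapt's value):

    import numpy as np
    import tensorflow as tf

    # Placeholder discriminator; adapt perturbs self.discriminator_ in place
    discriminator = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    discriminator.build((None, 5))

    weights = discriminator.get_weights()
    noisy_weights = [w + 0.01 * np.random.standard_normal(w.shape) for w in weights]
    discriminator.set_weights(noisy_weights)  # small perturbation, same shapes
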
adapt/instance_based/_iwn.py

Lines changed: 53 additions & 19 deletions
@@ -12,7 +12,7 @@
 
 from adapt.base import BaseAdaptDeep, make_insert_doc
 from adapt.utils import (check_arrays, check_network, get_default_task,
-                         set_random_seed, check_estimator, check_sample_weight)
+                         set_random_seed, check_estimator, check_sample_weight, check_if_compiled)
 
 EPS = np.finfo(np.float32).eps
 
@@ -141,8 +141,21 @@ def _initialize_networks(self):
                                            name="weighter")
         self.sigma_ = tf.Variable(self.sigma_init,
                                   trainable=self.update_sigma)
-
-
+
+        if not hasattr(self, "estimator_"):
+            self.estimator_ = check_estimator(self.estimator,
+                                              copy=self.copy,
+                                              force_copy=True)
+
+
+    def _initialize_weights(self, shape_X):
+        if hasattr(self, "weighter_"):
+            self.weighter_.build((None,) + shape_X)
+            self.build((None,) + shape_X)
+        if isinstance(self.estimator_, Model):
+            self.estimator_.build((None,) + shape_X)
+
+
     def pretrain_step(self, data):
         # Unpack the data.
         Xs, Xt, ys, yt = self._unpack_data(data)
 
@@ -163,7 +176,7 @@ def pretrain_step(self, data):
         gradients = tape.gradient(loss, trainable_vars)
 
         # Update weights
-        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
+        self.pretrain_optimizer.apply_gradients(zip(gradients, trainable_vars))
 
         logs = {"loss": loss}
         return logs
 
@@ -200,7 +213,7 @@ def train_step(self, data):
 
         # Update weights
         self.optimizer.apply_gradients(zip(gradients, trainable_vars))
-        self.optimizer.apply_gradients(zip(gradients_sigma, [self.sigma_]))
+        self.optimizer_sigma.apply_gradients(zip(gradients_sigma, [self.sigma_]))
 
         # Return a dict mapping metric names to current value
         logs = {"loss": loss, "sigma": self.sigma_}
 
@@ -214,6 +227,26 @@ def fit(self, X, y=None, Xt=None, yt=None, domains=None,
         return self
 
 
+    def compile(self,
+                optimizer=None,
+                loss=None,
+                metrics=None,
+                loss_weights=None,
+                weighted_metrics=None,
+                run_eagerly=None,
+                steps_per_execution=None,
+                **kwargs):
+        super().compile(optimizer=optimizer,
+                        loss=loss,
+                        metrics=metrics,
+                        loss_weights=loss_weights,
+                        weighted_metrics=weighted_metrics,
+                        run_eagerly=run_eagerly,
+                        steps_per_execution=steps_per_execution,
+                        **kwargs)
+        self.optimizer_sigma = self.optimizer.__class__.from_config(self.optimizer.get_config())
+
+
     def fit_weights(self, Xs, Xt, **fit_params):
         """
         Fit importance weighting.
 
@@ -276,22 +309,23 @@ def fit_estimator(self, X, y, sample_weight=None,
         X, y = check_arrays(X, y, accept_sparse=True)
         set_random_seed(random_state)
 
-        if (not warm_start) or (not hasattr(self, "estimator_")):
-            estimator = self.estimator
-            self.estimator_ = check_estimator(estimator,
+        if not hasattr(self, "estimator_"):
+            self.estimator_ = check_estimator(self.estimator,
                                               copy=self.copy,
                                               force_copy=True)
-            if isinstance(self.estimator_, Model):
-                compile_params = {}
-                if estimator._is_compiled:
-                    compile_params["loss"] = deepcopy(estimator.loss)
-                    compile_params["optimizer"] = deepcopy(estimator.optimizer)
-                else:
-                    raise ValueError("The given `estimator` argument"
-                                     " is not compiled yet. "
-                                     "Please give a compiled estimator or "
-                                     "give a `loss` and `optimizer` arguments.")
-                self.estimator_.compile(**compile_params)
+
+        estimator = self.estimator
+        if isinstance(self.estimator_, Model):
+            compile_params = {}
+            if check_if_compiled(estimator):
+                compile_params["loss"] = deepcopy(estimator.loss)
+                compile_params["optimizer"] = deepcopy(estimator.optimizer)
+            else:
+                raise ValueError("The given `estimator` argument"
+                                 " is not compiled yet. "
+                                 "Please give a compiled estimator or "
+                                 "give a `loss` and `optimizer` arguments.")
+            self.estimator_.compile(**compile_params)
 
         fit_args = [
             p.name

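IWN's new compile override gives the bandwidth variable sigma_ its own optimizer by cloning the user's optimizer from its config, so the two variable groups never share slot state such as Adam moments. A small standalone demonstration of that cloning pattern (generic tf.keras, not adapt-specific):

    import tensorflow as tf

    base = tf.keras.optimizers.Adam(learning_rate=1e-3)
    # Same class and hyperparameters, but a fresh optimizer with its own slot
    # variables; this is the pattern used for optimizer_sigma above.
    clone = base.__class__.from_config(base.get_config())
    print(type(clone).__name__, clone.get_config()["learning_rate"])  # Adam 0.001
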
adapt/instance_based/_wann.py

Lines changed: 35 additions & 7 deletions
@@ -116,6 +116,16 @@ def _initialize_networks(self):
                                                 name="discriminator")
 
 
+    def _initialize_weights(self, shape_X):
+        if hasattr(self, "weighter_"):
+            self.weighter_.build((None,) + shape_X)
+        if hasattr(self, "task_"):
+            self.task_.build((None,) + shape_X)
+        if hasattr(self, "discriminator_"):
+            self.discriminator_.build((None,) + shape_X)
+        self.build((None,) + shape_X)
+
+
     def _add_regularization(self, weighter):
         for i in range(len(weighter.layers)):
             if hasattr(weighter.layers[i], "kernel_constraint"):
 
@@ -149,7 +159,7 @@ def pretrain_step(self, data):
         gradients = tape.gradient(loss, trainable_vars)
 
         # Update weights
-        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
+        self.pretrain_optimizer.apply_gradients(zip(gradients, trainable_vars))
 
         logs = {"loss": loss}
         return logs
 
@@ -217,15 +227,33 @@ def train_step(self, data):
 
         # Update weights
         self.optimizer.apply_gradients(zip(gradients_task, trainable_vars_task))
-        self.optimizer.apply_gradients(zip(gradients_weight, trainable_vars_weight))
-        self.optimizer.apply_gradients(zip(gradients_disc, trainable_vars_disc))
+        self.optimizer_weight.apply_gradients(zip(gradients_weight, trainable_vars_weight))
+        self.optimizer_disc.apply_gradients(zip(gradients_disc, trainable_vars_disc))
 
         # Update metrics
-        self.compiled_metrics.update_state(ys, ys_pred)
-        self.compiled_loss(ys, ys_pred)
-        # Return a dict mapping metric names to current value
-        logs = {m.name: m.result() for m in self.metrics}
+        logs = self._update_logs(ys, ys_pred)
         return logs
+
+
+    def compile(self,
+                optimizer=None,
+                loss=None,
+                metrics=None,
+                loss_weights=None,
+                weighted_metrics=None,
+                run_eagerly=None,
+                steps_per_execution=None,
+                **kwargs):
+        super().compile(optimizer=optimizer,
+                        loss=loss,
+                        metrics=metrics,
+                        loss_weights=loss_weights,
+                        weighted_metrics=weighted_metrics,
+                        run_eagerly=run_eagerly,
+                        steps_per_execution=steps_per_execution,
+                        **kwargs)
+        self.optimizer_weight = self.optimizer.__class__.from_config(self.optimizer.get_config())
+        self.optimizer_disc = self.optimizer.__class__.from_config(self.optimizer.get_config())
 
 
     def predict_weights(self, X):

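WANN now builds each sub-network explicitly with a (None,) + shape_X input shape instead of tracing a dummy forward pass. A short usage note on that build pattern (the network below is an arbitrary example, not WANN's default weighter):

    import tensorflow as tf

    net = tf.keras.Sequential([
        tf.keras.layers.Dense(10, activation="relu"),
        tf.keras.layers.Dense(1),
    ])

    shape_X = (5,)                # feature shape without the batch dimension
    net.build((None,) + shape_X)  # create weights for a variable batch size
    print([w.shape for w in net.get_weights()])  # [(5, 10), (10,), (10, 1), (1,)]
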
adapt/parameter_based/_finetuning.py

Lines changed: 4 additions & 9 deletions
@@ -146,10 +146,7 @@ def pretrain_step(self, data):
         self.optimizer.apply_gradients(zip(gradients_task, trainable_vars_task))
 
         # Update metrics
-        self.compiled_metrics.update_state(ys, ys_pred)
-        self.compiled_loss(ys, ys_pred)
-        # Return a dict mapping metric names to current value
-        logs = {m.name: m.result() for m in self.metrics}
+        logs = self._update_logs(ys, ys_pred)
         return logs
 
 
@@ -185,13 +182,11 @@ def train_step(self, data):
 
         # Update weights
         self.optimizer.apply_gradients(zip(gradients_task, trainable_vars_task))
-        self.optimizer_enc.apply_gradients(zip(gradients_enc, trainable_vars_enc))
+        if len(trainable_vars_enc) > 0:
+            self.optimizer_enc.apply_gradients(zip(gradients_enc, trainable_vars_enc))
 
         # Update metrics
-        self.compiled_metrics.update_state(ys, ys_pred)
-        self.compiled_loss(ys, ys_pred)
-        # Return a dict mapping metric names to current value
-        logs = {m.name: m.result() for m in self.metrics}
+        logs = self._update_logs(ys, ys_pred)
         return logs
 

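The new guard in FineTuning.train_step matters when the encoder is fully frozen (for example with training=[False]): its trainable-variable list is then empty, and apply_gradients on an empty sequence can raise on some TensorFlow versions. A small check of that behavior (the exact outcome depends on the TF release):

    import tensorflow as tf

    opt = tf.keras.optimizers.Adam()
    try:
        opt.apply_gradients(zip([], []))   # nothing to update
        print("this TF version tolerates an empty update")
    except ValueError as err:
        print("empty update rejected:", err)
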
adapt/utils.py

Lines changed: 1 addition & 1 deletion
@@ -590,7 +590,7 @@ def check_if_compiled(network):
     """
     if hasattr(network, "compiled") and network.compiled:
         return True
-    elif hasattr(network, "_is_compiled") and networtf._is_compiled:
+    elif hasattr(network, "_is_compiled") and network._is_compiled:
         return True
     else:
         return False

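The one-character fix replaces the mistyped networtf with network. A quick usage sketch of the corrected helper, assuming this commit's version of adapt is installed and a standard tf.keras model:

    import tensorflow as tf
    from adapt.utils import check_if_compiled

    net = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    print(check_if_compiled(net))   # False: no loss/optimizer set yet

    net.compile(optimizer="adam", loss="mse")
    print(check_if_compiled(net))   # True once compile() has been called
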
tests/test_finetuning.py

Lines changed: 3 additions & 6 deletions
@@ -5,10 +5,7 @@
 from adapt.utils import make_classification_da
 from adapt.parameter_based import FineTuning
 from tensorflow.keras.initializers import GlorotUniform
-try:
-    from tensorflow.keras.optimizers.legacy import Adam
-except:
-    from tensorflow.keras.optimizers import Adam
+from tensorflow.keras.optimizers import Adam
 
 np.random.seed(0)
 tf.random.set_seed(0)
 
@@ -44,7 +41,7 @@ def test_finetune():
                             loss="bce", optimizer=Adam(), random_state=0)
     fine_tuned.fit(Xt[ind], yt[ind], epochs=100, verbose=0)
 
-    assert np.abs(fine_tuned.encoder_.get_weights()[0] - model.encoder_.get_weights()[0]).sum() > 1.
+    assert np.abs(fine_tuned.encoder_.get_weights()[0] - model.encoder_.get_weights()[0]).sum() > 0.5
     assert np.mean((fine_tuned.predict(Xt).ravel()>0.5) == yt) > 0.9
 
     fine_tuned = FineTuning(encoder=model.encoder_, task=model.task_,
 
@@ -53,7 +50,7 @@ def test_finetune():
     fine_tuned.fit(Xt[ind], yt[ind], epochs=100, verbose=0)
 
     assert np.abs(fine_tuned.encoder_.get_weights()[0] - model.encoder_.get_weights()[0]).sum() == 0.
-    assert np.abs(fine_tuned.encoder_.get_weights()[-1] - model.encoder_.get_weights()[-1]).sum() > 1.
+    assert np.abs(fine_tuned.encoder_.get_weights()[-1] - model.encoder_.get_weights()[-1]).sum() > .5
 
     fine_tuned = FineTuning(encoder=model.encoder_, task=model.task_,
                             training=[False],

tests/test_iwc.py

Lines changed: 1 addition & 4 deletions
@@ -7,10 +7,7 @@
 from adapt.utils import make_classification_da
 from adapt.instance_based import IWC
 from adapt.utils import get_default_discriminator
-try:
-    from tensorflow.keras.optimizers.legacy import Adam
-except:
-    from tensorflow.keras.optimizers import Adam
+from tensorflow.keras.optimizers import Adam
 
 Xs, ys, Xt, yt = make_classification_da()

tests/test_iwn.py

Lines changed: 1 addition & 4 deletions
@@ -7,10 +7,7 @@
 from adapt.instance_based import IWN
 from adapt.utils import get_default_task
 from sklearn.neighbors import KNeighborsClassifier
-try:
-    from tensorflow.keras.optimizers.legacy import Adam
-except:
-    from tensorflow.keras.optimizers import Adam
+from tensorflow.keras.optimizers import Adam
 
 Xs, ys, Xt, yt = make_classification_da()
