Commit 09cf3b0

Merge pull request #49 from KevinMusgrave/dev
Added test_equivalent_adapter to hook tests
2 parents 8a9c672 + 6f2720d

16 files changed: +367 −104 lines

16 files changed

+367
-104
lines changed

src/pytorch_adapt/adapters/adabn.py

Lines changed: 2 additions & 2 deletions
@@ -8,9 +8,9 @@
 class AdaBN(BaseAdapter):
     hook_cls = AdaBNHook
 
-    def __init__(self, inference_fn=None, **kwargs):
+    def __init__(self, *args, inference_fn=None, **kwargs):
         inference_fn = c_f.default(inference_fn, adabn_fn)
-        super().__init__(inference_fn=inference_fn, **kwargs)
+        super().__init__(*args, inference_fn=inference_fn, **kwargs)
 
     def init_hook(self, hook_kwargs):
         self.hook = self.hook_cls(**hook_kwargs)
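
Note: the same two-line change is applied to every adapter in this commit. Forwarding *args to BaseAdapter lets its containers be passed positionally, which the new test_equivalent_adapter helpers below rely on. A minimal sketch of the usage this enables (the toy modules are illustrative, not from the commit):

    import torch
    from pytorch_adapt.adapters import AdaBN
    from pytorch_adapt.containers import Models

    # Toy models; any modules the Models container accepts would do here.
    G = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))
    C = torch.nn.Identity()

    # Previously Models had to be passed by keyword; with *args forwarded,
    # it can be passed positionally, as the new tests do: AdaBN(models).
    adapter = AdaBN(Models({"G": G, "C": C}))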

src/pytorch_adapt/adapters/adda.py

Lines changed: 2 additions & 2 deletions
@@ -23,9 +23,9 @@ class ADDA(BaseAdapter):
 
     hook_cls = ADDAHook
 
-    def __init__(self, inference_fn=None, **kwargs):
+    def __init__(self, *args, inference_fn=None, **kwargs):
         inference_fn = c_f.default(inference_fn, adda_fn)
-        super().__init__(inference_fn=inference_fn, **kwargs)
+        super().__init__(*args, inference_fn=inference_fn, **kwargs)
 
     def get_default_containers(self) -> MultipleContainers:
         """

src/pytorch_adapt/adapters/aligner.py

Lines changed: 2 additions & 2 deletions
@@ -38,9 +38,9 @@ class RTN(Aligner):
 
     hook_cls = RTNHook
 
-    def __init__(self, inference_fn=None, **kwargs):
+    def __init__(self, *args, inference_fn=None, **kwargs):
         inference_fn = c_f.default(inference_fn, rtn_fn)
-        super().__init__(inference_fn=inference_fn, **kwargs)
+        super().__init__(*args, inference_fn=inference_fn, **kwargs)
 
     def get_key_enforcer(self) -> KeyEnforcer:
         ke = super().get_key_enforcer()

src/pytorch_adapt/adapters/mcd.py

Lines changed: 2 additions & 2 deletions
@@ -22,9 +22,9 @@ class MCD(BaseGCAdapter):
 
     hook_cls = MCDHook
 
-    def __init__(self, inference_fn=None, **kwargs):
+    def __init__(self, *args, inference_fn=None, **kwargs):
         inference_fn = c_f.default(inference_fn, mcd_fn)
-        super().__init__(inference_fn=inference_fn, **kwargs)
+        super().__init__(*args, inference_fn=inference_fn, **kwargs)
 
     def init_hook(self, hook_kwargs):
         self.hook = self.hook_cls(

src/pytorch_adapt/adapters/symnets.py

Lines changed: 2 additions & 2 deletions
@@ -20,9 +20,9 @@ class SymNets(BaseGCAdapter):
 
     hook_cls = SymNetsHook
 
-    def __init__(self, inference_fn=None, **kwargs):
+    def __init__(self, *args, inference_fn=None, **kwargs):
         inference_fn = c_f.default(inference_fn, symnets_fn)
-        super().__init__(inference_fn=inference_fn, **kwargs)
+        super().__init__(*args, inference_fn=inference_fn, **kwargs)
 
     def init_hook(self, hook_kwargs):
         self.hook = self.hook_cls(

tests/hooks/test_adabn.py

Lines changed: 22 additions & 8 deletions
@@ -3,13 +3,22 @@
 
 import torch
 
+from pytorch_adapt.adapters import AdaBN
+from pytorch_adapt.containers import Models
 from pytorch_adapt.hooks import AdaBNHook, validate_hook
 from pytorch_adapt.layers import AdaptiveBatchNorm2d
 from pytorch_adapt.layers.adaptive_batch_norm import set_curr_domain
 
 from .utils import assertRequiresGrad
 
 
+def test_equivalent_adapter(G, C, data):
+    models = Models({"G": copy.deepcopy(G), "C": copy.deepcopy(C)})
+    adapter = AdaBN(models)
+    adapter.training_step(data)
+    return models
+
+
 class Net(torch.nn.Module):
     def __init__(self, in_size, out_size):
         super().__init__()
@@ -53,17 +62,22 @@ def test_adabn_hook(self):
         assertRequiresGrad(self, outputs)
         self.assertTrue(len(losses) == 0)
 
+        adapter_models = test_equivalent_adapter(originalG, originalC, data)
+
         originalG.net[0].bns[0](src_imgs)
         originalG.net[0].bns[1](target_imgs)
 
         for i in range(2):
-            self.assertTrue(
-                torch.equal(
-                    G.net[0].bns[i].running_mean, originalG.net[0].bns[i].running_mean
+            for M in [models, adapter_models]:
+                self.assertTrue(
+                    torch.equal(
+                        M["G"].net[0].bns[i].running_mean,
+                        originalG.net[0].bns[i].running_mean,
+                    )
                 )
-            )
-            self.assertTrue(
-                torch.equal(
-                    G.net[0].bns[i].running_var, originalG.net[0].bns[i].running_var
+                self.assertTrue(
+                    torch.equal(
+                        M["G"].net[0].bns[i].running_var,
+                        originalG.net[0].bns[i].running_var,
+                    )
                 )
-            )
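
Note: the assertions above compare per-domain BatchNorm statistics. Judging from this test, AdaptiveBatchNorm2d keeps one BatchNorm module per domain (the bns list) and set_curr_domain selects the active one, so only that domain's running stats get updated. A rough stand-in for the idea (not the library's actual implementation):

    import torch

    class PerDomainBatchNorm2d(torch.nn.Module):
        # Illustrative sketch of AdaptiveBatchNorm2d's apparent structure.
        def __init__(self, num_features, num_domains=2):
            super().__init__()
            self.bns = torch.nn.ModuleList(
                [torch.nn.BatchNorm2d(num_features) for _ in range(num_domains)]
            )
            self.curr_domain = 0

        def forward(self, x):
            # Only the selected domain's running_mean/running_var update.
            return self.bns[self.curr_domain](x)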

tests/hooks/test_adda.py

Lines changed: 31 additions & 9 deletions
@@ -5,18 +5,34 @@
 import torch
 import torch.nn.functional as F
 
+from pytorch_adapt.adapters import ADDA
+from pytorch_adapt.containers import Models, Optimizers
 from pytorch_adapt.hooks import ADDAHook, BSPHook, validate_hook
 from pytorch_adapt.utils import common_functions as c_f
 
 from .utils import (
     Net,
+    assert_equal_models,
     assertRequiresGrad,
+    get_opt_tuple,
     get_opts,
     post_g_hook_update_keys,
     post_g_hook_update_total_loss,
 )
 
 
+def test_equivalent_adapter(G, D, data, post_g, threshold):
+    models = Models(
+        {"G": copy.deepcopy(G), "D": copy.deepcopy(D), "C": torch.nn.Identity()}
+    )
+    optimizers = Optimizers(get_opt_tuple())
+    adapter = ADDA(
+        models, optimizers, hook_kwargs={"post_g": post_g, "threshold": threshold}
+    )
+    adapter.training_step(data)
+    return models
+
+
 def get_models_and_data():
     src_domain = torch.randint(0, 2, size=(100,)).float()
     target_domain = torch.randint(0, 2, size=(100,)).float()
@@ -30,7 +46,7 @@ def get_models_and_data():
 class TestADDA(unittest.TestCase):
     def test_adda(self):
         torch.manual_seed(922)
-        for post_g in [None, BSPHook(domains=["target"])]:
+        for post_g in [None, [BSPHook(domains=["target"])]]:
             for threshold in np.linspace(0, 1, 10):
                 (
                     G,
@@ -46,9 +62,8 @@ def test_adda(self):
                 originalT = copy.deepcopy(T)
                 d_opts = get_opts(D)
                 g_opts = get_opts(T)
-                post_g_ = [post_g] if post_g is not None else post_g
                 h = ADDAHook(
-                    d_opts=d_opts, g_opts=g_opts, threshold=threshold, post_g=post_g_
+                    d_opts=d_opts, g_opts=g_opts, threshold=threshold, post_g=post_g
                 )
                 models = {"G": G, "D": D, "T": T}
                 data = {
@@ -78,6 +93,10 @@ def test_adda(self):
                 )
                 self.assertTrue(losses["g_loss"].keys() == g_loss_keys)
 
+                adapter_models = test_equivalent_adapter(
+                    originalG, originalD, data, post_g, threshold
+                )
+
                 d_opts = get_opts(originalD)[0]
                 g_opts = get_opts(originalT)[0]
                 originalG.eval()
@@ -142,9 +161,12 @@ def test_adda(self):
                 # can't use model_counts for conditional part
                 self.assertTrue(D.count == d_count)
 
-                for x, y in [(G, originalG), (T, originalT), (D, originalD)]:
-                    self.assertTrue(
-                        c_f.state_dicts_are_equal(
-                            x.state_dict(), y.state_dict(), rtol=1e-3
-                        )
-                    )
+                assert_equal_models(
+                    self, (G, adapter_models["G"], originalG), rtol=1e-3
+                )
+                assert_equal_models(
+                    self, (T, adapter_models["T"], originalT), rtol=1e-3
+                )
+                assert_equal_models(
+                    self, (D, adapter_models["D"], originalD), rtol=1e-3
+                )
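
Note: assert_equal_models and get_opt_tuple come from tests/hooks/utils.py, whose diff is not shown on this page. From the call sites and the code it replaces, assert_equal_models presumably asserts that all models in the tuple have equal state dicts within rtol; a plausible sketch:

    from pytorch_adapt.utils import common_functions as c_f

    def assert_equal_models(cls, models, rtol):
        # Compare each model to the first one, mirroring the removed
        # per-pair c_f.state_dicts_are_equal assertions.
        first, *rest = models
        for m in rest:
            cls.assertTrue(
                c_f.state_dicts_are_equal(
                    first.state_dict(), m.state_dict(), rtol=rtol
                )
            )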

tests/hooks/test_aligners.py

Lines changed: 31 additions & 8 deletions
@@ -5,11 +5,30 @@
 import torch
 import torch.nn.functional as F
 
+from pytorch_adapt.adapters import Aligner
+from pytorch_adapt.containers import Models, Optimizers
 from pytorch_adapt.hooks import AlignerPlusCHook, JointAlignerHook, validate_hook
 from pytorch_adapt.layers import CORALLoss, MMDLoss
-from pytorch_adapt.utils import common_functions as c_f
 
-from .utils import assertRequiresGrad, get_models_and_data, get_opts
+from .utils import (
+    assert_equal_models,
+    assertRequiresGrad,
+    get_models_and_data,
+    get_opt_tuple,
+    get_opts,
+)
+
+
+def test_equivalent_adapter(G, C, data, aligner_hook, loss_fn):
+    models = Models({"G": copy.deepcopy(G), "C": copy.deepcopy(C)})
+    optimizers = Optimizers(get_opt_tuple())
+    adapter = Aligner(
+        models,
+        optimizers,
+        hook_kwargs={"aligner_hook": aligner_hook, "loss_fn": loss_fn},
+    )
+    adapter.training_step(data)
+    return models
 
 
 class TestAligners(unittest.TestCase):
@@ -74,6 +93,10 @@ def test_aligner_plus_classifier_hook(self):
                     G.count == model_counts["G"] == C.count == model_counts["C"] == 2
                 )
 
+                adapter_models = test_equivalent_adapter(
+                    originalG, originalC, data, aligner_hook, loss_fn()
+                )
+
                 opts = get_opts(originalG, originalC)
 
                 src_features = originalG(src_imgs)
@@ -120,9 +143,9 @@ def test_aligner_plus_classifier_hook(self):
                 total_loss.backward()
                 [x.step() for x in opts]
 
-                for x, y in [(G, originalG), (C, originalC)]:
-                    self.assertTrue(
-                        c_f.state_dicts_are_equal(
-                            x.state_dict(), y.state_dict(), rtol=1e-6
-                        )
-                    )
+                assert_equal_models(
+                    self, (G, adapter_models["G"], originalG), rtol=1e-6
+                )
+                assert_equal_models(
+                    self, (C, adapter_models["C"], originalC), rtol=1e-6
+                )
tests/hooks/test_cdan.py

Lines changed: 48 additions & 10 deletions
@@ -5,6 +5,8 @@
 import numpy as np
 import torch
 
+from pytorch_adapt.adapters import CDANE
+from pytorch_adapt.containers import Misc, Models, Optimizers
 from pytorch_adapt.hooks import (
     AFNHook,
     BNMHook,
@@ -16,18 +18,41 @@
     validate_hook,
 )
 from pytorch_adapt.layers import RandomizedDotProduct
-from pytorch_adapt.utils import common_functions as c_f
 
 from .utils import (
+    assert_equal_models,
     assertRequiresGrad,
     get_entropy_weights,
     get_models_and_data,
+    get_opt_tuple,
     get_opts,
     post_g_hook_update_keys,
     post_g_hook_update_total_loss,
 )
 
 
+def test_equivalent_adapter(
+    G, D, C, feature_combiner, data, detach_reducer, post_g, softmax
+):
+    models = Models(
+        {"G": copy.deepcopy(G), "D": copy.deepcopy(D), "C": copy.deepcopy(C)}
+    )
+    misc = Misc({"feature_combiner": copy.deepcopy(feature_combiner)})
+    optimizers = Optimizers(get_opt_tuple())
+    adapter = CDANE(
+        models,
+        optimizers,
+        misc=misc,
+        hook_kwargs={
+            "detach_entropy_reducer": detach_reducer,
+            "post_g": post_g,
+            "softmax": softmax,
+        },
+    )
+    adapter.training_step(data)
+    return models
+
+
 def get_correct_domain_losses(
     G,
     C,
@@ -159,7 +184,7 @@ def test_cdan_domain_hooks(self):
     def test_cdan_hook(self):
         torch.manual_seed(985)
         for detach_reducer in [False, True]:
-            for post_g in [None, BSPHook(), BNMHook(), MCCHook(), AFNHook()]:
+            for post_g in [None, [BSPHook()], [BNMHook()], [MCCHook()], [AFNHook()]]:
                 softmax = True
                 fc_out_size = 16
                 (
@@ -189,13 +214,12 @@ def test_cdan_hook(self):
                     "src_domain": src_domain,
                     "target_domain": target_domain,
                 }
-                post_g_ = [post_g] if post_g is not None else post_g
                 hook = CDANEHook(
                     detach_entropy_reducer=detach_reducer,
                     d_opts=d_opts,
                     g_opts=g_opts,
                     softmax=softmax,
-                    post_g=post_g_,
+                    post_g=post_g,
                 )
                 model_counts = validate_hook(hook, list(data.keys()))
                 outputs, losses = hook({**models, **data})
@@ -236,6 +260,17 @@ def test_cdan_hook(self):
                 )
                 self.assertTrue(D.count == model_counts["D"] == 4)
 
+                adapter_models = test_equivalent_adapter(
+                    originalG,
+                    originalD,
+                    originalC,
+                    originalFeatureCombiner,
+                    data,
+                    detach_reducer,
+                    post_g,
+                    softmax,
+                )
+
                 d_opts = get_opts(originalD)
                 g_opts = get_opts(originalG, originalC)
 
@@ -338,9 +373,12 @@ def test_cdan_hook(self):
                 total_loss.backward()
                 [x.step() for x in g_opts]
 
-                for x, y in [(G, originalG), (C, originalC), (D, originalD)]:
-                    self.assertTrue(
-                        c_f.state_dicts_are_equal(
-                            x.state_dict(), y.state_dict(), rtol=1e-2
-                        )
-                    )
+                assert_equal_models(
+                    self, (G, adapter_models["G"], originalG), rtol=1e-2
+                )
+                assert_equal_models(
+                    self, (C, adapter_models["C"], originalC), rtol=1e-2
+                )
+                assert_equal_models(
+                    self, (D, adapter_models["D"], originalD), rtol=1e-2
+                )
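
Note: a pattern repeated across these test files: post_g is now built as a list where it is declared ([BSPHook()] instead of BSPHook()), so the post_g_ = [post_g] if post_g is not None else post_g wrapping disappears and the same value can be handed unchanged to both the hook and the adapter's hook_kwargs. For example:

    from pytorch_adapt.hooks import BSPHook

    # Old style: bare hook, wrapped into a list at each call site.
    # New style: declare the list up front and pass it everywhere as-is.
    post_g = [BSPHook()]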
