55import torch
66import torch .nn .functional as F
77
8+ from pytorch_adapt .adapters import ADDA
9+ from pytorch_adapt .containers import Models , Optimizers
810from pytorch_adapt .hooks import ADDAHook , BSPHook , validate_hook
911from pytorch_adapt .utils import common_functions as c_f
1012
1113from .utils import (
1214 Net ,
15+ assert_equal_models ,
1316 assertRequiresGrad ,
17+ get_opt_tuple ,
1418 get_opts ,
1519 post_g_hook_update_keys ,
1620 post_g_hook_update_total_loss ,
1721)
1822
1923
24+ def test_equivalent_adapter (G , D , data , post_g , threshold ):
25+ models = Models (
26+ {"G" : copy .deepcopy (G ), "D" : copy .deepcopy (D ), "C" : torch .nn .Identity ()}
27+ )
28+ optimizers = Optimizers (get_opt_tuple ())
29+ adapter = ADDA (
30+ models , optimizers , hook_kwargs = {"post_g" : post_g , "threshold" : threshold }
31+ )
32+ adapter .training_step (data )
33+ return models
34+
35+
2036def get_models_and_data ():
2137 src_domain = torch .randint (0 , 2 , size = (100 ,)).float ()
2238 target_domain = torch .randint (0 , 2 , size = (100 ,)).float ()
@@ -30,7 +46,7 @@ def get_models_and_data():
3046class TestADDA (unittest .TestCase ):
3147 def test_adda (self ):
3248 torch .manual_seed (922 )
33- for post_g in [None , BSPHook (domains = ["target" ])]:
49+ for post_g in [None , [ BSPHook (domains = ["target" ])] ]:
3450 for threshold in np .linspace (0 , 1 , 10 ):
3551 (
3652 G ,
@@ -46,9 +62,8 @@ def test_adda(self):
4662 originalT = copy .deepcopy (T )
4763 d_opts = get_opts (D )
4864 g_opts = get_opts (T )
49- post_g_ = [post_g ] if post_g is not None else post_g
5065 h = ADDAHook (
51- d_opts = d_opts , g_opts = g_opts , threshold = threshold , post_g = post_g_
66+ d_opts = d_opts , g_opts = g_opts , threshold = threshold , post_g = post_g
5267 )
5368 models = {"G" : G , "D" : D , "T" : T }
5469 data = {
@@ -78,6 +93,10 @@ def test_adda(self):
7893 )
7994 self .assertTrue (losses ["g_loss" ].keys () == g_loss_keys )
8095
96+ adapter_models = test_equivalent_adapter (
97+ originalG , originalD , data , post_g , threshold
98+ )
99+
81100 d_opts = get_opts (originalD )[0 ]
82101 g_opts = get_opts (originalT )[0 ]
83102 originalG .eval ()
@@ -142,9 +161,12 @@ def test_adda(self):
142161 # can't use model_counts for conditional part
143162 self .assertTrue (D .count == d_count )
144163
145- for x , y in [(G , originalG ), (T , originalT ), (D , originalD )]:
146- self .assertTrue (
147- c_f .state_dicts_are_equal (
148- x .state_dict (), y .state_dict (), rtol = 1e-3
149- )
150- )
164+ assert_equal_models (
165+ self , (G , adapter_models ["G" ], originalG ), rtol = 1e-3
166+ )
167+ assert_equal_models (
168+ self , (T , adapter_models ["T" ], originalT ), rtol = 1e-3
169+ )
170+ assert_equal_models (
171+ self , (D , adapter_models ["D" ], originalD ), rtol = 1e-3
172+ )
0 commit comments