Skip to content

Commit 3cef36f

Browse files
Merge pull request #97 from KevinMusgrave/dev
v0.0.82
2 parents 073c20e + 7227732 commit 3cef36f

File tree

18 files changed

+134
-75
lines changed

18 files changed

+134
-75
lines changed

.flake8

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
[flake8]
22

33
extend-ignore =
4-
E266 # too many leading '#' for block comment
5-
E203 # whitespace before ':'
6-
E402 # module level import not at top of file
7-
E501 # line too long
8-
4+
# too many leading '#' for block comment
5+
E266
6+
# whitespace before ':'
7+
E203
8+
# module level import not at top of file
9+
E402
10+
# line too long
11+
E501
912
per-file-ignores =
1013
__init__.py:F401

setup.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
extras_require_detection = ["albumentations >= 1.2.1"]
1313
extras_require_ignite = ["pytorch-ignite == 0.4.9"]
1414
extras_require_lightning = ["pytorch-lightning"]
15-
extras_require_record_keeper = ["record-keeper >= 0.9.32"]
15+
extras_require_record_keeper = ["record-keeper >= 0.9.32", "tensorboard"]
1616
extras_require_timm = ["timm"]
1717
extras_require_docs = [
1818
"mkdocs-material",
@@ -44,8 +44,8 @@
4444
"numpy",
4545
"torch",
4646
"torchvision",
47-
"torchmetrics >= 0.9.3",
48-
"pytorch-metric-learning >= 1.5.2",
47+
"torchmetrics == 0.9.3",
48+
"pytorch-metric-learning >= 1.6.3",
4949
],
5050
extras_require={
5151
"detection": extras_require_detection,

src/pytorch_adapt/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.0.81"
1+
__version__ = "0.0.82"

src/pytorch_adapt/hooks/classification.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def __init__(
5757
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
5858
detach_features: bool = False,
5959
f_hook: BaseHook = None,
60+
domains=("src",),
6061
**kwargs,
6162
):
6263
"""
@@ -75,21 +76,28 @@ def __init__(
7576
self.hook = c_f.default(
7677
f_hook,
7778
FeaturesAndLogitsHook,
78-
{"domains": ["src"], "detach_features": detach_features},
79+
{"domains": domains, "detach_features": detach_features},
7980
)
8081

8182
def call(self, inputs, losses):
8283
""""""
8384
outputs = self.hook(inputs, losses)[0]
84-
[src_logits] = c_f.extract(
85-
[outputs, inputs], c_f.filter(self.hook.out_keys, "_logits$")
85+
output_losses = {}
86+
logits = c_f.extract(
87+
[outputs, inputs],
88+
c_f.filter(
89+
self.hook.out_keys, "_logits$", [f"^{d}" for d in self.hook.domains]
90+
),
8691
)
87-
loss = self.loss_fn(src_logits, inputs["src_labels"])
88-
return outputs, {self._loss_keys()[0]: loss}
92+
for i, d in enumerate(self.hook.domains):
93+
output_losses[self._loss_keys()[i]] = self.loss_fn(
94+
logits[i], inputs[f"{d}_labels"]
95+
)
96+
return outputs, output_losses
8997

9098
def _loss_keys(self):
9199
""""""
92-
return ["c_loss"]
100+
return [f"{d}_c_loss" for d in self.hook.domains]
93101

94102

95103
class ClassifierHook(BaseWrapperHook):

src/pytorch_adapt/hooks/features.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,7 @@ def __init__(
315315
):
316316
for i in range(len(hooks) - 1):
317317
hooks[i + 1].set_in_keys(hooks[i].out_keys)
318+
self.domains = hooks[-1].domains
318319
super().__init__(*hooks, **kwargs)
319320

320321

@@ -326,7 +327,7 @@ class FeaturesAndLogitsHook(FeaturesChainHook):
326327

327328
def __init__(
328329
self,
329-
domains: List[str] = None,
330+
domains: List[str] = ("src", "target"),
330331
detach_features: bool = False,
331332
detach_logits: bool = False,
332333
other_hooks: List[BaseHook] = None,
@@ -343,6 +344,7 @@ def __init__(
343344
other_hooks: A list of hooks that will be called after
344345
the features and logits hooks.
345346
"""
347+
self.domains = domains
346348
features_hook = FeaturesHook(detach=detach_features, domains=domains)
347349
logits_hook = LogitsHook(detach=detach_logits, domains=domains)
348350
other_hooks = c_f.default(other_hooks, [])

tests/adapters/run_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
# log files should be a mapping from csv file name, to number of columns in file
2929
def run_adapter(cls, test_folder, adapter, log_files=None, inference_fn=None):
30-
checkpoint_fn = CheckpointFnCreator(dirname=test_folder)
30+
checkpoint_fn = CheckpointFnCreator(dirname=test_folder, require_empty=False)
3131
logger = IgniteRecordKeeperLogger(folder=test_folder)
3232
datasets = get_datasets()
3333
validator = ScoreHistory(EntropyValidator())

tests/adapters/test_running.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def gan_log_files():
9595
"total",
9696
"g_src_domain_loss",
9797
"g_target_domain_loss",
98-
"c_loss",
98+
"src_c_loss",
9999
},
100100
"engine_output_d_loss": {
101101
"total",
@@ -153,7 +153,7 @@ def test_aligner(self):
153153
"optimizers_C_Adam": {"lr"},
154154
"engine_output_total_loss": {
155155
"total",
156-
"c_loss",
156+
"src_c_loss",
157157
"features_confusion_loss",
158158
"logits_confusion_loss",
159159
},
@@ -172,7 +172,9 @@ def test_aligner(self):
172172
def test_cdan(self):
173173
models = get_gcd()
174174
misc = Misc({"feature_combiner": RandomizedDotProduct([512, 10], 512)})
175-
g_weighter = MeanWeighter(weights={"g_target_domain_loss": 0.5, "c_loss": 0.1})
175+
g_weighter = MeanWeighter(
176+
weights={"g_target_domain_loss": 0.5, "src_c_loss": 0.1}
177+
)
176178
adapter = CDAN(models=models, misc=misc, hook_kwargs={"g_weighter": g_weighter})
177179
self.assertTrue(isinstance(adapter.hook, CDANHook))
178180
log_files = gan_log_files()
@@ -186,7 +188,7 @@ def test_cdan(self):
186188
},
187189
"hook_8c2a74151317b9315573314fafc0d8ad6e12f72a84433739f6f0762a4ca11ab0_weights": {
188190
"g_target_domain_loss",
189-
"c_loss",
191+
"src_c_loss",
190192
},
191193
}
192194
)
@@ -204,7 +206,7 @@ def test_classifier(self):
204206
"optimizers_C_Adam": {"lr"},
205207
"engine_output_total_loss": {
206208
"total",
207-
"c_loss",
209+
"src_c_loss",
208210
},
209211
"hook_ClassifierHook_hook_ChainHook_hooks0_OptimizerHook_weighter_MeanWeighter": {
210212
"scale"
@@ -224,7 +226,7 @@ def test_dann(self):
224226
"optimizers_D_Adam": {"lr"},
225227
"engine_output_total_loss": {
226228
"total",
227-
"c_loss",
229+
"src_c_loss",
228230
"src_domain_loss",
229231
"target_domain_loss",
230232
},
@@ -267,7 +269,7 @@ def test_finetuner(self):
267269
"optimizers_C_Adam": {"lr"},
268270
"engine_output_total_loss": {
269271
"total",
270-
"c_loss",
272+
"src_c_loss",
271273
},
272274
"hook_FinetunerHook_hook_ChainHook_hooks0_OptimizerHook_weighter_MeanWeighter": {
273275
"scale"
@@ -305,7 +307,7 @@ def test_joint_aligner(self):
305307
"optimizers_C_Adam": {"lr"},
306308
"engine_output_total_loss": {
307309
"total",
308-
"c_loss",
310+
"src_c_loss",
309311
"joint_confusion_loss",
310312
},
311313
"hook_AlignerPlusCHook_hook_ChainHook_hooks0_OptimizerHook_weighter_MeanWeighter": {
@@ -328,7 +330,7 @@ def test_gvb(self):
328330
"optimizers_D_Adam": {"lr"},
329331
"engine_output_total_loss": {
330332
"total",
331-
"c_loss",
333+
"src_c_loss",
332334
"src_domain_loss",
333335
"target_domain_loss",
334336
"g_src_bridge_loss",
@@ -361,13 +363,13 @@ def test_mcd(self):
361363
"optimizers_C_Adam": {"lr"},
362364
"engine_output_x_loss": {
363365
"total",
364-
"c_loss0",
365-
"c_loss1",
366+
"src_c_loss0",
367+
"src_c_loss1",
366368
},
367369
"engine_output_y_loss": {
368370
"total",
369-
"c_loss0",
370-
"c_loss1",
371+
"src_c_loss0",
372+
"src_c_loss1",
371373
"discrepancy_loss",
372374
},
373375
"engine_output_z_loss": {"total", "discrepancy_loss"},
@@ -399,7 +401,7 @@ def test_rtn(self):
399401
"optimizers_residual_model_Adam": {"lr"},
400402
"engine_output_total_loss": {
401403
"total",
402-
"c_loss",
404+
"src_c_loss",
403405
"entropy_loss",
404406
"features_confusion_loss",
405407
},
@@ -424,8 +426,8 @@ def test_symnets(self):
424426
"optimizers_G_Adam": {"lr"},
425427
"optimizers_C_Adam": {"lr"},
426428
"engine_output_c_loss": {
427-
"c_loss0",
428-
"c_loss1",
429+
"src_c_loss0",
430+
"src_c_loss1",
429431
"c_symnets_src_domain_loss_0",
430432
"c_symnets_target_domain_loss_1",
431433
"total",

tests/hooks/test_aligners.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def test_aligner_plus_classifier_hook(self):
7575
)
7676

7777
loss_keys = {
78-
"c_loss",
78+
"src_c_loss",
7979
"total",
8080
}
8181

@@ -120,15 +120,15 @@ def test_aligner_plus_classifier_hook(self):
120120
[F.softmax(target_logits, dim=1), target_features],
121121
)
122122
total_loss = (f_loss + c_loss) / 2
123-
correct_losses = [c_loss, f_loss, total_loss]
123+
correct_losses = [f_loss, c_loss, total_loss]
124124
else:
125125
f_loss = loss_fn()(src_features, target_features)
126126
l_loss = loss_fn()(
127127
F.softmax(src_logits, dim=1), F.softmax(target_logits, dim=1)
128128
)
129129

130130
total_loss = (f_loss + l_loss + c_loss) / 3
131-
correct_losses = [c_loss, f_loss, l_loss, total_loss]
131+
correct_losses = [f_loss, l_loss, c_loss, total_loss]
132132

133133
computed_losses = [
134134
losses["total_loss"][k] for k in sorted(list(loss_keys))

tests/hooks/test_cdan.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ def test_cdan_hook(self):
250250
g_loss_keys = {
251251
"g_src_domain_loss",
252252
"g_target_domain_loss",
253-
"c_loss",
253+
"src_c_loss",
254254
"total",
255255
}
256256

@@ -349,14 +349,18 @@ def test_cdan_hook(self):
349349
g_losses["g_target_domain_loss"] * target_entropy_weights
350350
)
351351

352-
g_losses["c_loss"] = torch.nn.functional.cross_entropy(
352+
g_losses["src_c_loss"] = torch.nn.functional.cross_entropy(
353353
c_logits[:bs], src_labels
354354
)
355355

356356
self.assertTrue(
357357
all(
358358
np.isclose(losses["g_loss"][k], g_losses[k].item())
359-
for k in ["g_src_domain_loss", "g_target_domain_loss", "c_loss"]
359+
for k in [
360+
"g_src_domain_loss",
361+
"g_target_domain_loss",
362+
"src_c_loss",
363+
]
360364
)
361365
)
362366
g_losses = list(g_losses.values())

tests/hooks/test_classification.py

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,19 +30,53 @@ def test_softmax_hook(self):
3030

3131
def test_closs_hook(self):
3232
torch.manual_seed(24242)
33+
src_imgs = torch.randn(100, 32)
34+
target_imgs = torch.randn(100, 32)
35+
src_labels = torch.randint(0, 10, size=(100,))
36+
target_labels = torch.randint(0, 10, size=(100,))
37+
G = Net(32, 16)
38+
C = Net(16, 10)
39+
3340
for detach_features in [True, False]:
34-
h = CLossHook(detach_features=detach_features)
35-
src_imgs = torch.randn(100, 32)
36-
target_imgs = torch.randn(100, 32)
37-
src_labels = torch.randint(0, 10, size=(100,))
38-
G = Net(32, 16)
39-
C = Net(16, 10)
40-
outputs, losses = h(locals())
41-
assertRequiresGrad(self, outputs)
42-
base_key = "src_imgs_features"
43-
if detach_features:
44-
base_key += "_detached"
45-
self.assertTrue(outputs.keys() == {base_key, f"{base_key}_logits"})
41+
for domains in [None, ("src",), ("target",), ("src", "target")]:
42+
if domains is None:
43+
h = CLossHook(detach_features=detach_features)
44+
else:
45+
h = CLossHook(detach_features=detach_features, domains=domains)
46+
outputs, losses = h(locals())
47+
assertRequiresGrad(self, outputs)
48+
base_keys = (
49+
[f"{d}_imgs_features" for d in domains]
50+
if domains
51+
else ["src_imgs_features"]
52+
)
53+
if detach_features:
54+
base_keys = [f"{x}_detached" for x in base_keys]
55+
logit_keys = [f"{x}_logits" for x in base_keys]
56+
self.assertTrue(outputs.keys() == {*base_keys, *logit_keys})
57+
58+
correct_loss_fn = torch.nn.functional.cross_entropy
59+
for k, v in losses.items():
60+
if k.startswith("src"):
61+
self.assertTrue(
62+
torch.equal(
63+
v,
64+
correct_loss_fn(
65+
C(G(src_imgs)), src_labels, reduction="none"
66+
),
67+
)
68+
)
69+
elif k.startswith("target"):
70+
self.assertTrue(
71+
torch.equal(
72+
v,
73+
correct_loss_fn(
74+
C(G(target_imgs)), target_labels, reduction="none"
75+
),
76+
)
77+
)
78+
else:
79+
raise KeyError
4680

4781
def test_classifier_hook(self):
4882
torch.manual_seed(53430)
@@ -77,4 +111,4 @@ def test_classifier_hook(self):
77111
_, losses = h(
78112
{"G": G, "C": C, "src_imgs": src_imgs, "src_labels": src_labels}
79113
)
80-
self.assertTrue(np.isclose(losses["total_loss"]["c_loss"], correct))
114+
self.assertTrue(np.isclose(losses["total_loss"]["src_c_loss"], correct))

0 commit comments

Comments
 (0)