Skip to content

Commit 2ef1b96

Browse files
Merge pull request #52 from KevinMusgrave/dev
v0.0.61
2 parents 09cf3b0 + 6dac330 commit 2ef1b96

File tree

10 files changed

+164
-32
lines changed

10 files changed

+164
-32
lines changed

src/pytorch_adapt/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.0.60"
1+
__version__ = "0.0.61"

src/pytorch_adapt/hooks/adabn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ def add_if_new(
1818
inputs,
1919
model_name,
2020
in_keys,
21-
other_args=[domain],
21+
other_args={"domain": domain},
22+
logger=self.logger,
2223
)
2324

2425

src/pytorch_adapt/hooks/base.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import torch
66

77
from ..utils import common_functions as c_f
8+
from .logger import HookLogger
89

910

1011
class BaseHook(ABC):
@@ -41,13 +42,16 @@ def __init__(
4142
self.out_suffix = out_suffix
4243
self.key_map = c_f.default(key_map, {})
4344
self.in_keys = []
45+
self.logger = HookLogger(c_f.cls_name(self))
4446

4547
def __call__(self, inputs, losses=None):
48+
self.logger("__call__")
4649
losses = c_f.default(losses, {})
4750
try:
4851
inputs = c_f.map_keys(inputs, self.key_map)
4952
x = self.call(inputs, losses)
5053
if isinstance(x, (bool, np.bool_)):
54+
self.logger.reset()
5155
return x
5256
elif isinstance(x, tuple):
5357
outputs, losses = x
@@ -56,14 +60,15 @@ def __call__(self, inputs, losses=None):
5660
outputs = wrap_keys(outputs, self.out_prefix, self.out_suffix)
5761
losses = wrap_keys(losses, self.loss_prefix, self.loss_suffix)
5862
self.check_losses_and_outputs(outputs, losses, inputs)
63+
self.logger.reset()
5964
return outputs, losses
6065
else:
6166
raise TypeError(
6267
f"Output is of type {type(x)}, but should be bool or tuple"
6368
)
6469
except Exception as e:
65-
if not isinstance(e, KeyError):
66-
c_f.append_error_message(e, self.str_for_error_msg(n=1))
70+
c_f.add_error_message(e, f"in {self.logger.str}\n", prepend=True)
71+
self.logger.reset()
6772
raise
6873

6974
@abstractmethod
@@ -124,13 +129,6 @@ def children_repr(self):
124129
all_modules = c_f.attrs_of_type(self, torch.nn.Module)
125130
return c_f.assert_dicts_are_disjoint(all_hooks, all_modules)
126131

127-
def str_for_error_msg(self, x=None, n=None):
128-
e = str(self if x is None else x)
129-
if n is not None:
130-
e = "\n".join(e.split("\n")[:n])
131-
e += "\n...\n"
132-
return f"\nERROR occuring in:\n{e}"
133-
134132
def check_losses_and_outputs(self, outputs, losses, inputs):
135133
check_keys_are_present(self, self.loss_keys, [losses], "loss_keys", "losses")
136134
check_keys_are_present(

src/pytorch_adapt/hooks/domain.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def call(self, inputs, losses):
100100
outputs = self.hook(inputs, losses)[0]
101101
labels = self.extract_domain_labels(inputs)
102102
for domain_name, labels in labels.items():
103+
self.logger(f"Computing loss for {domain_name} domain")
103104
[dlogits] = c_f.extract(
104105
[outputs, inputs],
105106
c_f.filter(self.hook.out_keys, f"_dlogits$", [f"^{domain_name}"]),
@@ -113,6 +114,7 @@ def call(self, inputs, losses):
113114
return outputs, losses
114115

115116
def extract_domain_labels(self, inputs):
117+
self.logger("Expecting 'src_domain' and 'target_domain' in inputs")
116118
[src_domain, target_domain] = c_f.extract(
117119
inputs, ["src_domain", "target_domain"]
118120
)

src/pytorch_adapt/hooks/features.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def call(self, inputs, losses):
7373
""""""
7474
outputs = {}
7575
for domain in self.domains:
76+
self.logger(f"Getting {domain}")
7677
detach = self.check_grad_mode(domain)
7778
func = self.mode_detached if detach else self.mode_with_grad
7879
in_keys = c_f.filter(self.in_keys, f"^{domain}")
@@ -129,7 +130,15 @@ def mode_detached(self, inputs, outputs, domain, in_keys):
129130
def add_if_new(
130131
self, outputs, full_key, output_vals, inputs, model_name, in_keys, domain
131132
):
132-
c_f.add_if_new(outputs, full_key, output_vals, inputs, model_name, in_keys)
133+
c_f.add_if_new(
134+
outputs,
135+
full_key,
136+
output_vals,
137+
inputs,
138+
model_name,
139+
in_keys,
140+
logger=self.logger,
141+
)
133142

134143
def create_keys(self, domain, suffix, starting_keys=None, detach=False):
135144
if starting_keys is None:

src/pytorch_adapt/hooks/gvb.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ def add_if_new(
3030
inputs,
3131
model_name,
3232
in_keys,
33-
other_args=[True],
33+
other_args={"return_bridge": True},
34+
logger=self.logger,
3435
)
3536

3637

src/pytorch_adapt/hooks/logger.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
class HookLogger:
    """Accumulates log messages for a hook into a single newline-joined string.

    Each message is prefixed with the logger's name. The accumulated text
    lives in the public ``str`` attribute so callers can inspect or embed it
    (e.g. in error messages), and ``reset()`` clears it.
    """

    def __init__(self, name):
        # Prefix prepended to every message.
        self.name = name
        self.reset()

    def __call__(self, x):
        """Append ``"<name>: <x>"``, separated from prior messages by a newline."""
        separator = "\n" if self.str else ""
        self.str = f"{self.str}{separator}{self.name}: {x}"

    def reset(self):
        """Discard all accumulated messages."""
        self.str = ""

src/pytorch_adapt/utils/common_functions.py

Lines changed: 47 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import errno
22
import glob
3+
import inspect
34
import itertools
45
import json
56
import logging
@@ -57,27 +58,45 @@ def dict_pop_lazy(x, key, *args, **kwargs):
5758
return default(y, *args, **kwargs)
5859

5960

60-
def add_if_new(d, key, x, kwargs, model_name, in_keys, other_args=None):
61-
# if key is list then assume model returns multiple args
61+
def try_use_model(model, model_name, input_vals):
    """Call ``model`` on ``input_vals``, enriching TypeErrors with the signature.

    A TypeError here typically means the wrong number or kind of inputs was
    passed to ``model.forward``; appending the actual signature to the
    exception makes the mismatch obvious before re-raising.
    """
    try:
        return default(None, model, input_vals, is_none)
    except TypeError as err:
        signature = inspect.signature(model.forward)
        add_error_message(err, f"\n{model_name}.forward() signature is {signature}")
        raise
70+
71+
72+
def assign_to_output(d, key, x, new_x, condition):
    """Store model output(s) ``new_x`` into ``d`` under ``key``.

    When ``x`` has multiple entries, ``new_x`` must be a list/tuple of the
    same length, and each element is stored only where ``condition`` holds
    for the corresponding entry of ``x``. With a single entry, ``new_x`` is
    stored unconditionally under ``key[0]``.
    """
    if len(x) <= 1:
        # Single-output case: no per-element filtering.
        d[key[0]] = new_x
        return
    if not is_list_or_tuple(new_x) or len(new_x) != len(x):
        raise TypeError(
            "if input x and key are lists, then output of model must be a list of the same length"
        )
    for k, existing, fresh in zip(key, x, new_x):
        if condition(existing):
            d[k] = fresh
83+
84+
85+
def add_if_new(d, key, x, kwargs, model_name, in_keys, other_args=None, logger=None):
    """Compute and store model outputs in ``d``, but only if not already present.

    ``key`` and ``x`` are parallel lists: ``key`` holds output names, ``x``
    holds the current values for those names (``None`` meaning "not yet
    computed"). The model ``kwargs[model_name]`` is invoked only when at
    least one entry of ``x`` is None; its inputs are the values of
    ``kwargs[k] for k in in_keys`` followed by the values of ``other_args``.

    Args:
        d: dict that receives newly computed outputs.
        key: list/tuple of output key names.
        x: list/tuple of current values aligned with ``key`` (None = missing).
        kwargs: mapping that contains the model and its input values.
        model_name: key of the model inside ``kwargs``.
        in_keys: keys in ``kwargs`` whose values are passed to the model.
        other_args: optional dict of extra named inputs; their values are
            appended (in dict order) after the ``in_keys`` values.
        logger: optional HookLogger-style callable used to record which
            outputs/model/inputs are involved (for error reporting).

    Raises:
        TypeError: if ``key`` or ``x`` is not a list/tuple, or (via
            ``assign_to_output``) if the model output length mismatches.
    """
    other_args = default(other_args, {})
    if logger:
        # Record intent before doing any work, so a failure inside the model
        # call still shows up with this context in the log string.
        logger(f"Getting output: {key}")
        logger(
            f"Using model {model_name} with inputs: {', '.join(in_keys + list(other_args.keys()))}"
        )
    if not is_list_or_tuple(key) or not is_list_or_tuple(x):
        raise TypeError("key and x must both be a list or tuple")
    condition = is_none
    # Skip the (potentially expensive) model call entirely when every
    # requested output already exists.
    if any(condition(y) for y in x):
        model = kwargs[model_name]
        input_vals = [kwargs[k] for k in in_keys] + list(other_args.values())
        new_x = try_use_model(model, model_name, input_vals)
        assign_to_output(d, key, x, new_x, condition)
81100

82101

83102
def class_default(cls, x, default):
@@ -399,9 +418,19 @@ def attrs_of_type(cls, obj):
399418
return {k: v for k, v in attrs.items() if isinstance(v, obj)}
400419

401420

402-
def append_error_message(e, msg):
421+
# https://stackoverflow.com/a/70114007/16941290
class ErrorMsgWithNewLines(str):
    """A str whose repr IS the string itself.

    Exception args are displayed via repr, which would show embedded "\\n"
    escapes instead of real line breaks; this subclass keeps multi-line
    error context readable.
    """

    def __repr__(self):
        return str(self)


def add_error_message(e, msg, prepend=False):
    """Mutate exception ``e`` in place, attaching ``msg`` to its first arg.

    Args:
        e: the exception to annotate (its ``args`` tuple is rewritten).
        msg: text to attach.
        prepend: if True, ``msg`` goes before the existing message,
            otherwise after it.
    """
    if len(e.args) >= 1:
        # Coerce non-str first args (e.g. OSError's errno int) so that
        # concatenation cannot itself raise a TypeError and lose the context.
        first = e.args[0] if isinstance(e.args[0], str) else str(e.args[0])
        combined = msg + first if prepend else first + msg
        e.args = (ErrorMsgWithNewLines(combined),) + e.args[1:]
    else:
        # An exception raised with no args would otherwise silently drop msg.
        e.args = (ErrorMsgWithNewLines(msg),)
e.args = x + e.args[1:]
405434

406435

407436
def requires_grad(x, does=True):

src/pytorch_adapt/weighters/base_weighter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def weight_losses(reduction, weights, scale, loss_dict):
1919
losses.append(loss)
2020
components[k] = loss.item()
2121
except Exception as e:
22-
c_f.append_error_message(e, f"\nError occuring with loss key = {k}")
22+
c_f.add_error_message(e, f"\nError occuring with loss key = {k}")
2323
raise
2424

2525
total = reduction(losses)

tests/hooks/test_logger.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import unittest
2+
3+
import torch
4+
5+
from pytorch_adapt.hooks import DANNHook, GVBHook
6+
from pytorch_adapt.hooks.logger import HookLogger
7+
8+
9+
class TestHookLogger(unittest.TestCase):
    """Tests for HookLogger and for the hook-chain context that hooks attach
    to exceptions via their loggers."""

    def test_basic(self):
        # Messages are prefixed with the logger's name; reset() clears the
        # accumulated string and logging works again afterwards.
        test_str = "test test"
        correct_out = f"{__name__}: {test_str}"
        logger = HookLogger(__name__)
        logger(test_str)
        self.assertTrue(logger.str == correct_out)
        logger.reset()
        self.assertTrue(logger.str == "")
        logger(test_str)
        self.assertTrue(logger.str == correct_out)

    def test_key_errors_dann(self):
        # "target_imgs" is deliberately missing from the inputs, so the hook
        # chain should fail with a KeyError whose message carries the full
        # prepended hook-call trail.
        hook = DANNHook(opts=[])
        data = {
            "src_imgs": torch.randn(32, 32),
        }
        models = {
            "G": torch.nn.Linear(32, 10),
            "C": torch.nn.Linear(10, 2),
            "D": torch.nn.Sequential(torch.nn.Linear(10, 1), torch.nn.Flatten(0)),
        }
        with self.assertRaises(KeyError) as cm:
            hook({**data, **models})

        # Expected message: each enclosing hook prepends "in <Hook>: __call__",
        # then the innermost FeaturesHook's own log lines, then the missing key.
        correct_str = (
            "in DANNHook: __call__"
            "\nin ChainHook: __call__"
            "\nin OptimizerHook: __call__"
            "\nin ChainHook: __call__"
            "\nin FeaturesForDomainLossHook: __call__"
            "\nin FeaturesHook: __call__"
            "\nFeaturesHook: Getting src"
            "\nFeaturesHook: Getting output: ['src_imgs_features']"
            "\nFeaturesHook: Using model G with inputs: src_imgs"
            "\nFeaturesHook: Getting target"
            "\nFeaturesHook: Getting output: ['target_imgs_features']"
            "\nFeaturesHook: Using model G with inputs: target_imgs"
            "\ntarget_imgs"
        )

        self.assertTrue(str(cm.exception) == correct_str)
        # Loggers must be reset after an exception propagates.
        self.assertTrue(hook.logger.str == "")

    def test_type_error_gvb(self):
        # "C" is a plain Linear, so passing the extra return_bridge argument
        # should raise a TypeError that try_use_model annotates with C's
        # forward() signature.
        hook = GVBHook(opts=[])
        data = {"src_imgs": torch.randn(32, 32), "target_imgs": torch.randn(32, 32)}
        models = {
            "G": torch.nn.Linear(32, 10),
            "C": torch.nn.Linear(10, 2),
            "D": torch.nn.Sequential(torch.nn.Linear(10, 1), torch.nn.Flatten(0)),
        }

        with self.assertRaises(TypeError) as cm:
            hook({**data, **models})

        correct_str = (
            "in GVBHook: __call__"
            "\nin ChainHook: __call__"
            "\nin OptimizerHook: __call__"
            "\nin ChainHook: __call__"
            "\nin FeaturesLogitsAndGBridge: __call__"
            "\nin GBridgeAndLogitsHook: __call__"
            "\nGBridgeAndLogitsHook: Getting src"
            "\nGBridgeAndLogitsHook: Getting output: ['src_imgs_features_logits', 'src_imgs_features_gbridge']"
            "\nGBridgeAndLogitsHook: Using model C with inputs: src_imgs_features, return_bridge"
            "\nforward() takes 2 positional arguments but 3 were given"
            "\nC.forward() signature is (input: torch.Tensor) -> torch.Tensor"
        )

        self.assertTrue(str(cm.exception) == correct_str)
        self.assertTrue(hook.logger.str == "")

0 commit comments

Comments
 (0)