
Commit eee3b1a

rnett, Borda, and SeanNaren authored
Unify attribute finding logic, fix not using dataloader when hparams present (#4559)
* Rebase onto master
* indent fix
* Remove duplicated logic
* Use single return
* Remove extra else
* add `__contains__` to TestHparamsNamespace to fix tests
* Fix lightning_setattr to set all valid attributes
* update doc
* better names
* fix holder order preference
* tests for new behavior
* Comment about using the last holder

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Sean Naren <[email protected]>
1 parent c76cc23 commit eee3b1a
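
What the fix means in practice: when both `hparams` and the datamodule define an attribute, the old `lightning_getattr` returned the `hparams` value and the datamodule was never consulted. Below is a minimal sketch of the corrected behavior, mirroring the TestModel7 fixture added in the tests; the stub classes are illustrative only, not library code, and assume a pytorch_lightning build that contains this commit.

```python
from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_setattr

class DataModule:                  # illustrative stand-in for a LightningDataModule
    batch_size = 8

class TrainerStub:                 # only the attribute parsing.py inspects
    datamodule = DataModule()

class Model:                       # hparams *also* carries batch_size
    trainer = TrainerStub()
    hparams = {'batch_size': 2}

model = Model()

# Previously the hparams dict shadowed the datamodule and 2 came back;
# the datamodule (the last holder found) is now preferred.
assert lightning_getattr(model, 'batch_size') == 8

# lightning_setattr now writes to *every* holder, keeping them in sync.
lightning_setattr(model, 'batch_size', 128)
assert model.hparams['batch_size'] == 128
assert model.trainer.datamodule.batch_size == 128
```

The same holder list drives `lightning_setattr`, which now updates every location where the attribute lives rather than only the first match.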

File tree

2 files changed (+77, -48 lines):

  pytorch_lightning/utilities/parsing.py
  tests/utilities/test_parsing.py

pytorch_lightning/utilities/parsing.py

Lines changed: 42 additions & 44 deletions
@@ -198,72 +198,70 @@ def __repr__(self):
         return out
 
 
-def lightning_hasattr(model, attribute):
-    """ Special hasattr for lightning. Checks for attribute in model namespace,
-    the old hparams namespace/dict, and the datamodule. """
+def lightning_get_all_attr_holders(model, attribute):
+    """ Special attribute finding for lightning. Gets all of the objects or dicts that holds attribute.
+    Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. """
     trainer = getattr(model, 'trainer', None)
 
-    attr = False
+    holders = []
+
     # Check if attribute in model
     if hasattr(model, attribute):
-        attr = True
+        holders.append(model)
+
     # Check if attribute in model.hparams, either namespace or dict
-    elif hasattr(model, 'hparams'):
-        if isinstance(model.hparams, dict):
-            attr = attribute in model.hparams
-        else:
-            attr = hasattr(model.hparams, attribute)
+    if hasattr(model, 'hparams'):
+        if attribute in model.hparams:
+            holders.append(model.hparams)
+
     # Check if the attribute in datamodule (datamodule gets registered in Trainer)
-    if not attr and trainer is not None:
-        attr = hasattr(trainer.datamodule, attribute)
+    if trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
+        holders.append(trainer.datamodule)
 
-    return attr
+    return holders
+
+
+def lightning_get_first_attr_holder(model, attribute):
+    """ Special attribute finding for lightning. Gets the object or dict that holds attribute, or None. Checks for attribute in model namespace,
+    the old hparams namespace/dict, and the datamodule, returns the last one that has it. """
+    holders = lightning_get_all_attr_holders(model, attribute)
+    if len(holders) == 0:
+        return None
+    # using the last holder to preserve backwards compatibility
+    return holders[-1]
+
+
+def lightning_hasattr(model, attribute):
+    """ Special hasattr for lightning. Checks for attribute in model namespace,
+    the old hparams namespace/dict, and the datamodule. """
+    return lightning_get_first_attr_holder(model, attribute) is not None
 
 
 def lightning_getattr(model, attribute):
     """ Special getattr for lightning. Checks for attribute in model namespace,
     the old hparams namespace/dict, and the datamodule. """
-    trainer = getattr(model, 'trainer', None)
-
-    # Check if attribute in model
-    if hasattr(model, attribute):
-        attr = getattr(model, attribute)
-    # Check if attribute in model.hparams, either namespace or dict
-    elif hasattr(model, 'hparams') and isinstance(model.hparams, dict) and attribute in model.hparams:
-        attr = model.hparams[attribute]
-    elif hasattr(model, 'hparams') and hasattr(model.hparams, attribute):
-        attr = getattr(model.hparams, attribute)
-    # Check if the attribute in datamodule (datamodule gets registered in Trainer)
-    elif trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
-        attr = getattr(trainer.datamodule, attribute)
-    else:
+    holder = lightning_get_first_attr_holder(model, attribute)
+    if holder is None:
         raise ValueError(f'{attribute} is neither stored in the model namespace'
                          ' nor the `hparams` namespace/dict, nor the datamodule.')
-    return attr
+
+    if isinstance(holder, dict):
+        return holder[attribute]
+    return getattr(holder, attribute)
 
 
 def lightning_setattr(model, attribute, value):
     """ Special setattr for lightning. Checks for attribute in model namespace
     and the old hparams namespace/dict.
     Will also set the attribute on datamodule, if it exists.
     """
-    if not lightning_hasattr(model, attribute):
+    holders = lightning_get_all_attr_holders(model, attribute)
+    if len(holders) == 0:
         raise ValueError(f'{attribute} is neither stored in the model namespace'
                          ' nor the `hparams` namespace/dict, nor the datamodule.')
 
-    trainer = getattr(model, 'trainer', None)
-
-    # Check if attribute in model
-    if hasattr(model, attribute):
-        setattr(model, attribute, value)
-
-    # Check if attribute in model.hparams, either namespace or dict
-    elif hasattr(model, 'hparams'):
-        if isinstance(model.hparams, dict):
-            model.hparams[attribute] = value
+    for holder in holders:
+        if isinstance(holder, dict):
+            holder[attribute] = value
         else:
-            setattr(model.hparams, attribute, value)
-
-    # Check if the attribute in datamodule (datamodule gets registered in Trainer)
-    if trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
-        setattr(trainer.datamodule, attribute, value)
+            setattr(holder, attribute, value)
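
All three public helpers now funnel through `lightning_get_all_attr_holders`, which collects holders in a fixed order (model, then `hparams`, then the datamodule) and lets the get/set helpers act on the last one found. A quick illustrative sketch, assuming a pytorch_lightning build with this commit; the Model stub is not library code:

```python
from argparse import Namespace

from pytorch_lightning.utilities.parsing import lightning_get_all_attr_holders

class Model:
    learning_rate = 0.01                    # defined on the model itself
    hparams = Namespace(learning_rate=0.1)  # and in the hparams namespace

model = Model()

# Holders come back in discovery order; lightning_getattr/setattr use
# holders[-1], so hparams would win over the model attribute here.
holders = lightning_get_all_attr_holders(model, 'learning_rate')
assert holders == [model, model.hparams]
```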

tests/utilities/test_parsing.py

Lines changed: 35 additions & 4 deletions
@@ -20,6 +20,9 @@ def _get_test_cases():
     class TestHparamsNamespace:
         learning_rate = 1
 
+        def __contains__(self, item):
+            return item == "learning_rate"
+
     TestHparamsDict = {'learning_rate': 2}
 
     class TestModel1:  # test for namespace
@@ -53,12 +56,26 @@ class TestModel5:  # test for datamodule
 
     model5 = TestModel5()
 
-    return model1, model2, model3, model4, model5
+    class TestModel6:  # test for datamodule w/ hparams w/o attribute (should use datamodule)
+        trainer = Trainer
+        hparams = TestHparamsDict
+
+    model6 = TestModel6()
+
+    TestHparamsDict2 = {'batch_size': 2}
+
+    class TestModel7:  # test for datamodule w/ hparams w/ attribute (should use datamodule)
+        trainer = Trainer
+        hparams = TestHparamsDict2
+
+    model7 = TestModel7()
+
+    return model1, model2, model3, model4, model5, model6, model7
 
 
 def test_lightning_hasattr(tmpdir):
     """ Test that the lightning_hasattr works in all cases"""
-    model1, model2, model3, model4, model5 = _get_test_cases()
+    model1, model2, model3, model4, model5, model6, model7 = _get_test_cases()
     assert lightning_hasattr(model1, 'learning_rate'), \
         'lightning_hasattr failed to find namespace variable'
     assert lightning_hasattr(model2, 'learning_rate'), \
@@ -69,6 +86,10 @@ def test_lightning_hasattr(tmpdir):
         'lightning_hasattr found variable when it should not'
     assert lightning_hasattr(model5, 'batch_size'), \
         'lightning_hasattr failed to find batch_size in datamodule'
+    assert lightning_hasattr(model6, 'batch_size'), \
+        'lightning_hasattr failed to find batch_size in datamodule w/ hparams present'
+    assert lightning_hasattr(model7, 'batch_size'), \
+        'lightning_hasattr failed to find batch_size in hparams w/ datamodule present'
 
 
 def test_lightning_getattr(tmpdir):
@@ -78,9 +99,13 @@ def test_lightning_getattr(tmpdir):
         value = lightning_getattr(m, 'learning_rate')
         assert value == i, 'attribute not correctly extracted'
 
-    model5 = models[4]
+    model5, model6, model7 = models[4:]
     assert lightning_getattr(model5, 'batch_size') == 8, \
         'batch_size not correctly extracted'
+    assert lightning_getattr(model6, 'batch_size') == 8, \
+        'batch_size not correctly extracted'
+    assert lightning_getattr(model7, 'batch_size') == 8, \
+        'batch_size not correctly extracted'
 
 
 def test_lightning_setattr(tmpdir):
@@ -91,7 +116,13 @@ def test_lightning_setattr(tmpdir):
     assert lightning_getattr(m, 'learning_rate') == 10, \
         'attribute not correctly set'
 
-    model5 = models[4]
+    model5, model6, model7 = models[4:]
     lightning_setattr(model5, 'batch_size', 128)
+    lightning_setattr(model6, 'batch_size', 128)
+    lightning_setattr(model7, 'batch_size', 128)
     assert lightning_getattr(model5, 'batch_size') == 128, \
         'batch_size not correctly set'
+    assert lightning_getattr(model6, 'batch_size') == 128, \
+        'batch_size not correctly set'
+    assert lightning_getattr(model7, 'batch_size') == 128, \
+        'batch_size not correctly set'
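
A side note on the fixture change at the top of this file: the unified hparams check is now written as `attribute in model.hparams`, so an hparams object must support the `in` operator. Plain dicts and `argparse.Namespace` already do; a bare class does not, which is why `TestHparamsNamespace` gained `__contains__`. A small self-contained illustration:

```python
from argparse import Namespace

class BareHparams:          # no __contains__, __iter__, or __getitem__
    learning_rate = 1

assert 'learning_rate' in {'learning_rate': 2}        # dict: supported
assert 'learning_rate' in Namespace(learning_rate=3)  # Namespace: supported
try:
    'learning_rate' in BareHparams()                  # raises TypeError
except TypeError:
    print("bare objects need __contains__ for the new check")
```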
