@@ -198,72 +198,70 @@ def __repr__(self):
198198 return out
199199
200200
201- def lightning_hasattr (model , attribute ):
202- """ Special hasattr for lightning. Checks for attribute in model namespace,
203- the old hparams namespace/dict, and the datamodule. """
201+ def lightning_get_all_attr_holders (model , attribute ):
202+ """ Special attribute finding for lightning. Gets all of the objects or dicts that holds attribute.
203+ Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. """
204204 trainer = getattr (model , 'trainer' , None )
205205
206- attr = False
206+ holders = []
207+
207208 # Check if attribute in model
208209 if hasattr (model , attribute ):
209- attr = True
210+ holders .append (model )
211+
210212 # Check if attribute in model.hparams, either namespace or dict
211- elif hasattr (model , 'hparams' ):
212- if isinstance (model .hparams , dict ):
213- attr = attribute in model .hparams
214- else :
215- attr = hasattr (model .hparams , attribute )
213+ if hasattr (model , 'hparams' ):
214+ if attribute in model .hparams :
215+ holders .append (model .hparams )
216+
216217 # Check if the attribute in datamodule (datamodule gets registered in Trainer)
217- if not attr and trainer is not None :
218- attr = hasattr (trainer .datamodule , attribute )
218+ if trainer is not None and trainer . datamodule is not None and hasattr ( trainer . datamodule , attribute ) :
219+ holders . append (trainer .datamodule )
219220
220- return attr
221+ return holders
222+
223+
224+ def lightning_get_first_attr_holder (model , attribute ):
225+ """ Special attribute finding for lightning. Gets the object or dict that holds attribute, or None. Checks for attribute in model namespace,
226+ the old hparams namespace/dict, and the datamodule, returns the last one that has it. """
227+ holders = lightning_get_all_attr_holders (model , attribute )
228+ if len (holders ) == 0 :
229+ return None
230+ # using the last holder to preserve backwards compatibility
231+ return holders [- 1 ]
232+
233+
234+ def lightning_hasattr (model , attribute ):
235+ """ Special hasattr for lightning. Checks for attribute in model namespace,
236+ the old hparams namespace/dict, and the datamodule. """
237+ return lightning_get_first_attr_holder (model , attribute ) is not None
221238
222239
223240def lightning_getattr (model , attribute ):
224241 """ Special getattr for lightning. Checks for attribute in model namespace,
225242 the old hparams namespace/dict, and the datamodule. """
226- trainer = getattr (model , 'trainer' , None )
227-
228- # Check if attribute in model
229- if hasattr (model , attribute ):
230- attr = getattr (model , attribute )
231- # Check if attribute in model.hparams, either namespace or dict
232- elif hasattr (model , 'hparams' ) and isinstance (model .hparams , dict ) and attribute in model .hparams :
233- attr = model .hparams [attribute ]
234- elif hasattr (model , 'hparams' ) and hasattr (model .hparams , attribute ):
235- attr = getattr (model .hparams , attribute )
236- # Check if the attribute in datamodule (datamodule gets registered in Trainer)
237- elif trainer is not None and trainer .datamodule is not None and hasattr (trainer .datamodule , attribute ):
238- attr = getattr (trainer .datamodule , attribute )
239- else :
243+ holder = lightning_get_first_attr_holder (model , attribute )
244+ if holder is None :
240245 raise ValueError (f'{ attribute } is neither stored in the model namespace'
241246 ' nor the `hparams` namespace/dict, nor the datamodule.' )
242- return attr
247+
248+ if isinstance (holder , dict ):
249+ return holder [attribute ]
250+ return getattr (holder , attribute )
243251
244252
245253def lightning_setattr (model , attribute , value ):
246254 """ Special setattr for lightning. Checks for attribute in model namespace
247255 and the old hparams namespace/dict.
248256 Will also set the attribute on datamodule, if it exists.
249257 """
250- if not lightning_hasattr (model , attribute ):
258+ holders = lightning_get_all_attr_holders (model , attribute )
259+ if len (holders ) == 0 :
251260 raise ValueError (f'{ attribute } is neither stored in the model namespace'
252261 ' nor the `hparams` namespace/dict, nor the datamodule.' )
253262
254- trainer = getattr (model , 'trainer' , None )
255-
256- # Check if attribute in model
257- if hasattr (model , attribute ):
258- setattr (model , attribute , value )
259-
260- # Check if attribute in model.hparams, either namespace or dict
261- elif hasattr (model , 'hparams' ):
262- if isinstance (model .hparams , dict ):
263- model .hparams [attribute ] = value
263+ for holder in holders :
264+ if isinstance (holder , dict ):
265+ holder [attribute ] = value
264266 else :
265- setattr (model .hparams , attribute , value )
266-
267- # Check if the attribute in datamodule (datamodule gets registered in Trainer)
268- if trainer is not None and trainer .datamodule is not None and hasattr (trainer .datamodule , attribute ):
269- setattr (trainer .datamodule , attribute , value )
267+ setattr (holder , attribute , value )
0 commit comments