2424from ema_pytorch import EMA
2525
2626from lightning import Fabric
27+ from lightning .fabric .wrappers import _unwrap_objects
2728
2829# constants
2930
@@ -341,9 +342,12 @@ def save(
341342
342343 path .parent .mkdir (exist_ok = True , parents = True )
343344
345+ unwrapped_model = _unwrap_objects (self .model )
346+ unwrapped_optimizer = _unwrap_objects (self .optimizer )
347+
344348 package = dict (
345- model = self . model .state_dict_with_init_args ,
346- optimizer = self . optimizer .state_dict (),
349+ model = unwrapped_model .state_dict_with_init_args ,
350+ optimizer = unwrapped_optimizer .state_dict (),
347351 scheduler = self .scheduler .state_dict (),
348352 steps = self .steps
349353 )
@@ -379,9 +383,14 @@ def load(
379383
380384 self .model_loaded_from_path = path
381385
386+ # get unwrapped model and optimizer
387+
388+ unwrapped_model = _unwrap_objects (self .model )
389+ unwrapped_optimizer = _unwrap_objects (self .optimizer )
390+
382391 # load model from path
383392
384- self . model .load (path )
393+ unwrapped_model .load (path )
385394
386395 if only_model :
387396 return
@@ -391,7 +400,7 @@ def load(
391400 package = torch .load (str (path ))
392401
393402 if 'optimizer' in package :
394- self . optimizer .load_state_dict (package ['optimizer' ])
403+ unwrapped_optimizer .load_state_dict (package ['optimizer' ])
395404
396405 if 'scheduler' in package :
397406 self .scheduler .load_state_dict (package ['scheduler' ])
0 commit comments