99import os
1010import pickle
1111from types import MethodType
12- from typing import TYPE_CHECKING , Any , Callable , Dict , Generator , Iterable , List , Optional , Set , TypeVar , Union
12+ from typing import (
13+ TYPE_CHECKING ,
14+ Any ,
15+ Callable ,
16+ ClassVar ,
17+ Dict ,
18+ Generator ,
19+ Iterable ,
20+ List ,
21+ Optional ,
22+ Protocol ,
23+ cast ,
24+ runtime_checkable ,
25+ )
1326
1427import yaml
1528
1629from . import futures , loaders , utils
17- from .base .utils import call_with_super_check , super_check
1830from .utils import PID_TYPE , SAVED_STATE_TYPE
1931
2032PersistedCheckpoint = collections .namedtuple ('PersistedCheckpoint' , ['pid' , 'tag' ])
@@ -88,10 +100,10 @@ def load(saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = N
88100 :return: The loaded Savable instance
89101
90102 """
91- load_context = _ensure_object_loader (load_context , saved_state )
103+ load_context = ensure_object_loader (load_context , saved_state )
92104 assert load_context .loader is not None # required for type checking
93105 try :
94- class_name = Savable . _get_class_name (saved_state )
106+ class_name = SaveUtil . get_class_name (saved_state )
95107 load_cls : Savable = load_context .loader .load_object (class_name )
96108 except KeyError :
97109 raise ValueError ('Class name not found in saved state' )
@@ -380,22 +392,7 @@ def delete_process_checkpoints(self, pid: PID_TYPE) -> None:
380392 del self ._checkpoints [pid ]
381393
382394
383- SavableClsType = TypeVar ('SavableClsType' , bound = 'type[Savable]' )
384-
385-
386- def auto_persist (* members : str ) -> Callable [[SavableClsType ], SavableClsType ]:
387- def wrapped (savable : SavableClsType ) -> SavableClsType :
388- if savable ._auto_persist is None :
389- savable ._auto_persist = set ()
390- else :
391- savable ._auto_persist = set (savable ._auto_persist )
392- savable .auto_persist (* members )
393- return savable
394-
395- return wrapped
396-
397-
398- def _ensure_object_loader (context : Optional ['LoadSaveContext' ], saved_state : SAVED_STATE_TYPE ) -> 'LoadSaveContext' :
395+ def ensure_object_loader (context : Optional ['LoadSaveContext' ], saved_state : SAVED_STATE_TYPE ) -> 'LoadSaveContext' :
399396 """
400397 Given a LoadSaveContext this method will ensure that it has a valid class loader
401398 using the following priorities:
@@ -417,7 +414,7 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV
417414 # 2) Try getting from saved_state
418415 default_loader = loaders .get_object_loader ()
419416 try :
420- loader_identifier = Savable .get_custom_meta (saved_state , META__OBJECT_LOADER )
417+ loader_identifier = SaveUtil .get_custom_meta (saved_state , META__OBJECT_LOADER )
421418 except ValueError :
422419 # 3) Fall back to default
423420 loader = default_loader
@@ -436,45 +433,10 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV
436433META__TYPE__SAVABLE : str = 'S'
437434
438435
439- class Savable :
440- CLASS_NAME : str = 'class_name'
441-
442- _auto_persist : Optional [Set [str ]] = None
443- _persist_configured = False
444-
445- @classmethod
446- def auto_persist (cls , * members : str ) -> None :
447- if cls ._auto_persist is None :
448- cls ._auto_persist = set ()
449- cls ._auto_persist .update (members )
450-
451- @classmethod
452- def recreate_from (cls , saved_state : SAVED_STATE_TYPE , load_context : Optional [LoadSaveContext ] = None ) -> 'Savable' :
453- """
454- Recreate a :class:`Savable` from a saved state using an optional load context.
455-
456- :param saved_state: The saved state
457- :param load_context: An optional load context
458-
459- :return: The recreated instance
460-
461- """
462- ...
463-
464- def save (self , save_context : Optional [LoadSaveContext ] = None ) -> SAVED_STATE_TYPE :
465- out_state : SAVED_STATE_TYPE = auto_save (self , save_context )
466-
467- return out_state
468-
469- def _ensure_persist_configured (self ) -> None :
470- if not self ._persist_configured :
471- self ._persist_configured = True
472-
473- # region Metadata getter/setters
474-
436+ class SaveUtil :
475437 @staticmethod
476438 def set_custom_meta (out_state : SAVED_STATE_TYPE , name : str , value : Any ) -> None :
477- user_dict = Savable . _get_create_meta (out_state ).setdefault (META__USER , {})
439+ user_dict = SaveUtil . get_create_meta (out_state ).setdefault (META__USER , {})
478440 user_dict [name ] = value
479441
480442 @staticmethod
@@ -485,47 +447,127 @@ def get_custom_meta(saved_state: SAVED_STATE_TYPE, name: str) -> Any:
485447 raise ValueError (f"Unknown meta key '{ name } '" )
486448
487449 @staticmethod
488- def _get_create_meta (out_state : SAVED_STATE_TYPE ) -> Dict [str , Any ]:
450+ def get_create_meta (out_state : SAVED_STATE_TYPE ) -> Dict [str , Any ]:
489451 return out_state .setdefault (META , {})
490452
491453 @staticmethod
492- def _set_class_name (out_state : SAVED_STATE_TYPE , name : str ) -> None :
493- Savable . _get_create_meta (out_state )[META__CLASS_NAME ] = name
454+ def set_class_name (out_state : SAVED_STATE_TYPE , name : str ) -> None :
455+ SaveUtil . get_create_meta (out_state )[META__CLASS_NAME ] = name
494456
495457 @staticmethod
496- def _get_class_name (saved_state : SAVED_STATE_TYPE ) -> str :
497- return Savable . _get_create_meta (saved_state )[META__CLASS_NAME ]
458+ def get_class_name (saved_state : SAVED_STATE_TYPE ) -> str :
459+ return SaveUtil . get_create_meta (saved_state )[META__CLASS_NAME ]
498460
499461 @staticmethod
500- def _set_meta_type (out_state : SAVED_STATE_TYPE , name : str , type_spec : Any ) -> None :
501- type_dict = Savable . _get_create_meta (out_state ).setdefault (META__TYPES , {})
462+ def set_meta_type (out_state : SAVED_STATE_TYPE , name : str , type_spec : Any ) -> None :
463+ type_dict = SaveUtil . get_create_meta (out_state ).setdefault (META__TYPES , {})
502464 type_dict [name ] = type_spec
503465
504466 @staticmethod
505- def _get_meta_type (saved_state : SAVED_STATE_TYPE , name : str ) -> Any :
467+ def get_meta_type (saved_state : SAVED_STATE_TYPE , name : str ) -> Any :
506468 try :
507469 return saved_state [META ][META__TYPES ][name ]
508470 except KeyError :
509471 pass
510472
511- # endregion
512473
513- def _get_value (
514- self , saved_state : SAVED_STATE_TYPE , name : str , load_context : Optional [LoadSaveContext ]
515- ) -> Union [MethodType , 'Savable' ]:
516- value = saved_state [name ]
474+ @runtime_checkable
475+ class Savable (Protocol ):
476+ @classmethod
477+ def recreate_from (cls , saved_state : SAVED_STATE_TYPE , load_context : LoadSaveContext | None = None ) -> 'Savable' :
478+ """
479+ Recreate a :class:`Savable` from a saved state using an optional load context.
480+
481+ :param saved_state: The saved state
482+ :param load_context: An optional load context
483+
484+ :return: The recreated instance
485+
486+ """
487+ ...
488+
489+ def save (self , save_context : LoadSaveContext | None = None ) -> SAVED_STATE_TYPE : ...
490+
491+
492+ @runtime_checkable
493+ class SavableWithAutoPersist (Savable , Protocol ):
494+ _auto_persist : ClassVar [set [str ]] = set ()
495+
496+
497+ def auto_save (obj : Savable , save_context : Optional [LoadSaveContext ] = None ) -> SAVED_STATE_TYPE :
498+ out_state : SAVED_STATE_TYPE = {}
499+
500+ if save_context is None :
501+ save_context = LoadSaveContext ()
502+
503+ utils .type_check (save_context , LoadSaveContext )
504+
505+ default_loader = loaders .get_object_loader ()
506+ # If the user has specified a class loader, then save it in the saved state
507+ if save_context .loader is not None :
508+ loader_class = default_loader .identify_object (save_context .loader .__class__ )
509+ SaveUtil .set_custom_meta (out_state , META__OBJECT_LOADER , loader_class )
510+ loader = save_context .loader
511+ else :
512+ loader = default_loader
513+
514+ SaveUtil .set_class_name (out_state , loader .identify_object (obj .__class__ ))
515+
516+ if isinstance (obj , SavableWithAutoPersist ):
517+ for member in obj ._auto_persist :
518+ value = getattr (obj , member )
519+ if inspect .ismethod (value ):
520+ if value .__self__ is not obj :
521+ raise TypeError ('Cannot persist methods of other classes' )
522+ SaveUtil .set_meta_type (out_state , member , META__TYPE__METHOD )
523+ value = value .__name__
524+ elif isinstance (value , Savable ) and not isinstance (value , type ):
525+ # persist for a savable obj, call `save` method of obj.
526+ SaveUtil .set_meta_type (out_state , member , META__TYPE__SAVABLE )
527+ value = value .save ()
528+ else :
529+ value = copy .deepcopy (value )
530+ out_state [member ] = value
531+
532+ return out_state
533+
534+
535+ def auto_load (obj : SavableWithAutoPersist , saved_state : SAVED_STATE_TYPE , load_context : LoadSaveContext ) -> None :
536+ for member in obj ._auto_persist :
537+ setattr (obj , member , _get_value (obj , saved_state , member , load_context ))
538+
517539
518- typ = Savable ._get_meta_type (saved_state , name )
519- if typ == META__TYPE__METHOD :
520- value = getattr (self , value )
521- elif typ == META__TYPE__SAVABLE :
522- value = load (value , load_context )
540+ def _get_value (
541+ obj : Any , saved_state : SAVED_STATE_TYPE , name : str , load_context : LoadSaveContext | None
542+ ) -> MethodType | Savable :
543+ value = saved_state [name ]
523544
524- return value
545+ typ = SaveUtil .get_meta_type (saved_state , name )
546+ if typ == META__TYPE__METHOD :
547+ value = getattr (obj , value )
548+ elif typ == META__TYPE__SAVABLE :
549+ value = load (value , load_context )
550+
551+ return value
552+
553+
554+ def auto_persist (* members : str ) -> Callable [..., Savable ]:
555+ def wrapped (savable_cls : type ) -> Savable :
556+ if not hasattr (savable_cls , '_auto_persist' ) or savable_cls ._auto_persist is None :
557+ savable_cls ._auto_persist = set () # type: ignore[attr-defined]
558+ else :
559+ savable_cls ._auto_persist = set (savable_cls ._auto_persist )
525560
561+ savable_cls ._auto_persist .update (members ) # type: ignore[attr-defined]
562+ # XXX: validate on `save` and `recreate_from` method??
563+ return cast (Savable , savable_cls )
526564
565+ return wrapped
566+
567+
568+ # FIXME: move me to another module? savablefuture.py?
527569@auto_persist ('_state' , '_result' )
528- class SavableFuture (futures .Future , Savable ):
570+ class SavableFuture (futures .Future ):
529571 """
530572 A savable future.
531573
@@ -550,7 +592,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
550592 :return: The recreated instance
551593
552594 """
553- load_context = _ensure_object_loader (load_context , saved_state )
595+ load_context = ensure_object_loader (load_context , saved_state )
554596
555597 try :
556598 loop = load_context .loop
@@ -586,48 +628,3 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
586628 # ## UNTILHERE XXX:
587629
588630 return obj
589-
590-
591- def auto_save (obj : Savable , save_context : Optional [LoadSaveContext ] = None ) -> SAVED_STATE_TYPE :
592- out_state : SAVED_STATE_TYPE = {}
593-
594- if save_context is None :
595- save_context = LoadSaveContext ()
596-
597- utils .type_check (save_context , LoadSaveContext )
598-
599- default_loader = loaders .get_object_loader ()
600- # If the user has specified a class loader, then save it in the saved state
601- if save_context .loader is not None :
602- loader_class = default_loader .identify_object (save_context .loader .__class__ )
603- Savable .set_custom_meta (out_state , META__OBJECT_LOADER , loader_class )
604- loader = save_context .loader
605- else :
606- loader = default_loader
607-
608- Savable ._set_class_name (out_state , loader .identify_object (obj .__class__ ))
609-
610- obj ._ensure_persist_configured ()
611- if obj ._auto_persist is not None :
612- for member in obj ._auto_persist :
613- value = getattr (obj , member )
614- if inspect .ismethod (value ):
615- if value .__self__ is not obj :
616- raise TypeError ('Cannot persist methods of other classes' )
617- Savable ._set_meta_type (out_state , member , META__TYPE__METHOD )
618- value = value .__name__
619- elif isinstance (value , Savable ):
620- Savable ._set_meta_type (out_state , member , META__TYPE__SAVABLE )
621- value = value .save ()
622- else :
623- value = copy .deepcopy (value )
624- out_state [member ] = value
625-
626- return out_state
627-
628-
629- def auto_load (obj : Savable , saved_state : SAVED_STATE_TYPE , load_context : LoadSaveContext ) -> None :
630- obj ._ensure_persist_configured ()
631- if obj ._auto_persist is not None :
632- for member in obj ._auto_persist :
633- setattr (obj , member , obj ._get_value (saved_state , member , load_context ))
0 commit comments