Skip to content

Commit 23abe62

Browse files
committed
forming Savable protocol
- remove persist_config flag of savable
1 parent 558e2a1 commit 23abe62

File tree

9 files changed

+278
-219
lines changed

9 files changed

+278
-219
lines changed

src/plumpy/event_helper.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,21 @@
22
import logging
33
from typing import TYPE_CHECKING, Any, Callable, Optional
44

5+
from plumpy.persistence import LoadSaveContext, Savable, auto_load, auto_save, ensure_object_loader
56
from plumpy.utils import SAVED_STATE_TYPE
67

78
from . import persistence
8-
from plumpy.persistence import Savable, LoadSaveContext, _ensure_object_loader, auto_load
99

1010
if TYPE_CHECKING:
1111
from typing import Set, Type
12+
1213
from .process_listener import ProcessListener
1314

1415
_LOGGER = logging.getLogger(__name__)
1516

1617

1718
@persistence.auto_persist('_listeners', '_listener_type')
18-
class EventHelper(persistence.Savable):
19+
class EventHelper:
1920
def __init__(self, listener_type: 'Type[ProcessListener]'):
2021
assert listener_type is not None, 'Must provide valid listener type'
2122

@@ -43,11 +44,16 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
4344
:return: The recreated instance
4445
4546
"""
46-
load_context = _ensure_object_loader(load_context, saved_state)
47+
load_context = ensure_object_loader(load_context, saved_state)
4748
obj = cls.__new__(cls)
4849
auto_load(obj, saved_state, load_context)
4950
return obj
5051

52+
def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
53+
out_state: SAVED_STATE_TYPE = auto_save(self, save_context)
54+
55+
return out_state
56+
5157
@property
5258
def listeners(self) -> 'Set[ProcessListener]':
5359
return self._listeners

src/plumpy/persistence.py

Lines changed: 121 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,24 @@
99
import os
1010
import pickle
1111
from types import MethodType
12-
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Optional, Set, TypeVar, Union
12+
from typing import (
13+
TYPE_CHECKING,
14+
Any,
15+
Callable,
16+
ClassVar,
17+
Dict,
18+
Generator,
19+
Iterable,
20+
List,
21+
Optional,
22+
Protocol,
23+
cast,
24+
runtime_checkable,
25+
)
1326

1427
import yaml
1528

1629
from . import futures, loaders, utils
17-
from .base.utils import call_with_super_check, super_check
1830
from .utils import PID_TYPE, SAVED_STATE_TYPE
1931

2032
PersistedCheckpoint = collections.namedtuple('PersistedCheckpoint', ['pid', 'tag'])
@@ -88,10 +100,10 @@ def load(saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = N
88100
:return: The loaded Savable instance
89101
90102
"""
91-
load_context = _ensure_object_loader(load_context, saved_state)
103+
load_context = ensure_object_loader(load_context, saved_state)
92104
assert load_context.loader is not None # required for type checking
93105
try:
94-
class_name = Savable._get_class_name(saved_state)
106+
class_name = SaveUtil.get_class_name(saved_state)
95107
load_cls: Savable = load_context.loader.load_object(class_name)
96108
except KeyError:
97109
raise ValueError('Class name not found in saved state')
@@ -380,22 +392,7 @@ def delete_process_checkpoints(self, pid: PID_TYPE) -> None:
380392
del self._checkpoints[pid]
381393

382394

383-
SavableClsType = TypeVar('SavableClsType', bound='type[Savable]')
384-
385-
386-
def auto_persist(*members: str) -> Callable[[SavableClsType], SavableClsType]:
387-
def wrapped(savable: SavableClsType) -> SavableClsType:
388-
if savable._auto_persist is None:
389-
savable._auto_persist = set()
390-
else:
391-
savable._auto_persist = set(savable._auto_persist)
392-
savable.auto_persist(*members)
393-
return savable
394-
395-
return wrapped
396-
397-
398-
def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAVED_STATE_TYPE) -> 'LoadSaveContext':
395+
def ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAVED_STATE_TYPE) -> 'LoadSaveContext':
399396
"""
400397
Given a LoadSaveContext this method will ensure that it has a valid class loader
401398
using the following priorities:
@@ -417,7 +414,7 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV
417414
# 2) Try getting from saved_state
418415
default_loader = loaders.get_object_loader()
419416
try:
420-
loader_identifier = Savable.get_custom_meta(saved_state, META__OBJECT_LOADER)
417+
loader_identifier = SaveUtil.get_custom_meta(saved_state, META__OBJECT_LOADER)
421418
except ValueError:
422419
# 3) Fall back to default
423420
loader = default_loader
@@ -436,45 +433,10 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV
436433
META__TYPE__SAVABLE: str = 'S'
437434

438435

439-
class Savable:
440-
CLASS_NAME: str = 'class_name'
441-
442-
_auto_persist: Optional[Set[str]] = None
443-
_persist_configured = False
444-
445-
@classmethod
446-
def auto_persist(cls, *members: str) -> None:
447-
if cls._auto_persist is None:
448-
cls._auto_persist = set()
449-
cls._auto_persist.update(members)
450-
451-
@classmethod
452-
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable':
453-
"""
454-
Recreate a :class:`Savable` from a saved state using an optional load context.
455-
456-
:param saved_state: The saved state
457-
:param load_context: An optional load context
458-
459-
:return: The recreated instance
460-
461-
"""
462-
...
463-
464-
def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
465-
out_state: SAVED_STATE_TYPE = auto_save(self, save_context)
466-
467-
return out_state
468-
469-
def _ensure_persist_configured(self) -> None:
470-
if not self._persist_configured:
471-
self._persist_configured = True
472-
473-
# region Metadata getter/setters
474-
436+
class SaveUtil:
475437
@staticmethod
476438
def set_custom_meta(out_state: SAVED_STATE_TYPE, name: str, value: Any) -> None:
477-
user_dict = Savable._get_create_meta(out_state).setdefault(META__USER, {})
439+
user_dict = SaveUtil.get_create_meta(out_state).setdefault(META__USER, {})
478440
user_dict[name] = value
479441

480442
@staticmethod
@@ -485,47 +447,127 @@ def get_custom_meta(saved_state: SAVED_STATE_TYPE, name: str) -> Any:
485447
raise ValueError(f"Unknown meta key '{name}'")
486448

487449
@staticmethod
488-
def _get_create_meta(out_state: SAVED_STATE_TYPE) -> Dict[str, Any]:
450+
def get_create_meta(out_state: SAVED_STATE_TYPE) -> Dict[str, Any]:
489451
return out_state.setdefault(META, {})
490452

491453
@staticmethod
492-
def _set_class_name(out_state: SAVED_STATE_TYPE, name: str) -> None:
493-
Savable._get_create_meta(out_state)[META__CLASS_NAME] = name
454+
def set_class_name(out_state: SAVED_STATE_TYPE, name: str) -> None:
455+
SaveUtil.get_create_meta(out_state)[META__CLASS_NAME] = name
494456

495457
@staticmethod
496-
def _get_class_name(saved_state: SAVED_STATE_TYPE) -> str:
497-
return Savable._get_create_meta(saved_state)[META__CLASS_NAME]
458+
def get_class_name(saved_state: SAVED_STATE_TYPE) -> str:
459+
return SaveUtil.get_create_meta(saved_state)[META__CLASS_NAME]
498460

499461
@staticmethod
500-
def _set_meta_type(out_state: SAVED_STATE_TYPE, name: str, type_spec: Any) -> None:
501-
type_dict = Savable._get_create_meta(out_state).setdefault(META__TYPES, {})
462+
def set_meta_type(out_state: SAVED_STATE_TYPE, name: str, type_spec: Any) -> None:
463+
type_dict = SaveUtil.get_create_meta(out_state).setdefault(META__TYPES, {})
502464
type_dict[name] = type_spec
503465

504466
@staticmethod
505-
def _get_meta_type(saved_state: SAVED_STATE_TYPE, name: str) -> Any:
467+
def get_meta_type(saved_state: SAVED_STATE_TYPE, name: str) -> Any:
506468
try:
507469
return saved_state[META][META__TYPES][name]
508470
except KeyError:
509471
pass
510472

511-
# endregion
512473

513-
def _get_value(
514-
self, saved_state: SAVED_STATE_TYPE, name: str, load_context: Optional[LoadSaveContext]
515-
) -> Union[MethodType, 'Savable']:
516-
value = saved_state[name]
474+
@runtime_checkable
475+
class Savable(Protocol):
476+
@classmethod
477+
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = None) -> 'Savable':
478+
"""
479+
Recreate a :class:`Savable` from a saved state using an optional load context.
480+
481+
:param saved_state: The saved state
482+
:param load_context: An optional load context
483+
484+
:return: The recreated instance
485+
486+
"""
487+
...
488+
489+
def save(self, save_context: LoadSaveContext | None = None) -> SAVED_STATE_TYPE: ...
490+
491+
492+
@runtime_checkable
493+
class SavableWithAutoPersist(Savable, Protocol):
494+
_auto_persist: ClassVar[set[str]] = set()
495+
496+
497+
def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
498+
out_state: SAVED_STATE_TYPE = {}
499+
500+
if save_context is None:
501+
save_context = LoadSaveContext()
502+
503+
utils.type_check(save_context, LoadSaveContext)
504+
505+
default_loader = loaders.get_object_loader()
506+
# If the user has specified a class loader, then save it in the saved state
507+
if save_context.loader is not None:
508+
loader_class = default_loader.identify_object(save_context.loader.__class__)
509+
SaveUtil.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class)
510+
loader = save_context.loader
511+
else:
512+
loader = default_loader
513+
514+
SaveUtil.set_class_name(out_state, loader.identify_object(obj.__class__))
515+
516+
if isinstance(obj, SavableWithAutoPersist):
517+
for member in obj._auto_persist:
518+
value = getattr(obj, member)
519+
if inspect.ismethod(value):
520+
if value.__self__ is not obj:
521+
raise TypeError('Cannot persist methods of other classes')
522+
SaveUtil.set_meta_type(out_state, member, META__TYPE__METHOD)
523+
value = value.__name__
524+
elif isinstance(value, Savable) and not isinstance(value, type):
525+
# persist for a savable obj, call `save` method of obj.
526+
SaveUtil.set_meta_type(out_state, member, META__TYPE__SAVABLE)
527+
value = value.save()
528+
else:
529+
value = copy.deepcopy(value)
530+
out_state[member] = value
531+
532+
return out_state
533+
534+
535+
def auto_load(obj: SavableWithAutoPersist, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None:
536+
for member in obj._auto_persist:
537+
setattr(obj, member, _get_value(obj, saved_state, member, load_context))
538+
517539

518-
typ = Savable._get_meta_type(saved_state, name)
519-
if typ == META__TYPE__METHOD:
520-
value = getattr(self, value)
521-
elif typ == META__TYPE__SAVABLE:
522-
value = load(value, load_context)
540+
def _get_value(
541+
obj: Any, saved_state: SAVED_STATE_TYPE, name: str, load_context: LoadSaveContext | None
542+
) -> MethodType | Savable:
543+
value = saved_state[name]
523544

524-
return value
545+
typ = SaveUtil.get_meta_type(saved_state, name)
546+
if typ == META__TYPE__METHOD:
547+
value = getattr(obj, value)
548+
elif typ == META__TYPE__SAVABLE:
549+
value = load(value, load_context)
550+
551+
return value
552+
553+
554+
def auto_persist(*members: str) -> Callable[..., Savable]:
555+
def wrapped(savable_cls: type) -> Savable:
556+
if not hasattr(savable_cls, '_auto_persist') or savable_cls._auto_persist is None:
557+
savable_cls._auto_persist = set() # type: ignore[attr-defined]
558+
else:
559+
savable_cls._auto_persist = set(savable_cls._auto_persist)
525560

561+
savable_cls._auto_persist.update(members) # type: ignore[attr-defined]
562+
# XXX: validate on `save` and `recreate_from` method??
563+
return cast(Savable, savable_cls)
526564

565+
return wrapped
566+
567+
568+
# FIXME: move me to another module? savablefuture.py?
527569
@auto_persist('_state', '_result')
528-
class SavableFuture(futures.Future, Savable):
570+
class SavableFuture(futures.Future):
529571
"""
530572
A savable future.
531573
@@ -550,7 +592,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
550592
:return: The recreated instance
551593
552594
"""
553-
load_context = _ensure_object_loader(load_context, saved_state)
595+
load_context = ensure_object_loader(load_context, saved_state)
554596

555597
try:
556598
loop = load_context.loop
@@ -586,48 +628,3 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
586628
# ## UNTILHERE XXX:
587629

588630
return obj
589-
590-
591-
def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
592-
out_state: SAVED_STATE_TYPE = {}
593-
594-
if save_context is None:
595-
save_context = LoadSaveContext()
596-
597-
utils.type_check(save_context, LoadSaveContext)
598-
599-
default_loader = loaders.get_object_loader()
600-
# If the user has specified a class loader, then save it in the saved state
601-
if save_context.loader is not None:
602-
loader_class = default_loader.identify_object(save_context.loader.__class__)
603-
Savable.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class)
604-
loader = save_context.loader
605-
else:
606-
loader = default_loader
607-
608-
Savable._set_class_name(out_state, loader.identify_object(obj.__class__))
609-
610-
obj._ensure_persist_configured()
611-
if obj._auto_persist is not None:
612-
for member in obj._auto_persist:
613-
value = getattr(obj, member)
614-
if inspect.ismethod(value):
615-
if value.__self__ is not obj:
616-
raise TypeError('Cannot persist methods of other classes')
617-
Savable._set_meta_type(out_state, member, META__TYPE__METHOD)
618-
value = value.__name__
619-
elif isinstance(value, Savable):
620-
Savable._set_meta_type(out_state, member, META__TYPE__SAVABLE)
621-
value = value.save()
622-
else:
623-
value = copy.deepcopy(value)
624-
out_state[member] = value
625-
626-
return out_state
627-
628-
629-
def auto_load(obj: Savable, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None:
630-
obj._ensure_persist_configured()
631-
if obj._auto_persist is not None:
632-
for member in obj._auto_persist:
633-
setattr(obj, member, obj._get_value(saved_state, member, load_context))

0 commit comments

Comments
 (0)