Skip to content

Commit 629832c

Browse files
committed
Make auto_load symmetry with auto_save and state/state_label distinguish
1 parent 23abe62 commit 629832c

File tree

11 files changed

+146
-135
lines changed

11 files changed

+146
-135
lines changed

src/plumpy/base/state_machine.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,13 @@ def create_initial_state(self, *args: Any, **kwargs: Any) -> State:
266266
return self.get_state_class(self.initial_state_label())(self, *args, **kwargs)
267267

268268
@property
269-
def state(self) -> Any:
269+
def state(self) -> State | None:
270+
if self._state is None:
271+
return None
272+
return self._state
273+
274+
@property
275+
def state_label(self) -> Any:
270276
if self._state is None:
271277
return None
272278
return self._state.LABEL
@@ -314,7 +320,7 @@ def transition_to(self, new_state: State | None, **kwargs: Any) -> None:
314320
# it can happened when transit from terminal state
315321
return None
316322

317-
initial_state_label = self._state.LABEL if self._state is not None else None
323+
initial_state_label = self.state_label
318324
label = None
319325
try:
320326
self._transitioning = True

src/plumpy/event_helper.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
4545
4646
"""
4747
load_context = ensure_object_loader(load_context, saved_state)
48-
obj = cls.__new__(cls)
49-
auto_load(obj, saved_state, load_context)
48+
obj = auto_load(cls, saved_state, load_context)
5049
return obj
5150

5251
def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:

src/plumpy/persistence.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
List,
2121
Optional,
2222
Protocol,
23+
TypeVar,
2324
cast,
2425
runtime_checkable,
2526
)
@@ -523,6 +524,8 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S
523524
value = value.__name__
524525
elif isinstance(value, Savable) and not isinstance(value, type):
525526
# persist for a savable obj, call `save` method of obj.
527+
# the rhs branch is for when value is a Savable class, it is true runtime check
528+
# of lhs condition.
526529
SaveUtil.set_meta_type(out_state, member, META__TYPE__SAVABLE)
527530
value = value.save()
528531
else:
@@ -532,11 +535,25 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S
532535
return out_state
533536

534537

535-
def auto_load(obj: SavableWithAutoPersist, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None:
538+
def load_auto_persist_params(
539+
obj: SavableWithAutoPersist, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None
540+
) -> None:
536541
for member in obj._auto_persist:
537542
setattr(obj, member, _get_value(obj, saved_state, member, load_context))
538543

539544

545+
T = TypeVar('T', bound=Savable)
546+
547+
548+
def auto_load(cls: type[T], saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None) -> T:
549+
obj = cls.__new__(cls)
550+
551+
if isinstance(obj, SavableWithAutoPersist):
552+
load_auto_persist_params(obj, saved_state, load_context)
553+
554+
return obj
555+
556+
540557
def _get_value(
541558
obj: Any, saved_state: SAVED_STATE_TYPE, name: str, load_context: LoadSaveContext | None
542559
) -> MethodType | Savable:

src/plumpy/process_states.py

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
auto_load,
4343
auto_persist,
4444
auto_save,
45+
ensure_object_loader,
4546
)
4647
from .utils import SAVED_STATE_TYPE
4748

@@ -102,8 +103,8 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
102103
:return: The recreated instance
103104
104105
"""
105-
obj = cls.__new__(cls)
106-
auto_load(obj, saved_state, load_context)
106+
load_context = ensure_object_loader(load_context, saved_state)
107+
obj = auto_load(cls, saved_state, load_context)
107108
return obj
108109

109110
def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
@@ -175,15 +176,15 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
175176
176177
"""
177178
load_context = ensure_object_loader(load_context, saved_state)
178-
obj = cls.__new__(cls)
179-
auto_load(obj, saved_state, load_context)
179+
obj = auto_load(cls, saved_state, load_context)
180180

181-
obj.state_machine = load_context.process
182181
try:
183182
obj.continue_fn = utils.load_function(saved_state[obj.CONTINUE_FN])
184183
except ValueError:
185-
process = load_context.process
186-
obj.continue_fn = getattr(process, saved_state[obj.CONTINUE_FN])
184+
if load_context is not None:
185+
obj.continue_fn = getattr(load_context.proc, saved_state[obj.CONTINUE_FN])
186+
else:
187+
raise
187188
return obj
188189

189190

@@ -239,12 +240,8 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
239240
240241
"""
241242
load_context = ensure_object_loader(load_context, saved_state)
242-
obj = cls.__new__(cls)
243-
244-
auto_load(obj, saved_state, load_context)
245-
243+
obj = auto_load(cls, saved_state, load_context)
246244
obj.process = load_context.process
247-
248245
obj.run_fn = getattr(obj.process, saved_state[obj.RUN_FN])
249246

250247
return obj
@@ -312,15 +309,13 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
312309
313310
"""
314311
load_context = ensure_object_loader(load_context, saved_state)
315-
obj = cls.__new__(cls)
316-
auto_load(obj, saved_state, load_context)
317-
312+
obj = auto_load(cls, saved_state, load_context)
318313
obj.process = load_context.process
319314

320315
obj.run_fn = ensure_coroutine(getattr(self.process, saved_state[self.RUN_FN]))
321316
if obj.COMMAND in saved_state:
322-
# FIXME: typing
323317
obj._command = persistence.load(saved_state[obj.COMMAND], load_context) # type: ignore
318+
324319
return obj
325320

326321
def interrupt(self, reason: Any) -> None:
@@ -450,9 +445,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
450445
451446
"""
452447
load_context = ensure_object_loader(load_context, saved_state)
453-
obj = cls.__new__(cls)
454-
auto_load(obj, saved_state, load_context)
455-
448+
obj = auto_load(cls, saved_state, load_context)
456449
obj.process = load_context.process
457450

458451
callback_name = saved_state.get(obj.DONE_CALLBACK, None)
@@ -556,8 +549,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
556549
557550
"""
558551
load_context = ensure_object_loader(load_context, saved_state)
559-
obj = cls.__new__(cls)
560-
auto_load(obj, saved_state, load_context)
552+
obj = auto_load(cls, saved_state, load_context)
561553

562554
obj.exception = yaml.load(saved_state[obj.EXC_VALUE], Loader=Loader)
563555
if _HAS_TBLIB:
@@ -616,8 +608,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
616608
617609
"""
618610
load_context = ensure_object_loader(load_context, saved_state)
619-
obj = cls.__new__(cls)
620-
auto_load(obj, saved_state, load_context)
611+
obj = auto_load(cls, saved_state, load_context)
621612
return obj
622613

623614
def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
@@ -665,8 +656,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
665656
666657
"""
667658
load_context = ensure_object_loader(load_context, saved_state)
668-
obj = cls.__new__(cls)
669-
auto_load(obj, saved_state, load_context)
659+
obj = auto_load(cls, saved_state, load_context)
670660
return obj
671661

672662
def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:

0 commit comments

Comments
 (0)