Skip to content

Commit 34d9b01

Browse files
committed
Make auto_load symmetry with auto_save and state/state_label distinguish
1 parent 66110f6 commit 34d9b01

File tree

11 files changed

+153
-145
lines changed

11 files changed

+153
-145
lines changed

src/plumpy/base/state_machine.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,13 @@ def create_initial_state(self, *args: Any, **kwargs: Any) -> State:
266266
return self.get_state_class(self.initial_state_label())(self, *args, **kwargs)
267267

268268
@property
269-
def state(self) -> Any:
269+
def state(self) -> State | None:
270+
if self._state is None:
271+
return None
272+
return self._state
273+
274+
@property
275+
def state_label(self) -> Any:
270276
if self._state is None:
271277
return None
272278
return self._state.LABEL
@@ -312,7 +318,7 @@ def transition_to(self, new_state: State | None, **kwargs: Any) -> None:
312318
if new_state is None:
313319
return None
314320

315-
initial_state_label = self._state.LABEL if self._state is not None else None
321+
initial_state_label = self.state_label
316322
label = None
317323
try:
318324
self._transitioning = True

src/plumpy/event_helper.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
4545
4646
"""
4747
load_context = ensure_object_loader(load_context, saved_state)
48-
obj = cls.__new__(cls)
49-
auto_load(obj, saved_state, load_context)
48+
obj = auto_load(cls, saved_state, load_context)
5049
return obj
5150

5251
def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:

src/plumpy/persistence.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
List,
2121
Optional,
2222
Protocol,
23+
TypeVar,
2324
cast,
2425
runtime_checkable,
2526
)
@@ -535,6 +536,8 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S
535536
value = value.__name__
536537
elif isinstance(value, Savable) and not isinstance(value, type):
537538
# persist for a savable obj, call `save` method of obj.
539+
# the rhs branch is for when value is a Savable class, it is true runtime check
540+
# of lhs condition.
538541
SaveUtil.set_meta_type(out_state, member, META__TYPE__SAVABLE)
539542
value = value.save()
540543
else:
@@ -544,11 +547,25 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S
544547
return out_state
545548

546549

547-
def auto_load(obj: SavableWithAutoPersist, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None:
550+
def load_auto_persist_params(
551+
obj: SavableWithAutoPersist, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None
552+
) -> None:
548553
for member in obj._auto_persist:
549554
setattr(obj, member, _get_value(obj, saved_state, member, load_context))
550555

551556

557+
T = TypeVar('T', bound=Savable)
558+
559+
560+
def auto_load(cls: type[T], saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None) -> T:
561+
obj = cls.__new__(cls)
562+
563+
if isinstance(obj, SavableWithAutoPersist):
564+
load_auto_persist_params(obj, saved_state, load_context)
565+
566+
return obj
567+
568+
552569
def _get_value(
553570
obj: Any, saved_state: SAVED_STATE_TYPE, name: str, load_context: LoadSaveContext | None
554571
) -> MethodType | Savable:

src/plumpy/process_states.py

Lines changed: 15 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import yaml
2323
from yaml.loader import Loader
2424

25-
from plumpy.persistence import ensure_object_loader
2625
from plumpy.process_comms import KillMessage, MessageType
2726

2827
try:
@@ -41,6 +40,7 @@
4140
auto_load,
4241
auto_persist,
4342
auto_save,
43+
ensure_object_loader,
4444
)
4545
from .utils import SAVED_STATE_TYPE
4646

@@ -98,8 +98,8 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
9898
:return: The recreated instance
9999
100100
"""
101-
obj = cls.__new__(cls)
102-
auto_load(obj, saved_state, load_context)
101+
load_context = ensure_object_loader(load_context, saved_state)
102+
obj = auto_load(cls, saved_state, load_context)
103103
return obj
104104

105105
def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
@@ -171,15 +171,15 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
171171
172172
"""
173173
load_context = ensure_object_loader(load_context, saved_state)
174-
obj = cls.__new__(cls)
175-
auto_load(obj, saved_state, load_context)
174+
obj = auto_load(cls, saved_state, load_context)
176175

177-
obj.state_machine = load_context.process
178176
try:
179177
obj.continue_fn = utils.load_function(saved_state[obj.CONTINUE_FN])
180178
except ValueError:
181-
process = load_context.process
182-
obj.continue_fn = getattr(process, saved_state[obj.CONTINUE_FN])
179+
if load_context is not None:
180+
obj.continue_fn = getattr(load_context.proc, saved_state[obj.CONTINUE_FN])
181+
else:
182+
raise
183183
return obj
184184

185185

@@ -235,12 +235,8 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
235235
236236
"""
237237
load_context = ensure_object_loader(load_context, saved_state)
238-
obj = cls.__new__(cls)
239-
240-
auto_load(obj, saved_state, load_context)
241-
238+
obj = auto_load(cls, saved_state, load_context)
242239
obj.process = load_context.process
243-
244240
obj.run_fn = getattr(obj.process, saved_state[obj.RUN_FN])
245241

246242
return obj
@@ -306,15 +302,12 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
306302
307303
"""
308304
load_context = ensure_object_loader(load_context, saved_state)
309-
obj = cls.__new__(cls)
310-
auto_load(obj, saved_state, load_context)
311-
305+
obj = auto_load(cls, saved_state, load_context)
312306
obj.process = load_context.process
313-
314307
obj.run_fn = getattr(obj.process, saved_state[obj.RUN_FN])
315308
if obj.COMMAND in saved_state:
316-
# FIXME: typing
317309
obj._command = persistence.load(saved_state[obj.COMMAND], load_context) # type: ignore
310+
318311
return obj
319312

320313
def interrupt(self, reason: Any) -> None:
@@ -444,9 +437,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
444437
445438
"""
446439
load_context = ensure_object_loader(load_context, saved_state)
447-
obj = cls.__new__(cls)
448-
auto_load(obj, saved_state, load_context)
449-
440+
obj = auto_load(cls, saved_state, load_context)
450441
obj.process = load_context.process
451442

452443
callback_name = saved_state.get(obj.DONE_CALLBACK, None)
@@ -550,8 +541,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
550541
551542
"""
552543
load_context = ensure_object_loader(load_context, saved_state)
553-
obj = cls.__new__(cls)
554-
auto_load(obj, saved_state, load_context)
544+
obj = auto_load(cls, saved_state, load_context)
555545

556546
obj.exception = yaml.load(saved_state[obj.EXC_VALUE], Loader=Loader)
557547
if _HAS_TBLIB:
@@ -610,8 +600,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
610600
611601
"""
612602
load_context = ensure_object_loader(load_context, saved_state)
613-
obj = cls.__new__(cls)
614-
auto_load(obj, saved_state, load_context)
603+
obj = auto_load(cls, saved_state, load_context)
615604
return obj
616605

617606
def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
@@ -659,8 +648,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
659648
660649
"""
661650
load_context = ensure_object_loader(load_context, saved_state)
662-
obj = cls.__new__(cls)
663-
auto_load(obj, saved_state, load_context)
651+
obj = auto_load(cls, saved_state, load_context)
664652
return obj
665653

666654
def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:

0 commit comments

Comments
 (0)