Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions statemachine/statemachine.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def __setstate__(self, state):
self._register_callbacks([])
self.add_listener(*listeners.keys())
self._engine = self._get_engine(rtc)
self._engine.start()

def _get_initial_state(self):
initial_state_value = self.start_value if self.start_value else self.initial_state.value
Expand Down
34 changes: 32 additions & 2 deletions tests/test_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from enum import auto

import pytest
from statemachine.exceptions import TransitionNotAllowed
from statemachine.states import States

from statemachine import State
from statemachine import StateMachine
from statemachine.exceptions import TransitionNotAllowed
from statemachine.states import States

logger = logging.getLogger(__name__)
DEBUG = logging.DEBUG
Expand Down Expand Up @@ -181,3 +181,33 @@ def test_copy_with_custom_init_and_vars(copy_method):
assert sm2.custom == 1
assert sm2.value == [1, 2, 3]
assert sm2.current_state == MyStateMachine.started


class _AsyncTrafficLightForPickleTest(StateMachine):
"""Defined at module level to be picklable for test_pickle_async_statemachine."""

green = State(initial=True)
yellow = State()
red = State()

cycle = green.to(yellow) | yellow.to(red) | red.to(green)

async def on_enter_state(self, target):
pass


def test_pickle_async_statemachine():
"""Regression test for issue #544: async SM fails after pickle."""
import asyncio

sm = _AsyncTrafficLightForPickleTest()

sm_copy = pickle.loads(pickle.dumps(sm))

async def verify():
await sm_copy.activate_initial_state() # type: ignore[awaitable]
assert sm_copy.current_state == _AsyncTrafficLightForPickleTest.green
await sm_copy.cycle()
assert sm_copy.current_state == _AsyncTrafficLightForPickleTest.yellow

asyncio.run(verify())