Skip to content

Commit e16e0ef

Browse files
authored
feat: Supporting pickle protocol (#500)
1 parent 9b55852 commit e16e0ef

File tree

3 files changed

+204
-118
lines changed

3 files changed

+204
-118
lines changed

statemachine/statemachine.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,27 @@ def __deepcopy__(self, memo):
147147
cp.add_listener(*cp._listeners.keys())
148148
return cp
149149

150+
def __getstate__(self):
151+
state = self.__dict__.copy()
152+
state["_rtc"] = self._engine._rtc
153+
del state["_callbacks"]
154+
del state["_states_for_instance"]
155+
del state["_engine"]
156+
return state
157+
158+
def __setstate__(self, state):
159+
listeners = state.pop("_listeners")
160+
rtc = state.pop("_rtc")
161+
self.__dict__.update(state)
162+
self._callbacks = CallbacksRegistry()
163+
self._states_for_instance: Dict[State, State] = {}
164+
165+
self._listeners: Dict[Any, Any] = {}
166+
167+
self._register_callbacks([])
168+
self.add_listener(*listeners.keys())
169+
self._engine = self._get_engine(rtc)
170+
150171
def _get_initial_state(self):
151172
initial_state_value = self.start_value if self.start_value else self.initial_state.value
152173
try:

tests/test_copy.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
import logging
2+
import pickle
3+
from copy import deepcopy
4+
from enum import Enum
5+
from enum import auto
6+
7+
import pytest
8+
9+
from statemachine import State
10+
from statemachine import StateMachine
11+
from statemachine.exceptions import TransitionNotAllowed
12+
from statemachine.states import States
13+
14+
logger = logging.getLogger(__name__)
15+
DEBUG = logging.DEBUG
16+
17+
18+
def copy_pickle(obj):
19+
return pickle.loads(pickle.dumps(obj))
20+
21+
22+
@pytest.fixture(params=[deepcopy, copy_pickle], ids=["deepcopy", "pickle"])
23+
def copy_method(request):
24+
return request.param
25+
26+
27+
class GameStates(str, Enum):
28+
GAME_START = auto()
29+
GAME_PLAYING = auto()
30+
TURN_END = auto()
31+
GAME_END = auto()
32+
33+
34+
class GameStateMachine(StateMachine):
35+
s = States.from_enum(GameStates, initial=GameStates.GAME_START)
36+
37+
play = s.GAME_START.to(s.GAME_PLAYING)
38+
stop = s.GAME_PLAYING.to(s.TURN_END)
39+
end_game = s.TURN_END.to(s.GAME_END)
40+
41+
@end_game.cond
42+
def game_is_over(self) -> bool:
43+
return True
44+
45+
advance_round = end_game | s.TURN_END.to(s.GAME_END)
46+
47+
48+
class MyStateMachine(StateMachine):
49+
created = State(initial=True)
50+
started = State()
51+
52+
start = created.to(started)
53+
54+
def __init__(self):
55+
super().__init__()
56+
self.custom = 1
57+
self.value = [1, 2, 3]
58+
59+
60+
class MySM(StateMachine):
61+
draft = State("Draft", initial=True, value="draft")
62+
published = State("Published", value="published", final=True)
63+
64+
publish = draft.to(published, cond="let_me_be_visible")
65+
66+
def on_transition(self, event: str):
67+
logger.debug(f"{self.__class__.__name__} recorded {event} transition")
68+
69+
def let_me_be_visible(self):
70+
logger.debug(f"{type(self).__name__} let_me_be_visible: True")
71+
return True
72+
73+
74+
class MyModel:
75+
def __init__(self, name: str) -> None:
76+
self.name = name
77+
self.let_me_be_visible = False
78+
79+
def __repr__(self) -> str:
80+
return f"{type(self).__name__}@{id(self)}({self.name!r})"
81+
82+
def on_transition(self, event: str):
83+
logger.debug(f"{type(self).__name__}({self.name!r}) recorded {event} transition")
84+
85+
@property
86+
def let_me_be_visible(self):
87+
logger.debug(
88+
f"{type(self).__name__}({self.name!r}) let_me_be_visible: {self._let_me_be_visible}"
89+
)
90+
return self._let_me_be_visible
91+
92+
@let_me_be_visible.setter
93+
def let_me_be_visible(self, value):
94+
self._let_me_be_visible = value
95+
96+
97+
def test_copy(copy_method):
98+
sm = MySM(MyModel("main_model"))
99+
100+
sm2 = copy_method(sm)
101+
102+
with pytest.raises(TransitionNotAllowed):
103+
sm2.send("publish")
104+
105+
106+
def test_copy_with_listeners(caplog, copy_method):
107+
model1 = MyModel("main_model")
108+
109+
sm1 = MySM(model1)
110+
111+
listener_1 = MyModel("observer_1")
112+
listener_2 = MyModel("observer_2")
113+
sm1.add_listener(listener_1)
114+
sm1.add_listener(listener_2)
115+
116+
sm2 = copy_method(sm1)
117+
118+
assert sm1.model is not sm2.model
119+
120+
caplog.set_level(logging.DEBUG, logger="tests")
121+
122+
def assertions(sm, _reference):
123+
caplog.clear()
124+
if not sm._listeners:
125+
pytest.fail("did not found any observer")
126+
127+
for listener in sm._listeners:
128+
listener.let_me_be_visible = False
129+
130+
with pytest.raises(TransitionNotAllowed):
131+
sm.send("publish")
132+
133+
sm.model.let_me_be_visible = True
134+
135+
for listener in sm._listeners:
136+
with pytest.raises(TransitionNotAllowed):
137+
sm.send("publish")
138+
139+
listener.let_me_be_visible = True
140+
141+
sm.send("publish")
142+
143+
assert caplog.record_tuples == [
144+
("tests.test_copy", DEBUG, "MySM let_me_be_visible: True"),
145+
("tests.test_copy", DEBUG, "MyModel('main_model') let_me_be_visible: False"),
146+
("tests.test_copy", DEBUG, "MySM let_me_be_visible: True"),
147+
("tests.test_copy", DEBUG, "MyModel('main_model') let_me_be_visible: True"),
148+
("tests.test_copy", DEBUG, "MyModel('observer_1') let_me_be_visible: False"),
149+
("tests.test_copy", DEBUG, "MySM let_me_be_visible: True"),
150+
("tests.test_copy", DEBUG, "MyModel('main_model') let_me_be_visible: True"),
151+
("tests.test_copy", DEBUG, "MyModel('observer_1') let_me_be_visible: True"),
152+
("tests.test_copy", DEBUG, "MyModel('observer_2') let_me_be_visible: False"),
153+
("tests.test_copy", DEBUG, "MySM let_me_be_visible: True"),
154+
("tests.test_copy", DEBUG, "MyModel('main_model') let_me_be_visible: True"),
155+
("tests.test_copy", DEBUG, "MyModel('observer_1') let_me_be_visible: True"),
156+
("tests.test_copy", DEBUG, "MyModel('observer_2') let_me_be_visible: True"),
157+
("tests.test_copy", DEBUG, "MySM recorded publish transition"),
158+
("tests.test_copy", DEBUG, "MyModel('main_model') recorded publish transition"),
159+
("tests.test_copy", DEBUG, "MyModel('observer_1') recorded publish transition"),
160+
("tests.test_copy", DEBUG, "MyModel('observer_2') recorded publish transition"),
161+
]
162+
163+
assertions(sm1, "original")
164+
assertions(sm2, "copy")
165+
166+
167+
def test_copy_with_enum(copy_method):
168+
sm = GameStateMachine()
169+
sm.play()
170+
assert sm.current_state == GameStateMachine.GAME_PLAYING
171+
172+
sm2 = copy_method(sm)
173+
assert sm2.current_state == GameStateMachine.GAME_PLAYING
174+
175+
176+
def test_copy_with_custom_init_and_vars(copy_method):
177+
sm = MyStateMachine()
178+
sm.start()
179+
180+
sm2 = copy_method(sm)
181+
assert sm2.custom == 1
182+
assert sm2.value == [1, 2, 3]
183+
assert sm2.current_state == MyStateMachine.started

tests/test_deepcopy.py

Lines changed: 0 additions & 118 deletions
This file was deleted.

0 commit comments

Comments
 (0)