Skip to content

Commit 0469cb0

Browse files
authored
chore: Improved factory type hints (#399)
1 parent 50cc5ef commit 0469cb0

File tree

2 files changed

+53
-30
lines changed

2 files changed

+53
-30
lines changed

statemachine/factory.py

Lines changed: 36 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import TYPE_CHECKING
22
from typing import Any
33
from typing import Dict
4+
from typing import List
45
from typing import Tuple
56
from uuid import uuid4
67

@@ -31,7 +32,13 @@ def __init__(cls, name: str, bases: Tuple[type], attrs: Dict[str, Any]):
3132
cls.add_inherited(bases)
3233
cls.add_from_attributes(attrs)
3334

34-
cls._set_special_states()
35+
try:
36+
cls.initial_state: State = next(s for s in cls.states if s.initial)
37+
except StopIteration:
38+
cls.initial_state = None # Abstract SM still don't have states
39+
40+
cls.final_states: List[State] = [state for state in cls.states if state.final]
41+
3542
cls._check()
3643

3744
if TYPE_CHECKING:
@@ -40,35 +47,6 @@ def __init__(cls, name: str, bases: Tuple[type], attrs: Dict[str, Any]):
4047
def __getattr__(self, attribute: str) -> Any:
4148
...
4249

43-
def _set_special_states(cls):
44-
if not cls.states:
45-
return
46-
initials = [s for s in cls.states if s.initial]
47-
if len(initials) != 1:
48-
raise InvalidDefinition(
49-
_(
50-
"There should be one and only one initial state. "
51-
"Your currently have these: {!r}"
52-
).format([s.id for s in initials])
53-
)
54-
cls.initial_state = initials[0]
55-
cls.final_states = [state for state in cls.states if state.final]
56-
57-
def _disconnected_states(cls, starting_state):
58-
visitable_states = set(visit_connected_states(starting_state))
59-
return set(cls.states) - visitable_states
60-
61-
def _check_disconnected_state(cls):
62-
disconnected_states = cls._disconnected_states(cls.initial_state)
63-
if disconnected_states:
64-
raise InvalidDefinition(
65-
_(
66-
"There are unreachable states. "
67-
"The statemachine graph should have a single component. "
68-
"Disconnected states: {}"
69-
).format([s.id for s in disconnected_states])
70-
)
71-
7250
def _check(cls):
7351
has_states = bool(cls.states)
7452
has_events = bool(cls._events)
@@ -85,8 +63,21 @@ def _check(cls):
8563
if not has_events:
8664
raise InvalidDefinition(_("There are no events."))
8765

66+
cls._check_initial_state()
67+
cls._check_final_states()
8868
cls._check_disconnected_state()
8969

70+
def _check_initial_state(cls):
71+
initials = [s for s in cls.states if s.initial]
72+
if len(initials) != 1:
73+
raise InvalidDefinition(
74+
_(
75+
"There should be one and only one initial state. "
76+
"Your currently have these: {!r}"
77+
).format([s.id for s in initials])
78+
)
79+
80+
def _check_final_states(cls):
9081
final_state_with_invalid_transitions = [
9182
state for state in cls.final_states if state.transitions
9283
]
@@ -98,6 +89,21 @@ def _check(cls):
9889
).format([s.id for s in final_state_with_invalid_transitions])
9990
)
10091

92+
def _disconnected_states(cls, starting_state):
93+
visitable_states = set(visit_connected_states(starting_state))
94+
return set(cls.states) - visitable_states
95+
96+
def _check_disconnected_state(cls):
97+
disconnected_states = cls._disconnected_states(cls.initial_state)
98+
if disconnected_states:
99+
raise InvalidDefinition(
100+
_(
101+
"There are unreachable states. "
102+
"The statemachine graph should have a single component. "
103+
"Disconnected states: {}"
104+
).format([s.id for s in disconnected_states])
105+
)
106+
101107
def add_inherited(cls, bases):
102108
for base in bases:
103109
for state in getattr(base, "states", []):

tests/test_statemachine.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,23 @@ class CampaignMachine(StateMachine):
5050
deliver = producing.to(closed)
5151

5252

53+
def test_machine_should_activate_initial_state():
54+
class CampaignMachine(StateMachine):
55+
"A workflow machine"
56+
producing = State()
57+
closed = State()
58+
draft = State(initial=True)
59+
60+
add_job = draft.to(draft) | producing.to(producing)
61+
produce = draft.to(producing)
62+
deliver = producing.to(closed)
63+
64+
sm = CampaignMachine()
65+
66+
assert sm.current_state == sm.draft
67+
assert sm.current_state.is_active
68+
69+
5370
def test_machine_should_not_allow_transitions_from_final_state():
5471
with pytest.raises(exceptions.InvalidDefinition):
5572

0 commit comments

Comments
 (0)