Skip to content

Commit 6f5e624

Browse files
committed
Improve integration
1 parent 29de16a commit 6f5e624

File tree

10 files changed

+422
-73
lines changed

10 files changed

+422
-73
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,6 @@ test.db
133133

134134
# django fsm command tests
135135
exports/*
136+
137+
# Codex
138+
.codex

CHANGELOG.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
Changelog
22
=========
33

4+
Unreleased
5+
~~~~~~~~~~
6+
7+
- Add typing
8+
9+
410
django-fsm-2 4.1.0 2025-11-03
511
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
612

django_fsm/__init__.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ class Transition:
9797
def __init__(
9898
self,
9999
method: Callable[..., _StateValue],
100-
source: _StateValue | Sequence[_StateValue] | State,
100+
source: _StateValue,
101101
target: _StateValue,
102102
on_error: _StateValue | None,
103103
conditions: list[_Condition] | None,
@@ -415,9 +415,7 @@ def get_all_transitions(self, instance_cls: type[_FSMModel]) -> Generator[Transi
415415
transitions = self.transitions[instance_cls]
416416

417417
for transition in transitions.values():
418-
meta = transition._django_fsm
419-
420-
yield from meta.transitions.values()
418+
yield from transition._django_fsm.transitions.values()
421419

422420
@override
423421
def contribute_to_class(
@@ -725,6 +723,8 @@ def has_transition_perm(bound_method: typing.Any, user: UserWithPermissions) ->
725723

726724

727725
class State:
726+
allowed_states: Sequence[_StateValue]
727+
728728
def get_state(
729729
self,
730730
model: _FSMModel,
@@ -737,8 +737,8 @@ def get_state(
737737

738738

739739
class RETURN_VALUE(State): # noqa: N801
740-
def __init__(self, *allowed_states: Sequence[_StateValue]) -> None:
741-
self.allowed_states = allowed_states if allowed_states else None
740+
def __init__(self, *allowed_states: _StateValue) -> None:
741+
self.allowed_states = allowed_states or []
742742

743743
def get_state(
744744
self,
@@ -762,7 +762,7 @@ def __init__(
762762
states: Sequence[_StateValue] | None = None,
763763
) -> None:
764764
self.func = func
765-
self.allowed_states = states
765+
self.allowed_states = states or []
766766

767767
def get_state(
768768
self,

django_fsm/management/commands/graph_transitions.py

Lines changed: 20 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ def one_fsm_fields_data(
3333
model: type[models.Model], field_name: str
3434
) -> tuple[FSMFieldMixin, type[models.Model]]:
3535
field = model._meta.get_field(field_name)
36-
assert isinstance(field, FSMFieldMixin)
36+
if not isinstance(field, FSMFieldMixin):
37+
raise LookupError(f"{field_name} is not an FSMField") # noqa: TRY004
3738
return (field, model)
3839

3940

@@ -61,9 +62,9 @@ def generate_dot( # noqa: C901, PLR0912
6162
for field, model in fields_data:
6263
sources: set[tuple[(str, str)]] = set()
6364
targets: set[tuple[str, str]] = set()
64-
edges: set[tuple[str, str, tuple[tuple[str, ...]]]] = set()
65+
edges: set[tuple[str, str, tuple[tuple[str, str]]]] = set()
6566
any_targets: set[tuple[_StateValue, str]] = set()
66-
any_except_targets: set[tuple[str, str]] = set()
67+
any_except_targets: set[tuple[_StateValue, str]] = set()
6768

6869
# dump nodes and edges
6970
for transition in field.get_all_transitions(model):
@@ -80,6 +81,7 @@ def generate_dot( # noqa: C901, PLR0912
8081
if isinstance(transition.source, GET_STATE | RETURN_VALUE)
8182
else ((transition.source, node_name(field, transition.source)),)
8283
)
84+
8385
for source, source_name in source_name_pair:
8486
if transition.on_error:
8587
on_error_name = node_name(field, transition.on_error)
@@ -92,16 +94,10 @@ def generate_dot( # noqa: C901, PLR0912
9294
elif transition.source == "+":
9395
any_except_targets.add((target, transition.name))
9496
else:
95-
add_transition(
96-
source,
97-
target,
98-
transition.name,
99-
source_name,
100-
field,
101-
sources,
102-
targets,
103-
edges,
104-
)
97+
target_name = node_name(field, target)
98+
sources.add((source_name, node_label(field, source)))
99+
targets.add((target_name, node_label(field, target)))
100+
edges.add((source_name, target_name, (("label", transition.name),)))
105101

106102
targets.update(
107103
{
@@ -134,46 +130,23 @@ def generate_dot( # noqa: C901, PLR0912
134130
final_states = targets - sources
135131
for name, label in final_states:
136132
subgraph.node(name, label=label, shape="doublecircle")
133+
137134
for name, label in (sources | targets) - final_states:
138135
subgraph.node(name, label=label, shape="circle")
139136
# Adding initial state notation
140137
if field.default and label == field.default:
141138
initial_name = node_name(field, "_initial")
142139
subgraph.node(name=initial_name, label="", shape="point")
143-
subgraph.edge(initial_name, name)
140+
subgraph.edge(tail_name=initial_name, head_name=name)
141+
144142
for source_name, target_name, attrs in edges:
145-
subgraph.edge(source_name, target_name, **dict(attrs))
143+
subgraph.edge(tail_name=source_name, head_name=target_name, **dict(attrs))
146144

147145
result.subgraph(subgraph)
148146

149147
return result
150148

151149

152-
def add_transition(
153-
transition_source: _StateValue,
154-
transition_target: _StateValue,
155-
transition_name: str,
156-
source_name: str,
157-
field: FSMFieldMixin,
158-
sources: set[tuple[str, str]],
159-
targets: set[tuple[str, str]],
160-
edges: set[tuple[str, str, tuple[tuple[str, str], ...]]],
161-
) -> None:
162-
target_name = node_name(field, transition_target)
163-
sources.add((source_name, node_label(field, transition_source)))
164-
targets.add((target_name, node_label(field, transition_target)))
165-
edges.add((source_name, target_name, (("label", transition_name),)))
166-
167-
168-
def get_graphviz_layouts() -> set[str] | Sequence[str]:
169-
try:
170-
import graphviz
171-
except ModuleNotFoundError:
172-
return {"sfdp", "circo", "twopi", "dot", "neato", "fdp", "osage", "patchwork"}
173-
else:
174-
return graphviz.ENGINES # type: ignore[no-any-return]
175-
176-
177150
class Command(BaseCommand):
178151
help = "Creates a GraphViz dot file with transitions for selected fields"
179152

@@ -192,7 +165,7 @@ def add_arguments(self, parser: ArgumentParser) -> None:
192165
action="store",
193166
dest="layout",
194167
default="dot",
195-
help=f"Layout to be used by GraphViz for visualization: {get_graphviz_layouts()}.",
168+
help=f"Layout to be used by GraphViz for visualization: {graphviz.ENGINES}.",
196169
)
197170
parser.add_argument(
198171
"--exclude",
@@ -204,13 +177,6 @@ def add_arguments(self, parser: ArgumentParser) -> None:
204177
)
205178
parser.add_argument("args", nargs="*", help=("[appname[.model[.field]]]"))
206179

207-
def render_output(self, graph: graphviz.Digraph, **options: typing.Any) -> None:
208-
filename, graph_format = options["outputfile"].rsplit(".", 1)
209-
210-
graph.engine = options["layout"]
211-
graph.format = graph_format
212-
graph.render(filename)
213-
214180
def handle(self, *args: str, **options: typing.Any) -> None:
215181
fields_data: list[tuple[FSMFieldMixin, type[models.Model]]] = []
216182
if args:
@@ -232,8 +198,11 @@ def handle(self, *args: str, **options: typing.Any) -> None:
232198

233199
dotdata = generate_dot(fields_data, ignore_transitions=options["exclude"].split(","))
234200

235-
outputfile = options["outputfile"]
236-
if outputfile:
237-
self.render_output(dotdata, **options)
201+
if outputfile := options["outputfile"]:
202+
filename, graph_format = outputfile.rsplit(".", 1)
203+
204+
dotdata.engine = options["layout"]
205+
dotdata.format = graph_format
206+
dotdata.render(filename)
238207
else:
239208
self.stdout.write(str(dotdata))

pyproject.toml

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,12 @@ classifiers = [
2525
"Programming Language :: Python :: 3.12",
2626
"Programming Language :: Python :: 3.13",
2727
"Programming Language :: Python :: 3.14",
28+
"Typing :: Typed",
2829
"Topic :: Software Development :: Libraries :: Python Modules",
2930
]
3031
dependencies = [
3132
"django>=4.2",
32-
"typing-extensions>=4.13.2",
33+
"typing-extensions; python_version < '3.12'",
3334
]
3435

3536
[project.urls]
@@ -46,7 +47,7 @@ graphviz = [
4647
dev = [
4748
"coverage",
4849
"django-guardian",
49-
"graphviz",
50+
"django-stubs-ext",
5051
"pre-commit",
5152
"pytest",
5253
"pytest-cov",
@@ -57,6 +58,7 @@ dev = [
5758
[tool.uv]
5859
default-groups = [
5960
"dev",
61+
"graphviz"
6062
]
6163

6264
[tool.hatch.build.targets.sdist]
@@ -156,9 +158,3 @@ module = [
156158
"django_fsm.tests.*"
157159
]
158160
disallow_untyped_defs = false
159-
160-
[[tool.mypy.overrides]]
161-
module = [
162-
"django_fsm.management.commands.graph_transitions.*",
163-
]
164-
ignore_errors = true

tests/settings.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,14 @@
1212

1313
from __future__ import annotations
1414

15+
import typing
1516
from pathlib import Path
1617

18+
if typing.TYPE_CHECKING:
19+
import django_stubs_ext
20+
21+
django_stubs_ext.monkeypatch()
22+
1723
# Build paths inside the project like this: BASE_DIR / 'subdir'.
1824
BASE_DIR = Path(__file__).resolve().parent.parent
1925

tests/testapp/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class Application(models.Model):
2222

2323
state = FSMField(default="new")
2424

25-
@transition(field=state, source="new", target="published")
25+
@transition(field=state, source="new", target="published", on_error="failed")
2626
def standard(self) -> None:
2727
pass
2828

0 commit comments

Comments
 (0)