diff --git a/django_fsm/management/commands/graph_transitions.py b/django_fsm/management/commands/graph_transitions.py index 467cc80..d1cd6e6 100644 --- a/django_fsm/management/commands/graph_transitions.py +++ b/django_fsm/management/commands/graph_transitions.py @@ -35,27 +35,28 @@ def generate_dot(fields_data): # noqa: C901, PLR0912 # dump nodes and edges for transition in field.get_all_transitions(model): - if transition.source == "*": - any_targets.add((transition.target, transition.name)) - elif transition.source == "+": - any_except_targets.add((transition.target, transition.name)) - else: - _targets = ( - (state for state in transition.target.allowed_states) - if isinstance(transition.target, (GET_STATE, RETURN_VALUE)) - else (transition.target,) - ) - source_name_pair = ( - ((state, node_name(field, state)) for state in transition.source.allowed_states) - if isinstance(transition.source, (GET_STATE, RETURN_VALUE)) - else ((transition.source, node_name(field, transition.source)),) - ) - for source, source_name in source_name_pair: - if transition.on_error: - on_error_name = node_name(field, transition.on_error) - targets.add((on_error_name, node_label(field, transition.on_error))) - edges.add((source_name, on_error_name, (("style", "dotted"),))) - for target in _targets: + _targets = list( + (state for state in transition.target.allowed_states) + if isinstance(transition.target, (GET_STATE, RETURN_VALUE)) + else (transition.target,) + ) + source_name_pair = ( + ((state, node_name(field, state)) for state in transition.source.allowed_states) + if isinstance(transition.source, (GET_STATE, RETURN_VALUE)) + else ((transition.source, node_name(field, transition.source)),) + ) + for source, source_name in source_name_pair: + if transition.on_error: + on_error_name = node_name(field, transition.on_error) + targets.add((on_error_name, node_label(field, transition.on_error))) + edges.add((source_name, on_error_name, (("style", "dotted"),))) + + for target in _targets: + if transition.source == "*": + any_targets.add((target, transition.name)) + elif transition.source == "+": + any_except_targets.add((target, transition.name)) + else: add_transition(source, target, transition.name, source_name, field, sources, targets, edges) targets.update( diff --git a/tests/testapp/models.py b/tests/testapp/models.py index c630d98..c012e90 100644 --- a/tests/testapp/models.py +++ b/tests/testapp/models.py @@ -2,6 +2,8 @@ from django.db import models +from django_fsm import GET_STATE +from django_fsm import RETURN_VALUE from django_fsm import FSMField from django_fsm import FSMKeyField from django_fsm import transition @@ -15,28 +17,69 @@ class Application(models.Model): state = FSMField(default="new") - @transition(field=state, source="new", target="draft") - def draft(self): + @transition(field=state, source="new", target="published") + def standard(self): pass - @transition(field=state, source=["new", "draft"], target="dept") - def submitted(self): + @transition(field=state, source="published") + def no_target(self): + pass + + @transition(field=state, source="*", target="blocked") + def any_source(self): + pass + + @transition(field=state, source="+", target="hidden") + def any_source_except_target(self): pass - @transition(field=state, source="dept", target="dean") - def dept_approved(self): + @transition( + field=state, + source="new", + target=GET_STATE( + lambda _, allowed: "published" if allowed else "rejected", + states=["published", "rejected"], + ), + ) + def get_state(self, *, allowed: bool): pass - @transition(field=state, source="dept", target="new") - def dept_rejected(self): + @transition( + field=state, + source="*", + target=GET_STATE( + lambda _, allowed: "published" if allowed else "rejected", + states=["published", "rejected"], + ), + ) + def get_state_any_source(self, *, allowed: bool): pass - @transition(field=state, source="dean", target="done") - def dean_approved(self): + @transition( + field=state, + source="+", + target=GET_STATE( + lambda _, allowed: "published" if allowed else "rejected", + states=["published", "rejected"], + ), + ) + def get_state_any_source_except_target(self, *, allowed: bool): pass - @transition(field=state, source="dean", target="dept") - def dean_rejected(self): + @transition(field=state, source="new", target=RETURN_VALUE("moderated", "blocked")) + def return_value(self): + return "published" + + @transition(field=state, source="*", target=RETURN_VALUE("moderated", "blocked")) + def return_value_any_source(self): + return "published" + + @transition(field=state, source="+", target=RETURN_VALUE("moderated", "blocked")) + def return_value_any_source_except_target(self): + return "published" + + @transition(field=state, source="new", target="published", on_error="failed") + def on_error(self): pass @@ -61,28 +104,69 @@ class FKApplication(models.Model): state = FSMKeyField(DbState, default="new", on_delete=models.CASCADE) - @transition(field=state, source="new", target="draft") - def draft(self): + @transition(field=state, source="new", target="published") + def standard(self): pass - @transition(field=state, source=["new", "draft"], target="dept") - def submitted(self): + @transition(field=state, source="published") + def no_target(self): + pass + + @transition(field=state, source="*", target="blocked") + def any_source(self): + pass + + @transition(field=state, source="+", target="hidden") + def any_source_except_target(self): pass - @transition(field=state, source="dept", target="dean") - def dept_approved(self): + @transition( + field=state, + source="new", + target=GET_STATE( + lambda _, allowed: "published" if allowed else "rejected", + states=["published", "rejected"], + ), + ) + def get_state(self, *, allowed: bool): pass - @transition(field=state, source="dept", target="new") - def dept_rejected(self): + @transition( + field=state, + source="*", + target=GET_STATE( + lambda _, allowed: "published" if allowed else "rejected", + states=["published", "rejected"], + ), + ) + def get_state_any_source(self, *, allowed: bool): pass - @transition(field=state, source="dean", target="done") - def dean_approved(self): + @transition( + field=state, + source="+", + target=GET_STATE( + lambda _, allowed: "published" if allowed else "rejected", + states=["published", "rejected"], + ), + ) + def get_state_any_source_except_target(self, *, allowed: bool): pass - @transition(field=state, source="dean", target="dept") - def dean_rejected(self): + @transition(field=state, source="new", target=RETURN_VALUE("moderated", "blocked")) + def return_value(self): + return "published" + + @transition(field=state, source="*", target=RETURN_VALUE("moderated", "blocked")) + def return_value_any_source(self): + return "published" + + @transition(field=state, source="+", target=RETURN_VALUE("moderated", "blocked")) + def return_value_any_source_except_target(self): + return "published" + + @transition(field=state, source="new", target="published", on_error="failed") + def on_error(self): pass diff --git a/tests/testapp/tests/test_graph_transitions.py b/tests/testapp/tests/test_graph_transitions.py index e814a11..d8f27ad 100644 --- a/tests/testapp/tests/test_graph_transitions.py +++ b/tests/testapp/tests/test_graph_transitions.py @@ -1,49 +1,22 @@ from __future__ import annotations from django.core.management import call_command -from django.db import models from django.test import TestCase -from django_fsm import FSMField -from django_fsm import transition +from django_fsm.management.commands.graph_transitions import get_graphviz_layouts -class VisualBlogPost(models.Model): - state = FSMField(default="new") - - @transition(field=state, source="new", target="published") - def publish(self): - pass - - @transition(source="published", field=state) - def notify_all(self): - pass - - @transition(source="published", target="hidden", field=state) - def hide(self): - pass - - @transition(source="new", target="removed", field=state) - def remove(self): - raise Exception("Upss") - - @transition(source=["published", "hidden"], target="stolen", field=state) - def steal(self): - pass - - @transition(source="*", target="moderated", field=state) - def moderate(self): - pass - - @transition(source="+", target="blocked", field=state) - def block(self): - pass +class GraphTransitionsCommandTest(TestCase): + def test_dummy(self): + call_command("graph_transitions", "testapp.Application") - @transition(source="*", target="", field=state) - def empty(self): - pass + def test_layouts(self): + for layout in get_graphviz_layouts(): + call_command("graph_transitions", "-l", layout, "testapp.Application") + def test_fk_dummy(self): + call_command("graph_transitions", "testapp.FKApplication") -class GraphTransitionsCommandTest(TestCase): - def test_dummy(self): - call_command("graph_transitions", "testapp.VisualBlogPost") + def test_fk_layouts(self): + for layout in get_graphviz_layouts(): + call_command("graph_transitions", "-l", layout, "testapp.FKApplication") diff --git a/tests/testapp/tests/test_transition_all_except_target.py b/tests/testapp/tests/test_transition_all_except_target.py index be36204..63baedb 100644 --- a/tests/testapp/tests/test_transition_all_except_target.py +++ b/tests/testapp/tests/test_transition_all_except_target.py @@ -8,7 +8,7 @@ from django_fsm import transition -class TestExceptTargetTransitionShortcut(models.Model): +class ExceptTargetTransition(models.Model): state = FSMField(default="new") @transition(field=state, source="new", target="published") @@ -22,7 +22,7 @@ def remove(self): class Test(TestCase): def setUp(self): - self.model = TestExceptTargetTransitionShortcut() + self.model = ExceptTargetTransition() def test_usecase(self): assert self.model.state == "new"