diff --git a/django_fsm/management/commands/graph_transitions.py b/django_fsm/management/commands/graph_transitions.py index ef8a5e0..8f50913 100644 --- a/django_fsm/management/commands/graph_transitions.py +++ b/django_fsm/management/commands/graph_transitions.py @@ -21,12 +21,10 @@ def node_name(field, state) -> str: return "{}.{}.{}.{}".format(opts.app_label, opts.verbose_name.replace(" ", "_"), field.name, state) -def node_label(field, state) -> str: - if isinstance(state, int): - return str(state) - if isinstance(state, bool) and hasattr(field, "choices"): - return force_str(dict(field.choices).get(state)) - return state +def node_label(field, state: str | None) -> str: + if isinstance(state, (int, bool)) and hasattr(field, "choices") and field.choices: + state = dict(field.choices).get(state) + return force_str(state) def generate_dot(fields_data, ignore_transitions: list[str] | None = None): # noqa: C901, PLR0912 diff --git a/tests/testapp/models.py b/tests/testapp/models.py index c012e90..0d5d216 100644 --- a/tests/testapp/models.py +++ b/tests/testapp/models.py @@ -170,12 +170,23 @@ def on_error(self): pass +class BlogPostState(models.IntegerChoices): + NEW = 0, "New" + PUBLISHED = 1, "Published" + HIDDEN = 2, "Hidden" + REMOVED = 3, "Removed" + RESTORED = 4, "Restored" + MODERATED = 5, "Moderated" + STOLEN = 6, "Stolen" + FAILED = 7, "Failed" + + class BlogPost(models.Model): """ Test workflow """ - state = FSMField(default="new", protected=True) + state = FSMField(choices=BlogPostState.choices, default=BlogPostState.NEW, protected=True) class Meta: permissions = [ @@ -186,41 +197,53 @@ class Meta: def can_restore(self, user): return user.is_superuser or user.is_staff - @transition(field=state, source="new", target="published", on_error="failed", permission="testapp.can_publish_post") + @transition( + field=state, + source=BlogPostState.NEW, + target=BlogPostState.PUBLISHED, + on_error=BlogPostState.FAILED, + permission="testapp.can_publish_post", + ) def publish(self): pass - @transition(field=state, source="published") + @transition(field=state, source=BlogPostState.PUBLISHED) def notify_all(self): pass @transition( field=state, - source="published", - target="hidden", - on_error="failed", + source=BlogPostState.PUBLISHED, + target=BlogPostState.HIDDEN, + on_error=BlogPostState.FAILED, ) def hide(self): pass @transition( field=state, - source="new", - target="removed", - on_error="failed", + source=BlogPostState.NEW, + target=BlogPostState.REMOVED, + on_error=BlogPostState.FAILED, permission=lambda _, u: u.has_perm("testapp.can_remove_post"), ) def remove(self): raise Exception(f"No rights to delete {self}") - @transition(field=state, source="new", target="restored", on_error="failed", permission=can_restore) + @transition( + field=state, + source=BlogPostState.NEW, + target=BlogPostState.RESTORED, + on_error=BlogPostState.FAILED, + permission=can_restore, + ) def restore(self): pass - @transition(field=state, source=["published", "hidden"], target="stolen") + @transition(field=state, source=[BlogPostState.PUBLISHED, BlogPostState.HIDDEN], target=BlogPostState.STOLEN) def steal(self): pass - @transition(field=state, source="*", target="moderated") + @transition(field=state, source="*", target=BlogPostState.MODERATED) def moderate(self): pass diff --git a/tests/testapp/tests/test_graph_transitions.py b/tests/testapp/tests/test_graph_transitions.py index 9c47b12..fe67782 100644 --- a/tests/testapp/tests/test_graph_transitions.py +++ b/tests/testapp/tests/test_graph_transitions.py @@ -4,6 +4,9 @@ from django.test import TestCase from django_fsm.management.commands.graph_transitions import get_graphviz_layouts +from django_fsm.management.commands.graph_transitions import node_label +from tests.testapp.models import BlogPost +from tests.testapp.models import BlogPostState class GraphTransitionsCommandTest(TestCase): @@ -12,6 +15,9 @@ class GraphTransitionsCommandTest(TestCase): "testapp.FKApplication", ] + def test_node_label(self): + assert node_label(BlogPost.state.field, BlogPostState.PUBLISHED.value) == BlogPostState.PUBLISHED.label + def test_app(self): call_command("graph_transitions", "testapp")