Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions django_fsm/management/commands/graph_transitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,10 @@ def node_name(field, state) -> str:
return "{}.{}.{}.{}".format(opts.app_label, opts.verbose_name.replace(" ", "_"), field.name, state)


def node_label(field, state) -> str:
if isinstance(state, int):
return str(state)
if isinstance(state, bool) and hasattr(field, "choices"):
return force_str(dict(field.choices).get(state))
return state
def node_label(field, state: str | None) -> str:
if isinstance(state, (int, bool)) and hasattr(field, "choices") and field.choices:
state = dict(field.choices).get(state)
return force_str(state)


def generate_dot(fields_data, ignore_transitions: list[str] | None = None): # noqa: C901, PLR0912
Expand Down
47 changes: 35 additions & 12 deletions tests/testapp/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,23 @@ def on_error(self):
pass


class BlogPostState(models.IntegerChoices):
NEW = 0, "New"
PUBLISHED = 1, "Published"
HIDDEN = 2, "Hidden"
REMOVED = 3, "Removed"
RESTORED = 4, "Restored"
MODERATED = 5, "Moderated"
STOLEN = 6, "Stolen"
FAILED = 7, "Failed"


class BlogPost(models.Model):
"""
Test workflow
"""

state = FSMField(default="new", protected=True)
state = FSMField(choices=BlogPostState.choices, default=BlogPostState.NEW, protected=True)

class Meta:
permissions = [
Expand All @@ -186,41 +197,53 @@ class Meta:
def can_restore(self, user):
return user.is_superuser or user.is_staff

@transition(field=state, source="new", target="published", on_error="failed", permission="testapp.can_publish_post")
@transition(
field=state,
source=BlogPostState.NEW,
target=BlogPostState.PUBLISHED,
on_error=BlogPostState.FAILED,
permission="testapp.can_publish_post",
)
def publish(self):
pass

@transition(field=state, source="published")
@transition(field=state, source=BlogPostState.PUBLISHED)
def notify_all(self):
pass

@transition(
field=state,
source="published",
target="hidden",
on_error="failed",
source=BlogPostState.PUBLISHED,
target=BlogPostState.HIDDEN,
on_error=BlogPostState.FAILED,
)
def hide(self):
pass

@transition(
field=state,
source="new",
target="removed",
on_error="failed",
source=BlogPostState.NEW,
target=BlogPostState.REMOVED,
on_error=BlogPostState.FAILED,
permission=lambda _, u: u.has_perm("testapp.can_remove_post"),
)
def remove(self):
raise Exception(f"No rights to delete {self}")

@transition(field=state, source="new", target="restored", on_error="failed", permission=can_restore)
@transition(
field=state,
source=BlogPostState.NEW,
target=BlogPostState.RESTORED,
on_error=BlogPostState.FAILED,
permission=can_restore,
)
def restore(self):
pass

@transition(field=state, source=["published", "hidden"], target="stolen")
@transition(field=state, source=[BlogPostState.PUBLISHED, BlogPostState.HIDDEN], target=BlogPostState.STOLEN)
def steal(self):
pass

@transition(field=state, source="*", target="moderated")
@transition(field=state, source="*", target=BlogPostState.MODERATED)
def moderate(self):
pass
6 changes: 6 additions & 0 deletions tests/testapp/tests/test_graph_transitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from django.test import TestCase

from django_fsm.management.commands.graph_transitions import get_graphviz_layouts
from django_fsm.management.commands.graph_transitions import node_label
from tests.testapp.models import BlogPost
from tests.testapp.models import BlogPostState


class GraphTransitionsCommandTest(TestCase):
Expand All @@ -12,6 +15,9 @@ class GraphTransitionsCommandTest(TestCase):
"testapp.FKApplication",
]

def test_node_label(self):
assert node_label(BlogPost.state.field, BlogPostState.PUBLISHED.value) == BlogPostState.PUBLISHED.label

def test_app(self):
call_command("graph_transitions", "testapp")

Expand Down