@@ -33,7 +33,8 @@ def one_fsm_fields_data(
3333 model : type [models .Model ], field_name : str
3434) -> tuple [FSMFieldMixin , type [models .Model ]]:
3535 field = model ._meta .get_field (field_name )
36- assert isinstance (field , FSMFieldMixin )
36+ if not isinstance (field , FSMFieldMixin ):
37+ raise LookupError (f"{ field_name } is not an FSMField" ) # noqa: TRY004
3738 return (field , model )
3839
3940
@@ -61,9 +62,9 @@ def generate_dot( # noqa: C901, PLR0912
6162 for field , model in fields_data :
6263 sources : set [tuple [(str , str )]] = set ()
6364 targets : set [tuple [str , str ]] = set ()
64- edges : set [tuple [str , str , tuple [tuple [str , ... ]]]] = set ()
65+ edges : set [tuple [str , str , tuple [tuple [str , str ]]]] = set ()
6566 any_targets : set [tuple [_StateValue , str ]] = set ()
66- any_except_targets : set [tuple [str , str ]] = set ()
67+ any_except_targets : set [tuple [_StateValue , str ]] = set ()
6768
6869 # dump nodes and edges
6970 for transition in field .get_all_transitions (model ):
@@ -80,6 +81,7 @@ def generate_dot( # noqa: C901, PLR0912
8081 if isinstance (transition .source , GET_STATE | RETURN_VALUE )
8182 else ((transition .source , node_name (field , transition .source )),)
8283 )
84+
8385 for source , source_name in source_name_pair :
8486 if transition .on_error :
8587 on_error_name = node_name (field , transition .on_error )
@@ -92,16 +94,10 @@ def generate_dot( # noqa: C901, PLR0912
9294 elif transition .source == "+" :
9395 any_except_targets .add ((target , transition .name ))
9496 else :
95- add_transition (
96- source ,
97- target ,
98- transition .name ,
99- source_name ,
100- field ,
101- sources ,
102- targets ,
103- edges ,
104- )
97+ target_name = node_name (field , target )
98+ sources .add ((source_name , node_label (field , source )))
99+ targets .add ((target_name , node_label (field , target )))
100+ edges .add ((source_name , target_name , (("label" , transition .name ),)))
105101
106102 targets .update (
107103 {
@@ -134,46 +130,23 @@ def generate_dot( # noqa: C901, PLR0912
134130 final_states = targets - sources
135131 for name , label in final_states :
136132 subgraph .node (name , label = label , shape = "doublecircle" )
133+
137134 for name , label in (sources | targets ) - final_states :
138135 subgraph .node (name , label = label , shape = "circle" )
139136 # Adding initial state notation
140137 if field .default and label == field .default :
141138 initial_name = node_name (field , "_initial" )
142139 subgraph .node (name = initial_name , label = "" , shape = "point" )
143- subgraph .edge (initial_name , name )
140+ subgraph .edge (tail_name = initial_name , head_name = name )
141+
144142 for source_name , target_name , attrs in edges :
145- subgraph .edge (source_name , target_name , ** dict (attrs ))
143+ subgraph .edge (tail_name = source_name , head_name = target_name , ** dict (attrs ))
146144
147145 result .subgraph (subgraph )
148146
149147 return result
150148
151149
152- def add_transition (
153- transition_source : _StateValue ,
154- transition_target : _StateValue ,
155- transition_name : str ,
156- source_name : str ,
157- field : FSMFieldMixin ,
158- sources : set [tuple [str , str ]],
159- targets : set [tuple [str , str ]],
160- edges : set [tuple [str , str , tuple [tuple [str , str ], ...]]],
161- ) -> None :
162- target_name = node_name (field , transition_target )
163- sources .add ((source_name , node_label (field , transition_source )))
164- targets .add ((target_name , node_label (field , transition_target )))
165- edges .add ((source_name , target_name , (("label" , transition_name ),)))
166-
167-
168- def get_graphviz_layouts () -> set [str ] | Sequence [str ]:
169- try :
170- import graphviz
171- except ModuleNotFoundError :
172- return {"sfdp" , "circo" , "twopi" , "dot" , "neato" , "fdp" , "osage" , "patchwork" }
173- else :
174- return graphviz .ENGINES # type: ignore[no-any-return]
175-
176-
177150class Command (BaseCommand ):
178151 help = "Creates a GraphViz dot file with transitions for selected fields"
179152
@@ -192,7 +165,7 @@ def add_arguments(self, parser: ArgumentParser) -> None:
192165 action = "store" ,
193166 dest = "layout" ,
194167 default = "dot" ,
195- help = f"Layout to be used by GraphViz for visualization: { get_graphviz_layouts () } ." ,
168+ help = f"Layout to be used by GraphViz for visualization: { graphviz . ENGINES } ." ,
196169 )
197170 parser .add_argument (
198171 "--exclude" ,
@@ -204,13 +177,6 @@ def add_arguments(self, parser: ArgumentParser) -> None:
204177 )
205178 parser .add_argument ("args" , nargs = "*" , help = ("[appname[.model[.field]]]" ))
206179
207- def render_output (self , graph : graphviz .Digraph , ** options : typing .Any ) -> None :
208- filename , graph_format = options ["outputfile" ].rsplit ("." , 1 )
209-
210- graph .engine = options ["layout" ]
211- graph .format = graph_format
212- graph .render (filename )
213-
214180 def handle (self , * args : str , ** options : typing .Any ) -> None :
215181 fields_data : list [tuple [FSMFieldMixin , type [models .Model ]]] = []
216182 if args :
@@ -232,8 +198,11 @@ def handle(self, *args: str, **options: typing.Any) -> None:
232198
233199 dotdata = generate_dot (fields_data , ignore_transitions = options ["exclude" ].split ("," ))
234200
235- outputfile = options ["outputfile" ]
236- if outputfile :
237- self .render_output (dotdata , ** options )
201+ if outputfile := options ["outputfile" ]:
202+ filename , graph_format = outputfile .rsplit ("." , 1 )
203+
204+ dotdata .engine = options ["layout" ]
205+ dotdata .format = graph_format
206+ dotdata .render (filename )
238207 else :
239208 self .stdout .write (str (dotdata ))
0 commit comments