Skip to content

Commit 10858a9

Browse files
Implements Support to Control Direction of State Diagram Generated by Mermaid Code (#716)
Co-authored-by: Sydney Runkle <[email protected]>
1 parent 81bf883 commit 10858a9

File tree

4 files changed

+93
-2
lines changed

4 files changed

+93
-2
lines changed

docs/graph.md

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -790,3 +790,35 @@ stateDiagram-v2
790790
classDef highlighted fill:#fdff32
791791
class Answer highlighted
792792
```
793+
794+
### Setting Direction of the State Diagram
795+
796+
You can specify the direction of the state diagram using one of the following values:
797+
798+
- `'TB'`: Top to bottom, the diagram flows vertically from top to bottom.
799+
- `'LR'`: Left to right, the diagram flows horizontally from left to right.
800+
- `'RL'`: Right to left, the diagram flows horizontally from right to left.
801+
- `'BT'`: Bottom to top, the diagram flows vertically from bottom to top.
802+
803+
Here is an example of how to do this using 'Left to Right' (LR) instead of the default 'Top to Bottom' (TB)
804+
```py {title="vending_machine_diagram.py" py="3.10"}
805+
from vending_machine import InsertCoin, vending_machine_graph
806+
807+
vending_machine_graph.mermaid_code(start_node=InsertCoin, direction='LR')
808+
```
809+
810+
```mermaid
811+
---
812+
title: vending_machine_graph
813+
---
814+
stateDiagram-v2
815+
direction LR
816+
[*] --> InsertCoin
817+
InsertCoin --> CoinsInserted
818+
CoinsInserted --> SelectProduct
819+
CoinsInserted --> Purchase
820+
SelectProduct --> Purchase
821+
Purchase --> InsertCoin
822+
Purchase --> SelectProduct
823+
Purchase --> [*]
824+
```

pydantic_graph/pydantic_graph/graph.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,7 @@ def mermaid_code(
286286
highlighted_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None,
287287
highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS,
288288
infer_name: bool = True,
289+
direction: mermaid.StateDiagramDirection | None = None,
289290
) -> str:
290291
"""Generate a diagram representing the graph as [mermaid](https://mermaid.js.org/) diagram.
291292
@@ -299,6 +300,7 @@ def mermaid_code(
299300
highlighted_nodes: Optional node or nodes to highlight.
300301
highlight_css: The CSS to use for highlighting nodes.
301302
infer_name: Whether to infer the graph name from the calling frame.
303+
direction: The direction of flow.
302304
303305
Returns:
304306
The mermaid code for the graph, which can then be rendered as a diagram.
@@ -346,6 +348,7 @@ def mermaid_code(
346348
title=title or None,
347349
edge_labels=edge_labels,
348350
notes=notes,
351+
direction=direction,
349352
)
350353

351354
def mermaid_image(

pydantic_graph/pydantic_graph/mermaid.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,27 @@
1616
if TYPE_CHECKING:
1717
from .graph import Graph
1818

19-
20-
__all__ = 'NodeIdent', 'DEFAULT_HIGHLIGHT_CSS', 'generate_code', 'MermaidConfig', 'request_image', 'save_image'
19+
__all__ = (
20+
'NodeIdent',
21+
'DEFAULT_HIGHLIGHT_CSS',
22+
'generate_code',
23+
'MermaidConfig',
24+
'request_image',
25+
'save_image',
26+
'StateDiagramDirection',
27+
)
2128
DEFAULT_HIGHLIGHT_CSS = 'fill:#fdff32'
2229
"""The default CSS to use for highlighting nodes."""
2330

31+
StateDiagramDirection = Literal['TB', 'LR', 'RL', 'BT']
32+
"""Used to specify the direction of the state diagram generated by mermaid.
33+
34+
- `'TB'`: Top to bottom, this is the default for mermaid charts.
35+
- `'LR'`: Left to right
36+
- `'RL'`: Right to left
37+
- `'BT'`: Bottom to top
38+
"""
39+
2440

2541
def generate_code( # noqa: C901
2642
graph: Graph[Any, Any, Any],
@@ -32,6 +48,7 @@ def generate_code( # noqa: C901
3248
title: str | None = None,
3349
edge_labels: bool = True,
3450
notes: bool = True,
51+
direction: StateDiagramDirection | None,
3552
) -> str:
3653
"""Generate [Mermaid state diagram](https://mermaid.js.org/syntax/stateDiagram.html) code for a graph.
3754
@@ -43,6 +60,8 @@ def generate_code( # noqa: C901
4360
title: The title of the diagram.
4461
edge_labels: Whether to include edge labels in the diagram.
4562
notes: Whether to include notes in the diagram.
63+
direction: The direction of flow.
64+
4665
4766
Returns:
4867
The Mermaid code for the graph.
@@ -56,6 +75,8 @@ def generate_code( # noqa: C901
5675
if title:
5776
lines = ['---', f'title: {title}', '---']
5877
lines.append('stateDiagram-v2')
78+
if direction is not None:
79+
lines.append(f' direction {direction}')
5980
for node_id, node_def in graph.node_defs.items():
6081
# we use round brackets (rounded box) for nodes other than the start and end
6182
if node_id in start_node_ids:
@@ -131,6 +152,7 @@ def request_image(
131152
title=kwargs.get('title'),
132153
edge_labels=kwargs.get('edge_labels', True),
133154
notes=kwargs.get('notes', True),
155+
direction=kwargs.get('direction'),
134156
)
135157
code_base64 = base64.b64encode(code.encode()).decode()
136158

@@ -245,6 +267,8 @@ class MermaidConfig(TypedDict, total=False):
245267
"""
246268
httpx_client: httpx.Client
247269
"""An HTTPX client to use for requests, mostly for testing purposes."""
270+
direction: StateDiagramDirection
271+
"""The direction of the state diagram."""
248272

249273

250274
NodeIdent: TypeAlias = 'type[BaseNode[Any, Any, Any]] | BaseNode[Any, Any, Any] | str'

tests/graph/test_mermaid.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,38 @@ def test_mermaid_code_all_nodes():
194194
""")
195195

196196

197+
def test_mermaid_code_all_nodes_no_direction():
198+
assert graph3.mermaid_code() == snapshot("""\
199+
---
200+
title: graph3
201+
---
202+
stateDiagram-v2
203+
AllNodes --> AllNodes
204+
AllNodes --> Foo
205+
AllNodes --> Bar
206+
Foo --> Bar
207+
Bar --> [*]\
208+
""")
209+
210+
211+
def test_mermaid_code_all_nodes_with_direction_lr():
212+
assert graph3.mermaid_code(direction='LR') == snapshot("""\
213+
---
214+
title: graph3
215+
---
216+
stateDiagram-v2
217+
direction LR
218+
AllNodes --> AllNodes
219+
AllNodes --> Foo
220+
AllNodes --> Bar
221+
Foo --> Bar
222+
Bar --> [*]\
223+
""")
224+
225+
226+
# Tests for direction ends here
227+
228+
197229
def test_docstring_notes_classvar():
198230
assert Spam.docstring_notes is True
199231
assert repr(Spam()) == 'Spam()'

0 commit comments

Comments
 (0)