Skip to content

Commit 665005b

Browse files
committed
test the model_to_mermaid
1 parent 04bdc71 commit 665005b

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

tests/test_model_graph.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414
import warnings
1515

16+
from textwrap import dedent
17+
1618
import numpy as np
1719
import pytensor
1820
import pytensor.tensor as pt
@@ -31,6 +33,7 @@
3133
NodeType,
3234
Plate,
3335
model_to_graphviz,
36+
model_to_mermaid,
3437
model_to_networkx,
3538
)
3639

@@ -629,3 +632,23 @@ def test_scalars_dim_info() -> None:
629632
]
630633

631634
assert graph.edges() == []
635+
636+
637+
def test_model_to_mermaid(simple_model):
638+
expected_mermaid_string = dedent("""
639+
graph TD
640+
%% Nodes:
641+
a([a ~ Normal])
642+
a@{ shape: rounded }
643+
b([b ~ Normal])
644+
b@{ shape: rounded }
645+
c([c ~ Normal])
646+
c@{ shape: rounded }
647+
648+
%% Edges:
649+
a --> b
650+
b --> c
651+
652+
%% Plates:
653+
""")
654+
assert model_to_mermaid(simple_model) == expected_mermaid_string.strip()

0 commit comments

Comments
 (0)