Skip to content

Commit df5a4c9

Browse files
committed
add test for diagram
1 parent 09ecb2d commit df5a4c9

File tree

1 file changed

+63
-0
lines changed

1 file changed

+63
-0
lines changed

tests/test_mermaid.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from textwrap import dedent
2+
3+
import pytest
4+
5+
from pytensor import function
6+
from pytensor import tensor as pt
7+
from pytensor.mermaid import function_to_mermaid
8+
9+
10+
@pytest.fixture
11+
def sample_function():
12+
x = pt.dmatrix("x")
13+
y = pt.dvector("y")
14+
z = pt.dot(x, y)
15+
z.name = "z"
16+
return function([x, y], z)
17+
18+
19+
def test_function_to_mermaid(sample_function):
20+
diagram = function_to_mermaid(sample_function)
21+
22+
assert (
23+
diagram
24+
== dedent("""
25+
graph TD
26+
%% Nodes:
27+
n1["Shape_i"]
28+
n1@{ shape: rounded }
29+
style n1 fill:#00FFFF
30+
n2["x"]
31+
n2@{ shape: rect }
32+
style n2 fill:#32CD32
33+
n2["x"]
34+
n2@{ shape: rect }
35+
style n2 fill:#32CD32
36+
n4["AllocEmpty"]
37+
n4@{ shape: rounded }
38+
n6["CGemv"]
39+
n6@{ shape: rounded }
40+
n7["1.0"]
41+
n7@{ shape: rect }
42+
style n7 fill:#00FF7F
43+
n8["y"]
44+
n8@{ shape: rect }
45+
style n8 fill:#32CD32
46+
n9["0.0"]
47+
n9@{ shape: rect }
48+
style n9 fill:#00FF7F
49+
n10["z"]
50+
n10@{ shape: rect }
51+
style n10 fill:#1E90FF
52+
53+
%% Edges:
54+
n2 --> n1
55+
n1 --> n4
56+
n4 --> n6
57+
n7 --> n6
58+
n2 --> n6
59+
n8 --> n6
60+
n9 --> n6
61+
n6 --> n10
62+
""").strip()
63+
)

0 commit comments

Comments
 (0)