Skip to content

Commit 2889378

Browse files
authored
tests: add dataflow test cases (#6053)
This is the start of test cases aimed to cover more dataflow cases not covered in tests (especially SQL). There are 2 existing xfail which should be fixed by #5992
1 parent 151dbe0 commit 2889378

File tree

1 file changed

+277
-0
lines changed

1 file changed

+277
-0
lines changed
Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
# Copyright 2024 Marimo. All rights reserved.
2+
from __future__ import annotations
3+
4+
from dataclasses import dataclass
5+
from functools import partial
6+
from typing import TYPE_CHECKING, Optional, Union
7+
8+
import pytest
9+
10+
from marimo._ast import compiler
11+
from marimo._dependencies.dependencies import DependencyManager
12+
from marimo._runtime import dataflow
13+
from marimo._types.ids import CellId_t
14+
15+
parse_cell = partial(compiler.compile_cell, cell_id=CellId_t("0"))
16+
17+
HAS_DUCKDB = DependencyManager.duckdb.has()
18+
19+
if TYPE_CHECKING:
20+
from collections.abc import Iterable
21+
22+
23+
@dataclass
24+
class GraphTestCase:
25+
"""A test case for dataflow graph operations."""
26+
27+
# Test description
28+
name: str
29+
30+
# If enabled
31+
32+
# Code to create and register
33+
code: dict[str, str]
34+
35+
# Expected graph structure
36+
expected_parents: Optional[dict[str, Iterable[str]]] = None
37+
expected_children: Optional[dict[str, Iterable[str]]] = None
38+
expected_stale: Optional[Iterable[str]] = None
39+
40+
# Expected refs/defs
41+
expected_refs: Optional[dict[str, Iterable[str]]] = None
42+
expected_defs: Optional[dict[str, Iterable[str]]] = None
43+
44+
enabled: bool = True
45+
xfail: Union[bool, str] = False
46+
47+
def __post_init__(self) -> None:
48+
# Convert all to a []
49+
if self.expected_parents is not None:
50+
self.expected_parents = {
51+
cell_id: set(parents)
52+
for cell_id, parents in self.expected_parents.items()
53+
}
54+
if self.expected_children is not None:
55+
self.expected_children = {
56+
cell_id: set(children)
57+
for cell_id, children in self.expected_children.items()
58+
}
59+
if self.expected_stale is not None:
60+
self.expected_stale = set(self.expected_stale)
61+
if self.expected_refs is not None:
62+
self.expected_refs = {
63+
cell_id: set(refs)
64+
for cell_id, refs in self.expected_refs.items()
65+
}
66+
if self.expected_defs is not None:
67+
self.expected_defs = {
68+
cell_id: set(defs)
69+
for cell_id, defs in self.expected_defs.items()
70+
}
71+
72+
73+
CASES = [
74+
# Basic Python Cases
75+
GraphTestCase(
76+
name="single node",
77+
code={"0": "x = 0"},
78+
expected_parents={"0": []},
79+
expected_children={"0": []},
80+
expected_refs={"0": []},
81+
expected_defs={"0": ["x"]},
82+
),
83+
GraphTestCase(
84+
name="chain",
85+
code={"0": "x = 0", "1": "y = x", "2": "z = y\nzz = x"},
86+
expected_parents={"0": [], "1": ["0"], "2": ["0", "1"]},
87+
expected_children={"0": ["1", "2"], "1": ["2"], "2": []},
88+
expected_refs={"0": [], "1": ["x"], "2": ["x", "y"]},
89+
expected_defs={
90+
"0": ["x"],
91+
"1": ["y"],
92+
"2": ["z", "zz"],
93+
},
94+
),
95+
GraphTestCase(
96+
name="cycle",
97+
code={"0": "x = y", "1": "y = x"},
98+
expected_parents={"0": ["1"], "1": ["0"]},
99+
expected_children={"0": ["1"], "1": ["0"]},
100+
expected_refs={"0": ["y"], "1": ["x"]},
101+
expected_defs={"0": ["x"], "1": ["y"]},
102+
),
103+
GraphTestCase(
104+
name="diamond",
105+
code={
106+
"0": "x = 0",
107+
"1": "y = x",
108+
"2": "z = y\nzz = x",
109+
"3": "a = z",
110+
},
111+
expected_parents={
112+
"0": [],
113+
"1": ["0"],
114+
"2": ["0", "1"],
115+
"3": ["2"],
116+
},
117+
expected_children={
118+
"0": ["1", "2"],
119+
"1": ["2"],
120+
"2": ["3"],
121+
"3": [],
122+
},
123+
expected_refs={
124+
"0": [],
125+
"1": ["x"],
126+
"2": ["x", "y"],
127+
"3": ["z"],
128+
},
129+
expected_defs={
130+
"0": ["x"],
131+
"1": ["y"],
132+
"2": ["z", "zz"],
133+
"3": ["a"],
134+
},
135+
),
136+
GraphTestCase(
137+
name="variable del",
138+
code={"0": "x = 0", "1": "y = x", "2": "del x"},
139+
expected_parents={"0": [], "1": ["0"], "2": ["0", "1"]},
140+
expected_children={"0": ["1", "2"], "1": ["2"], "2": []},
141+
expected_refs={"0": [], "1": ["x"], "2": ["x"]},
142+
expected_defs={
143+
"0": ["x"],
144+
"1": ["y"],
145+
"2": [],
146+
},
147+
),
148+
# SQL Cases
149+
GraphTestCase(
150+
name="python -> sql",
151+
code={
152+
"0": "df = pd.read_csv('data.csv')",
153+
"1": "result = mo.sql(f'FROM df WHERE name = {name}')",
154+
},
155+
expected_parents={"0": [], "1": ["0"]},
156+
expected_children={"0": ["1"], "1": []},
157+
expected_refs={"0": ["pd"], "1": ["df", "mo", "name"]},
158+
expected_defs={"0": ["df"], "1": ["result"]},
159+
),
160+
GraphTestCase(
161+
name="sql -> python via output",
162+
code={
163+
"0": "result = mo.sql(f'FROM my_table WHERE name = {name}')",
164+
"1": "df = result.head()",
165+
},
166+
expected_parents={"0": [], "1": ["0"]},
167+
expected_children={"0": ["1"], "1": []},
168+
expected_refs={"0": ["mo", "name", "my_table"], "1": ["result"]},
169+
expected_defs={"0": ["result"], "1": ["df"]},
170+
),
171+
GraphTestCase(
172+
name="sql -/> python when creating a table",
173+
code={
174+
"0": "_ = mo.sql(f'CREATE TABLE my_table (name STRING)')",
175+
"1": "my_table = df.head()",
176+
},
177+
expected_parents={"0": [], "1": []},
178+
expected_children={"0": [], "1": []},
179+
expected_refs={"0": ["mo"], "1": ["df"]},
180+
expected_defs={"0": ["my_table"], "1": ["my_table"]},
181+
),
182+
GraphTestCase(
183+
name="sql redefinition",
184+
code={
185+
"0": "df = pd.read_csv('data.csv')",
186+
"1": "df = mo.sql(f'FROM df')",
187+
},
188+
expected_parents={"0": [], "1": ["0"]},
189+
expected_children={"0": ["1"], "1": []},
190+
expected_refs={"0": ["pd"], "1": ["df", "mo"]},
191+
expected_defs={"0": ["df"], "1": ["df"]},
192+
),
193+
GraphTestCase(
194+
name="python and sql not related because has schema",
195+
enabled=HAS_DUCKDB,
196+
code={
197+
"0": "df = pd.read_csv('data.csv')",
198+
"1": "result = mo.sql(f'FROM my_schema.df')",
199+
},
200+
expected_parents={"0": [], "1": []},
201+
expected_children={"0": [], "1": []},
202+
expected_refs={"0": ["pd"], "1": ["df", "mo", "my_schema"]},
203+
# This is correct
204+
# expected_refs={"0": ["pd"], "1": ["df.my_schema", "mo"]},
205+
expected_defs={"0": ["df"], "1": ["result"]},
206+
xfail=True,
207+
),
208+
GraphTestCase(
209+
name="sql should not reference python variables when schema",
210+
enabled=HAS_DUCKDB,
211+
code={
212+
"0": "my_schema = 100",
213+
"1": "_ = mo.sql(f'FROM my_schema.df')",
214+
},
215+
expected_parents={"0": [], "1": []},
216+
expected_children={"0": [], "1": []},
217+
expected_refs={"0": [], "1": ["mo", "my_schema", "df"]},
218+
# This is correct
219+
# expected_refs={"0": ["pd"], "1": ["my_schema.df", "mo"]},
220+
expected_defs={"0": ["my_schema"], "1": []},
221+
xfail=True,
222+
),
223+
GraphTestCase(
224+
name="sql should not reference python variables when catalog",
225+
enabled=HAS_DUCKDB,
226+
code={
227+
"0": "my_catalog = 100",
228+
"1": "_ = mo.sql(f'FROM my_catalog.my_schema.df')",
229+
},
230+
expected_parents={"0": [], "1": []},
231+
expected_children={"0": [], "1": []},
232+
expected_refs={"0": [], "1": ["mo", "my_catalog", "my_schema", "df"]},
233+
# This is correct
234+
# expected_refs={"0": ["pd"], "1": ["my_catalog.my_schema.df", "mo"]},
235+
expected_defs={"0": ["my_catalog"], "1": []},
236+
xfail=True,
237+
),
238+
]
239+
240+
241+
@pytest.mark.parametrize("case", CASES)
242+
def test_cases(case: GraphTestCase) -> None:
243+
print(f"Running {case.name}")
244+
graph = dataflow.DirectedGraph()
245+
246+
if not case.enabled:
247+
pytest.skip(f"Skipping {case.name} because it's not enabled")
248+
249+
for cell_id, code in case.code.items():
250+
cell = parse_cell(code)
251+
graph.register_cell(CellId_t(cell_id), cell)
252+
253+
def make_assertions():
254+
if case.expected_refs:
255+
for cell_id, refs in case.expected_refs.items():
256+
assert graph.cells[CellId_t(cell_id)].refs == refs, (
257+
f"Cell {cell_id} has refs {graph.cells[CellId_t(cell_id)].refs}, expected {refs}"
258+
)
259+
if case.expected_defs:
260+
for cell_id, defs in case.expected_defs.items():
261+
assert graph.cells[CellId_t(cell_id)].defs == defs, (
262+
f"Cell {cell_id} has defs {graph.cells[CellId_t(cell_id)].defs}, expected {defs}"
263+
)
264+
assert graph.parents == case.expected_parents, (
265+
f"Graph parents {graph.parents} do not match expected {case.expected_parents}"
266+
)
267+
assert graph.children == case.expected_children, (
268+
f"Graph children {graph.children} do not match expected {case.expected_children}"
269+
)
270+
271+
if case.xfail:
272+
if isinstance(case.xfail, str):
273+
print(case.xfail)
274+
with pytest.raises(AssertionError):
275+
make_assertions()
276+
else:
277+
make_assertions()

0 commit comments

Comments
 (0)