Skip to content

Commit f1c9d15

Browse files
authored
✅ Simplify tests setup, one test file for multiple source variants (#1407)
1 parent af52d65 commit f1c9d15

File tree

3 files changed

+43
-193
lines changed

3 files changed

+43
-193
lines changed

tests/conftest.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import shutil
22
import subprocess
33
import sys
4+
from dataclasses import dataclass, field
45
from pathlib import Path
5-
from typing import Any, Callable, Dict, List, Union
6+
from typing import Any, Callable, Dict, Generator, List, Union
7+
from unittest.mock import patch
68

79
import pytest
810
from pydantic import BaseModel
@@ -26,7 +28,7 @@ def clear_sqlmodel() -> Any:
2628

2729

2830
@pytest.fixture()
29-
def cov_tmp_path(tmp_path: Path):
31+
def cov_tmp_path(tmp_path: Path) -> Generator[Path, None, None]:
3032
yield tmp_path
3133
for coverage_path in tmp_path.glob(".coverage*"):
3234
coverage_destiny_path = top_level_path / coverage_path.name
@@ -53,8 +55,8 @@ def coverage_run(*, module: str, cwd: Union[str, Path]) -> subprocess.CompletedP
5355
def get_testing_print_function(
5456
calls: List[List[Union[str, Dict[str, Any]]]],
5557
) -> Callable[..., Any]:
56-
def new_print(*args):
57-
data = []
58+
def new_print(*args: Any) -> None:
59+
data: List[Any] = []
5860
for arg in args:
5961
if isinstance(arg, BaseModel):
6062
data.append(arg.model_dump())
@@ -71,6 +73,19 @@ def new_print(*args):
7173
return new_print
7274

7375

76+
@dataclass
77+
class PrintMock:
78+
calls: List[Any] = field(default_factory=list)
79+
80+
81+
@pytest.fixture(name="print_mock")
82+
def print_mock_fixture() -> Generator[PrintMock, None, None]:
83+
print_mock = PrintMock()
84+
new_print = get_testing_print_function(print_mock.calls)
85+
with patch("builtins.print", new=new_print):
86+
yield print_mock
87+
88+
7489
needs_pydanticv2 = pytest.mark.skipif(not IS_PYDANTIC_V2, reason="requires Pydantic v2")
7590
needs_pydanticv1 = pytest.mark.skipif(IS_PYDANTIC_V2, reason="requires Pydantic v1")
7691

tests/test_tutorial/test_automatic_id_none_refresh/test_tutorial001_py310_tutorial002_py310.py

Lines changed: 0 additions & 163 deletions
This file was deleted.

tests/test_tutorial/test_automatic_id_none_refresh/test_tutorial001_tutorial002.py

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1+
import importlib
2+
from types import ModuleType
13
from typing import Any, Dict, List, Union
2-
from unittest.mock import patch
34

5+
import pytest
46
from sqlmodel import create_engine
57

6-
from tests.conftest import get_testing_print_function
8+
from tests.conftest import PrintMock, needs_py310
79

810

9-
def check_calls(calls: List[List[Union[str, Dict[str, Any]]]]):
11+
def check_calls(calls: List[List[Union[str, Dict[str, Any]]]]) -> None:
1012
assert calls[0] == ["Before interacting with the database"]
1113
assert calls[1] == [
1214
"Hero 1:",
@@ -133,29 +135,25 @@ def check_calls(calls: List[List[Union[str, Dict[str, Any]]]]):
133135
]
134136

135137

136-
def test_tutorial_001():
137-
from docs_src.tutorial.automatic_id_none_refresh import tutorial001 as mod
138+
@pytest.fixture(
139+
name="module",
140+
params=[
141+
"tutorial001",
142+
"tutorial002",
143+
pytest.param("tutorial001_py310", marks=needs_py310),
144+
pytest.param("tutorial002_py310", marks=needs_py310),
145+
],
146+
)
147+
def get_module(request: pytest.FixtureRequest) -> ModuleType:
148+
module = importlib.import_module(
149+
f"docs_src.tutorial.automatic_id_none_refresh.{request.param}"
150+
)
151+
module.sqlite_url = "sqlite://"
152+
module.engine = create_engine(module.sqlite_url)
138153

139-
mod.sqlite_url = "sqlite://"
140-
mod.engine = create_engine(mod.sqlite_url)
141-
calls = []
154+
return module
142155

143-
new_print = get_testing_print_function(calls)
144156

145-
with patch("builtins.print", new=new_print):
146-
mod.main()
147-
check_calls(calls)
148-
149-
150-
def test_tutorial_002():
151-
from docs_src.tutorial.automatic_id_none_refresh import tutorial002 as mod
152-
153-
mod.sqlite_url = "sqlite://"
154-
mod.engine = create_engine(mod.sqlite_url)
155-
calls = []
156-
157-
new_print = get_testing_print_function(calls)
158-
159-
with patch("builtins.print", new=new_print):
160-
mod.main()
161-
check_calls(calls)
157+
def test_tutorial_001_tutorial_002(print_mock: PrintMock, module: ModuleType) -> None:
158+
module.main()
159+
check_calls(print_mock.calls)

0 commit comments

Comments
 (0)