Skip to content

Commit 5fd3982

Browse files
committed
✅ Refactor test to use new print mock and parametrize import modules, one test file for 4 source files
1 parent a00492a commit 5fd3982

File tree

1 file changed

+24
-26
lines changed

1 file changed

+24
-26
lines changed

tests/test_tutorial/test_automatic_id_none_refresh/test_tutorial001_tutorial002.py

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1+
import importlib
2+
from types import ModuleType
13
from typing import Any, Dict, List, Union
2-
from unittest.mock import patch
34

5+
import pytest
46
from sqlmodel import create_engine
57

6-
from tests.conftest import get_testing_print_function
8+
from tests.conftest import PrintMock, needs_py310
79

810

9-
def check_calls(calls: List[List[Union[str, Dict[str, Any]]]]):
11+
def check_calls(calls: List[List[Union[str, Dict[str, Any]]]]) -> None:
1012
assert calls[0] == ["Before interacting with the database"]
1113
assert calls[1] == [
1214
"Hero 1:",
@@ -133,29 +135,25 @@ def check_calls(calls: List[List[Union[str, Dict[str, Any]]]]):
133135
]
134136

135137

136-
def test_tutorial_001():
137-
from docs_src.tutorial.automatic_id_none_refresh import tutorial001 as mod
138+
@pytest.fixture(
139+
name="module",
140+
params=[
141+
"tutorial001",
142+
"tutorial002",
143+
pytest.param("tutorial001_py310", marks=needs_py310),
144+
pytest.param("tutorial002_py310", marks=needs_py310),
145+
],
146+
)
147+
def get_module(request: pytest.FixtureRequest) -> ModuleType:
148+
module = importlib.import_module(
149+
f"docs_src.tutorial.automatic_id_none_refresh.{request.param}"
150+
)
151+
module.sqlite_url = "sqlite://"
152+
module.engine = create_engine(module.sqlite_url)
138153

139-
mod.sqlite_url = "sqlite://"
140-
mod.engine = create_engine(mod.sqlite_url)
141-
calls = []
154+
return module
142155

143-
new_print = get_testing_print_function(calls)
144156

145-
with patch("builtins.print", new=new_print):
146-
mod.main()
147-
check_calls(calls)
148-
149-
150-
def test_tutorial_002():
151-
from docs_src.tutorial.automatic_id_none_refresh import tutorial002 as mod
152-
153-
mod.sqlite_url = "sqlite://"
154-
mod.engine = create_engine(mod.sqlite_url)
155-
calls = []
156-
157-
new_print = get_testing_print_function(calls)
158-
159-
with patch("builtins.print", new=new_print):
160-
mod.main()
161-
check_calls(calls)
157+
def test_tutorial_001_tutorial_002(print_mock: PrintMock, module: ModuleType) -> None:
158+
module.main()
159+
check_calls(print_mock.calls)

0 commit comments

Comments
 (0)