|
| 1 | +import importlib |
| 2 | +from types import ModuleType |
1 | 3 | from typing import Any, Dict, List, Union
|
2 |
| -from unittest.mock import patch |
3 | 4 |
|
| 5 | +import pytest |
4 | 6 | from sqlmodel import create_engine
|
5 | 7 |
|
6 |
| -from tests.conftest import get_testing_print_function |
| 8 | +from tests.conftest import PrintMock, needs_py310 |
7 | 9 |
|
8 | 10 |
|
9 |
| -def check_calls(calls: List[List[Union[str, Dict[str, Any]]]]): |
| 11 | +def check_calls(calls: List[List[Union[str, Dict[str, Any]]]]) -> None: |
10 | 12 | assert calls[0] == ["Before interacting with the database"]
|
11 | 13 | assert calls[1] == [
|
12 | 14 | "Hero 1:",
|
@@ -133,29 +135,25 @@ def check_calls(calls: List[List[Union[str, Dict[str, Any]]]]):
|
133 | 135 | ]
|
134 | 136 |
|
135 | 137 |
|
136 |
| -def test_tutorial_001(): |
137 |
| - from docs_src.tutorial.automatic_id_none_refresh import tutorial001 as mod |
| 138 | +@pytest.fixture( |
| 139 | + name="module", |
| 140 | + params=[ |
| 141 | + "tutorial001", |
| 142 | + "tutorial002", |
| 143 | + pytest.param("tutorial001_py310", marks=needs_py310), |
| 144 | + pytest.param("tutorial002_py310", marks=needs_py310), |
| 145 | + ], |
| 146 | +) |
| 147 | +def get_module(request: pytest.FixtureRequest) -> ModuleType: |
| 148 | + module = importlib.import_module( |
| 149 | + f"docs_src.tutorial.automatic_id_none_refresh.{request.param}" |
| 150 | + ) |
| 151 | + module.sqlite_url = "sqlite://" |
| 152 | + module.engine = create_engine(module.sqlite_url) |
138 | 153 |
|
139 |
| - mod.sqlite_url = "sqlite://" |
140 |
| - mod.engine = create_engine(mod.sqlite_url) |
141 |
| - calls = [] |
| 154 | + return module |
142 | 155 |
|
143 |
| - new_print = get_testing_print_function(calls) |
144 | 156 |
|
145 |
| - with patch("builtins.print", new=new_print): |
146 |
| - mod.main() |
147 |
| - check_calls(calls) |
148 |
| - |
149 |
| - |
150 |
| -def test_tutorial_002(): |
151 |
| - from docs_src.tutorial.automatic_id_none_refresh import tutorial002 as mod |
152 |
| - |
153 |
| - mod.sqlite_url = "sqlite://" |
154 |
| - mod.engine = create_engine(mod.sqlite_url) |
155 |
| - calls = [] |
156 |
| - |
157 |
| - new_print = get_testing_print_function(calls) |
158 |
| - |
159 |
| - with patch("builtins.print", new=new_print): |
160 |
| - mod.main() |
161 |
| - check_calls(calls) |
| 157 | +def test_tutorial_001_tutorial_002(print_mock: PrintMock, module: ModuleType) -> None: |
| 158 | + module.main() |
| 159 | + check_calls(print_mock.calls) |
0 commit comments