Skip to content

Commit 8967635

Browse files
committed
Pytest plugins in their own file
1 parent 0510f27 commit 8967635

File tree

3 files changed

+50
-48
lines changed

3 files changed

+50
-48
lines changed

generated/test_class.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

main.py

Lines changed: 1 addition & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,14 @@
22
import typer
33

44
# for the test runner
5-
import time
65
import pytest
76
# ------------------
87

98
import langroid as lr
109
from langroid.utils.configuration import set_global, Settings
1110
from langroid.utils.logging import setup_colored_logging
1211

13-
# import the empty generated code
14-
import generated.test_class
15-
# ------------------
12+
from pytest_plugins import ResultsCollector, SessionStartPlugin
1613

1714
app = typer.Typer()
1815
setup_colored_logging()
@@ -43,49 +40,6 @@ def generate_first_attempt() -> None:
4340
_out.write(response.content)
4441

4542

46-
class ResultsCollector:
47-
def __init__(self):
48-
self.reports = []
49-
self.collected = 0
50-
self.exitcode = 0
51-
self.passed = 0
52-
self.failed = 0
53-
self.xfailed = 0
54-
self.skipped = 0
55-
self.total_duration = 0
56-
57-
@pytest.hookimpl(hookwrapper=True)
58-
def pytest_runtest_makereport(self, item, call):
59-
outcome = yield
60-
report = outcome.get_result()
61-
if report.when == 'call':
62-
self.reports.append(report)
63-
64-
def pytest_collection_modifyitems(self, items):
65-
self.collected = len(items)
66-
67-
def pytest_terminal_summary(self, terminalreporter, exitstatus):
68-
self.exitcode = exitstatus
69-
self.passed = len(terminalreporter.stats.get('passed', []))
70-
self.failed = len(terminalreporter.stats.get('failed', []))
71-
self.xfailed = len(terminalreporter.stats.get('xfailed', []))
72-
self.skipped = len(terminalreporter.stats.get('skipped', []))
73-
74-
self.total_duration = time.time() - terminalreporter._sessionstarttime
75-
76-
77-
class SessionStartPlugin:
78-
"""
79-
The goal of this plugin is to allow us to run pytest multiple times
80-
and have it pick up the changes we generate in `generated/test_class.py`
81-
"""
82-
def pytest_sessionstart(self):
83-
if globals().get('generated', None) is not None:
84-
import importlib
85-
print("Reloading generated.test_class module...")
86-
importlib.reload(generated.test_class)
87-
88-
8943
def get_test_results() -> str:
9044
collector = ResultsCollector()
9145
setup = SessionStartPlugin()

pytest_plugins.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import pytest
2+
import time
3+
try:
4+
import generated.test_class
5+
except ImportError:
6+
pass # Since this is a generated file, it sometimes doesn't exist.
7+
8+
9+
class ResultsCollector:
10+
def __init__(self):
11+
self.reports = []
12+
self.collected = 0
13+
self.exitcode = 0
14+
self.passed = 0
15+
self.failed = 0
16+
self.xfailed = 0
17+
self.skipped = 0
18+
self.total_duration = 0
19+
20+
@pytest.hookimpl(hookwrapper=True)
21+
def pytest_runtest_makereport(self, item, call):
22+
outcome = yield
23+
report = outcome.get_result()
24+
if report.when == 'call':
25+
self.reports.append(report)
26+
27+
def pytest_collection_modifyitems(self, items):
28+
self.collected = len(items)
29+
30+
def pytest_terminal_summary(self, terminalreporter, exitstatus):
31+
self.exitcode = exitstatus
32+
self.passed = len(terminalreporter.stats.get('passed', []))
33+
self.failed = len(terminalreporter.stats.get('failed', []))
34+
self.xfailed = len(terminalreporter.stats.get('xfailed', []))
35+
self.skipped = len(terminalreporter.stats.get('skipped', []))
36+
37+
self.total_duration = time.time() - terminalreporter._sessionstarttime
38+
39+
40+
class SessionStartPlugin:
41+
"""
42+
The goal of this plugin is to allow us to run pytest multiple times
43+
and have it pick up the changes we generate in `generated/test_class.py`
44+
"""
45+
def pytest_sessionstart(self):
46+
if globals().get('generated', None) is not None:
47+
import importlib
48+
print("Reloading generated.test_class module...")
49+
importlib.reload(generated.test_class)

0 commit comments

Comments
 (0)