Skip to content

Commit d0c304d

Browse files
authored
nixos/test-driver: improve error reporting and assertions (#390996)
2 parents 04addb2 + deff22b commit d0c304d

File tree

9 files changed

+118
-23
lines changed

9 files changed

+118
-23
lines changed

nixos/doc/manual/development/writing-nixos-tests.section.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,7 @@ and checks that the output is more-or-less correct:
121121
```py
122122
machine.start()
123123
machine.wait_for_unit("default.target")
124-
if not "Linux" in machine.succeed("uname"):
125-
raise Exception("Wrong OS")
124+
t.assertIn("Linux", machine.succeed("uname"), "Wrong OS")
126125
```
127126

128127
The first line is technically unnecessary; machines are implicitly started
@@ -134,6 +133,8 @@ starting them in parallel:
134133
start_all()
135134
```
136135

136+
Under the variable `t`, all assertions from [`unittest.TestCase`](https://docs.python.org/3/library/unittest.html) are available.
137+
137138
If the hostname of a node contains characters that can't be used in a
138139
Python variable name, those characters will be replaced with
139140
underscores in the variable name, so `nodes.machine-a` will be exposed

nixos/lib/test-driver/default.nix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ python3Packages.buildPythonApplication {
3131
colorama
3232
junit-xml
3333
ptpython
34+
ipython
3435
]
3536
++ extraPythonPackages python3Packages;
3637

nixos/lib/test-driver/src/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ target-version = "py312"
2121
line-length = 88
2222

2323
lint.select = ["E", "F", "I", "U", "N"]
24-
lint.ignore = ["E501"]
24+
lint.ignore = ["E501", "N818"]
2525

2626
# xxx: we can import https://pypi.org/project/types-colorama/ here
2727
[[tool.mypy.overrides]]

nixos/lib/test-driver/src/test_driver/__init__.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import time
44
from pathlib import Path
55

6-
import ptpython.repl
6+
import ptpython.ipython
77

88
from test_driver.driver import Driver
99
from test_driver.logger import (
@@ -136,11 +136,10 @@ def main() -> None:
136136
if args.interactive:
137137
history_dir = os.getcwd()
138138
history_path = os.path.join(history_dir, ".nixos-test-history")
139-
ptpython.repl.embed(
140-
driver.test_symbols(),
141-
{},
139+
ptpython.ipython.embed(
140+
user_ns=driver.test_symbols(),
142141
history_filename=history_path,
143-
)
142+
) # type:ignore
144143
else:
145144
tic = time.time()
146145
driver.run_tests()

nixos/lib/test-driver/src/test_driver/driver.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
import os
22
import re
33
import signal
4+
import sys
45
import tempfile
56
import threading
7+
import traceback
68
from collections.abc import Callable, Iterator
79
from contextlib import AbstractContextManager, contextmanager
810
from pathlib import Path
911
from typing import Any
12+
from unittest import TestCase
1013

14+
from test_driver.errors import MachineError, RequestedAssertionFailed
1115
from test_driver.logger import AbstractLogger
1216
from test_driver.machine import Machine, NixStartScript, retry
1317
from test_driver.polling_condition import PollingCondition
@@ -16,6 +20,18 @@
1620
SENTINEL = object()
1721

1822

23+
class AssertionTester(TestCase):
24+
"""
25+
Subclass of `unittest.TestCase` which is used in the
26+
`testScript` to perform assertions.
27+
28+
It throws a custom exception whose parent class
29+
gets special treatment in the logs.
30+
"""
31+
32+
failureException = RequestedAssertionFailed
33+
34+
1935
def get_tmp_dir() -> Path:
2036
"""Returns a temporary directory that is defined by TMPDIR, TEMP, TMP or CWD
2137
Raises an exception in case the retrieved temporary directory is not writeable
@@ -115,7 +131,7 @@ def subtest(self, name: str) -> Iterator[None]:
115131
try:
116132
yield
117133
except Exception as e:
118-
self.logger.error(f'Test "{name}" failed with error: "{e}"')
134+
self.logger.log_test_error(f'Test "{name}" failed with error: "{e}"')
119135
raise e
120136

121137
def test_symbols(self) -> dict[str, Any]:
@@ -140,6 +156,7 @@ def subtest(name: str) -> Iterator[None]:
140156
serial_stdout_on=self.serial_stdout_on,
141157
polling_condition=self.polling_condition,
142158
Machine=Machine, # for typing
159+
t=AssertionTester(),
143160
)
144161
machine_symbols = {pythonize_name(m.name): m for m in self.machines}
145162
# If there's exactly one machine, make it available under the name
@@ -163,7 +180,36 @@ def test_script(self) -> None:
163180
"""Run the test script"""
164181
with self.logger.nested("run the VM test script"):
165182
symbols = self.test_symbols() # call eagerly
166-
exec(self.tests, symbols, None)
183+
try:
184+
exec(self.tests, symbols, None)
185+
except MachineError:
186+
for line in traceback.format_exc().splitlines():
187+
self.logger.log_test_error(line)
188+
sys.exit(1)
189+
except RequestedAssertionFailed:
190+
exc_type, exc, tb = sys.exc_info()
191+
# We manually print the stack frames, keeping only the ones from the test script
192+
# (note: because the script is not a real file, the frame filename is `<string>`)
193+
filtered = [
194+
frame
195+
for frame in traceback.extract_tb(tb)
196+
if frame.filename == "<string>"
197+
]
198+
199+
self.logger.log_test_error("Traceback (most recent call last):")
200+
201+
code = self.tests.splitlines()
202+
for frame, line in zip(filtered, traceback.format_list(filtered)):
203+
self.logger.log_test_error(line.rstrip())
204+
if lineno := frame.lineno:
205+
self.logger.log_test_error(f" {code[lineno - 1].strip()}")
206+
207+
self.logger.log_test_error("") # blank line for readability
208+
exc_prefix = exc_type.__name__ if exc_type is not None else "Error"
209+
for line in f"{exc_prefix}: {exc}".splitlines():
210+
self.logger.log_test_error(line)
211+
212+
sys.exit(1)
167213

168214
def run_tests(self) -> None:
169215
"""Run the test script (for non-interactive test runs)"""
nixos/lib/test-driver/src/test_driver/errors.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
class MachineError(Exception):
2+
"""
3+
Exception that indicates an error that is NOT the user's fault,
4+
i.e. something went wrong without the test being necessarily invalid,
5+
such as failing OCR.
6+
7+
To make it easier to spot, this exception (and its subclasses)
8+
get a `!!!` prefix in the log output.
9+
"""
10+
11+
12+
class RequestedAssertionFailed(AssertionError):
13+
"""
14+
Special assertion that gets thrown on an assertion error,
15+
e.g. a failing `t.assertEqual(...)` or `machine.succeed(...)`.
16+
17+
This gets special treatment in error reporting: i.e. it gets
18+
`!!!` as prefix just as `MachineError`, but only stack frames coming
19+
from `testScript` will show up in logs.
20+
"""

nixos/lib/test-driver/src/test_driver/logger.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ def warning(self, *args, **kwargs) -> None: # type: ignore
4444
def error(self, *args, **kwargs) -> None: # type: ignore
4545
pass
4646

47+
@abstractmethod
48+
def log_test_error(self, *args, **kwargs) -> None: # type:ignore
49+
pass
50+
4751
@abstractmethod
4852
def log_serial(self, message: str, machine: str) -> None:
4953
pass
@@ -97,6 +101,9 @@ def error(self, *args, **kwargs) -> None: # type: ignore
97101
self.tests[self.currentSubtest].stderr += args[0] + os.linesep
98102
self.tests[self.currentSubtest].failure = True
99103

104+
def log_test_error(self, *args, **kwargs) -> None: # type: ignore
105+
self.error(*args, **kwargs)
106+
100107
def log_serial(self, message: str, machine: str) -> None:
101108
if not self._print_serial_logs:
102109
return
@@ -156,6 +163,10 @@ def warning(self, *args, **kwargs) -> None: # type: ignore
156163
for logger in self.logger_list:
157164
logger.warning(*args, **kwargs)
158165

166+
def log_test_error(self, *args, **kwargs) -> None: # type: ignore
167+
for logger in self.logger_list:
168+
logger.log_test_error(*args, **kwargs)
169+
159170
def error(self, *args, **kwargs) -> None: # type: ignore
160171
for logger in self.logger_list:
161172
logger.error(*args, **kwargs)
@@ -202,7 +213,7 @@ def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None
202213
tic = time.time()
203214
yield
204215
toc = time.time()
205-
self.log(f"(finished: {message}, in {toc - tic:.2f} seconds)")
216+
self.log(f"(finished: {message}, in {toc - tic:.2f} seconds)", attributes)
206217

207218
def info(self, *args, **kwargs) -> None: # type: ignore
208219
self.log(*args, **kwargs)
@@ -222,6 +233,11 @@ def log_serial(self, message: str, machine: str) -> None:
222233

223234
self._eprint(Style.DIM + f"{machine} # {message}" + Style.RESET_ALL)
224235

236+
def log_test_error(self, *args, **kwargs) -> None: # type: ignore
237+
prefix = Fore.RED + "!!! " + Style.RESET_ALL
238+
# NOTE: using `warning` instead of `error` to ensure it does not exit after printing the first log
239+
self.warning(f"{prefix}{args[0]}", *args[1:], **kwargs)
240+
225241

226242
class XMLLogger(AbstractLogger):
227243
def __init__(self, outfile: str) -> None:
@@ -261,6 +277,9 @@ def warning(self, *args, **kwargs) -> None: # type: ignore
261277
def error(self, *args, **kwargs) -> None: # type: ignore
262278
self.log(*args, **kwargs)
263279

280+
def log_test_error(self, *args, **kwargs) -> None: # type: ignore
281+
self.log(*args, **kwargs)
282+
264283
def log(self, message: str, attributes: dict[str, str] = {}) -> None:
265284
self.drain_log_queue()
266285
self.log_line(message, attributes)

nixos/lib/test-driver/src/test_driver/machine.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from queue import Queue
2020
from typing import Any
2121

22+
from test_driver.errors import MachineError, RequestedAssertionFailed
2223
from test_driver.logger import AbstractLogger
2324

2425
from .qmp import QMPSession
@@ -129,7 +130,7 @@ def _preprocess_screenshot(screenshot_path: str, negate: bool = False) -> str:
129130
)
130131

131132
if ret.returncode != 0:
132-
raise Exception(
133+
raise MachineError(
133134
f"Image processing failed with exit code {ret.returncode}, stdout: {ret.stdout.decode()}, stderr: {ret.stderr.decode()}"
134135
)
135136

@@ -140,7 +141,7 @@ def _perform_ocr_on_screenshot(
140141
screenshot_path: str, model_ids: Iterable[int]
141142
) -> list[str]:
142143
if shutil.which("tesseract") is None:
143-
raise Exception("OCR requested but enableOCR is false")
144+
raise MachineError("OCR requested but enableOCR is false")
144145

145146
processed_image = _preprocess_screenshot(screenshot_path, negate=False)
146147
processed_negative = _preprocess_screenshot(screenshot_path, negate=True)
@@ -163,7 +164,7 @@ def _perform_ocr_on_screenshot(
163164
capture_output=True,
164165
)
165166
if ret.returncode != 0:
166-
raise Exception(f"OCR failed with exit code {ret.returncode}")
167+
raise MachineError(f"OCR failed with exit code {ret.returncode}")
167168
model_results.append(ret.stdout.decode("utf-8"))
168169

169170
return model_results
@@ -180,7 +181,9 @@ def retry(fn: Callable, timeout: int = 900) -> None:
180181
time.sleep(1)
181182

182183
if not fn(True):
183-
raise Exception(f"action timed out after {timeout} seconds")
184+
raise RequestedAssertionFailed(
185+
f"action timed out after {timeout} tries with one-second pause in-between"
186+
)
184187

185188

186189
class StartCommand:
@@ -409,14 +412,14 @@ def wait_for_unit(
409412
def check_active(_last_try: bool) -> bool:
410413
state = self.get_unit_property(unit, "ActiveState", user)
411414
if state == "failed":
412-
raise Exception(f'unit "{unit}" reached state "{state}"')
415+
raise RequestedAssertionFailed(f'unit "{unit}" reached state "{state}"')
413416

414417
if state == "inactive":
415418
status, jobs = self.systemctl("list-jobs --full 2>&1", user)
416419
if "No jobs" in jobs:
417420
info = self.get_unit_info(unit, user)
418421
if info["ActiveState"] == state:
419-
raise Exception(
422+
raise RequestedAssertionFailed(
420423
f'unit "{unit}" is inactive and there are no pending jobs'
421424
)
422425

@@ -431,7 +434,7 @@ def check_active(_last_try: bool) -> bool:
431434
def get_unit_info(self, unit: str, user: str | None = None) -> dict[str, str]:
432435
status, lines = self.systemctl(f'--no-pager show "{unit}"', user)
433436
if status != 0:
434-
raise Exception(
437+
raise RequestedAssertionFailed(
435438
f'retrieving systemctl info for unit "{unit}"'
436439
+ ("" if user is None else f' under user "{user}"')
437440
+ f" failed with exit code {status}"
@@ -461,7 +464,7 @@ def get_unit_property(
461464
user,
462465
)
463466
if status != 0:
464-
raise Exception(
467+
raise RequestedAssertionFailed(
465468
f'retrieving systemctl property "{property}" for unit "{unit}"'
466469
+ ("" if user is None else f' under user "{user}"')
467470
+ f" failed with exit code {status}"
@@ -509,7 +512,7 @@ def require_unit_state(self, unit: str, require_state: str = "active") -> None:
509512
info = self.get_unit_info(unit)
510513
state = info["ActiveState"]
511514
if state != require_state:
512-
raise Exception(
515+
raise RequestedAssertionFailed(
513516
f"Expected unit '{unit}' to to be in state "
514517
f"'{require_state}' but it is in state '{state}'"
515518
)
@@ -663,7 +666,9 @@ def succeed(self, *commands: str, timeout: int | None = None) -> str:
663666
(status, out) = self.execute(command, timeout=timeout)
664667
if status != 0:
665668
self.log(f"output: {out}")
666-
raise Exception(f"command `{command}` failed (exit code {status})")
669+
raise RequestedAssertionFailed(
670+
f"command `{command}` failed (exit code {status})"
671+
)
667672
output += out
668673
return output
669674

@@ -677,7 +682,9 @@ def fail(self, *commands: str, timeout: int | None = None) -> str:
677682
with self.nested(f"must fail: {command}"):
678683
(status, out) = self.execute(command, timeout=timeout)
679684
if status == 0:
680-
raise Exception(f"command `{command}` unexpectedly succeeded")
685+
raise RequestedAssertionFailed(
686+
f"command `{command}` unexpectedly succeeded"
687+
)
681688
output += out
682689
return output
683690

@@ -922,7 +929,7 @@ def screenshot(self, filename: str) -> None:
922929
ret = subprocess.run(f"pnmtopng '{tmp}' > '{filename}'", shell=True)
923930
os.unlink(tmp)
924931
if ret.returncode != 0:
925-
raise Exception("Cannot convert screenshot")
932+
raise MachineError("Cannot convert screenshot")
926933

927934
def copy_from_host_via_shell(self, source: str, target: str) -> None:
928935
"""Copy a file from the host into the guest by piping it over the

nixos/lib/test-script-prepend.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Callable, Iterator, ContextManager, Optional, List, Dict, Any, Union
99
from typing_extensions import Protocol
1010
from pathlib import Path
11+
from unittest import TestCase
1112

1213

1314
class RetryProtocol(Protocol):
@@ -51,3 +52,4 @@ def __call__(
5152
serial_stdout_off: Callable[[], None]
5253
serial_stdout_on: Callable[[], None]
5354
polling_condition: PollingConditionProtocol
55+
t: TestCase

0 commit comments

Comments
 (0)