|
1 | | -import os |
2 | 1 | import random as rand |
| 2 | +from contextlib import nullcontext |
3 | 3 |
|
4 | 4 | import torch |
5 | | -from pytest import fixture, mark |
| 5 | +from device import DEVICE |
| 6 | +from pytest import RaisesExc, fixture, mark |
| 7 | +from torch import Tensor |
| 8 | +from utils.architectures import ModuleFactory |
6 | 9 |
|
7 | | -try: |
8 | | - _device_str = os.environ["PYTEST_TORCH_DEVICE"] |
9 | | -except KeyError: |
10 | | - _device_str = "cpu" # Default to cpu if environment variable not set |
11 | | - |
12 | | -if _device_str != "cuda:0" and _device_str != "cpu": |
13 | | - raise ValueError(f"Invalid value of environment variable PYTEST_TORCH_DEVICE: {_device_str}") |
14 | | - |
15 | | -if _device_str == "cuda:0" and not torch.cuda.is_available(): |
16 | | - raise ValueError('Requested device "cuda:0" but cuda is not available.') |
17 | | - |
18 | | -DEVICE = torch.device(_device_str) |
| 10 | +from torchjd.aggregation import Aggregator, Weighting |
19 | 11 |
|
20 | 12 |
|
21 | 13 | @fixture(autouse=True) |
@@ -48,3 +40,24 @@ def pytest_collection_modifyitems(config, items): |
48 | 40 | item.add_marker(skip_slow) |
49 | 41 | if "xfail_if_cuda" in item.keywords and str(DEVICE).startswith("cuda"): |
50 | 42 | item.add_marker(xfail_cuda) |
| 43 | + |
| 44 | + |
| 45 | +def pytest_make_parametrize_id(config, val, argname): |
| 46 | + MAX_SIZE = 40 |
| 47 | + optional_string = None # Returning None means using pytest's way of making the string |
| 48 | + |
| 49 | + if isinstance(val, (Aggregator, ModuleFactory, Weighting)): |
| 50 | + optional_string = str(val) |
| 51 | + elif isinstance(val, Tensor): |
| 52 | + optional_string = "T" + str(list(val.shape)) # T to indicate that it's a tensor |
| 53 | + elif isinstance(val, (tuple, list, set)) and len(val) < 20: |
| 54 | + optional_string = str(val) |
| 55 | + elif isinstance(val, RaisesExc): |
| 56 | + optional_string = " or ".join([f"{exc.__name__}" for exc in val.expected_exceptions]) |
| 57 | + elif isinstance(val, nullcontext): |
| 58 | + optional_string = "does_not_raise()" |
| 59 | + |
| 60 | + if isinstance(optional_string, str) and len(optional_string) > MAX_SIZE: |
| 61 | + optional_string = optional_string[: MAX_SIZE - 3] + "+++" # Can't use dots with pytest |
| 62 | + |
| 63 | + return optional_string |
0 commit comments