Skip to content

Commit 290a393

Browse files
authored
test: Improve test ids (#460)
* Move conftest.py from tests.unit to tests * Separate DEVICE creation from conftest.py to device.py * Add pytest_make_parametrize_id
1 parent f2535dc commit 290a393

File tree

9 files changed

+53
-23
lines changed

9 files changed

+53
-23
lines changed

CONTRIBUTING.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,18 +130,19 @@ from utils.tensors import ones_
130130
a = ones_(3, 4)
131131
```
132132
133-
This will automatically call `torch.ones` with `device=unit.conftest.DEVICE`.
133+
This will automatically call `torch.ones` with `device=DEVICE`.
134134
If the function you need does not exist yet as a partial function in `tensors.py`, add it.
135135
Lastly, when you create a model or a random generator, you have to move them manually to the right
136-
device (the `DEVICE` defined in `unit.conftest`):
136+
device (the `DEVICE` defined in `device.py`).
137137
```python
138138
import torch
139139
from torch.nn import Linear
140-
from unit.conftest import DEVICE
140+
from device import DEVICE
141141
142142
model = Linear(3, 4).to(device=DEVICE)
143143
rng = torch.Generator(device=DEVICE)
144144
```
145+
You may also use a `ModuleFactory` to make the modules on `DEVICE` automatically.
145146
146147
### Coding
147148

tests/unit/conftest.py renamed to tests/conftest.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,13 @@
1-
import os
21
import random as rand
2+
from contextlib import nullcontext
33

44
import torch
5-
from pytest import fixture, mark
5+
from device import DEVICE
6+
from pytest import RaisesExc, fixture, mark
7+
from torch import Tensor
8+
from utils.architectures import ModuleFactory
69

7-
try:
8-
_device_str = os.environ["PYTEST_TORCH_DEVICE"]
9-
except KeyError:
10-
_device_str = "cpu" # Default to cpu if environment variable not set
11-
12-
if _device_str != "cuda:0" and _device_str != "cpu":
13-
raise ValueError(f"Invalid value of environment variable PYTEST_TORCH_DEVICE: {_device_str}")
14-
15-
if _device_str == "cuda:0" and not torch.cuda.is_available():
16-
raise ValueError('Requested device "cuda:0" but cuda is not available.')
17-
18-
DEVICE = torch.device(_device_str)
10+
from torchjd.aggregation import Aggregator, Weighting
1911

2012

2113
@fixture(autouse=True)
@@ -48,3 +40,24 @@ def pytest_collection_modifyitems(config, items):
4840
item.add_marker(skip_slow)
4941
if "xfail_if_cuda" in item.keywords and str(DEVICE).startswith("cuda"):
5042
item.add_marker(xfail_cuda)
43+
44+
45+
def pytest_make_parametrize_id(config, val, argname):
46+
MAX_SIZE = 40
47+
optional_string = None # Returning None means using pytest's way of making the string
48+
49+
if isinstance(val, (Aggregator, ModuleFactory, Weighting)):
50+
optional_string = str(val)
51+
elif isinstance(val, Tensor):
52+
optional_string = "T" + str(list(val.shape)) # T to indicate that it's a tensor
53+
elif isinstance(val, (tuple, list, set)) and len(val) < 20:
54+
optional_string = str(val)
55+
elif isinstance(val, RaisesExc):
56+
optional_string = " or ".join([f"{exc.__name__}" for exc in val.expected_exceptions])
57+
elif isinstance(val, nullcontext):
58+
optional_string = "does_not_raise()"
59+
60+
if isinstance(optional_string, str) and len(optional_string) > MAX_SIZE:
61+
optional_string = optional_string[: MAX_SIZE - 3] + "+++" # Can't use dots with pytest
62+
63+
return optional_string

tests/device.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import os
2+
3+
import torch
4+
5+
try:
6+
_device_str = os.environ["PYTEST_TORCH_DEVICE"]
7+
except KeyError:
8+
_device_str = "cpu" # Default to cpu if environment variable not set
9+
10+
if _device_str != "cuda:0" and _device_str != "cpu":
11+
raise ValueError(f"Invalid value of environment variable PYTEST_TORCH_DEVICE: {_device_str}")
12+
13+
if _device_str == "cuda:0" and not torch.cuda.is_available():
14+
raise ValueError('Requested device "cuda:0" but cuda is not available.')
15+
16+
DEVICE = torch.device(_device_str)

tests/speed/autogram/grad_vs_jac_vs_gram.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import time
33

44
import torch
5-
from unit.conftest import DEVICE
5+
from device import DEVICE
66
from utils.architectures import (
77
AlexNet,
88
Cifar10Model,

tests/unit/aggregation/_inputs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import torch
2-
from unit.conftest import DEVICE
2+
from device import DEVICE
33
from utils.tensors import zeros_
44

55
from ._matrix_samplers import NonWeakSampler, NormalSampler, StrictlyWeakSampler, StrongSampler

tests/unit/autojac/_transform/test_aggregate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import math
22

33
import torch
4+
from device import DEVICE
45
from pytest import mark, raises
5-
from unit.conftest import DEVICE
66
from utils.dict_assertions import assert_tensor_dicts_are_close
77
from utils.tensors import rand_, tensor_, zeros_
88

tests/unit/autojac/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1+
from device import DEVICE
12
from pytest import mark, raises
23
from torch.nn import Linear, MSELoss, ReLU, Sequential
3-
from unit.conftest import DEVICE
44
from utils.tensors import randn_, tensor_
55

66
from torchjd.autojac._utils import get_leaf_tensors

tests/utils/architectures.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
import torch
44
import torchvision
5+
from device import DEVICE
56
from torch import Tensor, nn
67
from torch.nn import Flatten, ReLU
78
from torch.utils._pytree import PyTree
8-
from unit.conftest import DEVICE
99

1010

1111
class ModuleFactory:

tests/utils/tensors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from functools import partial
22

33
import torch
4+
from device import DEVICE
45
from torch.utils._pytree import PyTree, tree_map
5-
from unit.conftest import DEVICE
66

77
# Curried calls to torch functions that require a device so that we automatically fix the device
88
# for code written in the tests, while not affecting code written in src (what

0 commit comments

Comments
 (0)