
Commit 568bea0

Merge branch 'main' into linear-gramian-computer
2 parents ac384a0 + aff0abc commit 568bea0

11 files changed: +224 -225 lines

CONTRIBUTING.md

Lines changed: 4 additions & 3 deletions
@@ -130,18 +130,19 @@ from utils.tensors import ones_
 a = ones_(3, 4)
 ```
 
-This will automatically call `torch.ones` with `device=unit.conftest.DEVICE`.
+This will automatically call `torch.ones` with `device=DEVICE`.
 If the function you need does not exist yet as a partial function in `tensors.py`, add it.
 Lastly, when you create a model or a random generator, you have to move them manually to the right
-device (the `DEVICE` defined in `unit.conftest`):
+device (the `DEVICE` defined in `device.py`).
 ```python
 import torch
 from torch.nn import Linear
-from unit.conftest import DEVICE
+from device import DEVICE
 
 model = Linear(3, 4).to(device=DEVICE)
 rng = torch.Generator(device=DEVICE)
 ```
+You may also use a `ModuleFactory` to make the modules on `DEVICE` automatically.
 
 ### Coding
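For reference, one plausible shape of such a device-defaulting partial in `utils/tensors.py` (that file is not part of this diff, so apart from `ones_`, `zeros_` and `DEVICE` the details below are assumptions):

```python
# Hypothetical sketch of utils/tensors.py helpers; only ones_, zeros_ and DEVICE
# appear in the changes above, the rest is illustrative.
from functools import partial

import torch

from device import DEVICE

# Each partial forwards to the matching torch factory with device=DEVICE preset.
ones_ = partial(torch.ones, device=DEVICE)
zeros_ = partial(torch.zeros, device=DEVICE)

a = ones_(3, 4)  # same as torch.ones(3, 4, device=DEVICE)
```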

tests/unit/conftest.py renamed to tests/conftest.py

Lines changed: 27 additions & 14 deletions
@@ -1,21 +1,13 @@
-import os
 import random as rand
+from contextlib import nullcontext
 
 import torch
-from pytest import fixture, mark
+from device import DEVICE
+from pytest import RaisesExc, fixture, mark
+from torch import Tensor
+from utils.architectures import ModuleFactory
 
-try:
-    _device_str = os.environ["PYTEST_TORCH_DEVICE"]
-except KeyError:
-    _device_str = "cpu"  # Default to cpu if environment variable not set
-
-if _device_str != "cuda:0" and _device_str != "cpu":
-    raise ValueError(f"Invalid value of environment variable PYTEST_TORCH_DEVICE: {_device_str}")
-
-if _device_str == "cuda:0" and not torch.cuda.is_available():
-    raise ValueError('Requested device "cuda:0" but cuda is not available.')
-
-DEVICE = torch.device(_device_str)
+from torchjd.aggregation import Aggregator, Weighting
 
 
 @fixture(autouse=True)
@@ -48,3 +40,24 @@ def pytest_collection_modifyitems(config, items):
             item.add_marker(skip_slow)
         if "xfail_if_cuda" in item.keywords and str(DEVICE).startswith("cuda"):
             item.add_marker(xfail_cuda)
+
+
+def pytest_make_parametrize_id(config, val, argname):
+    MAX_SIZE = 40
+    optional_string = None  # Returning None means using pytest's way of making the string
+
+    if isinstance(val, (Aggregator, ModuleFactory, Weighting)):
+        optional_string = str(val)
+    elif isinstance(val, Tensor):
+        optional_string = "T" + str(list(val.shape))  # T to indicate that it's a tensor
+    elif isinstance(val, (tuple, list, set)) and len(val) < 20:
+        optional_string = str(val)
+    elif isinstance(val, RaisesExc):
+        optional_string = " or ".join([f"{exc.__name__}" for exc in val.expected_exceptions])
+    elif isinstance(val, nullcontext):
+        optional_string = "does_not_raise()"
+
+    if isinstance(optional_string, str) and len(optional_string) > MAX_SIZE:
+        optional_string = optional_string[: MAX_SIZE - 3] + "+++"  # Can't use dots with pytest
+
+    return optional_string
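To see what this hook changes in practice, consider a hypothetical parametrized test (not part of this diff): values such as tensors now yield IDs derived from their shape instead of pytest's default `argname0`, `argname1` names.

```python
# Hypothetical test, only to illustrate the IDs produced by pytest_make_parametrize_id.
import torch
from pytest import mark

from utils.tensors import ones_


@mark.parametrize("matrix", [ones_(3, 4), ones_(2, 2)])
def test_shapes(matrix: torch.Tensor):
    # With the hook above, the collected IDs should look like
    # test_shapes[T[3, 4]] and test_shapes[T[2, 2]].
    assert matrix.ndim == 2
```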

tests/device.py

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+import os
+
+import torch
+
+try:
+    _device_str = os.environ["PYTEST_TORCH_DEVICE"]
+except KeyError:
+    _device_str = "cpu"  # Default to cpu if environment variable not set
+
+if _device_str != "cuda:0" and _device_str != "cpu":
+    raise ValueError(f"Invalid value of environment variable PYTEST_TORCH_DEVICE: {_device_str}")
+
+if _device_str == "cuda:0" and not torch.cuda.is_available():
+    raise ValueError('Requested device "cuda:0" but cuda is not available.')
+
+DEVICE = torch.device(_device_str)
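Because `DEVICE` is resolved once at import time, `PYTEST_TORCH_DEVICE` must be set before the test session starts. A small illustration (the invocations in the comments are assumptions; only the variable name and the accepted values come from the file above):

```python
# Hypothetical illustration of device selection for the test suite:
#   pytest tests                             -> DEVICE is cpu (default)
#   PYTEST_TORCH_DEVICE=cuda:0 pytest tests  -> DEVICE is cuda:0 (requires CUDA)
#   PYTEST_TORCH_DEVICE=mps pytest tests     -> ValueError (only "cpu" and "cuda:0" are accepted)
from device import DEVICE

print(DEVICE)  # the torch.device shared by all tests
```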

tests/speed/autogram/grad_vs_jac_vs_gram.py

Lines changed: 18 additions & 20 deletions
@@ -2,18 +2,19 @@
 import time
 
 import torch
-from unit.conftest import DEVICE
+from device import DEVICE
 from utils.architectures import (
     AlexNet,
     Cifar10Model,
     FreeParam,
     GroupNormMobileNetV3Small,
     InstanceNormMobileNetV2,
     InstanceNormResNet18,
+    ModuleFactory,
     NoFreeParam,
-    ShapedModule,
     SqueezeNet,
     WithTransformerLarge,
+    get_in_out_shapes,
 )
 from utils.forward_backwards import (
     autograd_forward_backward,
@@ -28,33 +29,30 @@
 from torchjd.autogram import Engine
 
 PARAMETRIZATIONS = [
-    (WithTransformerLarge, 8),
-    (FreeParam, 64),
-    (NoFreeParam, 64),
-    (Cifar10Model, 64),
-    (AlexNet, 8),
-    (InstanceNormResNet18, 16),
-    (GroupNormMobileNetV3Small, 16),
-    (SqueezeNet, 4),
-    (InstanceNormMobileNetV2, 2),
+    (ModuleFactory(WithTransformerLarge), 8),
+    (ModuleFactory(FreeParam), 64),
+    (ModuleFactory(NoFreeParam), 64),
+    (ModuleFactory(Cifar10Model), 64),
+    (ModuleFactory(AlexNet), 8),
+    (ModuleFactory(InstanceNormResNet18), 16),
+    (ModuleFactory(GroupNormMobileNetV3Small), 16),
+    (ModuleFactory(SqueezeNet), 4),
+    (ModuleFactory(InstanceNormMobileNetV2), 2),
 ]
 
 
-def compare_autograd_autojac_and_autogram_speed(architecture: type[ShapedModule], batch_size: int):
-    input_shapes = architecture.INPUT_SHAPES
-    output_shapes = architecture.OUTPUT_SHAPES
+def compare_autograd_autojac_and_autogram_speed(factory: ModuleFactory, batch_size: int):
+    model = factory()
+    input_shapes, output_shapes = get_in_out_shapes(model)
     inputs = make_tensors(batch_size, input_shapes)
     targets = make_tensors(batch_size, output_shapes)
     loss_fn = make_mse_loss_fn(targets)
 
-    model = architecture().to(device=DEVICE)
-
     A = Mean()
     W = A.weighting
 
     print(
-        f"\nTimes for forward + backward on {architecture.__name__} with BS={batch_size}, A={A}"
-        f" on {DEVICE}."
+        f"\nTimes for forward + backward on {factory} with BS={batch_size}, A={A}" f" on {DEVICE}."
     )
 
     def fn_autograd():
@@ -148,8 +146,8 @@ def time_call(fn, init_fn=noop, pre_fn=noop, post_fn=noop, n_runs: int = 10) ->
 
 
 def main():
-    for architecture, batch_size in PARAMETRIZATIONS:
-        compare_autograd_autojac_and_autogram_speed(architecture, batch_size)
+    for factory, batch_size in PARAMETRIZATIONS:
+        compare_autograd_autojac_and_autogram_speed(factory, batch_size)
     print("\n")
 
 
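The `ModuleFactory` class itself lives in `utils/architectures.py` and is not shown in this diff; based on how it is used here and in `CONTRIBUTING.md`, a minimal sketch of its role (instantiate an architecture directly on `DEVICE` and print as a readable name) might look like this, purely as an assumption:

```python
# Hypothetical minimal sketch of ModuleFactory; the real implementation in
# utils/architectures.py is not part of this diff.
from torch.nn import Module

from device import DEVICE


class ModuleFactory:
    def __init__(self, architecture: type[Module], *args, **kwargs):
        self.architecture = architecture
        self.args = args
        self.kwargs = kwargs

    def __call__(self) -> Module:
        # Build the module and move it to the test device automatically.
        return self.architecture(*self.args, **self.kwargs).to(device=DEVICE)

    def __str__(self) -> str:
        # Used for readable parametrize IDs and log messages.
        return self.architecture.__name__
```

Under this reading, `(ModuleFactory(AlexNet), 8)` replaces the previous `(AlexNet, 8)` entry while keeping the printed name and the device placement of the old `architecture().to(device=DEVICE)` call.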

tests/unit/aggregation/_inputs.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 import torch
-from unit.conftest import DEVICE
+from device import DEVICE
 from utils.tensors import zeros_
 
 from ._matrix_samplers import NonWeakSampler, NormalSampler, StrictlyWeakSampler, StrongSampler
