Skip to content

Commit d8c54e4

Browse files
committed
test: Add ModuleFactory (#459)
* Add ModuleFactory and use it to instantiate models in tests * Add get_in_out_shapes and use it to obtain input and output shapes in tests
1 parent ac384a0 commit d8c54e4

File tree

4 files changed

+156
-157
lines changed

4 files changed

+156
-157
lines changed

tests/speed/autogram/grad_vs_jac_vs_gram.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@
1010
GroupNormMobileNetV3Small,
1111
InstanceNormMobileNetV2,
1212
InstanceNormResNet18,
13+
ModuleFactory,
1314
NoFreeParam,
14-
ShapedModule,
1515
SqueezeNet,
1616
WithTransformerLarge,
17+
get_in_out_shapes,
1718
)
1819
from utils.forward_backwards import (
1920
autograd_forward_backward,
@@ -28,33 +29,30 @@
2829
from torchjd.autogram import Engine
2930

3031
PARAMETRIZATIONS = [
31-
(WithTransformerLarge, 8),
32-
(FreeParam, 64),
33-
(NoFreeParam, 64),
34-
(Cifar10Model, 64),
35-
(AlexNet, 8),
36-
(InstanceNormResNet18, 16),
37-
(GroupNormMobileNetV3Small, 16),
38-
(SqueezeNet, 4),
39-
(InstanceNormMobileNetV2, 2),
32+
(ModuleFactory(WithTransformerLarge), 8),
33+
(ModuleFactory(FreeParam), 64),
34+
(ModuleFactory(NoFreeParam), 64),
35+
(ModuleFactory(Cifar10Model), 64),
36+
(ModuleFactory(AlexNet), 8),
37+
(ModuleFactory(InstanceNormResNet18), 16),
38+
(ModuleFactory(GroupNormMobileNetV3Small), 16),
39+
(ModuleFactory(SqueezeNet), 4),
40+
(ModuleFactory(InstanceNormMobileNetV2), 2),
4041
]
4142

4243

43-
def compare_autograd_autojac_and_autogram_speed(architecture: type[ShapedModule], batch_size: int):
44-
input_shapes = architecture.INPUT_SHAPES
45-
output_shapes = architecture.OUTPUT_SHAPES
44+
def compare_autograd_autojac_and_autogram_speed(factory: ModuleFactory, batch_size: int):
45+
model = factory()
46+
input_shapes, output_shapes = get_in_out_shapes(model)
4647
inputs = make_tensors(batch_size, input_shapes)
4748
targets = make_tensors(batch_size, output_shapes)
4849
loss_fn = make_mse_loss_fn(targets)
4950

50-
model = architecture().to(device=DEVICE)
51-
5251
A = Mean()
5352
W = A.weighting
5453

5554
print(
56-
f"\nTimes for forward + backward on {architecture.__name__} with BS={batch_size}, A={A}"
57-
f" on {DEVICE}."
55+
f"\nTimes for forward + backward on {factory} with BS={batch_size}, A={A}" f" on {DEVICE}."
5856
)
5957

6058
def fn_autograd():
@@ -148,8 +146,8 @@ def time_call(fn, init_fn=noop, pre_fn=noop, post_fn=noop, n_runs: int = 10) ->
148146

149147

150148
def main():
151-
for architecture, batch_size in PARAMETRIZATIONS:
152-
compare_autograd_autojac_and_autogram_speed(architecture, batch_size)
149+
for factory, batch_size in PARAMETRIZATIONS:
150+
compare_autograd_autojac_and_autogram_speed(factory, batch_size)
153151
print("\n")
154152

155153

0 commit comments

Comments
 (0)