10 | 10 | GroupNormMobileNetV3Small, |
11 | 11 | InstanceNormMobileNetV2, |
12 | 12 | InstanceNormResNet18, |
| 13 | + ModuleFactory, |
13 | 14 | NoFreeParam, |
14 | | - ShapedModule, |
15 | 15 | SqueezeNet, |
16 | 16 | WithTransformerLarge, |
| 17 | + get_in_out_shapes, |
17 | 18 | ) |
18 | 19 | from utils.forward_backwards import ( |
19 | 20 | autograd_forward_backward, |
|
28 | 29 | from torchjd.autogram import Engine |
29 | 30 |
|
30 | 31 | PARAMETRIZATIONS = [ |
31 | | - (WithTransformerLarge, 8), |
32 | | - (FreeParam, 64), |
33 | | - (NoFreeParam, 64), |
34 | | - (Cifar10Model, 64), |
35 | | - (AlexNet, 8), |
36 | | - (InstanceNormResNet18, 16), |
37 | | - (GroupNormMobileNetV3Small, 16), |
38 | | - (SqueezeNet, 4), |
39 | | - (InstanceNormMobileNetV2, 2), |
| 32 | + (ModuleFactory(WithTransformerLarge), 8), |
| 33 | + (ModuleFactory(FreeParam), 64), |
| 34 | + (ModuleFactory(NoFreeParam), 64), |
| 35 | + (ModuleFactory(Cifar10Model), 64), |
| 36 | + (ModuleFactory(AlexNet), 8), |
| 37 | + (ModuleFactory(InstanceNormResNet18), 16), |
| 38 | + (ModuleFactory(GroupNormMobileNetV3Small), 16), |
| 39 | + (ModuleFactory(SqueezeNet), 4), |
| 40 | + (ModuleFactory(InstanceNormMobileNetV2), 2), |
40 | 41 | ] |
41 | 42 |
|
42 | 43 |
|
43 | | -def compare_autograd_autojac_and_autogram_speed(architecture: type[ShapedModule], batch_size: int): |
44 | | - input_shapes = architecture.INPUT_SHAPES |
45 | | - output_shapes = architecture.OUTPUT_SHAPES |
| 44 | +def compare_autograd_autojac_and_autogram_speed(factory: ModuleFactory, batch_size: int): |
| 45 | + model = factory() |
| 46 | + input_shapes, output_shapes = get_in_out_shapes(model) |
46 | 47 | inputs = make_tensors(batch_size, input_shapes) |
47 | 48 | targets = make_tensors(batch_size, output_shapes) |
48 | 49 | loss_fn = make_mse_loss_fn(targets) |
49 | 50 |
|
50 | | - model = architecture().to(device=DEVICE) |
51 | | - |
52 | 51 | A = Mean() |
53 | 52 | W = A.weighting |
54 | 53 |
|
55 | 54 | print( |
56 | | - f"\nTimes for forward + backward on {architecture.__name__} with BS={batch_size}, A={A}" |
57 | | - f" on {DEVICE}." |
| 55 | + f"\nTimes for forward + backward on {factory} with BS={batch_size}, A={A}" f" on {DEVICE}." |
58 | 56 | ) |
59 | 57 |
|
60 | 58 | def fn_autograd(): |
@@ -148,8 +146,8 @@ def time_call(fn, init_fn=noop, pre_fn=noop, post_fn=noop, n_runs: int = 10) -> |
148 | 146 |
|
149 | 147 |
|
150 | 148 | def main(): |
151 | | - for architecture, batch_size in PARAMETRIZATIONS: |
152 | | - compare_autograd_autojac_and_autogram_speed(architecture, batch_size) |
| 149 | + for factory, batch_size in PARAMETRIZATIONS: |
| 150 | + compare_autograd_autojac_and_autogram_speed(factory, batch_size) |
153 | 151 | print("\n") |
154 | 152 |
|
155 | 153 |
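
Note on the refactor: the parametrizations switch from passing architecture classes (which exposed class-level INPUT_SHAPES / OUTPUT_SHAPES via ShapedModule) to ModuleFactory instances, with shapes now obtained from the built model through get_in_out_shapes. Neither utility is shown in this diff, so the following is only a rough sketch of what such a pair could look like; the constructor signature, the __str__ behavior, and the per-instance input_shapes / output_shapes attributes are assumptions for illustration, not the repository's actual code.

# Hypothetical sketch only -- the real ModuleFactory / get_in_out_shapes live in the
# benchmark's utils modules and may be implemented differently.
from torch import nn


class ModuleFactory:
    """Defers model construction so a parametrization stays a cheap, printable value
    and the module is only built when the benchmark actually runs."""

    def __init__(self, module_class: type[nn.Module], *args, **kwargs):
        self.module_class = module_class
        self.args = args
        self.kwargs = kwargs

    def __call__(self) -> nn.Module:
        # Build a fresh model instance; device placement could also happen here,
        # which would be consistent with the removed `.to(device=DEVICE)` call.
        return self.module_class(*self.args, **self.kwargs)

    def __str__(self) -> str:
        # Lets the print statement interpolate the factory directly: f"... on {factory} ..."
        return self.module_class.__name__


def get_in_out_shapes(model: nn.Module) -> tuple:
    # Assumed contract: shapes are read from the model instance rather than from
    # class-level INPUT_SHAPES / OUTPUT_SHAPES attributes.
    return model.input_shapes, model.output_shapes

With a pair like this, an entry such as (ModuleFactory(AlexNet), 8) stays lightweight until compare_autograd_autojac_and_autogram_speed calls factory(), and the input/output shapes are derived from the resulting instance instead of the class.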
|
|