Skip to content

Commit 7f5c097

Browse files
authored
Merge branch 'main' into block-diagonal-tensor
2 parents c5f868c + 4676f11 commit 7f5c097

File tree

6 files changed

+222
-143
lines changed

6 files changed

+222
-143
lines changed

tests/speed/autogram/grad_vs_jac_vs_gram.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
NoFreeParam,
1515
SqueezeNet,
1616
WithTransformerLarge,
17-
get_in_out_shapes,
1817
)
1918
from utils.forward_backwards import (
2019
autograd_forward_backward,
@@ -23,7 +22,7 @@
2322
autojac_forward_backward,
2423
make_mse_loss_fn,
2524
)
26-
from utils.tensors import make_tensors
25+
from utils.tensors import make_inputs_and_targets
2726

2827
from torchjd.aggregation import Mean
2928
from torchjd.autogram import Engine
@@ -43,9 +42,7 @@
4342

4443
def compare_autograd_autojac_and_autogram_speed(factory: ModuleFactory, batch_size: int):
4544
model = factory()
46-
input_shapes, output_shapes = get_in_out_shapes(model)
47-
inputs = make_tensors(batch_size, input_shapes)
48-
targets = make_tensors(batch_size, output_shapes)
45+
inputs, targets = make_inputs_and_targets(model, batch_size)
4946
loss_fn = make_mse_loss_fn(targets)
5047

5148
A = Mean()
@@ -64,7 +61,7 @@ def init_fn_autograd():
6461
fn_autograd()
6562

6663
def fn_autograd_gramian():
67-
autograd_gramian_forward_backward(model, inputs, list(model.parameters()), loss_fn, W)
64+
autograd_gramian_forward_backward(model, inputs, loss_fn, W)
6865

6966
def init_fn_autograd_gramian():
7067
torch.cuda.empty_cache()
@@ -80,7 +77,7 @@ def init_fn_autojac():
8077
fn_autojac()
8178

8279
def fn_autogram():
83-
autogram_forward_backward(model, engine, W, inputs, loss_fn)
80+
autogram_forward_backward(model, inputs, loss_fn, engine, W)
8481

8582
def init_fn_autogram():
8683
torch.cuda.empty_cache()

0 commit comments

Comments
 (0)