1414 NoFreeParam ,
1515 SqueezeNet ,
1616 WithTransformerLarge ,
17- get_in_out_shapes ,
1817)
1918from utils .forward_backwards import (
2019 autograd_forward_backward ,
2322 autojac_forward_backward ,
2423 make_mse_loss_fn ,
2524)
26- from utils .tensors import make_tensors
25+ from utils .tensors import make_inputs_and_targets
2726
2827from torchjd .aggregation import Mean
2928from torchjd .autogram import Engine
4342
4443def compare_autograd_autojac_and_autogram_speed (factory : ModuleFactory , batch_size : int ):
4544 model = factory ()
46- input_shapes , output_shapes = get_in_out_shapes (model )
47- inputs = make_tensors (batch_size , input_shapes )
48- targets = make_tensors (batch_size , output_shapes )
45+ inputs , targets = make_inputs_and_targets (model , batch_size )
4946 loss_fn = make_mse_loss_fn (targets )
5047
5148 A = Mean ()
@@ -64,7 +61,7 @@ def init_fn_autograd():
6461 fn_autograd ()
6562
6663 def fn_autograd_gramian ():
67- autograd_gramian_forward_backward (model , inputs , list ( model . parameters ()), loss_fn , W )
64+ autograd_gramian_forward_backward (model , inputs , loss_fn , W )
6865
6966 def init_fn_autograd_gramian ():
7067 torch .cuda .empty_cache ()
@@ -80,7 +77,7 @@ def init_fn_autojac():
8077 fn_autojac ()
8178
8279 def fn_autogram ():
83- autogram_forward_backward (model , engine , W , inputs , loss_fn )
80+ autogram_forward_backward (model , inputs , loss_fn , engine , W )
8481
8582 def init_fn_autogram ():
8683 torch .cuda .empty_cache ()
0 commit comments