     WithSomeFrozenModule,
     WithTransformer,
     WithTransformerLarge,
-    get_in_out_shapes,
 )
 from utils.dict_assertions import assert_tensor_dicts_are_close
 from utils.forward_backwards import (
     autograd_forward_backward,
     autogram_forward_backward,
     compute_gramian,
     compute_gramian_with_autograd,
+    forward_pass,
     make_mse_loss_fn,
     reduce_to_first_tensor,
     reduce_to_matrix,
     reduce_to_scalar,
     reduce_to_vector,
 )
-from utils.tensors import make_tensors, ones_, randn_, zeros_
+from utils.tensors import make_inputs_and_targets, ones_, randn_, zeros_
 
 from torchjd.aggregation import UPGradWeighting
 from torchjd.autogram._engine import Engine
@@ -144,22 +144,14 @@ def _assert_gramian_is_equivalent_to_autograd(
     factory: ModuleFactory, batch_size: int, batch_dim: int | None
 ):
     model_autograd, model_autogram = factory(), factory()
-    input_shapes, output_shapes = get_in_out_shapes(model_autograd)
-
     engine = Engine(model_autogram, batch_dim=batch_dim)
-
-    inputs = make_tensors(batch_size, input_shapes)
-    targets = make_tensors(batch_size, output_shapes)
+    inputs, targets = make_inputs_and_targets(model_autograd, batch_size)
     loss_fn = make_mse_loss_fn(targets)
 
-    torch.random.manual_seed(0)  # Fix randomness for random models
-    output = model_autograd(inputs)
-    losses = reduce_to_vector(loss_fn(output))
+    losses = forward_pass(model_autograd, inputs, loss_fn, reduce_to_vector)
     autograd_gramian = compute_gramian_with_autograd(losses, list(model_autograd.parameters()))
 
-    torch.random.manual_seed(0)  # Fix randomness for random models
-    output = model_autogram(inputs)
-    losses = reduce_to_vector(loss_fn(output))
+    losses = forward_pass(model_autogram, inputs, loss_fn, reduce_to_vector)
     autogram_gramian = engine.compute_gramian(losses)
 
     assert_close(autogram_gramian, autograd_gramian, rtol=1e-4, atol=3e-5)
@@ -255,26 +247,18 @@ def test_compute_gramian_various_output_shapes(
 
     factory = ModuleFactory(Ndim2Output)
     model_autograd, model_autogram = factory(), factory()
-    input_shapes, output_shapes = get_in_out_shapes(model_autograd)
-
-    engine = Engine(model_autogram, batch_dim=batch_dim)
-
-    inputs = make_tensors(batch_size, input_shapes)
-    targets = make_tensors(batch_size, output_shapes)
+    inputs, targets = make_inputs_and_targets(model_autograd, batch_size)
     loss_fn = make_mse_loss_fn(targets)
 
-    torch.random.manual_seed(0)  # Fix randomness for random models
-    output = model_autograd(inputs)
-    losses = reduction(loss_fn(output))
+    losses = forward_pass(model_autograd, inputs, loss_fn, reduction)
     reshaped_losses = torch.movedim(losses, movedim_source, movedim_destination)
     # Go back to a vector so that compute_gramian_with_autograd works
     loss_vector = reshaped_losses.reshape([-1])
     autograd_gramian = compute_gramian_with_autograd(loss_vector, list(model_autograd.parameters()))
     expected_gramian = reshape_gramian(autograd_gramian, list(reshaped_losses.shape))
 
-    torch.random.manual_seed(0)  # Fix randomness for random models
-    output = model_autogram(inputs)
-    losses = reduction(loss_fn(output))
+    engine = Engine(model_autogram, batch_dim=batch_dim)
+    losses = forward_pass(model_autogram, inputs, loss_fn, reduction)
     reshaped_losses = torch.movedim(losses, movedim_source, movedim_destination)
     autogram_gramian = engine.compute_gramian(reshaped_losses)
 
@@ -296,30 +280,20 @@ def test_compute_partial_gramian(gramian_module_names: set[str], batch_dim: int
     the model parameters is specified.
     """
 
-    factory = ModuleFactory(SimpleBranched)
-    model = factory()
-    input_shapes, output_shapes = get_in_out_shapes(model)
+    model = SimpleBranched()
     batch_size = 64
-
-    input = make_tensors(batch_size, input_shapes)
-    targets = make_tensors(batch_size, output_shapes)
+    inputs, targets = make_inputs_and_targets(model, batch_size)
     loss_fn = make_mse_loss_fn(targets)
-
-    output = model(input)
-    losses = reduce_to_vector(loss_fn(output))
-
     gramian_modules = [model.get_submodule(name) for name in gramian_module_names]
     gramian_params = []
     for m in gramian_modules:
         gramian_params += list(m.parameters())
 
+    losses = forward_pass(model, inputs, loss_fn, reduce_to_vector)
     autograd_gramian = compute_gramian_with_autograd(losses, gramian_params, retain_graph=True)
-    torch.manual_seed(0)
 
     engine = Engine(*gramian_modules, batch_dim=batch_dim)
-
-    output = model(input)
-    losses = reduce_to_vector(loss_fn(output))
+    losses = forward_pass(model, inputs, loss_fn, reduce_to_vector)
     gramian = engine.compute_gramian(losses)
 
     assert_close(gramian, autograd_gramian)
@@ -331,22 +305,15 @@ def test_iwrm_steps_with_autogram(factory: ModuleFactory, batch_size: int, batch
331305 """Tests that the autogram engine doesn't raise any error during several IWRM iterations."""
332306
333307 n_iter = 3
334-
335308 model = factory ()
336- input_shapes , output_shapes = get_in_out_shapes (model )
337-
338309 weighting = UPGradWeighting ()
339-
340310 engine = Engine (model , batch_dim = batch_dim )
341311 optimizer = SGD (model .parameters (), lr = 1e-7 )
342312
343313 for i in range (n_iter ):
344- inputs = make_tensors (batch_size , input_shapes )
345- targets = make_tensors (batch_size , output_shapes )
314+ inputs , targets = make_inputs_and_targets (model , batch_size )
346315 loss_fn = make_mse_loss_fn (targets )
347-
348- autogram_forward_backward (model , engine , weighting , inputs , loss_fn )
349-
316+ autogram_forward_backward (model , inputs , loss_fn , engine , weighting )
350317 optimizer .step ()
351318 model .zero_grad ()
352319
@@ -363,29 +330,22 @@ def test_autograd_while_modules_are_hooked(
363330 """
364331
365332 model , model_autogram = factory (), factory ()
366- input_shapes , output_shapes = get_in_out_shapes (model )
367-
368- input = make_tensors (batch_size , input_shapes )
369- targets = make_tensors (batch_size , output_shapes )
333+ inputs , targets = make_inputs_and_targets (model , batch_size )
370334 loss_fn = make_mse_loss_fn (targets )
371335
372- torch .manual_seed (0 ) # Fix randomness for random models
373- autograd_forward_backward (model , input , loss_fn )
336+ autograd_forward_backward (model , inputs , loss_fn )
374337 autograd_grads = {name : p .grad for name , p in model .named_parameters () if p .grad is not None }
375338
376339 # Hook modules and optionally compute the Gramian
377340 engine = Engine (model_autogram , batch_dim = batch_dim )
378341 if use_engine :
379- torch .manual_seed (0 ) # Fix randomness for random models
380- output = model_autogram (input )
381- losses = reduce_to_vector (loss_fn (output ))
342+ losses = forward_pass (model_autogram , inputs , loss_fn , reduce_to_vector )
382343 _ = engine .compute_gramian (losses )
383344
384345 # Verify that even with the hooked modules, autograd works normally when not using the engine.
385346 # Results should be the same as a normal call to autograd, and no time should be spent computing
386347 # the gramian at all.
387- torch .manual_seed (0 ) # Fix randomness for random models
388- autograd_forward_backward (model_autogram , input , loss_fn )
348+ autograd_forward_backward (model_autogram , inputs , loss_fn )
389349 grads = {name : p .grad for name , p in model_autogram .named_parameters () if p .grad is not None }
390350
391351 assert_tensor_dicts_are_close (grads , autograd_grads )
@@ -416,12 +376,11 @@ def test_compute_gramian_manual():
 
     in_dims = 18
     out_dims = 25
-
     factory = ModuleFactory(Linear, in_dims, out_dims)
     model = factory()
-    engine = Engine(model, batch_dim=None)
-
     input = randn_(in_dims)
+
+    engine = Engine(model, batch_dim=None)
     output = model(input)
     gramian = engine.compute_gramian(output)
 
@@ -462,21 +421,19 @@ def test_reshape_equivariance(shape: list[int]):
 
     input_size = shape[0]
     output_size = prod(shape[1:])
-
     factory = ModuleFactory(Linear, input_size, output_size)
     model1, model2 = factory(), factory()
+    input = randn_([input_size])
 
     engine1 = Engine(model1, batch_dim=None)
-    engine2 = Engine(model2, batch_dim=None)
-
-    input = randn_([input_size])
     output = model1(input)
-    reshaped_output = model2(input).reshape(shape[1:])
-
     gramian = engine1.compute_gramian(output)
-    reshaped_gramian = engine2.compute_gramian(reshaped_output)
     expected_reshaped_gramian = reshape_gramian(gramian, shape[1:])
 
+    engine2 = Engine(model2, batch_dim=None)
+    reshaped_output = model2(input).reshape(shape[1:])
+    reshaped_gramian = engine2.compute_gramian(reshaped_output)
+
     assert_close(reshaped_gramian, expected_reshaped_gramian)
 
 
@@ -502,21 +459,19 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination:
 
     input_size = shape[0]
     output_size = prod(shape[1:])
-
     factory = ModuleFactory(Linear, input_size, output_size)
     model1, model2 = factory(), factory()
+    input = randn_([input_size])
 
     engine1 = Engine(model1, batch_dim=None)
-    engine2 = Engine(model2, batch_dim=None)
-
-    input = randn_([input_size])
     output = model1(input).reshape(shape[1:])
-    moved_output = model2(input).reshape(shape[1:]).movedim(source, destination)
-
     gramian = engine1.compute_gramian(output)
-    moved_gramian = engine2.compute_gramian(moved_output)
     expected_moved_gramian = movedim_gramian(gramian, source, destination)
 
+    engine2 = Engine(model2, batch_dim=None)
+    moved_output = model2(input).reshape(shape[1:]).movedim(source, destination)
+    moved_gramian = engine2.compute_gramian(moved_output)
+
     assert_close(moved_gramian, expected_moved_gramian)
 
 
@@ -545,18 +500,16 @@ def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int):
     input_size = prod(non_batched_shape)
     batch_size = shape[batch_dim]
     output_size = input_size
-
     factory = ModuleFactory(Linear, input_size, output_size)
     model1, model2 = factory(), factory()
+    input = randn_([batch_size, input_size])
 
     engine1 = Engine(model1, batch_dim=batch_dim)
-    engine2 = Engine(model2, batch_dim=None)
-
-    input = randn_([batch_size, input_size])
     output1 = model1(input).reshape([batch_size] + non_batched_shape).movedim(0, batch_dim)
-    output2 = model2(input).reshape([batch_size] + non_batched_shape).movedim(0, batch_dim)
-
     gramian1 = engine1.compute_gramian(output1)
+
+    engine2 = Engine(model2, batch_dim=None)
+    output2 = model2(input).reshape([batch_size] + non_batched_shape).movedim(0, batch_dim)
     gramian2 = engine2.compute_gramian(output2)
 
     assert_close(gramian1, gramian2)
@@ -573,24 +526,15 @@ def test_batched_non_batched_equivalence_2(factory: ModuleFactory, batch_size: i
573526 """
574527
575528 model_0 , model_none = factory (), factory ()
576- input_shapes , output_shapes = get_in_out_shapes (model_0 )
577-
578- engine_0 = Engine (model_0 , batch_dim = 0 )
579- engine_none = Engine (model_none , batch_dim = None )
580-
581- inputs = make_tensors (batch_size , input_shapes )
582- targets = make_tensors (batch_size , output_shapes )
529+ inputs , targets = make_inputs_and_targets (model_0 , batch_size )
583530 loss_fn = make_mse_loss_fn (targets )
584531
585- torch .random .manual_seed (0 ) # Fix randomness for random models
586- output = model_0 (inputs )
587- losses_0 = reduce_to_vector (loss_fn (output ))
588-
589- torch .random .manual_seed (0 ) # Fix randomness for random models
590- output = model_none (inputs )
591- losses_none = reduce_to_vector (loss_fn (output ))
592-
532+ engine_0 = Engine (model_0 , batch_dim = 0 )
533+ losses_0 = forward_pass (model_0 , inputs , loss_fn , reduce_to_vector )
593534 gramian_0 = engine_0 .compute_gramian (losses_0 )
535+
536+ engine_none = Engine (model_none , batch_dim = None )
537+ losses_none = forward_pass (model_none , inputs , loss_fn , reduce_to_vector )
594538 gramian_none = engine_none .compute_gramian (losses_none )
595539
596540 assert_close (gramian_0 , gramian_none , rtol = 1e-4 , atol = 1e-5 )
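
A note on the two helpers this refactor adopts: their implementations are not part of this diff, but the call sites pin down their contracts. The sketch below is a hypothetical reconstruction, not the actual code in utils/forward_backwards.py and utils/tensors.py; it assumes forward_pass centralizes the seeding that the removed "Fix randomness for random models" lines used to do, and that make_inputs_and_targets folds together the removed get_in_out_shapes and make_tensors steps.

import torch

# Hypothetical sketch; the real helpers live in the test utils package.
from utils.tensors import make_tensors  # was imported by this file before the refactor


def forward_pass(model, inputs, loss_fn, reduction):
    # Run one forward pass and reduce the per-element losses. Seeding here is
    # assumed to replace the per-test torch.random.manual_seed(0) calls, so
    # that two models built from the same factory see identical randomness.
    torch.random.manual_seed(0)
    output = model(inputs)
    return reduction(loss_fn(output))


def make_inputs_and_targets(model, batch_size):
    # Assumed to combine the removed get_in_out_shapes + make_tensors steps:
    # read the model's input/output shapes, then draw random batches.
    input_shapes, output_shapes = get_in_out_shapes(model)  # source module not shown in this diff
    inputs = make_tensors(batch_size, input_shapes)
    targets = make_tensors(batch_size, output_shapes)
    return inputs, targets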