66import torch
77from pytest import mark , param
88from torch import Tensor
9- from torch .nn import RNN , BatchNorm2d , InstanceNorm2d , Linear
9+ from torch .nn import BatchNorm2d , InstanceNorm2d , Linear
1010from torch .optim import SGD
1111from torch .testing import assert_close
1212from utils .architectures import (
5656 WithModuleWithStringOutput ,
5757 WithMultiHeadAttention ,
5858 WithNoTensorOutput ,
59+ WithRNN ,
5960 WithSideEffect ,
6061 WithSomeFrozenModule ,
6162 WithTransformer ,
6263 WithTransformerLarge ,
63- get_in_out_shapes ,
6464)
6565from utils .dict_assertions import assert_tensor_dicts_are_close
6666from utils .forward_backwards import (
6767 autograd_forward_backward ,
6868 autogram_forward_backward ,
6969 compute_gramian ,
7070 compute_gramian_with_autograd ,
71+ forward_pass ,
7172 make_mse_loss_fn ,
7273 reduce_to_first_tensor ,
7374 reduce_to_matrix ,
7475 reduce_to_scalar ,
7576 reduce_to_vector ,
7677)
77- from utils .tensors import make_tensors , ones_ , randn_ , zeros_
78+ from utils .tensors import make_inputs_and_targets , ones_ , randn_ , zeros_
7879
7980from torchjd .aggregation import UPGradWeighting
8081from torchjd .autogram ._engine import Engine
@@ -143,22 +144,14 @@ def _assert_gramian_is_equivalent_to_autograd(
143144 factory : ModuleFactory , batch_size : int , batch_dim : int | None
144145):
145146 model_autograd , model_autogram = factory (), factory ()
146- input_shapes , output_shapes = get_in_out_shapes (model_autograd )
147-
148147 engine = Engine (model_autogram , batch_dim = batch_dim )
149-
150- inputs = make_tensors (batch_size , input_shapes )
151- targets = make_tensors (batch_size , output_shapes )
148+ inputs , targets = make_inputs_and_targets (model_autograd , batch_size )
152149 loss_fn = make_mse_loss_fn (targets )
153150
154- torch .random .manual_seed (0 ) # Fix randomness for random models
155- output = model_autograd (inputs )
156- losses = reduce_to_vector (loss_fn (output ))
151+ losses = forward_pass (model_autograd , inputs , loss_fn , reduce_to_vector )
157152 autograd_gramian = compute_gramian_with_autograd (losses , list (model_autograd .parameters ()))
158153
159- torch .random .manual_seed (0 ) # Fix randomness for random models
160- output = model_autogram (inputs )
161- losses = reduce_to_vector (loss_fn (output ))
154+ losses = forward_pass (model_autogram , inputs , loss_fn , reduce_to_vector )
162155 autogram_gramian = engine .compute_gramian (losses )
163156
164157 assert_close (autogram_gramian , autograd_gramian , rtol = 1e-4 , atol = 3e-5 )
@@ -179,10 +172,7 @@ def test_compute_gramian(factory: ModuleFactory, batch_size: int, batch_dim: int
179172 ModuleFactory (WithSideEffect ),
180173 ModuleFactory (Randomness ),
181174 ModuleFactory (InstanceNorm2d , num_features = 3 , affine = True , track_running_stats = True ),
182- param (
183- ModuleFactory (RNN , input_size = 8 , hidden_size = 5 , batch_first = True ),
184- marks = mark .xfail_if_cuda ,
185- ),
175+ param (ModuleFactory (WithRNN ), marks = mark .xfail_if_cuda ),
186176 ],
187177)
188178@mark .parametrize ("batch_size" , [1 , 3 , 32 ])
@@ -257,26 +247,18 @@ def test_compute_gramian_various_output_shapes(
257247
258248 factory = ModuleFactory (Ndim2Output )
259249 model_autograd , model_autogram = factory (), factory ()
260- input_shapes , output_shapes = get_in_out_shapes (model_autograd )
261-
262- engine = Engine (model_autogram , batch_dim = batch_dim )
263-
264- inputs = make_tensors (batch_size , input_shapes )
265- targets = make_tensors (batch_size , output_shapes )
250+ inputs , targets = make_inputs_and_targets (model_autograd , batch_size )
266251 loss_fn = make_mse_loss_fn (targets )
267252
268- torch .random .manual_seed (0 ) # Fix randomness for random models
269- output = model_autograd (inputs )
270- losses = reduction (loss_fn (output ))
253+ losses = forward_pass (model_autograd , inputs , loss_fn , reduction )
271254 reshaped_losses = torch .movedim (losses , movedim_source , movedim_destination )
272255 # Go back to a vector so that compute_gramian_with_autograd works
273256 loss_vector = reshaped_losses .reshape ([- 1 ])
274257 autograd_gramian = compute_gramian_with_autograd (loss_vector , list (model_autograd .parameters ()))
275258 expected_gramian = reshape_gramian (autograd_gramian , list (reshaped_losses .shape ))
276259
277- torch .random .manual_seed (0 ) # Fix randomness for random models
278- output = model_autogram (inputs )
279- losses = reduction (loss_fn (output ))
260+ engine = Engine (model_autogram , batch_dim = batch_dim )
261+ losses = forward_pass (model_autogram , inputs , loss_fn , reduction )
280262 reshaped_losses = torch .movedim (losses , movedim_source , movedim_destination )
281263 autogram_gramian = engine .compute_gramian (reshaped_losses )
282264
@@ -298,30 +280,20 @@ def test_compute_partial_gramian(gramian_module_names: set[str], batch_dim: int
298280 the model parameters is specified.
299281 """
300282
301- factory = ModuleFactory (SimpleBranched )
302- model = factory ()
303- input_shapes , output_shapes = get_in_out_shapes (model )
283+ model = SimpleBranched ()
304284 batch_size = 64
305-
306- input = make_tensors (batch_size , input_shapes )
307- targets = make_tensors (batch_size , output_shapes )
285+ inputs , targets = make_inputs_and_targets (model , batch_size )
308286 loss_fn = make_mse_loss_fn (targets )
309-
310- output = model (input )
311- losses = reduce_to_vector (loss_fn (output ))
312-
313287 gramian_modules = [model .get_submodule (name ) for name in gramian_module_names ]
314288 gramian_params = []
315289 for m in gramian_modules :
316290 gramian_params += list (m .parameters ())
317291
292+ losses = forward_pass (model , inputs , loss_fn , reduce_to_vector )
318293 autograd_gramian = compute_gramian_with_autograd (losses , gramian_params , retain_graph = True )
319- torch .manual_seed (0 )
320294
321295 engine = Engine (* gramian_modules , batch_dim = batch_dim )
322-
323- output = model (input )
324- losses = reduce_to_vector (loss_fn (output ))
296+ losses = forward_pass (model , inputs , loss_fn , reduce_to_vector )
325297 gramian = engine .compute_gramian (losses )
326298
327299 assert_close (gramian , autograd_gramian )
@@ -333,22 +305,15 @@ def test_iwrm_steps_with_autogram(factory: ModuleFactory, batch_size: int, batch
333305 """Tests that the autogram engine doesn't raise any error during several IWRM iterations."""
334306
335307 n_iter = 3
336-
337308 model = factory ()
338- input_shapes , output_shapes = get_in_out_shapes (model )
339-
340309 weighting = UPGradWeighting ()
341-
342310 engine = Engine (model , batch_dim = batch_dim )
343311 optimizer = SGD (model .parameters (), lr = 1e-7 )
344312
345313 for i in range (n_iter ):
346- inputs = make_tensors (batch_size , input_shapes )
347- targets = make_tensors (batch_size , output_shapes )
314+ inputs , targets = make_inputs_and_targets (model , batch_size )
348315 loss_fn = make_mse_loss_fn (targets )
349-
350- autogram_forward_backward (model , engine , weighting , inputs , loss_fn )
351-
316+ autogram_forward_backward (model , inputs , loss_fn , engine , weighting )
352317 optimizer .step ()
353318 model .zero_grad ()
354319
@@ -365,29 +330,22 @@ def test_autograd_while_modules_are_hooked(
365330 """
366331
367332 model , model_autogram = factory (), factory ()
368- input_shapes , output_shapes = get_in_out_shapes (model )
369-
370- input = make_tensors (batch_size , input_shapes )
371- targets = make_tensors (batch_size , output_shapes )
333+ inputs , targets = make_inputs_and_targets (model , batch_size )
372334 loss_fn = make_mse_loss_fn (targets )
373335
374- torch .manual_seed (0 ) # Fix randomness for random models
375- autograd_forward_backward (model , input , loss_fn )
336+ autograd_forward_backward (model , inputs , loss_fn )
376337 autograd_grads = {name : p .grad for name , p in model .named_parameters () if p .grad is not None }
377338
378339 # Hook modules and optionally compute the Gramian
379340 engine = Engine (model_autogram , batch_dim = batch_dim )
380341 if use_engine :
381- torch .manual_seed (0 ) # Fix randomness for random models
382- output = model_autogram (input )
383- losses = reduce_to_vector (loss_fn (output ))
342+ losses = forward_pass (model_autogram , inputs , loss_fn , reduce_to_vector )
384343 _ = engine .compute_gramian (losses )
385344
386345 # Verify that even with the hooked modules, autograd works normally when not using the engine.
387346 # Results should be the same as a normal call to autograd, and no time should be spent computing
388347 # the gramian at all.
389- torch .manual_seed (0 ) # Fix randomness for random models
390- autograd_forward_backward (model_autogram , input , loss_fn )
348+ autograd_forward_backward (model_autogram , inputs , loss_fn )
391349 grads = {name : p .grad for name , p in model_autogram .named_parameters () if p .grad is not None }
392350
393351 assert_tensor_dicts_are_close (grads , autograd_grads )
@@ -398,7 +356,7 @@ def test_autograd_while_modules_are_hooked(
398356 ["factory" , "batch_dim" ],
399357 [
400358 (ModuleFactory (InstanceNorm2d , num_features = 3 , affine = True , track_running_stats = True ), 0 ),
401- (ModuleFactory (RNN , input_size = 8 , hidden_size = 5 , batch_first = True ), 0 ),
359+ param (ModuleFactory (WithRNN ), 0 ),
402360 (ModuleFactory (BatchNorm2d , num_features = 3 , affine = True , track_running_stats = False ), 0 ),
403361 ],
404362)
@@ -418,12 +376,11 @@ def test_compute_gramian_manual():
418376
419377 in_dims = 18
420378 out_dims = 25
421-
422379 factory = ModuleFactory (Linear , in_dims , out_dims )
423380 model = factory ()
424- engine = Engine (model , batch_dim = None )
425-
426381 input = randn_ (in_dims )
382+
383+ engine = Engine (model , batch_dim = None )
427384 output = model (input )
428385 gramian = engine .compute_gramian (output )
429386
@@ -464,21 +421,19 @@ def test_reshape_equivariance(shape: list[int]):
464421
465422 input_size = shape [0 ]
466423 output_size = prod (shape [1 :])
467-
468424 factory = ModuleFactory (Linear , input_size , output_size )
469425 model1 , model2 = factory (), factory ()
426+ input = randn_ ([input_size ])
470427
471428 engine1 = Engine (model1 , batch_dim = None )
472- engine2 = Engine (model2 , batch_dim = None )
473-
474- input = randn_ ([input_size ])
475429 output = model1 (input )
476- reshaped_output = model2 (input ).reshape (shape [1 :])
477-
478430 gramian = engine1 .compute_gramian (output )
479- reshaped_gramian = engine2 .compute_gramian (reshaped_output )
480431 expected_reshaped_gramian = reshape_gramian (gramian , shape [1 :])
481432
433+ engine2 = Engine (model2 , batch_dim = None )
434+ reshaped_output = model2 (input ).reshape (shape [1 :])
435+ reshaped_gramian = engine2 .compute_gramian (reshaped_output )
436+
482437 assert_close (reshaped_gramian , expected_reshaped_gramian )
483438
484439
@@ -504,21 +459,19 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination:
504459
505460 input_size = shape [0 ]
506461 output_size = prod (shape [1 :])
507-
508462 factory = ModuleFactory (Linear , input_size , output_size )
509463 model1 , model2 = factory (), factory ()
464+ input = randn_ ([input_size ])
510465
511466 engine1 = Engine (model1 , batch_dim = None )
512- engine2 = Engine (model2 , batch_dim = None )
513-
514- input = randn_ ([input_size ])
515467 output = model1 (input ).reshape (shape [1 :])
516- moved_output = model2 (input ).reshape (shape [1 :]).movedim (source , destination )
517-
518468 gramian = engine1 .compute_gramian (output )
519- moved_gramian = engine2 .compute_gramian (moved_output )
520469 expected_moved_gramian = movedim_gramian (gramian , source , destination )
521470
471+ engine2 = Engine (model2 , batch_dim = None )
472+ moved_output = model2 (input ).reshape (shape [1 :]).movedim (source , destination )
473+ moved_gramian = engine2 .compute_gramian (moved_output )
474+
522475 assert_close (moved_gramian , expected_moved_gramian )
523476
524477
@@ -547,18 +500,16 @@ def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int):
547500 input_size = prod (non_batched_shape )
548501 batch_size = shape [batch_dim ]
549502 output_size = input_size
550-
551503 factory = ModuleFactory (Linear , input_size , output_size )
552504 model1 , model2 = factory (), factory ()
505+ input = randn_ ([batch_size , input_size ])
553506
554507 engine1 = Engine (model1 , batch_dim = batch_dim )
555- engine2 = Engine (model2 , batch_dim = None )
556-
557- input = randn_ ([batch_size , input_size ])
558508 output1 = model1 (input ).reshape ([batch_size ] + non_batched_shape ).movedim (0 , batch_dim )
559- output2 = model2 (input ).reshape ([batch_size ] + non_batched_shape ).movedim (0 , batch_dim )
560-
561509 gramian1 = engine1 .compute_gramian (output1 )
510+
511+ engine2 = Engine (model2 , batch_dim = None )
512+ output2 = model2 (input ).reshape ([batch_size ] + non_batched_shape ).movedim (0 , batch_dim )
562513 gramian2 = engine2 .compute_gramian (output2 )
563514
564515 assert_close (gramian1 , gramian2 )
@@ -575,24 +526,15 @@ def test_batched_non_batched_equivalence_2(factory: ModuleFactory, batch_size: i
575526 """
576527
577528 model_0 , model_none = factory (), factory ()
578- input_shapes , output_shapes = get_in_out_shapes (model_0 )
579-
580- engine_0 = Engine (model_0 , batch_dim = 0 )
581- engine_none = Engine (model_none , batch_dim = None )
582-
583- inputs = make_tensors (batch_size , input_shapes )
584- targets = make_tensors (batch_size , output_shapes )
529+ inputs , targets = make_inputs_and_targets (model_0 , batch_size )
585530 loss_fn = make_mse_loss_fn (targets )
586531
587- torch .random .manual_seed (0 ) # Fix randomness for random models
588- output = model_0 (inputs )
589- losses_0 = reduce_to_vector (loss_fn (output ))
590-
591- torch .random .manual_seed (0 ) # Fix randomness for random models
592- output = model_none (inputs )
593- losses_none = reduce_to_vector (loss_fn (output ))
594-
532+ engine_0 = Engine (model_0 , batch_dim = 0 )
533+ losses_0 = forward_pass (model_0 , inputs , loss_fn , reduce_to_vector )
595534 gramian_0 = engine_0 .compute_gramian (losses_0 )
535+
536+ engine_none = Engine (model_none , batch_dim = None )
537+ losses_none = forward_pass (model_none , inputs , loss_fn , reduce_to_vector )
596538 gramian_none = engine_none .compute_gramian (losses_none )
597539
598540 assert_close (gramian_0 , gramian_none , rtol = 1e-4 , atol = 1e-5 )
0 commit comments