 from pytest import mark, param
 from torch import nn
 from torch.optim import SGD
+from torch.testing import assert_close
 from unit.conftest import DEVICE
 from utils.architectures import (
     AlexNet,
 from utils.forward_backwards import (
     autograd_forward_backward,
     autogram_forward_backward,
-    autojac_forward_backward,
+    compute_gramian_with_autograd,
     make_mse_loss_fn,
 )
 from utils.tensors import make_tensors

-from torchjd.aggregation import UPGrad, UPGradWeighting
+from torchjd.aggregation import UPGradWeighting
 from torchjd.autogram._engine import Engine
-from torchjd.autojac._transform import Diagonalize, Init, Jac, OrderedSet
-from torchjd.autojac._transform._aggregate import _Matrixify

 PARAMETRIZATIONS = [
     (OverlyNested, 32),


 @mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS)
-def test_equivalence_autojac_autogram(
-    architecture: type[ShapedModule],
-    batch_size: int,
-):
-    """
-    Tests that the autogram engine gives the same results as the autojac engine on IWRM for several
-    JD steps.
-    """
-
-    n_iter = 3
+def test_compute_gramian(architecture: type[ShapedModule], batch_size: int):
+    """Tests that the autograd and the autogram engines compute the same gramian."""

     input_shapes = architecture.INPUT_SHAPES
     output_shapes = architecture.OUTPUT_SHAPES

-    weighting = UPGradWeighting()
-    aggregator = UPGrad()
-
     torch.manual_seed(0)
-    model_autojac = architecture().to(device=DEVICE)
+    model_autograd = architecture().to(device=DEVICE)
     torch.manual_seed(0)
     model_autogram = architecture().to(device=DEVICE)

     engine = Engine(model_autogram.modules())
-    optimizer_autojac = SGD(model_autojac.parameters(), lr=1e-7)
-    optimizer_autogram = SGD(model_autogram.parameters(), lr=1e-7)
-
-    for i in range(n_iter):
-        inputs = make_tensors(batch_size, input_shapes)
-        targets = make_tensors(batch_size, output_shapes)
-        loss_fn = make_mse_loss_fn(targets)
-
-        torch.random.manual_seed(0)  # Fix randomness for random aggregators and random models
-        autojac_forward_backward(model_autojac, inputs, loss_fn, aggregator)
-        expected_grads = {
-            name: p.grad for name, p in model_autojac.named_parameters() if p.grad is not None
-        }
-
-        torch.random.manual_seed(0)  # Fix randomness for random weightings and random models
-        autogram_forward_backward(model_autogram, engine, weighting, inputs, loss_fn)
-        grads = {
-            name: p.grad for name, p in model_autogram.named_parameters() if p.grad is not None
-        }
-
-        assert_tensor_dicts_are_close(grads, expected_grads)
-
-        optimizer_autojac.step()
-        model_autojac.zero_grad()
-
-        optimizer_autogram.step()
-        model_autogram.zero_grad()
-
-
-@mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS)
-def test_autograd_while_modules_are_hooked(architecture: type[ShapedModule], batch_size: int):
-    """
-    Tests that the hooks added when constructing the engine do not interfere with a simple autograd
-    call.
-    """
-
-    input_shapes = architecture.INPUT_SHAPES
-    output_shapes = architecture.OUTPUT_SHAPES

-    W = UPGradWeighting()
-    A = UPGrad()
-    input = make_tensors(batch_size, input_shapes)
+    inputs = make_tensors(batch_size, input_shapes)
     targets = make_tensors(batch_size, output_shapes)
     loss_fn = make_mse_loss_fn(targets)

-    torch.manual_seed(0)
-    model = architecture().to(device=DEVICE)
-
-    torch.manual_seed(0)  # Fix randomness for random models
-    autojac_forward_backward(model, input, loss_fn, A)
-    autojac_grads = {
-        name: p.grad.clone() for name, p in model.named_parameters() if p.grad is not None
-    }
-    model.zero_grad()
-
-    torch.manual_seed(0)  # Fix randomness for random models
-    autograd_forward_backward(model, input, loss_fn)
-    autograd_grads = {
-        name: p.grad.clone() for name, p in model.named_parameters() if p.grad is not None
-    }
-
-    torch.manual_seed(0)
-    model_autogram = architecture().to(device=DEVICE)
+    torch.random.manual_seed(0)  # Fix randomness for random models
+    output = model_autograd(inputs)
+    losses = loss_fn(output)
+    autograd_gramian = compute_gramian_with_autograd(losses, list(model_autograd.parameters()))

-    # Hook modules and verify that we're equivalent to autojac when using the engine
-    engine = Engine(model_autogram.modules())
-    torch.manual_seed(0)  # Fix randomness for random models
-    autogram_forward_backward(model_autogram, engine, W, input, loss_fn)
-    grads = {name: p.grad for name, p in model_autogram.named_parameters() if p.grad is not None}
-    assert_tensor_dicts_are_close(grads, autojac_grads)
-    model_autogram.zero_grad()
+    torch.random.manual_seed(0)  # Fix randomness for random models
+    output = model_autogram(inputs)
+    losses = loss_fn(output)
+    autogram_gramian = engine.compute_gramian(losses)

-    # Verify that even with the hooked modules, autograd works normally when not using the engine.
-    # Results should be the same as a normal call to autograd, and no time should be spent computing
-    # the gramian at all.
-    torch.manual_seed(0)  # Fix randomness for random models
-    autograd_forward_backward(model_autogram, input, loss_fn)
-    assert engine._gramian_accumulator.gramian is None
-    grads = {name: p.grad for name, p in model_autogram.named_parameters() if p.grad is not None}
-    assert_tensor_dicts_are_close(grads, autograd_grads)
-    model_autogram.zero_grad()
+    assert_close(autogram_gramian, autograd_gramian, rtol=1e-4, atol=1e-5)


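For reference, compute_gramian_with_autograd is a helper from utils.forward_backwards whose body is not part of this diff. A minimal sketch of such a reference computation, assuming per-sample scalar losses and using only torch.autograd.grad (the real helper may be implemented differently), could look like this:

    import torch

    def compute_gramian_with_autograd(losses, params, retain_graph=False):
        # Sketch: build the Jacobian of the loss vector w.r.t. `params` one row
        # at a time with plain autograd, then form the Gramian G = J @ J.T.
        rows = []
        for i, loss in enumerate(losses):
            keep = retain_graph or i < len(losses) - 1  # keep the graph for the remaining rows
            grads = torch.autograd.grad(loss, params, retain_graph=keep, allow_unused=True)
            flat = torch.cat(
                [
                    g.reshape(-1) if g is not None else torch.zeros_like(p).reshape(-1)
                    for g, p in zip(grads, params)
                ]
            )
            rows.append(flat)
        jacobian = torch.stack(rows)  # shape: (n_losses, n_params)
        return jacobian @ jacobian.T  # shape: (n_losses, n_losses)
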
 def _non_empty_subsets(elements: set) -> list[set]:
@@ -221,20 +144,15 @@ def _non_empty_subsets(elements: set) -> list[set]:


 @mark.parametrize("gramian_module_names", _non_empty_subsets({"fc0", "fc1", "fc2", "fc3", "fc4"}))
-def test_partial_autogram(gramian_module_names: set[str]):
+def test_compute_partial_gramian(gramian_module_names: set[str]):
     """
-    Tests that partial JD via the autogram engine works similarly as if the gramian was computed via
-    the autojac engine.
-
-    Note that this test is a bit redundant now that we have the Engine interface, because it now
-    just compares two ways of computing the Gramian, which is independant of the idea of partial JD.
+    Tests that the autograd and the autogram engines compute the same gramian when only a subset of
+    the model parameters is specified.
     """

     architecture = SimpleBranched
     batch_size = 64

-    weighting = UPGradWeighting()
-
     input_shapes = architecture.INPUT_SHAPES
     output_shapes = architecture.OUTPUT_SHAPES

@@ -247,39 +165,91 @@ def test_partial_autogram(gramian_module_names: set[str]):

     output = model(input)
     losses = loss_fn(output)
-    losses_ = OrderedSet(losses)
-
-    init = Init(losses_)
-    diag = Diagonalize(losses_)

     gramian_modules = [model.get_submodule(name) for name in gramian_module_names]
-    gramian_params = OrderedSet({})
+    gramian_params = []
     for m in gramian_modules:
-        gramian_params += OrderedSet(m.parameters())
+        gramian_params += list(m.parameters())

-    jac = Jac(losses_, OrderedSet(gramian_params), None, True)
-    mat = _Matrixify()
-    transform = mat << jac << diag << init
-
-    jacobian_matrices = transform({})
-    jacobian_matrix = torch.cat(list(jacobian_matrices.values()), dim=1)
-    gramian = jacobian_matrix @ jacobian_matrix.T
+    autograd_gramian = compute_gramian_with_autograd(losses, gramian_params, retain_graph=True)
     torch.manual_seed(0)
-    losses.backward(weighting(gramian))
-
-    expected_grads = {name: p.grad for name, p in model.named_parameters() if p.grad is not None}
-    model.zero_grad()

     engine = Engine(gramian_modules)

     output = model(input)
     losses = loss_fn(output)
     gramian = engine.compute_gramian(losses)
+
+    assert_close(gramian, autograd_gramian)
+
+
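The partial-Gramian usage exercised above can be condensed into a standalone sketch; the toy model, input, and loss below are hypothetical, and only Engine and compute_gramian are taken from this diff:

    import torch
    from torch import nn

    from torchjd.autogram._engine import Engine

    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
    engine = Engine([model[2]])  # restrict the Gramian to the last Linear layer's parameters

    x = torch.randn(16, 4)
    losses = (model(x) ** 2).mean(dim=1)      # one (hypothetical) loss per sample -> 16 losses
    gramian = engine.compute_gramian(losses)  # expected to have shape (16, 16)
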
+@mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS)
+def test_iwrm_steps_with_autogram(architecture: type[ShapedModule], batch_size: int):
+    """Tests that the autogram engine doesn't raise any error during several IWRM iterations."""
+
+    n_iter = 3
+
+    input_shapes = architecture.INPUT_SHAPES
+    output_shapes = architecture.OUTPUT_SHAPES
+
+    weighting = UPGradWeighting()
+
+    model = architecture().to(device=DEVICE)
+
+    engine = Engine(model.modules())
+    optimizer = SGD(model.parameters(), lr=1e-7)
+
+    for i in range(n_iter):
+        inputs = make_tensors(batch_size, input_shapes)
+        targets = make_tensors(batch_size, output_shapes)
+        loss_fn = make_mse_loss_fn(targets)
+
+        autogram_forward_backward(model, engine, weighting, inputs, loss_fn)
+
+        optimizer.step()
+        model.zero_grad()
+
+
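autogram_forward_backward is another helper from utils.forward_backwards that is not shown in this diff. Based on how the removed code combined the Gramian with a weighting (losses.backward(weighting(gramian))), one IWRM step with the autogram engine plausibly reduces to the following sketch:

    def autogram_forward_backward(model, engine, weighting, inputs, loss_fn):
        # Plausible sketch of the test helper; the real implementation may differ.
        output = model(inputs)
        losses = loss_fn(output)                   # one loss per sample
        gramian = engine.compute_gramian(losses)   # Gramian of the Jacobian of the losses
        weights = weighting(gramian)               # e.g. UPGradWeighting: one weight per loss
        losses.backward(weights)                   # weighted backward populates the .grad fields
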
+@mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS)
+@mark.parametrize("compute_gramian", [False, True])
+def test_autograd_while_modules_are_hooked(
+    architecture: type[ShapedModule], batch_size: int, compute_gramian: bool
+):
+    """
+    Tests that the hooks added when constructing the engine do not interfere with a simple autograd
+    call.
+    """
+
+    input = make_tensors(batch_size, architecture.INPUT_SHAPES)
+    targets = make_tensors(batch_size, architecture.OUTPUT_SHAPES)
+    loss_fn = make_mse_loss_fn(targets)
+
+    torch.manual_seed(0)
+    model = architecture().to(device=DEVICE)
     torch.manual_seed(0)
-    losses.backward(weighting(gramian))
+    model_autogram = architecture().to(device=DEVICE)
+
+    torch.manual_seed(0)  # Fix randomness for random models
+    autograd_forward_backward(model, input, loss_fn)
+    autograd_grads = {name: p.grad for name, p in model.named_parameters() if p.grad is not None}
+
+    # Hook modules and optionally compute the Gramian
+    engine = Engine(model_autogram.modules())
+    if compute_gramian:
+        torch.manual_seed(0)  # Fix randomness for random models
+        output = model_autogram(input)
+        losses = loss_fn(output)
+        _ = engine.compute_gramian(losses)

-    grads = {name: p.grad for name, p in model.named_parameters() if p.grad is not None}
-    assert_tensor_dicts_are_close(grads, expected_grads)
+    # Verify that even with the hooked modules, autograd works normally when not using the engine.
+    # Results should be the same as a normal call to autograd, and no time should be spent computing
+    # the gramian at all.
+    torch.manual_seed(0)  # Fix randomness for random models
+    autograd_forward_backward(model_autogram, input, loss_fn)
+    grads = {name: p.grad for name, p in model_autogram.named_parameters() if p.grad is not None}
+
+    assert_tensor_dicts_are_close(grads, autograd_grads)
+    assert engine._gramian_accumulator.gramian is None


 @mark.parametrize("architecture", [WithRNN, WithModuleTrackingRunningStats])