 )
 from utils.tensors import make_tensors

-from torchjd.aggregation import (
-    IMTLG,
-    MGDA,
-    Aggregator,
-    AlignedMTL,
-    AlignedMTLWeighting,
-    DualProj,
-    DualProjWeighting,
-    IMTLGWeighting,
-    Mean,
-    MeanWeighting,
-    MGDAWeighting,
-    PCGrad,
-    PCGradWeighting,
-    Random,
-    RandomWeighting,
-    Sum,
-    SumWeighting,
-    UPGrad,
-    UPGradWeighting,
-    Weighting,
-)
+from torchjd.aggregation import UPGrad, UPGradWeighting
 from torchjd.autogram._engine import Engine
 from torchjd.autojac._transform import Diagonalize, Init, Jac, OrderedSet
 from torchjd.autojac._transform._aggregate import _Matrixify
@@ -126,35 +105,11 @@
     param(InstanceNormMobileNetV2, 2, marks=[mark.slow, mark.garbage_collect]),
 ]

-AGGREGATORS_AND_WEIGHTINGS: list[tuple[Aggregator, Weighting]] = [
-    (UPGrad(), UPGradWeighting()),
-    (AlignedMTL(), AlignedMTLWeighting()),
-    (DualProj(), DualProjWeighting()),
-    (IMTLG(), IMTLGWeighting()),
-    (Mean(), MeanWeighting()),
-    (MGDA(), MGDAWeighting()),
-    (PCGrad(), PCGradWeighting()),
-    (Random(), RandomWeighting()),
-    (Sum(), SumWeighting()),
-]
-
-try:
-    from torchjd.aggregation import CAGrad, CAGradWeighting
-
-    AGGREGATORS_AND_WEIGHTINGS.append((CAGrad(c=0.5), CAGradWeighting(c=0.5)))
-except ImportError:
-    pass
-
-WEIGHTINGS = [weighting for _, weighting in AGGREGATORS_AND_WEIGHTINGS]
-

 @mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS)
-@mark.parametrize(["aggregator", "weighting"], AGGREGATORS_AND_WEIGHTINGS)
 def test_equivalence_autojac_autogram(
     architecture: type[ShapedModule],
     batch_size: int,
-    aggregator: Aggregator,
-    weighting: Weighting,
 ):
     """
     Tests that the autogram engine gives the same results as the autojac engine on IWRM for several
@@ -166,6 +121,9 @@ def test_equivalence_autojac_autogram(
     input_shapes = architecture.INPUT_SHAPES
     output_shapes = architecture.OUTPUT_SHAPES

+    weighting = UPGradWeighting()
+    aggregator = UPGrad()
+
     torch.manual_seed(0)
     model_autojac = architecture().to(device=DEVICE)
     torch.manual_seed(0)
@@ -262,9 +220,8 @@ def _non_empty_subsets(elements: set) -> list[set]:
     return [set(c) for r in range(1, len(elements) + 1) for c in combinations(elements, r)]


-@mark.parametrize("weighting", WEIGHTINGS)
 @mark.parametrize("gramian_module_names", _non_empty_subsets({"fc0", "fc1", "fc2", "fc3", "fc4"}))
-def test_partial_autogram(weighting: Weighting, gramian_module_names: set[str]):
+def test_partial_autogram(gramian_module_names: set[str]):
     """
     Tests that partial JD via the autogram engine works similarly as if the gramian was computed via
     the autojac engine.
@@ -276,6 +233,8 @@ def test_partial_autogram(weighting: Weighting, gramian_module_names: set[str]):
     architecture = SimpleBranched
     batch_size = 64

+    weighting = UPGradWeighting()
+
     input_shapes = architecture.INPUT_SHAPES
     output_shapes = architecture.OUTPUT_SHAPES

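Both tests now pin the aggregator/weighting pair to UPGrad instead of parametrizing over every pair. For readers unfamiliar with the relationship these equivalence tests rely on, here is a minimal sketch of the two paths being compared. Calling `UPGrad()` directly on a Jacobian follows torchjd's documented aggregator interface; calling `UPGradWeighting()` on the Gramian `J @ J.T` to obtain one weight per objective is an assumption inferred from how these tests pair aggregators with weightings, not a documented API.

```python
import torch

from torchjd.aggregation import UPGrad, UPGradWeighting

# Toy Jacobian: 2 objectives (rows) x 3 parameters (columns).
J = torch.tensor([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])

# Aggregator path (what autojac exercises): map the Jacobian directly
# to a single update vector.
update_from_jacobian = UPGrad()(J)

# Weighting path (what autogram exercises): derive per-objective weights
# from the Gramian, then combine the rows of J. Calling the weighting on
# G = J @ J.T is an assumption inferred from the aggregator/weighting
# pairing in these tests rather than from documented API.
weights = UPGradWeighting()(J @ J.T)
update_from_gramian = weights @ J

# The premise of the equivalence tests: both paths yield the same update.
assert torch.allclose(update_from_jacobian, update_from_gramian, atol=1e-6)
```

Since every pair in the removed `AGGREGATORS_AND_WEIGHTINGS` list exercised this same Jacobian-vs-Gramian relationship, pinning a single representative pair presumably keeps that code path covered while shrinking a test matrix that the architecture parametrization already multiplies.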