Skip to content

Commit 1684cbd

Browse files
authored
test(autogram): Only test with UPGrad (#412)
1 parent d159430 commit 1684cbd

File tree

1 file changed

+7
-48
lines changed

1 file changed

+7
-48
lines changed

tests/unit/autogram/test_engine.py

Lines changed: 7 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -57,28 +57,7 @@
5757
)
5858
from utils.tensors import make_tensors
5959

60-
from torchjd.aggregation import (
61-
IMTLG,
62-
MGDA,
63-
Aggregator,
64-
AlignedMTL,
65-
AlignedMTLWeighting,
66-
DualProj,
67-
DualProjWeighting,
68-
IMTLGWeighting,
69-
Mean,
70-
MeanWeighting,
71-
MGDAWeighting,
72-
PCGrad,
73-
PCGradWeighting,
74-
Random,
75-
RandomWeighting,
76-
Sum,
77-
SumWeighting,
78-
UPGrad,
79-
UPGradWeighting,
80-
Weighting,
81-
)
60+
from torchjd.aggregation import UPGrad, UPGradWeighting
8261
from torchjd.autogram._engine import Engine
8362
from torchjd.autojac._transform import Diagonalize, Init, Jac, OrderedSet
8463
from torchjd.autojac._transform._aggregate import _Matrixify
@@ -126,35 +105,11 @@
126105
param(InstanceNormMobileNetV2, 2, marks=[mark.slow, mark.garbage_collect]),
127106
]
128107

129-
AGGREGATORS_AND_WEIGHTINGS: list[tuple[Aggregator, Weighting]] = [
130-
(UPGrad(), UPGradWeighting()),
131-
(AlignedMTL(), AlignedMTLWeighting()),
132-
(DualProj(), DualProjWeighting()),
133-
(IMTLG(), IMTLGWeighting()),
134-
(Mean(), MeanWeighting()),
135-
(MGDA(), MGDAWeighting()),
136-
(PCGrad(), PCGradWeighting()),
137-
(Random(), RandomWeighting()),
138-
(Sum(), SumWeighting()),
139-
]
140-
141-
try:
142-
from torchjd.aggregation import CAGrad, CAGradWeighting
143-
144-
AGGREGATORS_AND_WEIGHTINGS.append((CAGrad(c=0.5), CAGradWeighting(c=0.5)))
145-
except ImportError:
146-
pass
147-
148-
WEIGHTINGS = [weighting for _, weighting in AGGREGATORS_AND_WEIGHTINGS]
149-
150108

151109
@mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS)
152-
@mark.parametrize(["aggregator", "weighting"], AGGREGATORS_AND_WEIGHTINGS)
153110
def test_equivalence_autojac_autogram(
154111
architecture: type[ShapedModule],
155112
batch_size: int,
156-
aggregator: Aggregator,
157-
weighting: Weighting,
158113
):
159114
"""
160115
Tests that the autogram engine gives the same results as the autojac engine on IWRM for several
@@ -166,6 +121,9 @@ def test_equivalence_autojac_autogram(
166121
input_shapes = architecture.INPUT_SHAPES
167122
output_shapes = architecture.OUTPUT_SHAPES
168123

124+
weighting = UPGradWeighting()
125+
aggregator = UPGrad()
126+
169127
torch.manual_seed(0)
170128
model_autojac = architecture().to(device=DEVICE)
171129
torch.manual_seed(0)
@@ -262,9 +220,8 @@ def _non_empty_subsets(elements: set) -> list[set]:
262220
return [set(c) for r in range(1, len(elements) + 1) for c in combinations(elements, r)]
263221

264222

265-
@mark.parametrize("weighting", WEIGHTINGS)
266223
@mark.parametrize("gramian_module_names", _non_empty_subsets({"fc0", "fc1", "fc2", "fc3", "fc4"}))
267-
def test_partial_autogram(weighting: Weighting, gramian_module_names: set[str]):
224+
def test_partial_autogram(gramian_module_names: set[str]):
268225
"""
269226
Tests that partial JD via the autogram engine works similarly as if the gramian was computed via
270227
the autojac engine.
@@ -276,6 +233,8 @@ def test_partial_autogram(weighting: Weighting, gramian_module_names: set[str]):
276233
architecture = SimpleBranched
277234
batch_size = 64
278235

236+
weighting = UPGradWeighting()
237+
279238
input_shapes = architecture.INPUT_SHAPES
280239
output_shapes = architecture.OUTPUT_SHAPES
281240

0 commit comments

Comments
 (0)