Skip to content

Commit f579ce7

Browse files
committed
Renames, documentation and style in test_engine.py
1 parent 2ce3fb7 commit f579ce7

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

tests/unit/autogram/test_engine.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,17 +145,22 @@
145145
except ImportError:
146146
pass
147147

148-
WEIGHTINGS = [pair[1] for pair in AGGREGATORS_AND_WEIGHTINGS]
148+
WEIGHTINGS = [weighting for _, weighting in AGGREGATORS_AND_WEIGHTINGS]
149149

150150

151151
@mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS)
152152
@mark.parametrize(["aggregator", "weighting"], AGGREGATORS_AND_WEIGHTINGS)
153-
def test_equivalence(
153+
def test_equivalence_autojac_autogram(
154154
architecture: type[ShapedModule],
155155
batch_size: int,
156156
aggregator: Aggregator,
157157
weighting: Weighting,
158158
):
159+
"""
160+
Tests that the autogram engine gives the same results as the autojac engine on IWRM for several
161+
JD steps.
162+
"""
163+
159164
n_iter = 3
160165

161166
input_shapes = architecture.INPUT_SHAPES
@@ -198,6 +203,11 @@ def test_equivalence(
198203

199204
@mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS)
200205
def test_autograd_while_modules_are_hooked(architecture: type[ShapedModule], batch_size: int):
206+
"""
207+
Tests that the hooks added when constructing the engine do not interfere with a simple autograd
208+
call.
209+
"""
210+
201211
input_shapes = architecture.INPUT_SHAPES
202212
output_shapes = architecture.OUTPUT_SHAPES
203213

@@ -315,6 +325,8 @@ def test_partial_autogram(weighting: Weighting, gramian_module_names: set[str]):
315325

316326
@mark.parametrize("architecture", [WithRNN, WithModuleTrackingRunningStats])
317327
def test_incompatible_modules(architecture: type[nn.Module]):
328+
"""Tests that the engine cannot be constructed with incompatible modules."""
329+
318330
model = architecture().to(device=DEVICE)
319331

320332
with pytest.raises(ValueError):

0 commit comments

Comments
 (0)