|
145 | 145 | except ImportError: |
146 | 146 | pass |
147 | 147 |
|
148 | | -WEIGHTINGS = [pair[1] for pair in AGGREGATORS_AND_WEIGHTINGS] |
| 148 | +WEIGHTINGS = [weighting for _, weighting in AGGREGATORS_AND_WEIGHTINGS] |
149 | 149 |
|
150 | 150 |
|
151 | 151 | @mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) |
152 | 152 | @mark.parametrize(["aggregator", "weighting"], AGGREGATORS_AND_WEIGHTINGS) |
153 | | -def test_equivalence( |
| 153 | +def test_equivalence_autojac_autogram( |
154 | 154 | architecture: type[ShapedModule], |
155 | 155 | batch_size: int, |
156 | 156 | aggregator: Aggregator, |
157 | 157 | weighting: Weighting, |
158 | 158 | ): |
| 159 | + """ |
| 160 | + Tests that the autogram engine gives the same results as the autojac engine on IWRM for several |
| 161 | + JD steps. |
| 162 | + """ |
| 163 | + |
159 | 164 | n_iter = 3 |
160 | 165 |
|
161 | 166 | input_shapes = architecture.INPUT_SHAPES |
@@ -198,6 +203,11 @@ def test_equivalence( |
198 | 203 |
|
199 | 204 | @mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) |
200 | 205 | def test_autograd_while_modules_are_hooked(architecture: type[ShapedModule], batch_size: int): |
| 206 | + """ |
| 207 | + Tests that the hooks added when constructing the engine do not interfere with a simple autograd |
| 208 | + call. |
| 209 | + """ |
| 210 | + |
201 | 211 | input_shapes = architecture.INPUT_SHAPES |
202 | 212 | output_shapes = architecture.OUTPUT_SHAPES |
203 | 213 |
|
@@ -315,6 +325,8 @@ def test_partial_autogram(weighting: Weighting, gramian_module_names: set[str]): |
315 | 325 |
|
316 | 326 | @mark.parametrize("architecture", [WithRNN, WithModuleTrackingRunningStats]) |
317 | 327 | def test_incompatible_modules(architecture: type[nn.Module]): |
| 328 | + """Tests that the engine cannot be constructed with incompatible modules.""" |
| 329 | + |
318 | 330 | model = architecture().to(device=DEVICE) |
319 | 331 |
|
320 | 332 | with pytest.raises(ValueError): |
|
0 commit comments