|
6 | 6 | import torch |
7 | 7 | from pytest import mark, param |
8 | 8 | from torch import Tensor |
9 | | -from torch.nn import Linear |
| 9 | +from torch.nn import RNN, BatchNorm2d, InstanceNorm2d, Linear |
10 | 10 | from torch.optim import SGD |
11 | 11 | from torch.testing import assert_close |
12 | 12 | from utils.architectures import ( |
|
47 | 47 | SomeUnusedOutput, |
48 | 48 | SomeUnusedParam, |
49 | 49 | SqueezeNet, |
50 | | - WithBatchNorm, |
51 | 50 | WithBuffered, |
52 | 51 | WithDropout, |
53 | | - WithModuleTrackingRunningStats, |
54 | 52 | WithModuleWithHybridPyTreeArg, |
55 | 53 | WithModuleWithHybridPyTreeKwarg, |
56 | 54 | WithModuleWithStringArg, |
57 | 55 | WithModuleWithStringKwarg, |
58 | 56 | WithModuleWithStringOutput, |
59 | 57 | WithMultiHeadAttention, |
60 | 58 | WithNoTensorOutput, |
61 | | - WithRNN, |
62 | 59 | WithSideEffect, |
63 | 60 | WithSomeFrozenModule, |
64 | 61 | WithTransformer, |
@@ -178,11 +175,14 @@ def test_compute_gramian(factory: ModuleFactory, batch_size: int, batch_dim: int |
178 | 175 | @mark.parametrize( |
179 | 176 | "factory", |
180 | 177 | [ |
181 | | - ModuleFactory(WithBatchNorm), |
| 178 | + ModuleFactory(BatchNorm2d, num_features=3, affine=True, track_running_stats=False), |
182 | 179 | ModuleFactory(WithSideEffect), |
183 | 180 | ModuleFactory(Randomness), |
184 | | - ModuleFactory(WithModuleTrackingRunningStats), |
185 | | - param(ModuleFactory(WithRNN), marks=mark.xfail_if_cuda), |
| 181 | + ModuleFactory(InstanceNorm2d, num_features=3, affine=True, track_running_stats=True), |
| 182 | + param( |
| 183 | + ModuleFactory(RNN, input_size=8, hidden_size=5, batch_first=True), |
| 184 | + marks=mark.xfail_if_cuda, |
| 185 | + ), |
186 | 186 | ], |
187 | 187 | ) |
188 | 188 | @mark.parametrize("batch_size", [1, 3, 32]) |
@@ -397,9 +397,9 @@ def test_autograd_while_modules_are_hooked( |
397 | 397 | @mark.parametrize( |
398 | 398 | ["factory", "batch_dim"], |
399 | 399 | [ |
400 | | - (ModuleFactory(WithModuleTrackingRunningStats), 0), |
401 | | - (ModuleFactory(WithRNN), 0), |
402 | | - (ModuleFactory(WithBatchNorm), 0), |
| 400 | + (ModuleFactory(InstanceNorm2d, num_features=3, affine=True, track_running_stats=True), 0), |
| 401 | + (ModuleFactory(RNN, input_size=8, hidden_size=5, batch_first=True), 0), |
| 402 | + (ModuleFactory(BatchNorm2d, num_features=3, affine=True, track_running_stats=False), 0), |
403 | 403 | ], |
404 | 404 | ) |
405 | 405 | def test_incompatible_modules(factory: ModuleFactory, batch_dim: int | None): |
|
0 commit comments