@@ -6,7 +6,7 @@
 import torch
 from pytest import mark, param
 from torch import Tensor
-from torch.nn import Linear
+from torch.nn import RNN, BatchNorm2d, InstanceNorm2d, Linear
 from torch.optim import SGD
 from torch.testing import assert_close
 from utils.architectures import (
@@ -47,18 +47,15 @@
     SomeUnusedOutput,
     SomeUnusedParam,
     SqueezeNet,
-    WithBatchNorm,
     WithBuffered,
     WithDropout,
-    WithModuleTrackingRunningStats,
     WithModuleWithHybridPyTreeArg,
     WithModuleWithHybridPyTreeKwarg,
     WithModuleWithStringArg,
     WithModuleWithStringKwarg,
     WithModuleWithStringOutput,
     WithMultiHeadAttention,
     WithNoTensorOutput,
-    WithRNN,
     WithSideEffect,
     WithSomeFrozenModule,
     WithTransformer,
@@ -178,11 +175,14 @@ def test_compute_gramian(factory: ModuleFactory, batch_size: int, batch_dim: int
 @mark.parametrize(
     "factory",
     [
-        ModuleFactory(WithBatchNorm),
+        ModuleFactory(BatchNorm2d, num_features=3, affine=True, track_running_stats=False),
         ModuleFactory(WithSideEffect),
         ModuleFactory(Randomness),
-        ModuleFactory(WithModuleTrackingRunningStats),
-        param(ModuleFactory(WithRNN), marks=mark.xfail_if_cuda),
+        ModuleFactory(InstanceNorm2d, num_features=3, affine=True, track_running_stats=True),
+        param(
+            ModuleFactory(RNN, input_size=8, hidden_size=5, batch_first=True),
+            marks=mark.xfail_if_cuda,
+        ),
     ],
 )
 @mark.parametrize("batch_size", [1, 3, 32])
@@ -397,9 +397,9 @@ def test_autograd_while_modules_are_hooked(
 @mark.parametrize(
     ["factory", "batch_dim"],
     [
-        (ModuleFactory(WithModuleTrackingRunningStats), 0),
-        (ModuleFactory(WithRNN), 0),
-        (ModuleFactory(WithBatchNorm), 0),
+        (ModuleFactory(InstanceNorm2d, num_features=3, affine=True, track_running_stats=True), 0),
+        (ModuleFactory(RNN, input_size=8, hidden_size=5, batch_first=True), 0),
+        (ModuleFactory(BatchNorm2d, num_features=3, affine=True, track_running_stats=False), 0),
     ],
 )
 def test_incompatible_modules(factory: ModuleFactory, batch_dim: int | None):
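
Note: the diff replaces dedicated wrapper modules (WithBatchNorm, WithRNN, WithModuleTrackingRunningStats) with direct ModuleFactory instantiation of torch.nn classes plus constructor kwargs. Below is a minimal, hypothetical sketch of a ModuleFactory consistent with the call sites above, assuming it merely stores a module class and its kwargs and builds a fresh instance on demand; the real utils.architectures.ModuleFactory may differ.

    from torch.nn import BatchNorm2d, Module

    class ModuleFactory:
        """Hypothetical stand-in: holds a module class and kwargs, builds fresh modules."""

        def __init__(self, module_class: type[Module], **kwargs):
            self.module_class = module_class
            self.kwargs = kwargs

        def __call__(self) -> Module:
            # Each call constructs a new, freshly initialized module.
            return self.module_class(**self.kwargs)

    # Usage mirroring the parametrized tests above:
    factory = ModuleFactory(BatchNorm2d, num_features=3, affine=True, track_running_stats=False)
    module = factory()  # a new BatchNorm2d(3) with no running-stats buffers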
 | 