@@ -6,7 +6,7 @@
 import torch
 from pytest import mark, param
 from torch import Tensor
-from torch.nn import RNN, BatchNorm2d, InstanceNorm2d, Linear
+from torch.nn import BatchNorm2d, InstanceNorm2d, Linear
 from torch.optim import SGD
 from torch.testing import assert_close
 from utils.architectures import (
@@ -56,6 +56,7 @@
     WithModuleWithStringOutput,
     WithMultiHeadAttention,
     WithNoTensorOutput,
+    WithRNN,
     WithSideEffect,
     WithSomeFrozenModule,
     WithTransformer,
@@ -179,10 +180,7 @@ def test_compute_gramian(factory: ModuleFactory, batch_size: int, batch_dim: int
         ModuleFactory(WithSideEffect),
         ModuleFactory(Randomness),
         ModuleFactory(InstanceNorm2d, num_features=3, affine=True, track_running_stats=True),
-        param(
-            ModuleFactory(RNN, input_size=8, hidden_size=5, batch_first=True),
-            marks=mark.xfail_if_cuda,
-        ),
+        param(ModuleFactory(WithRNN), marks=mark.xfail_if_cuda),
     ],
 )
 @mark.parametrize("batch_size", [1, 3, 32])
@@ -398,7 +396,7 @@ def test_autograd_while_modules_are_hooked(
     ["factory", "batch_dim"],
     [
         (ModuleFactory(InstanceNorm2d, num_features=3, affine=True, track_running_stats=True), 0),
-        (ModuleFactory(RNN, input_size=8, hidden_size=5, batch_first=True), 0),
+        param(ModuleFactory(WithRNN), 0),
     ],
 )
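For context: this diff replaces the raw torch.nn.RNN parametrization with a WithRNN wrapper imported from utils.architectures, whose definition is not shown here. As a minimal sketch, assuming WithRNN simply wraps an RNN with the former hyperparameters and returns only the output tensor, it could look like this:

from torch import Tensor
from torch.nn import RNN, Module

class WithRNN(Module):
    # Hypothetical sketch: the real definition lives in utils/architectures.py.
    # Assumes it wraps torch.nn.RNN with the hyperparameters the raw
    # parametrization used (input_size=8, hidden_size=5, batch_first=True).
    def __init__(self) -> None:
        super().__init__()
        self.rnn = RNN(input_size=8, hidden_size=5, batch_first=True)

    def forward(self, x: Tensor) -> Tensor:
        # RNN.forward returns an (output, h_n) tuple; keep only the output
        # so downstream test helpers see a single Tensor.
        output, _ = self.rnn(x)
        return output

Returning a single Tensor keeps the module usable by test code that expects tensor outputs; the actual class in the repository may differ, so treat the names and hyperparameters above as assumptions.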