
Commit 0b7c3f6

test: Fix RNN testing (#463)
* Revert removal of WithRNN (part of aff0abc)
* Fix output of WithRNN to not include the hidden state
1 parent aff0abc commit 0b7c3f6
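For context, torch.nn.RNN's forward returns a pair (output, h_n) rather than a single tensor, which is what the second fix addresses: the WithRNN wrapper keeps only the output sequence. A minimal sketch of the shapes involved (batch size 3 is an arbitrary choice; variable names are illustrative only):

import torch
from torch import nn

rnn = nn.RNN(input_size=8, hidden_size=5, batch_first=True)
x = torch.randn(3, 20, 8)  # (batch, seq_len=20, input_size=8)
output, h_n = rnn(x)       # nn.RNN returns a tuple, not a single tensor
print(output.shape)        # torch.Size([3, 20, 5]): hidden state at every step
print(h_n.shape)           # torch.Size([1, 3, 5]): final hidden state only
# WithRNN.forward returns only `output`, so callers that expect a plain Tensor
# (e.g. shape-based test utilities) keep working.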

File tree

2 files changed: +19 −11 lines changed


tests/unit/autogram/test_engine.py

Lines changed: 4 additions & 6 deletions

@@ -6,7 +6,7 @@
 import torch
 from pytest import mark, param
 from torch import Tensor
-from torch.nn import RNN, BatchNorm2d, InstanceNorm2d, Linear
+from torch.nn import BatchNorm2d, InstanceNorm2d, Linear
 from torch.optim import SGD
 from torch.testing import assert_close
 from utils.architectures import (
@@ -56,6 +56,7 @@
     WithModuleWithStringOutput,
     WithMultiHeadAttention,
     WithNoTensorOutput,
+    WithRNN,
     WithSideEffect,
     WithSomeFrozenModule,
     WithTransformer,
@@ -179,10 +180,7 @@ def test_compute_gramian(factory: ModuleFactory, batch_size: int, batch_dim: int
         ModuleFactory(WithSideEffect),
         ModuleFactory(Randomness),
         ModuleFactory(InstanceNorm2d, num_features=3, affine=True, track_running_stats=True),
-        param(
-            ModuleFactory(RNN, input_size=8, hidden_size=5, batch_first=True),
-            marks=mark.xfail_if_cuda,
-        ),
+        param(ModuleFactory(WithRNN), marks=mark.xfail_if_cuda),
     ],
 )
 @mark.parametrize("batch_size", [1, 3, 32])
@@ -398,7 +396,7 @@ def test_autograd_while_modules_are_hooked(
     ["factory", "batch_dim"],
     [
         (ModuleFactory(InstanceNorm2d, num_features=3, affine=True, track_running_stats=True), 0),
-        (ModuleFactory(RNN, input_size=8, hidden_size=5, batch_first=True), 0),
+        param(ModuleFactory(WithRNN), 0),
         (ModuleFactory(BatchNorm2d, num_features=3, affine=True, track_running_stats=False), 0),
     ],
 )

tests/utils/architectures.py

Lines changed: 15 additions & 5 deletions

@@ -45,11 +45,6 @@ def get_in_out_shapes(module: nn.Module) -> tuple[PyTree, PyTree]:
     if isinstance(module, ShapedModule):
         return module.INPUT_SHAPES, module.OUTPUT_SHAPES
 
-    elif isinstance(module, nn.RNN):
-        assert module.batch_first
-        SEQ_LEN = 20  # Arbitrary choice
-        return (SEQ_LEN, module.input_size), (SEQ_LEN, module.hidden_size)
-
     elif isinstance(module, (nn.BatchNorm2d, nn.InstanceNorm2d)):
         HEIGHT = 6  # Arbitrary choice
         WIDTH = 6  # Arbitrary choice
@@ -737,6 +732,21 @@ def forward(self, input: Tensor) -> Tensor:
         return torch.einsum("bi,icdef->bcdef", input, self.tensor)
 
 
+class WithRNN(ShapedModule):
+    """Simple model containing an RNN module."""
+
+    INPUT_SHAPES = (20, 8)  # Size 20, dim input_size (8)
+    OUTPUT_SHAPES = (20, 5)  # Size 20, dim hidden_size (5)
+
+    def __init__(self):
+        super().__init__()
+        self.rnn = nn.RNN(input_size=8, hidden_size=5, batch_first=True)
+
+    def forward(self, input: Tensor) -> Tensor:
+        output, _ = self.rnn(input)
+        return output
+
+
 class WithDropout(ShapedModule):
     """Simple model containing Dropout layers."""
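A rough usage sketch of the new WithRNN helper, assuming it is instantiated and called like any other ShapedModule in these tests (batch size 3 is arbitrary; the assertion is only for illustration):

model = WithRNN()
x = torch.randn(3, *WithRNN.INPUT_SHAPES)      # (3, 20, 8)
y = model(x)                                   # forward drops the hidden state
assert y.shape == (3, *WithRNN.OUTPUT_SHAPES)  # (3, 20, 5)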
