Commit aff0abc

test: Remove trivial ShapedModules (#461)
* Add support for RNN, BatchNorm2d and InstanceNorm2d in get_in_out_shapes
* Remove WithRNN, WithBatchNorm and WithModuleTrackingRunningStats - use simple factories instead
1 parent: 290a393 · commit: aff0abc
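The "simple factories" mentioned in the second bullet take a `torch.nn` class and its constructor kwargs directly, so a trivial one-layer wrapper class is no longer needed per module. As a minimal sketch of the idea (hypothetical: the real `ModuleFactory` lives in the repo's test utilities and may store or build modules differently):

```python
from typing import Any

from torch import nn


class ModuleFactory:
    """Hypothetical minimal sketch: remember a module class and its
    constructor kwargs, and instantiate the module on demand."""

    def __init__(self, module_class: type[nn.Module], **kwargs: Any):
        self.module_class = module_class
        self.kwargs = kwargs

    def __call__(self) -> nn.Module:
        return self.module_class(**self.kwargs)


# A one-layer wrapper like WithBatchNorm then becomes redundant:
factory = ModuleFactory(nn.BatchNorm2d, num_features=3, affine=True, track_running_stats=False)
module = factory()  # an nn.BatchNorm2d(3, ...) instance
```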

File tree: 2 files changed (+22 lines, -52 lines)

- tests/unit/autogram/test_engine.py
- tests/utils/architectures.py

tests/unit/autogram/test_engine.py

Lines changed: 10 additions & 10 deletions
```diff
@@ -6,7 +6,7 @@
 import torch
 from pytest import mark, param
 from torch import Tensor
-from torch.nn import Linear
+from torch.nn import RNN, BatchNorm2d, InstanceNorm2d, Linear
 from torch.optim import SGD
 from torch.testing import assert_close
 from utils.architectures import (
@@ -47,18 +47,15 @@
     SomeUnusedOutput,
     SomeUnusedParam,
     SqueezeNet,
-    WithBatchNorm,
     WithBuffered,
     WithDropout,
-    WithModuleTrackingRunningStats,
     WithModuleWithHybridPyTreeArg,
     WithModuleWithHybridPyTreeKwarg,
     WithModuleWithStringArg,
     WithModuleWithStringKwarg,
     WithModuleWithStringOutput,
     WithMultiHeadAttention,
     WithNoTensorOutput,
-    WithRNN,
     WithSideEffect,
     WithSomeFrozenModule,
     WithTransformer,
@@ -178,11 +175,14 @@ def test_compute_gramian(factory: ModuleFactory, batch_size: int, batch_dim: int
 @mark.parametrize(
     "factory",
     [
-        ModuleFactory(WithBatchNorm),
+        ModuleFactory(BatchNorm2d, num_features=3, affine=True, track_running_stats=False),
         ModuleFactory(WithSideEffect),
         ModuleFactory(Randomness),
-        ModuleFactory(WithModuleTrackingRunningStats),
-        param(ModuleFactory(WithRNN), marks=mark.xfail_if_cuda),
+        ModuleFactory(InstanceNorm2d, num_features=3, affine=True, track_running_stats=True),
+        param(
+            ModuleFactory(RNN, input_size=8, hidden_size=5, batch_first=True),
+            marks=mark.xfail_if_cuda,
+        ),
     ],
 )
 @mark.parametrize("batch_size", [1, 3, 32])
@@ -397,9 +397,9 @@ def test_autograd_while_modules_are_hooked(
 @mark.parametrize(
     ["factory", "batch_dim"],
     [
-        (ModuleFactory(WithModuleTrackingRunningStats), 0),
-        (ModuleFactory(WithRNN), 0),
-        (ModuleFactory(WithBatchNorm), 0),
+        (ModuleFactory(InstanceNorm2d, num_features=3, affine=True, track_running_stats=True), 0),
+        (ModuleFactory(RNN, input_size=8, hidden_size=5, batch_first=True), 0),
+        (ModuleFactory(BatchNorm2d, num_features=3, affine=True, track_running_stats=False), 0),
     ],
 )
 def test_incompatible_modules(factory: ModuleFactory, batch_dim: int | None):
```

tests/utils/architectures.py

Lines changed: 12 additions & 42 deletions
```diff
@@ -44,6 +44,18 @@ def __init_subclass__(cls):
 def get_in_out_shapes(module: nn.Module) -> tuple[PyTree, PyTree]:
     if isinstance(module, ShapedModule):
         return module.INPUT_SHAPES, module.OUTPUT_SHAPES
+
+    elif isinstance(module, nn.RNN):
+        assert module.batch_first
+        SEQ_LEN = 20  # Arbitrary choice
+        return (SEQ_LEN, module.input_size), (SEQ_LEN, module.hidden_size)
+
+    elif isinstance(module, (nn.BatchNorm2d, nn.InstanceNorm2d)):
+        HEIGHT = 6  # Arbitrary choice
+        WIDTH = 6  # Arbitrary choice
+        shape = (module.num_features, HEIGHT, WIDTH)
+        return shape, shape
+
     else:
         raise ValueError("Unknown input / output shapes of module", module)
 
```
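For illustration, the new branches resolve shapes as follows. The values trace directly from the constants above (sequence length 20, spatial size 6×6), and each shape describes a single sample, with the batch dimension presumably added elsewhere by the test harness:

```python
from torch import nn

rnn = nn.RNN(input_size=8, hidden_size=5, batch_first=True)
in_shape, out_shape = get_in_out_shapes(rnn)
# in_shape == (20, 8): (SEQ_LEN, input_size)
# out_shape == (20, 5): (SEQ_LEN, hidden_size)

norm = nn.InstanceNorm2d(num_features=3, affine=True, track_running_stats=True)
in_shape, out_shape = get_in_out_shapes(norm)
# both (3, 6, 6): (num_features, HEIGHT, WIDTH), unchanged by normalization
```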

```diff
@@ -725,48 +737,6 @@ def forward(self, input: Tensor) -> Tensor:
         return torch.einsum("bi,icdef->bcdef", input, self.tensor)
 
 
-class WithRNN(ShapedModule):
-    """Simple model containing an RNN module."""
-
-    INPUT_SHAPES = (20, 8)  # Size 20, dim input_size (8)
-    OUTPUT_SHAPES = (20, 5)  # Size 20, dim hidden_size (5)
-
-    def __init__(self):
-        super().__init__()
-        self.rnn = nn.RNN(input_size=8, hidden_size=5, batch_first=True)
-
-    def forward(self, input: Tensor) -> Tensor:
-        return self.rnn(input)
-
-
-class WithModuleTrackingRunningStats(ShapedModule):
-    """Simple model containing a module that has side-effects and modifies tensors in-place."""
-
-    INPUT_SHAPES = (3, 6, 6)
-    OUTPUT_SHAPES = (3, 6, 6)
-
-    def __init__(self):
-        super().__init__()
-        self.instance_norm = nn.InstanceNorm2d(3, affine=True, track_running_stats=True)
-
-    def forward(self, input: Tensor) -> Tensor:
-        return self.instance_norm(input)
-
-
-class WithBatchNorm(ShapedModule):
-    """Simple model containing a BatchNorm layer."""
-
-    INPUT_SHAPES = (3, 6, 6)
-    OUTPUT_SHAPES = (3, 6, 6)
-
-    def __init__(self):
-        super().__init__()
-        self.batch_norm = nn.BatchNorm2d(3, affine=True, track_running_stats=False)
-
-    def forward(self, input: Tensor) -> Tensor:
-        return self.batch_norm(input)
-
-
 class WithDropout(ShapedModule):
     """Simple model containing Dropout layers."""
 
```
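With the shape inference above in place, each deleted wrapper collapses to a factory call at its test sites: the module class is instantiated directly and its shapes come from `get_in_out_shapes` rather than hard-coded class attributes. For example, the removed `WithRNN` is now spelled as follows (assuming the `ModuleFactory` semantics sketched earlier):

```python
from torch.nn import RNN

# WithRNN hard-coded INPUT_SHAPES = (20, 8) and OUTPUT_SHAPES = (20, 5);
# get_in_out_shapes now derives the same shapes from the module itself
# (input_size=8, hidden_size=5, SEQ_LEN=20).
factory = ModuleFactory(RNN, input_size=8, hidden_size=5, batch_first=True)
```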
