Skip to content

Commit 5ee7a11

Browse files
authored
test(autogram): Add new architectures (#437)
* Add WithModuleWithStringArg * Add WithModuleWithStringOutput * Update test parametrizations
1 parent b075f6f commit 5ee7a11

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

tests/unit/autogram/test_engine.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@
5151
WithBuffered,
5252
WithDropout,
5353
WithModuleTrackingRunningStats,
54+
WithModuleWithStringArg,
55+
WithModuleWithStringOutput,
5456
WithNoTensorOutput,
5557
WithRNN,
5658
WithSideEffect,
@@ -164,6 +166,8 @@ def test_compute_gramian(architecture: type[ShapedModule], batch_size: int, batc
164166
Randomness,
165167
WithModuleTrackingRunningStats,
166168
param(WithRNN, marks=mark.xfail_if_cuda),
169+
WithModuleWithStringArg,
170+
param(WithModuleWithStringOutput, marks=mark.xfail),
167171
],
168172
)
169173
@mark.parametrize("batch_size", [1, 3, 32])

tests/utils/architectures.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -772,6 +772,54 @@ def forward(self, input: Tensor) -> Tensor:
772772
return input @ self.linear.weight.T + self.linear.bias
773773

774774

775+
class WithModuleWithStringArg(ShapedModule):
776+
"""Model containing a module that has a string argument."""
777+
778+
INPUT_SHAPES = (2,)
779+
OUTPUT_SHAPES = (3,)
780+
781+
class WithStringArg(nn.Module):
782+
def __init__(self):
783+
super().__init__()
784+
self.matrix = nn.Parameter(torch.randn(2, 3))
785+
786+
def forward(self, s: str, input: Tensor) -> Tensor:
787+
if s == "two":
788+
return input @ self.matrix * 2.0
789+
else:
790+
return input @ self.matrix
791+
792+
def __init__(self):
793+
super().__init__()
794+
self.with_string_arg = self.WithStringArg()
795+
796+
def forward(self, input: Tensor) -> Tensor:
797+
return self.with_string_arg("two", input)
798+
799+
800+
class WithModuleWithStringOutput(ShapedModule):
801+
"""Model containing a module that has a string output."""
802+
803+
INPUT_SHAPES = (2,)
804+
OUTPUT_SHAPES = (3,)
805+
806+
class WithStringOutput(nn.Module):
807+
def __init__(self):
808+
super().__init__()
809+
self.matrix = nn.Parameter(torch.randn(2, 3))
810+
811+
def forward(self, input: Tensor) -> tuple[str, Tensor]:
812+
return "test", input @ self.matrix
813+
814+
def __init__(self):
815+
super().__init__()
816+
self.with_string_output = self.WithStringOutput()
817+
818+
def forward(self, input: Tensor) -> Tensor:
819+
_, output = self.with_string_output(input)
820+
return output
821+
822+
775823
class FreeParam(ShapedModule):
776824
"""
777825
Model that contains a free (i.e. not contained in a submodule) parameter, that is used at the

0 commit comments

Comments
 (0)