Skip to content

Commit dd577cf

Browse files
authored
test(autogram): Add ModelAlsoUsingSubmoduleParamsDirectly (#446)
1 parent 6d7d140 commit dd577cf

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

tests/unit/autogram/test_engine.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
InterModuleParamReuse,
2121
MIMOBranched,
2222
MISOBranched,
23+
ModelAlsoUsingSubmoduleParamsDirectly,
2324
ModelUsingSubmoduleParamsDirectly,
2425
ModuleReuse,
2526
MultiInputMultiOutput,
@@ -190,7 +191,9 @@ def test_compute_gramian_with_weird_modules(
190191

191192

192193
@mark.xfail
193-
@mark.parametrize("architecture", [ModelUsingSubmoduleParamsDirectly])
194+
@mark.parametrize(
195+
"architecture", [ModelUsingSubmoduleParamsDirectly, ModelAlsoUsingSubmoduleParamsDirectly]
196+
)
194197
@mark.parametrize("batch_size", [1, 3, 32])
195198
@mark.parametrize("batch_dim", [0, None])
196199
def test_compute_gramian_unsupported_architectures(

tests/utils/architectures.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -772,6 +772,22 @@ def forward(self, input: Tensor) -> Tensor:
772772
return input @ self.linear.weight.T + self.linear.bias
773773

774774

775+
class ModelAlsoUsingSubmoduleParamsDirectly(ShapedModule):
776+
"""
777+
Model that uses its submodule's parameters directly but that also calls its submodule's forward.
778+
"""
779+
780+
INPUT_SHAPES = (2,)
781+
OUTPUT_SHAPES = (3,)
782+
783+
def __init__(self):
784+
super().__init__()
785+
self.linear = nn.Linear(2, 3)
786+
787+
def forward(self, input: Tensor) -> Tensor:
788+
return input @ self.linear.weight.T + self.linear.bias + self.linear(input)
789+
790+
775791
class _WithStringArg(nn.Module):
776792
def __init__(self):
777793
super().__init__()

0 commit comments

Comments
 (0)