Skip to content

Commit c90a921

Browse files
pstjohnksivaman
andauthored
Add tests that reset_parameters doesn't change parameter initial value ranges (#2550)
* Add tests for 2528 and 2529 Signed-off-by: Peter St. John <pstjohn@nvidia.com> * Update tests/pytorch/test_deferred_init.py Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Update tests/pytorch/test_deferred_init.py Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by: Peter St. John <pstjohn@nvidia.com> Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
1 parent 4f364c8 commit c90a921

File tree

1 file changed

+42
-1
lines changed

1 file changed

+42
-1
lines changed

tests/pytorch/test_deferred_init.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828

2929

3030
class TestDeferredInit:
31-
3231
@staticmethod
3332
def get_module_args(module):
3433
hidden_size = num_heads * head_dim
@@ -82,3 +81,45 @@ def test_reset_parameters(
8281
"on CUDA device"
8382
)
8483
del module
84+
85+
@pytest.mark.parametrize("module_type", _core_modules)
86+
def test_reset_parameters_doesnt_change_parameter_stats(
87+
self,
88+
module_type: torch.nn.Module,
89+
) -> None:
90+
"""Test for github issue #2528 and #2529 to ensure that reset_parameters() doesn't change
91+
the parameter mean and std"""
92+
args, kwargs = TestDeferredInit.get_module_args(module_type)
93+
kwargs["device"] = "cuda"
94+
module = module_type(*args, **kwargs)
95+
96+
param_stats = {
97+
name: {"mean": param.mean(), "std": param.std()}
98+
for name, param in module.named_parameters()
99+
}
100+
101+
with torch.no_grad():
102+
module.reset_parameters()
103+
104+
param_stats_after = {
105+
name: {"mean": param.mean(), "std": param.std()}
106+
for name, param in module.named_parameters()
107+
}
108+
109+
for name, stats in param_stats_after.items():
110+
torch.testing.assert_close(
111+
stats["mean"],
112+
param_stats[name]["mean"],
113+
atol=1e-3,
114+
rtol=1e-3,
115+
msg=f"{name} mean changed after reset_parameters",
116+
)
117+
torch.testing.assert_close(
118+
stats["std"],
119+
param_stats[name]["std"],
120+
atol=1e-3,
121+
rtol=1e-3,
122+
msg=f"{name} std changed after reset_parameters",
123+
)
124+
125+
del module

0 commit comments

Comments
 (0)