Skip to content

Commit 22affdb

Browse files
committed
update: test_copy_stochastic
1 parent 636fec8 commit 22affdb

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

tests/test_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
CPUOffloadOptimizer,
2828
clip_grad_norm,
2929
compare_versions,
30+
copy_stochastic,
3031
disable_running_stats,
3132
enable_running_stats,
3233
has_overflow,
@@ -348,3 +349,19 @@ def test_zero_power_via_newton_schulz_5():
348349

349350
with pytest.raises(ValueError):
350351
_ = zero_power_via_newton_schulz_5(x[0], num_steps=6)
352+
353+
354+
def test_copy_stochastic():
355+
n: int = 512
356+
357+
a = torch.full((n,), 1.0, dtype=torch.bfloat16)
358+
b = torch.full((n,), 0.0002, dtype=torch.bfloat16)
359+
result = torch.full((n,), 0.0, dtype=torch.bfloat16)
360+
361+
added = a.to(dtype=torch.float32) + b
362+
363+
result.copy_(added)
364+
np.testing.assert_almost_equal(1.0000, result.to(dtype=torch.float32).mean().item(), decimal=4)
365+
366+
copy_stochastic(result, added)
367+
np.testing.assert_almost_equal(1.0002, result.to(dtype=torch.float32).mean().item(), decimal=4)

0 commit comments

Comments
 (0)