File tree Expand file tree Collapse file tree 1 file changed +17
-0
lines changed Expand file tree Collapse file tree 1 file changed +17
-0
lines changed Original file line number Diff line number Diff line change 2727 CPUOffloadOptimizer ,
2828 clip_grad_norm ,
2929 compare_versions ,
30+ copy_stochastic ,
3031 disable_running_stats ,
3132 enable_running_stats ,
3233 has_overflow ,
@@ -348,3 +349,19 @@ def test_zero_power_via_newton_schulz_5():
348349
349350 with pytest .raises (ValueError ):
350351 _ = zero_power_via_newton_schulz_5 (x [0 ], num_steps = 6 )
352+
353+
354+ def test_copy_stochastic ():
355+ n : int = 512
356+
357+ a = torch .full ((n ,), 1.0 , dtype = torch .bfloat16 )
358+ b = torch .full ((n ,), 0.0002 , dtype = torch .bfloat16 )
359+ result = torch .full ((n ,), 0.0 , dtype = torch .bfloat16 )
360+
361+ added = a .to (dtype = torch .float32 ) + b
362+
363+ result .copy_ (added )
364+ np .testing .assert_almost_equal (1.0000 , result .to (dtype = torch .float32 ).mean ().item (), decimal = 4 )
365+
366+ copy_stochastic (result , added )
367+ np .testing .assert_almost_equal (1.0002 , result .to (dtype = torch .float32 ).mean ().item (), decimal = 4 )
You can’t perform that action at this time.
0 commit comments