File tree Expand file tree Collapse file tree 2 files changed +24
-0
lines changed Expand file tree Collapse file tree 2 files changed +24
-0
lines changed Original file line number Diff line number Diff line change 168168from pytorch_optimizer .optimizer .utils import (
169169 CPUOffloadOptimizer ,
170170 clip_grad_norm ,
171+ copy_stochastic ,
171172 disable_running_stats ,
172173 enable_running_stats ,
173174 get_global_gradient_norm ,
Original file line number Diff line number Diff line change @@ -310,3 +310,26 @@ def reg_noise(
310310 loss .add_ (reg - noise .mul_ (noise_coef ).sum ())
311311
312312 return loss
313+
314+
315+ @torch .no_grad ()
316+ def copy_stochastic (target : torch .Tensor , source : torch .Tensor ) -> None :
317+ r"""Copy stochastic.
318+
319+ reference: https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905
320+
321+ :param target: torch.Tensor. bfloat16 tensor.
322+ :param source: torch.Tensor. float32 tensor.
323+ """
324+ result = torch .randint_like (
325+ source ,
326+ dtype = torch .int32 ,
327+ low = 0 ,
328+ high = 1 << 16 ,
329+ )
330+
331+ result .add_ (source .view (dtype = torch .int32 ))
332+
333+ result .bitwise_and_ (- 65536 )
334+
335+ target .copy_ (result .view (dtype = torch .float32 ))
You can’t perform that action at this time.
0 commit comments