Skip to content

Commit f68346c

Browse files
committed
feature: copy_stochastic
1 parent 22affdb commit f68346c

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

pytorch_optimizer/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@
168168
from pytorch_optimizer.optimizer.utils import (
169169
CPUOffloadOptimizer,
170170
clip_grad_norm,
171+
copy_stochastic,
171172
disable_running_stats,
172173
enable_running_stats,
173174
get_global_gradient_norm,

pytorch_optimizer/optimizer/utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,3 +310,26 @@ def reg_noise(
310310
loss.add_(reg - noise.mul_(noise_coef).sum())
311311

312312
return loss
313+
314+
315+
@torch.no_grad()
316+
def copy_stochastic(target: torch.Tensor, source: torch.Tensor) -> None:
317+
r"""Copy stochastic.
318+
319+
reference: https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905
320+
321+
:param target: torch.Tensor. bfloat16 tensor.
322+
:param source: torch.Tensor. float32 tensor.
323+
"""
324+
result = torch.randint_like(
325+
source,
326+
dtype=torch.int32,
327+
low=0,
328+
high=1 << 16,
329+
)
330+
331+
result.add_(source.view(dtype=torch.int32))
332+
333+
result.bitwise_and_(-65536)
334+
335+
target.copy_(result.view(dtype=torch.float32))

0 commit comments

Comments
 (0)