Skip to content

Commit 8c8b821

Browse files
committed
docs: E-MCMC
1 parent 2da8916 commit 8c8b821

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

docs/util.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@
8484
:docstring:
8585
:members:
8686

87+
::: pytorch_optimizer.optimizer.utils.reg_noise
88+
:docstring:
89+
:members:
90+
8791
## Newton methods
8892

8993
::: pytorch_optimizer.optimizer.shampoo_utils.power_iteration

pytorch_optimizer/optimizer/utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,17 @@ def reduce_max_except_dim(x: torch.Tensor, dim: int) -> torch.Tensor:
283283
def reg_noise(
284284
network1: nn.Module, network2: nn.Module, num_data: int, lr: float, eta: float = 8e-3, temperature: float = 1e-4
285285
) -> torch.Tensor | float:
286+
r"""Entropy-MCMC: Sampling from flat basins with ease.
287+
288+
usage: https://github.com/lblaoke/EMCMC/blob/master/exp/cifar10_emcmc.py
289+
290+
:param network1: nn.Module. network.
291+
:param network2: nn.Module. network.
292+
:param num_data: int. number of training data.
293+
:param lr: float. learning rate.
294+
:param eta: float. eta.
295+
:param temperature: float. temperature.
296+
"""
286297
reg_coef: float = 0.5 / (eta * num_data)
287298
noise_coef: float = math.sqrt(2.0 / lr / num_data * temperature)
288299

0 commit comments

Comments
 (0)