Skip to content

Commit cbc45b6

Browse files
committed
update: running_stats
1 parent 7c0b867 commit cbc45b6

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

pytorch_optimizer/optimizer/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ def neuron_mean(x: torch.Tensor) -> torch.Tensor:
199199

200200
def disable_running_stats(model):
201201
r"""disable running stats (momentum) of BatchNorm"""
202+
202203
def _disable(module):
203204
if isinstance(module, _BatchNorm):
204205
module.backup_momentum = module.momentum
@@ -209,6 +210,7 @@ def _disable(module):
209210

210211
def enable_running_stats(model):
211212
r"""enable running stats (momentum) of BatchNorm"""
213+
212214
def _enable(module):
213215
if isinstance(module, _BatchNorm) and hasattr(module, 'backup_momentum'):
214216
module.momentum = module.backup_momentum

0 commit comments

Comments
 (0)