We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 7c0b867 commit cbc45b6Copy full SHA for cbc45b6
pytorch_optimizer/optimizer/utils.py
@@ -199,6 +199,7 @@ def neuron_mean(x: torch.Tensor) -> torch.Tensor:
199
200
def disable_running_stats(model):
201
r"""disable running stats (momentum) of BatchNorm"""
202
+
203
def _disable(module):
204
if isinstance(module, _BatchNorm):
205
module.backup_momentum = module.momentum
@@ -209,6 +210,7 @@ def _disable(module):
209
210
211
def enable_running_stats(model):
212
r"""enable running stats (momentum) of BatchNorm"""
213
214
def _enable(module):
215
if isinstance(module, _BatchNorm) and hasattr(module, 'backup_momentum'):
216
module.momentum = module.backup_momentum
0 commit comments