
Commit 05c737d

update: logging_active to False
1 parent 87c17b1 commit 05c737d

File tree: 1 file changed (+7, −25 lines)


pytorch_optimizer/ranger21.py

Lines changed: 7 additions & 25 deletions
@@ -82,7 +82,7 @@ def __init__(
         decay_type: str = 'stable',
         warmup_type: str = 'linear',
         warmup_pct_default: float = 0.22,
-        logging_active: bool = True,
+        logging_active: bool = False,
     ):
         """Ranger optimizer (RAdam + Lookahead + Gradient Centralization, combined into one optimizer)
         :param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups
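With this commit the optimizer no longer logs by default. A minimal sketch of opting back in, assuming the class is named Ranger21 and is importable from pytorch_optimizer.ranger21; the class name, import path, and any other required constructor arguments are not shown in this diff:

import torch
from pytorch_optimizer.ranger21 import Ranger21  # name and import path assumed, not shown in this diff

model = torch.nn.Linear(10, 2)

# logging_active now defaults to False; pass True explicitly to keep the previous behaviour.
optimizer = Ranger21(model.parameters(), lr=1e-3, logging_active=True)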
@@ -361,10 +361,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 if p.grad is None:
                     continue
 
-                # if not self.param_size:
                 param_size += p.numel()
 
-                # apply agc if enabled
                 if self.agc_active:
                     agc(
                         p, agc_eps=self.agc_eps, agc_clip_val=self.agc_clip_val
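The agc() call above applies adaptive gradient clipping before the state update. A rough sketch of the idea, assuming a global-norm variant; the repository's agc() helper and its defaults for agc_eps and agc_clip_val may differ, for example by using unit-wise (per output channel) norms:

import torch

def agc_sketch(p: torch.Tensor, agc_eps: float = 1e-3, agc_clip_val: float = 1e-2) -> None:
    # Shrink the gradient in place when its norm is large relative to the parameter norm.
    if p.grad is None:
        return
    p_norm = p.detach().norm().clamp(min=agc_eps)
    g_norm = p.grad.detach().norm()
    max_norm = p_norm * agc_clip_val
    if g_norm > max_norm:
        p.grad.detach().mul_(max_norm / g_norm)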
@@ -377,43 +377,27 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 state = self.state[p]
 
-                # State initialization
                 if len(state) == 0:
                     state['step'] = 0
-
-                    # Exponential moving average of gradient values
-                    state['grad_ma'] = torch.zeros_like(
-                        p, memory_format=torch.preserve_format
-                    )
-                    # Exponential moving average of squared gradient values
-                    state['variance_ma'] = torch.zeros_like(
-                        p, memory_format=torch.preserve_format
-                    )
+                    state['grad_ma'] = torch.zeros_like(p)
+                    state['variance_ma'] = torch.zeros_like(p)
 
                     if self.lookahead_active:
                         state['lookahead_params'] = torch.zeros_like(p.data)
                         state['lookahead_params'].copy_(p.data)
 
                     if self.use_adabelief:
-                        state['variance_ma_belief'] = torch.zeros_like(
-                            p, memory_format=torch.preserve_format
-                        )
+                        state['variance_ma_belief'] = torch.zeros_like(p)
                     if self.momentum_pnm:
-                        state['neg_grad_ma'] = torch.zeros_like(
-                            p, memory_format=torch.preserve_format
-                        )
-
-                        # Maintains max of all exp. moving avg. of sq. grad. values
-                        state['max_variance_ma'] = torch.zeros_like(
-                            p, memory_format=torch.preserve_format
-                        )
+                        state['neg_grad_ma'] = torch.zeros_like(p)
+                        state['max_variance_ma'] = torch.zeros_like(p)
 
-                # centralize gradients
                 if self.use_gc:
                     grad = centralize_gradient(
                         grad,
                         gc_conv_only=self.gc_conv_only,
                     )
+
                 if self.use_gc_norm:
                     grad = normalize_gradient(grad)
 
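The shortened torch.zeros_like(p) calls above allocate the same state buffers as the removed multi-line ones: torch.zeros_like already defaults to memory_format=torch.preserve_format, so dropping the explicit keyword does not change behaviour. A small check illustrating this (shapes chosen arbitrarily):

import torch

p = torch.randn(4, 3, 8, 8).contiguous(memory_format=torch.channels_last)
a = torch.zeros_like(p)
b = torch.zeros_like(p, memory_format=torch.preserve_format)
# Both calls preserve the input's memory layout, so the strides match.
assert a.stride() == b.stride()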
