
Commit d6afc04

update: get_global_gradient_norm

1 parent: 9ae964c


pytorch_optimizer/optimizer/lamb.py

Lines changed: 2 additions & 5 deletions
@@ -106,10 +106,7 @@ def get_global_gradient_norm(self) -> Union[torch.Tensor, float]:
         if self.defaults['max_grad_norm'] == 0.0:
             return 1.0
 
-        device = self.param_groups[0]['params'][0].device
-
-        global_grad_norm = torch.zeros(1, device=device)
-        max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device)
+        global_grad_norm = torch.zeros(1, dtype=torch.float32, device=self.param_groups[0]['params'][0].device)
 
         for group in self.param_groups:
             for p in group['params']:
@@ -118,7 +115,7 @@ def get_global_gradient_norm(self) -> Union[torch.Tensor, float]:
 
         global_grad_norm.sqrt_()
 
-        return torch.clamp(max_grad_norm / (global_grad_norm + self.eps), max=1.0)
+        return torch.clamp(self.defaults['max_grad_norm'] / (global_grad_norm + self.eps), max=1.0)
 
     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
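For reference, here is a minimal sketch of how the full method might read after this commit. Only the lines shown in the diff above are confirmed by the commit; the per-parameter accumulation inside the loop is elided in the hunk, so the `if p.grad is not None` guard and the `global_grad_norm.add_(...)` line below are assumptions about what that body does.

def get_global_gradient_norm(self) -> Union[torch.Tensor, float]:
    # skip clipping entirely when the feature is disabled
    if self.defaults['max_grad_norm'] == 0.0:
        return 1.0

    # accumulator lives on the same device as the first parameter
    global_grad_norm = torch.zeros(1, dtype=torch.float32, device=self.param_groups[0]['params'][0].device)

    for group in self.param_groups:
        for p in group['params']:
            # assumption: the elided body sums squared per-parameter gradient norms
            if p.grad is not None:
                global_grad_norm.add_(torch.linalg.norm(p.grad).pow(2))

    global_grad_norm.sqrt_()

    # a Python float divided by a tensor broadcasts to a tensor, so torch.clamp
    # works without first materializing max_grad_norm as a tensor
    return torch.clamp(self.defaults['max_grad_norm'] / (global_grad_norm + self.eps), max=1.0)

The change works because float-tensor broadcasting makes the explicit `torch.tensor(...)` wrapper around `max_grad_norm` redundant, which also avoids allocating a one-off tensor on every call; that accounts for the +2/-5 shape of the diff.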
