Skip to content

Commit bc93c1b

Browse files
committed
fix: device mismatch
1 parent 0b7c53b commit bc93c1b

File tree

5 files changed

+6
-6
lines changed

5 files changed

+6
-6
lines changed

pytorch_optimizer/optimizer/adan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def get_global_gradient_norm(self) -> Union[torch.Tensor, float]:
8383
if self.defaults['max_grad_norm'] == 0.0:
8484
return 1.0
8585

86-
global_grad_norm = get_global_gradient_norm(self.param_groups, self.param_groups[0]['params'][0].device)
86+
global_grad_norm = get_global_gradient_norm(self.param_groups)
8787
global_grad_norm.sqrt_().add_(self.defaults['eps'])
8888

8989
return torch.clamp(self.defaults['max_grad_norm'] / global_grad_norm, max=1.0)

pytorch_optimizer/optimizer/alig.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def reset(self):
5252
@torch.no_grad()
5353
def compute_step_size(self, loss: float) -> float:
5454
r"""Compute step_size."""
55-
global_grad_norm = get_global_gradient_norm(self.param_groups, torch.device('cpu'))
55+
global_grad_norm = get_global_gradient_norm(self.param_groups)
5656
global_grad_norm.add_(1e-6)
5757

5858
return loss / global_grad_norm.item()

pytorch_optimizer/optimizer/dadapt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
473473
numerator_weighted = group['numerator_weighted']
474474

475475
if group['step'] == 0:
476-
group['g0_norm'] = get_global_gradient_norm(self.param_groups, device).sqrt_().item()
476+
group['g0_norm'] = get_global_gradient_norm(self.param_groups).sqrt_().item()
477477
g0_norm = group['g0_norm']
478478

479479
if g0_norm == 0:

pytorch_optimizer/optimizer/lamb.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def get_global_gradient_norm(self) -> Union[torch.Tensor, float]:
103103
if self.defaults['max_grad_norm'] == 0.0:
104104
return 1.0
105105

106-
global_grad_norm = get_global_gradient_norm(self.param_groups, self.param_groups[0]['params'][0].device)
106+
global_grad_norm = get_global_gradient_norm(self.param_groups)
107107
global_grad_norm.sqrt_().add_(self.defaults['eps'])
108108

109109
return torch.clamp(self.defaults['max_grad_norm'] / global_grad_norm, max=1.0)

pytorch_optimizer/optimizer/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,9 +272,9 @@ def l2_projection(parameters: PARAMETERS, max_norm: float = 1e2):
272272

273273

274274
@torch.no_grad()
275-
def get_global_gradient_norm(param_groups: List[Dict], device: torch.device) -> torch.Tensor:
275+
def get_global_gradient_norm(param_groups: List[Dict]) -> torch.Tensor:
276276
r"""Get global gradient norm."""
277-
global_grad_norm = torch.zeros(1, dtype=torch.float32, device=device)
277+
global_grad_norm = torch.zeros(1, dtype=torch.float32, device=param_groups[0]['params'][0].device)
278278

279279
for group in param_groups:
280280
for p in group['params']:

0 commit comments

Comments (0)