
Commit 8f5892c

refactor: PCGrad
1 parent 9737e2b commit 8f5892c

2 files changed: 19 additions, 18 deletions

pytorch_optimizer/pcgrad.py

Lines changed: 4 additions & 18 deletions
@@ -2,12 +2,12 @@
 from copy import deepcopy
 from typing import Iterable, List, Tuple

-import numpy as np
 import torch
 from torch import nn
 from torch.optim.optimizer import Optimizer

 from pytorch_optimizer.base_optimizer import BaseOptimizer
+from pytorch_optimizer.utils import flatten_grad, un_flatten_grad


 class PCGrad(BaseOptimizer):
@@ -41,20 +41,6 @@ def validate_parameters(self):
     def reset(self):
         pass

-    @staticmethod
-    def flatten_grad(grads: List[torch.Tensor]) -> torch.Tensor:
-        return torch.cat([g.flatten() for g in grads])
-
-    @staticmethod
-    def un_flatten_grad(grads: torch.Tensor, shapes: List[int]) -> List[torch.Tensor]:
-        idx: int = 0
-        un_flatten_grad: List[torch.Tensor] = []
-        for shape in shapes:
-            length = np.prod(shape)
-            un_flatten_grad.append(grads[idx : idx + length].view(shape).clone())
-            idx += length
-        return un_flatten_grad
-
     def zero_grad(self):
         return self.optimizer.zero_grad(set_to_none=True)

@@ -97,8 +83,8 @@ def pack_grad(self, objectives: Iterable) -> Tuple[List[torch.Tensor], List[List

             grad, shape, has_grad = self.retrieve_grad()

-            grads.append(self.flatten_grad(grad))
-            has_grads.append(self.flatten_grad(has_grad))
+            grads.append(flatten_grad(grad))
+            has_grads.append(flatten_grad(has_grad))
             shapes.append(shape)

         return grads, shapes, has_grads
@@ -136,6 +122,6 @@ def pc_backward(self, objectives: Iterable[nn.Module]):
         """
         grads, shapes, has_grads = self.pack_grad(objectives)
         pc_grad = self.project_conflicting(grads, has_grads)
-        pc_grad = self.un_flatten_grad(pc_grad, shapes[0])
+        pc_grad = un_flatten_grad(pc_grad, shapes[0])

         self.set_grad(pc_grad)
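
A minimal usage sketch of the class after this refactor, for orientation. Only zero_grad and pc_backward appear in the hunks above; the constructor call and the final step() delegation are assumptions about the surrounding class (consistent with the self.optimizer calls visible in the diff), and the model and losses are purely illustrative.

import torch
from torch import nn

from pytorch_optimizer import PCGrad  # assumed top-level export

model = nn.Linear(8, 2)
# Assumption: PCGrad wraps a base optimizer, matching the self.optimizer calls above.
optimizer = PCGrad(torch.optim.SGD(model.parameters(), lr=1e-2))

x, y = torch.randn(16, 8), torch.randn(16, 2)
y_pred = model(x)

# Two objectives on the same prediction; pc_backward packs and projects their
# gradients before writing the result back to the parameters.
loss1 = nn.functional.mse_loss(y_pred, y)
loss2 = nn.functional.l1_loss(y_pred, y)

optimizer.zero_grad()
optimizer.pc_backward([loss1, loss2])
optimizer.step()  # assumption: step() delegates to the wrapped optimizer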

pytorch_optimizer/utils.py

Lines changed: 15 additions & 0 deletions
@@ -1,6 +1,7 @@
 import math
 from typing import List, Optional, Tuple, Union

+import numpy as np
 import torch
 from torch import nn
 from torch.distributed import all_reduce
@@ -35,6 +36,20 @@ def normalize_gradient(x: torch.Tensor, use_channels: bool = False, epsilon: flo
     return x


+def flatten_grad(grads: List[torch.Tensor]) -> torch.Tensor:
+    return torch.cat([g.flatten() for g in grads])
+
+
+def un_flatten_grad(grads: torch.Tensor, shapes: List[int]) -> List[torch.Tensor]:
+    idx: int = 0
+    un_flatten_grad: List[torch.Tensor] = []
+    for shape in shapes:
+        length = np.prod(shape)
+        un_flatten_grad.append(grads[idx : idx + length].view(shape).clone())
+        idx += length
+    return un_flatten_grad
+
+
 def clip_grad_norm(parameters: PARAMETERS, max_norm: float = 0, sync: bool = False) -> Union[torch.Tensor, float]:
     """Clips grad norms.
     During combination with FSDP, will also ensure that grad norms are aggregated
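
A quick round-trip check of the two helpers in their new home, pytorch_optimizer.utils. The tensors here only stand in for the per-parameter gradients that PCGrad.retrieve_grad collects; flatten_grad concatenates them into one 1-D tensor and un_flatten_grad restores the original shapes.

import torch

from pytorch_optimizer.utils import flatten_grad, un_flatten_grad

# Illustrative per-parameter gradients with different shapes.
grads = [torch.randn(3, 4), torch.randn(5)]
shapes = [g.shape for g in grads]

flat = flatten_grad(grads)                 # one 1-D tensor of length 3 * 4 + 5 = 17
restored = un_flatten_grad(flat, shapes)   # tensors with the original shapes

assert flat.numel() == 17
assert all(torch.equal(a, b) for a, b in zip(grads, restored))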
