 from copy import deepcopy
 from typing import Iterable, List, Tuple

-import numpy as np
 import torch
 from torch import nn
 from torch.optim.optimizer import Optimizer

 from pytorch_optimizer.base_optimizer import BaseOptimizer
+from pytorch_optimizer.utils import flatten_grad, un_flatten_grad


 class PCGrad(BaseOptimizer):
@@ -41,20 +41,6 @@ def validate_parameters(self):
     def reset(self):
         pass

-    @staticmethod
-    def flatten_grad(grads: List[torch.Tensor]) -> torch.Tensor:
-        return torch.cat([g.flatten() for g in grads])
-
-    @staticmethod
-    def un_flatten_grad(grads: torch.Tensor, shapes: List[int]) -> List[torch.Tensor]:
-        idx: int = 0
-        un_flatten_grad: List[torch.Tensor] = []
-        for shape in shapes:
-            length = np.prod(shape)
-            un_flatten_grad.append(grads[idx : idx + length].view(shape).clone())
-            idx += length
-        return un_flatten_grad
-
     def zero_grad(self):
         return self.optimizer.zero_grad(set_to_none=True)

@@ -97,8 +83,8 @@ def pack_grad(self, objectives: Iterable) -> Tuple[List[torch.Tensor], List[List

             grad, shape, has_grad = self.retrieve_grad()

-            grads.append(self.flatten_grad(grad))
-            has_grads.append(self.flatten_grad(has_grad))
+            grads.append(flatten_grad(grad))
+            has_grads.append(flatten_grad(has_grad))
             shapes.append(shape)

         return grads, shapes, has_grads
@@ -136,6 +122,6 @@ def pc_backward(self, objectives: Iterable[nn.Module]):
         """
         grads, shapes, has_grads = self.pack_grad(objectives)
         pc_grad = self.project_conflicting(grads, has_grads)
-        pc_grad = self.un_flatten_grad(pc_grad, shapes[0])
+        pc_grad = un_flatten_grad(pc_grad, shapes[0])

         self.set_grad(pc_grad)
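For reference, a minimal sketch of what the relocated helpers in pytorch_optimizer.utils could look like, mirroring the static methods removed above; the actual utils implementation may differ (for example, in how it computes each parameter's element count or in its type annotations).

from typing import List

import numpy as np
import torch


def flatten_grad(grads: List[torch.Tensor]) -> torch.Tensor:
    # concatenate every per-parameter gradient into one flat 1-D tensor
    return torch.cat([grad.flatten() for grad in grads])


def un_flatten_grad(grads: torch.Tensor, shapes: List[torch.Size]) -> List[torch.Tensor]:
    # slice the flat tensor back into per-parameter gradients with their original shapes
    idx: int = 0
    unflat_grads: List[torch.Tensor] = []
    for shape in shapes:
        length = int(np.prod(shape))  # number of elements in this parameter
        unflat_grads.append(grads[idx:idx + length].view(shape).clone())
        idx += length
    return unflat_grads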
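The public entry point is untouched by this change: pack_grad still flattens each objective's gradients, project_conflicting resolves conflicts on the flat vectors, and un_flatten_grad restores per-parameter shapes before set_grad writes them back. A rough usage sketch of that flow; the wrapper's constructor signature, step(), and import path are assumed here rather than shown in this diff.

import torch
from torch import nn

from pytorch_optimizer import PCGrad  # import path assumed

model = nn.Linear(10, 2)
# assumption: PCGrad wraps a regular torch optimizer
optimizer = PCGrad(torch.optim.SGD(model.parameters(), lr=0.1))

x = torch.randn(8, 10)
y1, y2 = torch.randn(8), torch.randn(8)

out = model(x)
loss1 = nn.functional.mse_loss(out[:, 0], y1)
loss2 = nn.functional.mse_loss(out[:, 1], y2)

# flatten -> project conflicting gradients -> un-flatten -> write back to .grad
optimizer.pc_backward([loss1, loss2])
optimizer.step()  # assumption: step() delegates to the wrapped optimizer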