Commit 49d3937

Merge pull request #34 from kozistr/feature/dummy
[Refactor] Refactor the codes
2 parents 3934366 + 7599d81 commit 49d3937

3 files changed: +43 -44 lines changed

pytorch_optimizer/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -17,4 +17,4 @@
 from pytorch_optimizer.sam import SAM
 from pytorch_optimizer.sgdp import SGDP

-__VERSION__ = '0.0.11'
+__VERSION__ = '0.1.0'

pytorch_optimizer/pcgrad.py

Lines changed: 40 additions & 41 deletions

@@ -61,15 +61,41 @@ def set_grad(self, grads):
                 p.grad = grads[idx]
                 idx += 1

-    def pc_backward(self, objectives: Iterable[nn.Module]):
-        """Calculate the gradient of the parameters
-        :param objectives: Iterable[nn.Module]. a list of objectives
+    def retrieve_grad(self):
+        """get the gradient of the parameters of the network with specific objective"""
+        grad, shape, has_grad = [], [], []
+        for group in self.optimizer.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    shape.append(p.shape)
+                    grad.append(torch.zeros_like(p).to(p.device))
+                    has_grad.append(torch.zeros_like(p).to(p.device))
+                    continue
+
+                shape.append(p.grad.shape)
+                grad.append(p.grad.clone())
+                has_grad.append(torch.ones_like(p).to(p.device))
+
+        return grad, shape, has_grad
+
+    def pack_grad(self, objectives: Iterable[nn.Module]):
+        """pack the gradient of the parameters of the network for each objective
+        :param objectives: Iterable[float]. a list of objectives
         :return:
         """
-        grads, shapes, has_grads = self.pack_grad(objectives)
-        pc_grad = self.project_conflicting(grads, has_grads)
-        pc_grad = self.un_flatten_grad(pc_grad, shapes[0])
-        self.set_grad(pc_grad)
+        grads, shapes, has_grads = [], [], []
+        for objective in objectives:
+            self.zero_grad()
+
+            objective.backward(retain_graph=True)
+
+            grad, shape, has_grad = self.retrieve_grad()
+
+            grads.append(self.flatten_grad(grad))
+            has_grads.append(self.flatten_grad(has_grad))
+            shapes.append(shape)
+
+        return grads, shapes, has_grads

     def project_conflicting(self, grads, has_grads) -> torch.Tensor:
         """

@@ -99,40 +125,13 @@ def project_conflicting(self, grads, has_grads) -> torch.Tensor:

         return merged_grad

-    def retrieve_grad(self):
-        """Get the gradient of the parameters of the network with specific objective
-        :return:
-        """
-        grad, shape, has_grad = [], [], []
-        for group in self.optimizer.param_groups:
-            for p in group['params']:
-                if p.grad is None:
-                    shape.append(p.shape)
-                    grad.append(torch.zeros_like(p).to(p.device))
-                    has_grad.append(torch.zeros_like(p).to(p.device))
-                    continue
-
-                shape.append(p.grad.shape)
-                grad.append(p.grad.clone())
-                has_grad.append(torch.ones_like(p).to(p.device))
-
-        return grad, shape, has_grad
-
-    def pack_grad(self, objectives: Iterable[nn.Module]):
-        """Pack the gradient of the parameters of the network for each objective
-        :param objectives: Iterable[float]. a list of objectives
+    def pc_backward(self, objectives: Iterable[nn.Module]):
+        """calculate the gradient of the parameters
+        :param objectives: Iterable[nn.Module]. a list of objectives
         :return:
         """
-        grads, shapes, has_grads = [], [], []
-        for objective in objectives:
-            self.zero_grad()
-
-            objective.backward(retain_graph=True)
-
-            grad, shape, has_grad = self.retrieve_grad()
-
-            grads.append(self.flatten_grad(grad))
-            has_grads.append(self.flatten_grad(has_grad))
-            shapes.append(shape)
+        grads, shapes, has_grads = self.pack_grad(objectives)
+        pc_grad = self.project_conflicting(grads, has_grads)
+        pc_grad = self.un_flatten_grad(pc_grad, shapes[0])

-        return grads, shapes, has_grads
+        self.set_grad(pc_grad)
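The pcgrad.py change is a pure reorder: retrieve_grad and pack_grad move above their caller so the file reads top-down. retrieve_grad collects each parameter's gradient (substituting zeros plus a zero has_grad mask when a parameter has no gradient), pack_grad runs one backward pass per objective with retain_graph=True and flattens the results, and pc_backward projects the conflicting gradients and writes the merged gradient back through set_grad. Below is a minimal usage sketch; the PCGrad class name, its constructor, and the toy model are assumptions for illustration (only pc_backward and the wrapped self.optimizer appear in this diff).

import torch
from torch import nn

from pytorch_optimizer.pcgrad import PCGrad  # class name assumed; module path is taken from this diff

# illustrative two-task setup: one shared linear layer, one loss per output column
model = nn.Linear(16, 2)
inner = torch.optim.AdamW(model.parameters(), lr=1e-3)
optimizer = PCGrad(inner)  # assumed to wrap an existing optimizer, per self.optimizer above

x = torch.randn(8, 16)
y1, y2 = torch.randn(8), torch.randn(8)

out = model(x)
losses = [
    nn.functional.mse_loss(out[:, 0], y1),
    nn.functional.mse_loss(out[:, 1], y2),
]

# pc_backward: backward per objective, project away conflicting components,
# then write the merged gradient onto the parameters via set_grad
optimizer.pc_backward(losses)
inner.step()  # the wrapped optimizer applies the merged gradient

Zeroing the gradients inside pack_grad before each objective's backward pass keeps the per-objective gradients separate, while retain_graph=True allows several losses that share one forward graph to be backpropagated in turn.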

pytorch_optimizer/utils.py

Lines changed: 2 additions & 2 deletions

@@ -26,7 +26,7 @@ def normalize_gradient(x: torch.Tensor, use_channels: bool = False, epsilon: flo
     return x


-def unit_norm(x: torch.Tensor) -> torch.Tensor:
+def unit_norm(x: torch.Tensor, norm: float = 2.0) -> torch.Tensor:
     keep_dim: bool = True
     dim: Optional[Union[int, Tuple[int, ...]]] = None

@@ -40,4 +40,4 @@ def unit_norm(x: torch.Tensor) -> torch.Tensor:
     else:
         dim = tuple(range(1, x_len))

-    return x.norm(dim=dim, keepdim=keep_dim, p=2.0)
+    return x.norm(dim=dim, keepdim=keep_dim, p=norm)
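The utils.py change threads a configurable norm order through unit_norm instead of the hard-coded p=2.0; the default value keeps the previous L2 behaviour. A small sketch of the new parameter, assuming unit_norm is imported from pytorch_optimizer.utils and that a 4-D tensor reaches the tuple(range(1, x_len)) branch shown above (or an equivalent one):

import torch

from pytorch_optimizer.utils import unit_norm

w = torch.randn(64, 3, 3, 3)  # e.g. a conv weight with out_channels first

l2 = unit_norm(w)            # p=2.0 by default: same result as before this commit
l1 = unit_norm(w, norm=1.0)  # new: any order accepted by torch.norm

# with dim=(1, 2, 3) and keepdim=True, both results have shape (64, 1, 1, 1):
# one norm per output channel
print(l2.shape, l1.shape)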
