Skip to content

Commit 22f994b

Browse files
authored
Merge pull request #259 from kozistr/feature/gradient-release
[Refactor] Stuffs
2 parents 3d4d440 + 1d9dfb0 commit 22f994b

File tree

5 files changed

+57
-53
lines changed

5 files changed

+57
-53
lines changed

poetry.lock

Lines changed: 1 addition & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ classifiers = [
4747
python = ">=3.8,<4.0.0"
4848
numpy = { version = "*", python = ">=3.8" }
4949
torch = { version = ">=1.10", python = ">=3.8", source = "torch" }
50+
bitsandbytes = { version = "^0.43", optional = true }
5051

5152
[tool.poetry.dev-dependencies]
5253
isort = { version = "^5", python = ">=3.8" }
@@ -55,6 +56,9 @@ ruff = "*"
5556
pytest = "*"
5657
pytest-cov = "*"
5758

59+
[tool.poetry.extras]
60+
bitsandbytes = ["bitsandbytes"]
61+
5862
[[tool.poetry.source]]
5963
name = "torch"
6064
url = "https://download.pytorch.org/whl/cpu"

pytorch_optimizer/optimizer/rotograd.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
1+
from importlib.util import find_spec
12
from typing import Any, List, Optional, Sequence
23

34
import torch
45
from torch import nn
56

6-
try:
7-
from geotorch import orthogonal
7+
HAS_GEOTORCH: bool = find_spec('geotorch') is not None
88

9-
HAS_GEOTORCH = True
10-
except ImportError:
11-
HAS_GEOTORCH = False
9+
if HAS_GEOTORCH:
10+
from geotorch import orthogonal
1211

1312

14-
def divide(numer, denom, eps: float = 1e-15):
13+
def divide(numer: torch.Tensor, de_nom: torch.Tensor, eps: float = 1e-15) -> torch.Tensor:
1514
r"""Numerically stable division."""
1615
return (
17-
torch.sign(numer) * torch.sign(denom) * torch.exp(torch.log(numer.abs() + eps) - torch.log(denom.abs() + eps))
16+
torch.sign(numer)
17+
* torch.sign(de_nom)
18+
* torch.exp(torch.log(numer.abs() + eps) - torch.log(de_nom.abs() + eps))
1819
)
1920

2021

@@ -181,7 +182,7 @@ def __init__(
181182
):
182183
super().__init__()
183184
if not HAS_GEOTORCH:
184-
raise ImportError('[-] you need to install geotorch to use RotoGrad. pip install geotorch')
185+
raise ImportError('[-] you need to install `geotorch` to use RotoGrad. `pip install geotorch`')
185186

186187
self._backbone = [backbone]
187188
self.heads = heads

pytorch_optimizer/optimizer/shampoo.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -336,13 +336,15 @@ def step(self, closure: CLOSURE = None) -> LOSS:
336336

337337
shampoo_grad.mul_(graft_norm / (shampoo_norm + 1e-16))
338338

339-
if group['weight_decay'] > 0.0:
340-
if not group['decoupled_weight_decay']:
341-
graft_grad.add_(p, alpha=group['weight_decay'])
342-
shampoo_grad.add_(p, alpha=group['weight_decay'])
343-
else:
344-
graft_grad.mul_(1.0 - group['lr'] * group['weight_decay'])
345-
shampoo_grad.mul_(1.0 - group['lr'] * group['weight_decay'])
339+
for g in (graft_grad, shampoo_grad):
340+
self.apply_weight_decay(
341+
p,
342+
g,
343+
group['lr'],
344+
group['weight_decay'],
345+
group['decoupled_weight_decay'],
346+
fixed_decay=False,
347+
)
346348

347349
state['momentum'].mul_(beta1).add_(shampoo_grad)
348350
graft_momentum = graft.update_momentum(grad, beta1)

pytorch_optimizer/optimizer/shampoo_utils.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class Graft:
2929
def __init__(self, *args):
3030
pass
3131

32-
def add_statistics(self, grad: torch.Tensor, unused_beta2: float):
32+
def add_statistics(self, grad: torch.Tensor, unused_beta2: float) -> None:
3333
r"""Add the statistics."""
3434
pass
3535

@@ -47,7 +47,7 @@ class SGDGraft(Graft):
4747

4848
def __init__(self, var: torch.Tensor):
4949
super().__init__(var)
50-
self.momentum: torch.Tensor = torch.zeros_like(var, device=var.device)
50+
self.momentum: torch.Tensor = torch.zeros_like(var)
5151

5252
def update_momentum(self, update: torch.Tensor, beta1: float) -> torch.Tensor:
5353
r"""Update momentum."""
@@ -78,13 +78,13 @@ def __init__(self, var: torch.Tensor, diagonal_eps: float):
7878
self.diagonal_eps = diagonal_eps
7979
self.statistics: torch.Tensor = torch.zeros_like(var)
8080

81-
def add_statistics(self, grad: torch.Tensor, _):
81+
def add_statistics(self, grad: torch.Tensor, _) -> None:
8282
r"""Add the statistics."""
8383
self.statistics.add_(grad.pow(2))
8484

8585
def precondition_gradient(self, grad: torch.Tensor) -> torch.Tensor:
8686
r"""Get preconditioned gradient."""
87-
return grad / (torch.sqrt(self.statistics) + self.diagonal_eps)
87+
return grad.div(self.statistics.sqrt().add_(self.diagonal_eps))
8888

8989

9090
class RMSPropGraft(SGDGraft):
@@ -99,13 +99,13 @@ def __init__(self, var: torch.Tensor, diagonal_eps: float):
9999
self.diagonal_eps = diagonal_eps
100100
self.statistics: torch.Tensor = torch.zeros_like(var)
101101

102-
def add_statistics(self, grad: torch.Tensor, beta2: float):
102+
def add_statistics(self, grad: torch.Tensor, beta2: float) -> None:
103103
r"""Add the statistics."""
104104
self.statistics.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
105105

106106
def precondition_gradient(self, grad: torch.Tensor) -> torch.Tensor:
107107
r"""Get preconditioned gradient."""
108-
return grad / (torch.sqrt(self.statistics) + self.diagonal_eps)
108+
return grad.div(self.statistics.sqrt().add_(self.diagonal_eps))
109109

110110

111111
class BlockPartitioner:
@@ -121,51 +121,51 @@ class BlockPartitioner:
121121
"""
122122

123123
def __init__(self, var: torch.Tensor, rank: int, block_size: int, pre_conditioner_type: int):
124-
self.shape: List[int] = var.shape
124+
self.shape: torch.Size = var.shape
125125

126-
self.splits: List[Tuple[int, np.ndarray]] = []
127-
self.split_sizes: List[Tuple[int, np.ndarray]] = []
126+
self.splits: List[Tuple[int, torch.Tensor]] = []
127+
self.split_sizes: List[Tuple[int, torch.Tensor]] = []
128128

129-
split_sizes: List[np.ndarray] = []
129+
split_sizes: List[torch.Tensor] = []
130130

131131
# We split var into smaller blocks. Here we store the metadata to make that split.
132132
for i, d in enumerate(self.shape):
133133
if block_size <= 0 or block_size >= d:
134-
split_sizes.append(np.array([d], dtype=np.int32))
134+
split_sizes.append(torch.tensor([d], dtype=torch.int32))
135135
continue
136136

137137
# d - 1, otherwise split appends a 0-size array.
138138
num_split: int = (d - 1) // block_size
139-
indices = (np.arange(num_split, dtype=np.int32) + 1) * block_size
139+
indices = (torch.arange(num_split, dtype=torch.int32) + 1) * block_size
140140

141-
sizes: np.ndarray = np.ones(num_split + 1, dtype=np.int32) * block_size
141+
sizes: torch.Tensor = torch.full((num_split + 1,), block_size, dtype=torch.int32)
142142
sizes[-1] = d - indices[-1]
143143

144144
self.splits.append((i, indices))
145145
self.split_sizes.append((i, sizes))
146146
split_sizes.append(sizes)
147147

148148
self.num_splits: int = len(split_sizes)
149-
self.pre_conditioner_shapes: List[List[int]] = self.build_pre_conditioner_shapes(
149+
self.pre_conditioner_shapes: List[List[torch.Tensor]] = self.build_pre_conditioner_shapes(
150150
split_sizes, pre_conditioner_type, rank
151151
)
152152

153153
@staticmethod
154154
def build_pre_conditioner_shapes(
155-
split_sizes: List[np.ndarray], pre_conditioner_type: int, rank: int
156-
) -> List[List[int]]:
155+
split_sizes: List[torch.Tensor], pre_conditioner_type: int, rank: int
156+
) -> List[List[torch.Tensor]]:
157157
r"""Build pre-conditioner shapes."""
158-
pre_conditioner_shapes: List[List[int]] = []
158+
pre_conditioner_shapes: List[List[torch.Tensor]] = []
159159
for t in itertools.product(*split_sizes):
160-
t_shape: List[Optional[List[int]]] = [[d, d] for d in t]
160+
t_shape: List[Optional[List[torch.Tensor]]] = [[d, d] for d in t]
161161
if pre_conditioner_type == PreConditionerType.INPUT:
162-
t_shape = t_shape[:-1] + [None]
163-
if pre_conditioner_type == PreConditionerType.OUTPUT:
162+
t_shape[-1] = None
163+
elif pre_conditioner_type == PreConditionerType.OUTPUT:
164164
t_shape = [None] * (rank - 1) + t_shape[-1:]
165165
pre_conditioner_shapes.extend(t_shape)
166166
return pre_conditioner_shapes
167167

168-
def shapes_for_pre_conditioners(self) -> List[List[int]]:
168+
def shapes_for_pre_conditioners(self) -> List[List[torch.Tensor]]:
169169
r"""Get shapes of pre-conditioner."""
170170
return self.pre_conditioner_shapes
171171

@@ -244,7 +244,7 @@ def __init__(
244244

245245
self.w2: float = 1.0 if self.beta2 == 1.0 else (1.0 - self.beta2)
246246

247-
self.original_shape: List[int] = var.shape
247+
self.original_shape: torch.Size = var.shape
248248
self.transformed_shape: List[int] = (
249249
merge_small_dims(self.original_shape, block_size) if shape_interpretation else var.shape
250250
)
@@ -267,7 +267,7 @@ def __init__(
267267
pre_conditioner_type=self.pre_conditioner_type,
268268
)
269269

270-
shapes: List[Optional[List[int]]] = self.partitioner.shapes_for_pre_conditioners()
270+
shapes: List[Optional[List[torch.Tensor]]] = self.partitioner.shapes_for_pre_conditioners()
271271
self.statistics = [self.matrix_eps * torch.eye(shape[0], device=var.device) for shape in shapes if shape]
272272
self.pre_conditioners = [torch.eye(shape[0], device=var.device) for shape in shapes if shape]
273273
self.is_same_shapes = None not in shapes and len(np.unique(shapes)) == 1
@@ -291,7 +291,7 @@ def skip_precondition(self, x: torch.Tensor) -> bool:
291291
dim > self.no_preconditioning_for_layers_with_dim_gt for dim in x.shape
292292
)
293293

294-
def add_statistics(self, grad: torch.Tensor):
294+
def add_statistics(self, grad: torch.Tensor) -> None:
295295
r"""Compute statistics from gradients and add to the correct state entries.
296296
297297
:param grad: torch.Tensor. gradient to compute statistics from.
@@ -302,14 +302,13 @@ def add_statistics(self, grad: torch.Tensor):
302302
reshaped_grad: torch.Tensor = torch.reshape(grad, self.transformed_shape)
303303
partitioned_grads: List[torch.Tensor] = self.partitioner.partition(reshaped_grad)
304304

305-
for j in range(len(partitioned_grads)):
306-
partitioned_grad: torch.Tensor = partitioned_grads[j]
305+
for j, partitioned_grad in enumerate(partitioned_grads):
307306
for i in range(self.rank):
308307
axes: List[int] = [ax for ax in range(partitioned_grad.ndim) if ax != i]
309308
stat: torch.Tensor = torch.tensordot(partitioned_grad, partitioned_grad, dims=[axes, axes])
310309
self.statistics[j * self.rank + i].mul_(self.beta2).add_(stat, alpha=self.w2)
311310

312-
def compute_pre_conditioners(self):
311+
def compute_pre_conditioners(self) -> None:
313312
r"""Compute L^{-1/exp} for each stats matrix L.
314313
315314
If `self.use_svd` is enabled and where all shapes of statistics & pre-conditioners are same, perform batch SVD.
@@ -333,15 +332,15 @@ def compute_pre_conditioners(self):
333332
def precondition_block(
334333
partitioned_grad: torch.Tensor,
335334
should_preconditioned_dims: List[bool],
336-
pre_conditioners_for_grad: List[torch.Tensor],
335+
pre_conditioners_for_grad: Union[List[torch.Tensor], torch.Tensor],
337336
) -> torch.Tensor:
338337
r"""Perform a preconditioning operation on a single gradient block.
339338
340339
Loop invariant: the dimension to be preconditioned is first
341340
We keep all axes in the same cyclic order they were originally.
342341
"""
343342
rank: int = len(partitioned_grad.shape)
344-
roll: Tuple[int, ...] = (*tuple(range(1, rank)), 0)
343+
roll: Tuple[int, ...] = (*range(1, rank), 0)
345344

346345
i: int = 0
347346
for should_precondition_dim in should_preconditioned_dims:
@@ -376,7 +375,7 @@ def preconditioned_grad(self, grad: torch.Tensor) -> torch.Tensor:
376375

377376
merged_grad = self.partitioner.merge_partitions(pre_cond_partitioned_grads)
378377

379-
return torch.reshape(merged_grad, self.original_shape)
378+
return merged_grad.reshape(self.original_shape)
380379

381380

382381
def build_graft(p: torch.Tensor, graft_type: int, diagonal_eps: float = 1e-10):
@@ -407,7 +406,8 @@ def power_iteration(mat_g: torch.Tensor, num_iters: int = 100) -> torch.Tensor:
407406

408407
for _ in range(num_iters):
409408
torch.mv(mat_g, v, out=mat_v)
410-
v = mat_v.div(torch.linalg.norm(mat_v))
409+
v.copy_(mat_v)
410+
v.div_(torch.linalg.norm(v))
411411

412412
return (v.t() @ mat_g @ v).clamp_min_(1e-16)
413413

@@ -490,7 +490,7 @@ def compute_power_schur_newton(
490490

491491
@torch.no_grad()
492492
def compute_power_svd(matrix: torch.Tensor, power: float) -> torch.Tensor:
493-
r"""Compute G^{-1/p} using a SVD.
493+
r"""Compute G^{-1/p} using SVD.
494494
495495
Calculate SVD on the GPU. Sometimes, SVD on the CPU is faster than GPU, but based on the several experiments,
496496
CUDA seems much faster than on CPU.
@@ -503,14 +503,14 @@ def compute_power_svd(matrix: torch.Tensor, power: float) -> torch.Tensor:
503503
return u @ (s.diag() if len(matrix.shape) == 2 else s.diag_embed()) @ vh
504504

505505

506-
def merge_small_dims(shape_to_merge: List[int], max_dim: int) -> List[int]:
506+
def merge_small_dims(shape_to_merge: Union[List[int], torch.Size], max_dim: int) -> List[int]:
507507
r"""Merge small dimensions.
508508
509509
If there are some small dimensions, we collapse them
510510
e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
511511
[1, 2, 768, 1, 2048] --> [2, 768, 2048].
512512
513-
:param shape_to_merge: List[int]. Shape to merge small dimensions.
513+
:param shape_to_merge: Union[List[int], torch.Size]. Shape to merge small dimensions.
514514
:param max_dim: int. Maximal dimension of output shape used in merging.
515515
"""
516516
merged_shape: List[int] = []

0 commit comments

Comments
 (0)