Skip to content

Commit 474510f

Browse files
authored
Merge pull request #260 from kozistr/refactor/code
[Update] Improve the performance
2 parents 22f994b + 4666678 commit 474510f

File tree

5 files changed

+70
-41
lines changed

5 files changed

+70
-41
lines changed

docs/changelogs/v3.1.0.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,15 @@
99
* you can use it via `optimizer = load_optimizer('q_galore_adamw8bit')`
1010
* Support more bnb optimizers. (#258)
1111
* `bnb_paged_adam8bit`, `bnb_paged_adamw8bit`, `bnb_*_*32bit`.
12+
* Speed up `power_iteration()` by up to 40%. (#259)
13+
* Speed up `reg_noise()` (E-MCMC) by up to 120%. (#260)
1214

1315
### Refactor
1416

15-
* Refactor `AdamMini`. (#258)
17+
* Refactor `AdamMini` optimizer. (#258)
1618
* Deprecate optional dependency, `bitsandbytes`. (#258)
1719
* Move `get_rms`, `approximate_sq_grad` functions to `BaseOptimizer` for reusability. (#258)
20+
* Refactor `shampoo_utils.py`. (#259)
1821

1922
### Bug
2023

poetry.lock

Lines changed: 49 additions & 27 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pytorch_optimizer/optimizer/utils.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch
88
from torch import nn
99
from torch.distributed import all_reduce
10-
from torch.nn import functional as f
10+
from torch.nn.functional import cosine_similarity
1111
from torch.nn.modules.batchnorm import _BatchNorm
1212
from torch.nn.utils import clip_grad_norm_
1313

@@ -62,7 +62,7 @@ def to_real(x: torch.Tensor) -> torch.Tensor:
6262
return x.real if torch.is_complex(x) else x
6363

6464

65-
def normalize_gradient(x: torch.Tensor, use_channels: bool = False, epsilon: float = 1e-8):
65+
def normalize_gradient(x: torch.Tensor, use_channels: bool = False, epsilon: float = 1e-8) -> None:
6666
r"""Normalize gradient with stddev.
6767
6868
:param x: torch.Tensor. gradient.
@@ -119,7 +119,7 @@ def cosine_similarity_by_view(
119119
"""
120120
x = view_func(x)
121121
y = view_func(y)
122-
return f.cosine_similarity(x, y, dim=1, eps=eps).abs_()
122+
return cosine_similarity(x, y, dim=1, eps=eps).abs_()
123123

124124

125125
def clip_grad_norm(
@@ -315,6 +315,7 @@ def reduce_max_except_dim(x: torch.Tensor, dim: int) -> torch.Tensor:
315315
return x
316316

317317

318+
@torch.no_grad()
318319
def reg_noise(
319320
network1: nn.Module, network2: nn.Module, num_data: int, lr: float, eta: float = 8e-3, temperature: float = 1e-4
320321
) -> Union[torch.Tensor, float]:
@@ -332,11 +333,14 @@ def reg_noise(
332333
reg_coef: float = 0.5 / (eta * num_data)
333334
noise_coef: float = math.sqrt(2.0 / lr / num_data * temperature)
334335

335-
loss = 0
336-
for param1, param2 in zip(network1.parameters(), network2.parameters(), strict=True):
337-
reg = torch.sub(param1, param2).pow_(2) * reg_coef
338-
noise1 = param1 * torch.randn_like(param1) * noise_coef
339-
noise2 = param2 * torch.randn_like(param2) * noise_coef
340-
loss += torch.sum(reg - noise1 - noise2)
336+
loss = torch.tensor(0.0, device=next(network1.parameters()).device)
337+
338+
for param1, param2 in zip(network1.parameters(), network2.parameters()):
339+
reg = (param1 - param2).pow_(2).mul_(reg_coef).sum()
340+
341+
noise = param1 * torch.randn_like(param1)
342+
noise.add_(param2 * torch.randn_like(param2))
343+
344+
loss.add_(reg - noise.mul_(noise_coef).sum())
341345

342346
return loss

requirements-dev.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ pathspec==0.12.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
2222
platformdirs==4.2.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
2323
pluggy==1.5.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
2424
pytest-cov==5.0.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
25-
pytest==8.2.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
26-
ruff==0.5.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
27-
sympy==1.13.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
25+
pytest==8.3.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
26+
ruff==0.5.4 ; python_version >= "3.8" and python_full_version < "4.0.0"
27+
sympy==1.13.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
2828
tbb==2021.13.0 ; python_version >= "3.8" and python_full_version < "4.0.0" and platform_system == "Windows"
2929
tomli==2.0.1 ; python_version >= "3.8" and python_full_version <= "3.11.0a6"
3030
torch==2.3.1+cpu ; python_version >= "3.8" and python_full_version < "4.0.0"

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ mkl==2021.4.0 ; python_version >= "3.8" and python_full_version < "4.0.0" and pl
99
mpmath==1.3.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
1010
networkx==3.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
1111
numpy==1.24.4 ; python_version >= "3.8" and python_full_version < "4.0.0"
12-
sympy==1.13.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
12+
sympy==1.13.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
1313
tbb==2021.13.0 ; python_version >= "3.8" and python_full_version < "4.0.0" and platform_system == "Windows"
1414
torch==2.3.1+cpu ; python_version >= "3.8" and python_full_version < "4.0.0"
1515
typing-extensions==4.12.2 ; python_version >= "3.8" and python_full_version < "4.0.0"

0 commit comments

Comments
 (0)