
Commit 00fbae0

Merge pull request #349 from kozistr/update/v3.4.1
[Release] v3.4.1
2 parents 5f4e62f + ba341db commit 00fbae0


101 files changed: +160 −185 lines

docs/changelogs/v3.4.1.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -16,6 +16,7 @@
 * change default beta1, beta2 to 0.95 and 0.98 respectively
 * Skip adding `Lookahead` wrapper in case of `Ranger*` optimizers, which already have it in `create_optimizer()`. (#340)
 * Improved optimizer visualization. (#345)
+* Rename `pytorch_optimizer.optimizer.gc` to `pytorch_optimizer.optimizer.gradient_centralization` to avoid possible conflict with Python built-in function `gc`. (#349)
 
 ### Bug
 
```
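The entry above only changes a module path. A minimal migration sketch, assuming (as in earlier releases) that the module exposes a `centralize_gradient` helper; verify the function name against your installed version:

```python
import torch

# Old import path (pre-3.4.1); the module name shadowed the stdlib `gc` module:
# from pytorch_optimizer.optimizer.gc import centralize_gradient

# New import path from v3.4.1 onwards:
from pytorch_optimizer.optimizer.gradient_centralization import centralize_gradient

# Illustrative use: centralize a conv-shaped gradient tensor
# (assumed signature: tensor in, tensor out).
grad = torch.randn(8, 4, 3, 3)
grad = centralize_gradient(grad, gc_conv_only=False)
```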

Two binary files changed (image previews not rendered): −483 Bytes and 122 Bytes.

examples/visualize_optimizers.py

Lines changed: 33 additions & 57 deletions
```diff
@@ -1,8 +1,8 @@
 import math
-import warnings
 from functools import partial
 from pathlib import Path
-from typing import Callable, Dict, Tuple, Union
+from typing import Callable, Dict, List, Tuple, Union
+from warnings import filterwarnings
 
 import numpy as np
 import torch
@@ -14,23 +14,23 @@
 from pytorch_optimizer.optimizer import OPTIMIZERS
 from pytorch_optimizer.optimizer.alig import l2_projection
 
-warnings.filterwarnings('ignore', category=UserWarning)
+filterwarnings('ignore', category=UserWarning)
 
 OPTIMIZERS_IGNORE = ('lomo', 'adalomo', 'demo', 'a2grad', 'alig')  # BUG: fix `alig`, invalid .__name__
 OPTIMIZERS_MODEL_INPUT_NEEDED = ('lomo', 'adalomo', 'adammini')
 OPTIMIZERS_GRAPH_NEEDED = ('adahessian', 'sophiah')
 OPTIMIZERS_CLOSURE_NEEDED = ('alig', 'bsam')
-EVAL_PER_HYPYPERPARAM = 540
-OPTIMIZATION_STEPS = 300
-TESTING_OPTIMIZATION_STEPS = 650
-DIFFICULT_RASTRIGIN = False
-USE_AVERAGE_LOSS_PENALTY = True
-AVERAGE_LOSS_PENALTY_FACTOR = 1.0
-SEARCH_SEED = 42
-LOSS_MIN_TRESH = 0
-
-default_search_space = {'lr': hp.uniform('lr', 0, 2)}
-special_search_spaces = {
+EVAL_PER_HYPERPARAM: int = 540
+OPTIMIZATION_STEPS: int = 300
+TESTING_OPTIMIZATION_STEPS: int = 650
+DIFFICULT_RASTRIGIN: bool = False
+USE_AVERAGE_LOSS_PENALTY: bool = True
+AVERAGE_LOSS_PENALTY_FACTOR: float = 1.0
+SEARCH_SEED: int = 42
+LOSS_MIN_THRESHOLD: float = 0.0
+
+DEFAULT_SEARCH_SPACES = {'lr': hp.uniform('lr', 0, 2)}
+SPECIAL_SEARCH_SPACES = {
     'adafactor': {'lr': hp.uniform('lr', 0, 10)},
     'adams': {'lr': hp.uniform('lr', 0, 10)},
     'dadaptadagrad': {'lr': hp.uniform('lr', 0, 10)},
@@ -170,7 +170,7 @@ def execute_steps(
     optimizer_class: torch.optim.Optimizer,
     optimizer_config: Dict,
     num_iters: int = 500,
-) -> torch.Tensor:
+) -> Tuple[torch.Tensor, List[float]]:
     """
     Execute optimization steps for a given configuration.
 
@@ -201,7 +201,6 @@ def closure() -> float:
 
         return closure
 
-    # Initialize the model and optimizer
     model = Model(func, initial_state)
     parameters = list(model.parameters())
     optimizer_name: str = optimizer_class.__name__.lower()
@@ -218,30 +217,25 @@ def closure() -> float:
     elif optimizer_name == 'bsam':
         optimizer_config['num_data'] = 1
 
-    # Special initialization for memory-efficient optimizers
     if optimizer_name in OPTIMIZERS_MODEL_INPUT_NEEDED:
         optimizer = optimizer_class(model, **optimizer_config)
     else:
         optimizer = optimizer_class(parameters, **optimizer_config)
 
-    # Track optimization path
-    losses = []
     steps = torch.zeros((2, num_iters + 1), dtype=torch.float32)
     steps[:, 0] = model.x.detach()
 
+    losses = []
     for i in range(1, num_iters + 1):
         optimizer.zero_grad()
+
         loss = model()
         losses.append(loss.item())
 
-        # Special handling for second-order optimizers
-        create_graph = optimizer_name in OPTIMIZERS_GRAPH_NEEDED
-        loss.backward(create_graph=create_graph)
+        loss.backward(create_graph=optimizer_name in OPTIMIZERS_GRAPH_NEEDED)
 
-        # Gradient clipping for stability
         nn.utils.clip_grad_norm_(parameters, 1.0)
 
-        # Closure required for certain optimizers
         closure = create_closure(loss) if optimizer_name in OPTIMIZERS_CLOSURE_NEEDED else None
         optimizer.step(closure)
 
@@ -279,25 +273,19 @@ def objective(
         - A penalty for boundary violations.
         - An optional penalty for higher average loss during optimization (if enabled).
     """
-    # Execute optimization steps and get losses
-    steps, losses = execute_steps(  # Modified to unpack losses
-        criterion, initial_state, optimizer_class, params, num_iters
-    )
+    steps, losses = execute_steps(criterion, initial_state, optimizer_class, params, num_iters)
 
-    # Calculate boundary violations
     x_min_violation = torch.clamp(x_bounds[0] - steps[0], min=0).max()
     x_max_violation = torch.clamp(steps[0] - x_bounds[1], min=0).max()
     y_min_violation = torch.clamp(y_bounds[0] - steps[1], min=0).max()
     y_max_violation = torch.clamp(steps[1] - y_bounds[1], min=0).max()
     total_violation = x_min_violation + x_max_violation + y_min_violation + y_max_violation
 
-    # Calculate average loss penalty
-    avg_loss = sum(losses) / len(losses) if losses else 0.0
     penalty = 75 * total_violation.item()
     if USE_AVERAGE_LOSS_PENALTY:
+        avg_loss: float = sum(losses) / len(losses) if losses else 0.0
         penalty += avg_loss * AVERAGE_LOSS_PENALTY_FACTOR
 
-    # Calculate final distance to minimum
     final_position = steps[:, -1]
     final_distance = ((final_position[0] - minimum[0]) ** 2 + (final_position[1] - minimum[1]) ** 2).item()
 
@@ -309,7 +297,7 @@ def plot_function(
     optimization_steps: torch.Tensor,
     output_path: Path,
     optimizer_name: str,
-    params: dict,
+    params: Dict,
     x_range: Tuple[float, float],
     y_range: Tuple[float, float],
     minimum: Tuple[float, float],
@@ -335,34 +323,29 @@ def plot_function(
     fig = plt.figure(figsize=(8, 8))
     ax = fig.add_subplot(1, 1, 1)
 
-    # Plot function contours and optimization path
     ax.contour(x_grid.numpy(), y_grid.numpy(), z.numpy(), 20, cmap='jet')
     ax.plot(optimization_steps[0], optimization_steps[1], color='r', marker='x', markersize=3)
 
-    # Mark global minimum and final position
     plt.plot(*minimum, 'gD', label='Global Minimum')
     plt.plot(optimization_steps[0, -1], optimization_steps[1, -1], 'bD', label='Final Position')
 
-    ax.set_title(
-        f'{func.__name__} func: {optimizer_name} with {TESTING_OPTIMIZATION_STEPS} iterations\n{
-            ", ".join(f"{k}={round(v, 4)}" for k, v in params.items())
-        }'
-    )
+    config: str = ', '.join(f'{k}={round(v, 4)}' for k, v in params.items())
+    ax.set_title(f'{func.__name__} func: {optimizer_name} with {TESTING_OPTIMIZATION_STEPS} iterations\n{config}')
     plt.legend()
     plt.savefig(str(output_path))
     plt.close()
 
 
 def execute_experiments(
-    optimizers: list,
+    optimizers: List,
     func: Callable,
     initial_state: Tuple[float, float],
     output_dir: Path,
     experiment_name: str,
     x_range: Tuple[float, float],
     y_range: Tuple[float, float],
     minimum: Tuple[float, float],
-    seed: int = 42,
+    seed: int = SEARCH_SEED,
 ) -> None:
     """
     Run optimization experiments for multiple optimizers.
@@ -382,15 +365,14 @@ def execute_experiments(
         optimizer_name = optimizer_class.__name__
         output_path = output_dir / f'{experiment_name}_{optimizer_name}.png'
         if output_path.exists():
-            continue  # Skip already generated plots
+            continue
 
         print(  # noqa: T201
            f'({i}/{len(optimizers)}) Processing {optimizer_name}... (Params to tune: {", ".join(search_space.keys())})'  # noqa: E501
        )
 
-        # Select hyperparameter search space
-        num_hyperparams = len(search_space)
-        max_evals = EVAL_PER_HYPYPERPARAM * num_hyperparams  # Scale evaluations based on hyperparameter count
+        num_hyperparams: int = len(search_space)
+        max_evals: int = EVAL_PER_HYPERPARAM * num_hyperparams
 
         objective_fn = partial(
            objective,
@@ -402,43 +384,38 @@ def execute_experiments(
            y_bounds=y_range,
            num_iters=OPTIMIZATION_STEPS,
        )
+
        try:
            best_params = fmin(
                fn=objective_fn,
                space=search_space,
                algo=tpe.suggest,
                max_evals=max_evals,
-                loss_threshold=LOSS_MIN_TRESH,
+                loss_threshold=LOSS_MIN_THRESHOLD,
                rstate=np.random.default_rng(seed),
            )
        except AllTrialsFailed:
            print(f'⚠️ {optimizer_name} failed to optimize {func.__name__}')  # noqa: T201
            continue
 
-        # Run final optimization with best parameters
-        steps, _ = execute_steps(  # Modified to ignore losses
-            func, initial_state, optimizer_class, best_params, TESTING_OPTIMIZATION_STEPS
-        )
+        steps, _ = execute_steps(func, initial_state, optimizer_class, best_params, TESTING_OPTIMIZATION_STEPS)
 
-        # Generate and save visualization
        plot_function(func, steps, output_path, optimizer_name, best_params, x_range, y_range, minimum)
 
 
 def main():
-    """Main execution routine for optimization experiments."""
    np.random.seed(SEARCH_SEED)
    torch.manual_seed(SEARCH_SEED)
+
    output_dir = Path('.') / 'docs' / 'visualizations'
    output_dir.mkdir(parents=True, exist_ok=True)
 
-    # Prepare the list of optimizers and their search spaces
    optimizers = [
-        (optimizer, special_search_spaces.get(optimizer_name, default_search_space))
+        (optimizer, SPECIAL_SEARCH_SPACES.get(optimizer_name, DEFAULT_SEARCH_SPACES))
        for optimizer_name, optimizer in OPTIMIZERS.items()
        if optimizer_name not in OPTIMIZERS_IGNORE
    ]
 
-    # Run experiments for the Rastrigin function
    print('Executing Rastrigin experiments...')  # noqa: T201
    execute_experiments(
        optimizers,
@@ -452,7 +429,6 @@ def main():
        seed=SEARCH_SEED,
    )
 
-    # Run experiments for the Rosenbrock function
    print('Executing Rosenbrock experiments...')  # noqa: T201
    execute_experiments(
        optimizers,
```
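As context for the diff above, the example script tunes each optimizer's hyperparameters with hyperopt before plotting its trajectory. Below is a minimal, self-contained sketch of the same search pattern (search-space dict, `tpe.suggest`, `loss_threshold`, seeded `rstate`); the toy quadratic objective and its names are illustrative only, not part of the repository.

```python
import numpy as np
from hyperopt import fmin, hp, tpe

# Toy stand-in for the script's `objective`: tune a learning rate for a
# quadratic whose loss is smallest near lr = 0.5 (illustrative only).
def toy_objective(params: dict) -> float:
    return (params['lr'] - 0.5) ** 2

search_space = {'lr': hp.uniform('lr', 0, 2)}  # same shape as DEFAULT_SEARCH_SPACES

best_params = fmin(
    fn=toy_objective,
    space=search_space,
    algo=tpe.suggest,
    max_evals=100,
    loss_threshold=0.0,                # mirrors LOSS_MIN_THRESHOLD: stop early at this loss
    rstate=np.random.default_rng(42),  # reproducible search, like SEARCH_SEED
)
print(best_params)  # e.g. {'lr': 0.50...}
```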

poetry.lock

Lines changed: 20 additions & 20 deletions
Generated lockfile; diff not rendered by default.

pyproject.toml

Lines changed: 2 additions & 2 deletions
```diff
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pytorch_optimizer"
-version = "3.4.0"
+version = "3.4.1"
 description = "optimizer & lr scheduler & objective function collections in PyTorch"
 license = "Apache-2.0"
 authors = ["kozistr <[email protected]>"]
@@ -97,7 +97,7 @@ select = [
     "TID", "ARG", "ERA", "RUF", "YTT", "PL", "Q"
 ]
 ignore = [
-    "A005", "B905",
+    "B905",
     "D100", "D102", "D104", "D105", "D107", "D203", "D213", "D413",
     "PLR0912", "PLR0913", "PLR0915", "PLR2004",
     "Q003", "ARG002",
```

pytorch_optimizer/base/optimizer.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -6,7 +6,7 @@
 from torch.optim import Optimizer
 
 from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError
-from pytorch_optimizer.base.types import (
+from pytorch_optimizer.base.type import (
     BETAS,
     CLOSURE,
     DEFAULTS,
```
File renamed without changes.

pytorch_optimizer/loss/dice.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -4,7 +4,7 @@
 from torch.nn.functional import logsigmoid, one_hot
 from torch.nn.modules.loss import _Loss
 
-from pytorch_optimizer.base.types import CLASS_MODE
+from pytorch_optimizer.base.type import CLASS_MODE
 
 
 def soft_dice_score(
```

pytorch_optimizer/loss/jaccard.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -4,7 +4,7 @@
 from torch.nn.functional import logsigmoid, one_hot
 from torch.nn.modules.loss import _Loss
 
-from pytorch_optimizer.base.types import CLASS_MODE
+from pytorch_optimizer.base.type import CLASS_MODE
 
 
 def soft_jaccard_score(
```
