36 changes: 24 additions & 12 deletions toolkit/optimizer.py
@@ -19,12 +19,12 @@ def get_optimizer(
         # dadaptation uses different lr that is values of 0.1 to 1.0. default to 1.0
         use_lr = 1.0
         if lower_type.endswith('lion'):
-            optimizer = dadaptation.DAdaptLion(params, eps=1e-6, lr=use_lr, **optimizer_params)
+            optimizer = dadaptation.DAdaptLion(params, eps=1e-8, lr=use_lr, **optimizer_params)
         elif lower_type.endswith('adam'):
-            optimizer = dadaptation.DAdaptLion(params, eps=1e-6, lr=use_lr, **optimizer_params)
+            optimizer = dadaptation.DAdaptLion(params, eps=1e-8, lr=use_lr, **optimizer_params)
         elif lower_type == 'dadaptation':
             # backwards compatibility
-            optimizer = dadaptation.DAdaptAdam(params, eps=1e-6, lr=use_lr, **optimizer_params)
+            optimizer = dadaptation.DAdaptAdam(params, eps=1e-8, lr=use_lr, **optimizer_params)
             # warn user that dadaptation is deprecated
             print("WARNING: Dadaptation optimizer type has been changed to DadaptationAdam. Please update your config.")
     elif lower_type.startswith("prodigy8bit"):
@@ -38,7 +38,19 @@ def get_optimizer(
         print(f"Using lr {use_lr}")
         # let net be the neural network you want to train
         # you can choose weight decay value based on your problem, 0 by default
-        optimizer = Prodigy8bit(params, lr=use_lr, eps=1e-6, **optimizer_params)
+        optimizer = Prodigy8bit(params, lr=use_lr, eps=1e-8, **optimizer_params)
+    elif lower_type.startswith("adamw_fp8"):
+        from toolkit.optimizers.adamw_fp8 import AdamWFP8
+        print("Using adamw_fp8")
+        use_lr = learning_rate
+
+        optimizer = AdamWFP8(params, lr=use_lr, eps=1e-8, **optimizer_params)
+    elif lower_type.startswith("adamw_bf16"):
+        from toolkit.optimizers.adamw_bf16 import AdamWBF16
+        print("Using adamw_bf16")
+        use_lr = learning_rate
+
+        optimizer = AdamWBF16(params, lr=use_lr, eps=1e-8, **optimizer_params)
     elif lower_type.startswith("prodigy"):
         from prodigyopt import Prodigy

@@ -51,32 +63,32 @@
         print(f"Using lr {use_lr}")
         # let net be the neural network you want to train
         # you can choose weight decay value based on your problem, 0 by default
-        optimizer = Prodigy(params, lr=use_lr, eps=1e-6, **optimizer_params)
+        optimizer = Prodigy(params, lr=use_lr, eps=1e-8, use_bias_correction=True, d0=5e-5, d_coef=1.0, safeguard_warmup=True, **optimizer_params)
     elif lower_type == "adam8":
         from toolkit.optimizers.adam8bit import Adam8bit

-        optimizer = Adam8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
+        optimizer = Adam8bit(params, lr=learning_rate, eps=1e-8, **optimizer_params)
     elif lower_type == "adamw8":
         from toolkit.optimizers.adam8bit import Adam8bit

-        optimizer = Adam8bit(params, lr=learning_rate, eps=1e-6, decouple=True, **optimizer_params)
+        optimizer = Adam8bit(params, lr=learning_rate, eps=1e-8, decouple=True, **optimizer_params)
     elif lower_type.endswith("8bit"):
         import bitsandbytes

         if lower_type == "adam8bit":
-            return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
+            return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, eps=1e-8, **optimizer_params)
         if lower_type == "ademamix8bit":
-            return bitsandbytes.optim.AdEMAMix8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
+            return bitsandbytes.optim.AdEMAMix8bit(params, lr=learning_rate, eps=1e-8, **optimizer_params)
         elif lower_type == "adamw8bit":
-            return bitsandbytes.optim.AdamW8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
+            return bitsandbytes.optim.AdamW8bit(params, lr=learning_rate, eps=1e-8, **optimizer_params)
         elif lower_type == "lion8bit":
             return bitsandbytes.optim.Lion8bit(params, lr=learning_rate, **optimizer_params)
         else:
             raise ValueError(f'Unknown optimizer type {optimizer_type}')
     elif lower_type == 'adam':
-        optimizer = torch.optim.Adam(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
+        optimizer = torch.optim.Adam(params, lr=float(learning_rate), eps=1e-8, **optimizer_params)
     elif lower_type == 'adamw':
-        optimizer = torch.optim.AdamW(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
+        optimizer = torch.optim.AdamW(params, lr=float(learning_rate), eps=1e-8, **optimizer_params)
     elif lower_type == 'lion':
         try:
             from lion_pytorch import Lion
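For reference, a minimal sketch of how the updated branches might be exercised. It assumes `get_optimizer(params, optimizer_type, learning_rate, optimizer_params)` roughly matches the signature implied by the diff; the model, keyword names, and values below are illustrative assumptions, not taken from the repo's docs.

```python
# Rough usage sketch; keyword names and values are assumptions based on the diff.
import torch
from toolkit.optimizer import get_optimizer

net = torch.nn.Linear(16, 4)  # stand-in model

# Plain AdamW branch: eps now defaults to 1e-8 instead of 1e-6.
opt = get_optimizer(
    net.parameters(),
    optimizer_type="adamw",
    learning_rate=1e-4,
    optimizer_params={"weight_decay": 0.01},  # forwarded as **optimizer_params
)

# Prodigy branch: lr is adapted internally, so ~1.0 is typical; the diff now also
# hardcodes use_bias_correction=True, d0=5e-5, d_coef=1.0, safeguard_warmup=True.
opt_prodigy = get_optimizer(
    net.parameters(),
    optimizer_type="prodigy",
    learning_rate=1.0,
    optimizer_params={},
)

# New type strings added by this PR; the bf16/fp8 variants presumably expect
# parameters in the matching dtype (assumption, not verified here).
net_bf16 = torch.nn.Linear(16, 4).to(torch.bfloat16)
opt_bf16 = get_optimizer(
    net_bf16.parameters(),
    optimizer_type="adamw_bf16",
    learning_rate=1e-4,
    optimizer_params={},
)
```

Note the split in how the learning rate is treated: the d-adaptive types (dadaptation, prodigy) use the passed value only as a multiplier near 1.0, while the adam/adamw variants and the new adamw_fp8/adamw_bf16 branches use `learning_rate` directly, which is why each branch keeps its own `use_lr` handling.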