diff --git a/toolkit/optimizer.py b/toolkit/optimizer.py
index 355512e9b..90a5ed079 100644
--- a/toolkit/optimizer.py
+++ b/toolkit/optimizer.py
@@ -19,12 +19,12 @@ def get_optimizer(
         # dadaptation uses different lr that is values of 0.1 to 1.0. default to 1.0
         use_lr = 1.0
         if lower_type.endswith('lion'):
-            optimizer = dadaptation.DAdaptLion(params, eps=1e-6, lr=use_lr, **optimizer_params)
+            optimizer = dadaptation.DAdaptLion(params, eps=1e-8, lr=use_lr, **optimizer_params)
         elif lower_type.endswith('adam'):
-            optimizer = dadaptation.DAdaptLion(params, eps=1e-6, lr=use_lr, **optimizer_params)
+            optimizer = dadaptation.DAdaptAdam(params, eps=1e-8, lr=use_lr, **optimizer_params)
         elif lower_type == 'dadaptation':
             # backwards compatibility
-            optimizer = dadaptation.DAdaptAdam(params, eps=1e-6, lr=use_lr, **optimizer_params)
+            optimizer = dadaptation.DAdaptAdam(params, eps=1e-8, lr=use_lr, **optimizer_params)
             # warn user that dadaptation is deprecated
             print("WARNING: Dadaptation optimizer type has been changed to DadaptationAdam. Please update your config.")
     elif lower_type.startswith("prodigy8bit"):
@@ -38,7 +38,19 @@
         print(f"Using lr {use_lr}")
         # let net be the neural network you want to train
         # you can choose weight decay value based on your problem, 0 by default
-        optimizer = Prodigy8bit(params, lr=use_lr, eps=1e-6, **optimizer_params)
+        optimizer = Prodigy8bit(params, lr=use_lr, eps=1e-8, **optimizer_params)
+    elif lower_type.startswith("adamw_fp8"):
+        from toolkit.optimizers.adamw_fp8 import AdamWFP8
+        print("Using adamw_fp8")
+        use_lr = learning_rate
+
+        optimizer = AdamWFP8(params, lr=use_lr, eps=1e-8, **optimizer_params)
+    elif lower_type.startswith("adamw_bf16"):
+        from toolkit.optimizers.adamw_bf16 import AdamWBF16
+        print("Using adamw_bf16")
+        use_lr = learning_rate
+
+        optimizer = AdamWBF16(params, lr=use_lr, eps=1e-8, **optimizer_params)
 
     elif lower_type.startswith("prodigy"):
         from prodigyopt import Prodigy
@@ -51,32 +63,32 @@
         print(f"Using lr {use_lr}")
         # let net be the neural network you want to train
         # you can choose weight decay value based on your problem, 0 by default
-        optimizer = Prodigy(params, lr=use_lr, eps=1e-6, **optimizer_params)
+        optimizer = Prodigy(params, lr=use_lr, eps=1e-8, use_bias_correction=True, d0=5e-5, d_coef=1.0, safeguard_warmup=True, **optimizer_params)
 
     elif lower_type == "adam8":
         from toolkit.optimizers.adam8bit import Adam8bit
-        optimizer = Adam8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
+        optimizer = Adam8bit(params, lr=learning_rate, eps=1e-8, **optimizer_params)
 
     elif lower_type == "adamw8":
         from toolkit.optimizers.adam8bit import Adam8bit
-        optimizer = Adam8bit(params, lr=learning_rate, eps=1e-6, decouple=True, **optimizer_params)
+        optimizer = Adam8bit(params, lr=learning_rate, eps=1e-8, decouple=True, **optimizer_params)
 
     elif lower_type.endswith("8bit"):
         import bitsandbytes
         if lower_type == "adam8bit":
-            return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
+            return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, eps=1e-8, **optimizer_params)
         if lower_type == "ademamix8bit":
-            return bitsandbytes.optim.AdEMAMix8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
+            return bitsandbytes.optim.AdEMAMix8bit(params, lr=learning_rate, eps=1e-8, **optimizer_params)
         elif lower_type == "adamw8bit":
-            return bitsandbytes.optim.AdamW8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
+            return bitsandbytes.optim.AdamW8bit(params, lr=learning_rate, eps=1e-8, **optimizer_params)
         elif lower_type == "lion8bit":
             return bitsandbytes.optim.Lion8bit(params, lr=learning_rate, **optimizer_params)
         else:
            raise ValueError(f'Unknown optimizer type {optimizer_type}')
     elif lower_type == 'adam':
-        optimizer = torch.optim.Adam(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
+        optimizer = torch.optim.Adam(params, lr=float(learning_rate), eps=1e-8, **optimizer_params)
     elif lower_type == 'adamw':
-        optimizer = torch.optim.AdamW(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
+        optimizer = torch.optim.AdamW(params, lr=float(learning_rate), eps=1e-8, **optimizer_params)
     elif lower_type == 'lion':
         try:
             from lion_pytorch import Lion
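
For reference, a minimal usage sketch for the two optimizer types this diff introduces. The full signature of get_optimizer is not visible in the hunks, so the keyword arguments below are assumptions inferred from the names the diff does show (params, optimizer_type, learning_rate, optimizer_params); verify against the actual definition before relying on it.

    # Sketch only: get_optimizer's parameter names and ordering are assumed,
    # not confirmed by this diff.
    import torch
    from toolkit.optimizer import get_optimizer

    net = torch.nn.Linear(16, 16)
    optimizer = get_optimizer(
        net.parameters(),              # becomes `params`
        optimizer_type="adamw_bf16",   # matched via lower_type.startswith("adamw_bf16");
                                       # "adamw_fp8" selects AdamWFP8 the same way
        learning_rate=1e-4,            # both new branches pass learning_rate straight through as use_lr
        optimizer_params={},           # extra kwargs forwarded as **optimizer_params to the optimizer
    )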