27 changes: 24 additions & 3 deletions bitsandbytes/optim/adam.py
@@ -3,7 +3,6 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


from bitsandbytes.optim.optimizer import Optimizer2State


@@ -100,8 +99,10 @@ def __init__(
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
Note: This parameter is not supported in Adam8bit and must be False.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
Note: This parameter is not used in Adam8bit as it always uses 8-bit optimization.
args (`object`, defaults to `None`):
An object with additional arguments.
min_8bit_size (`int`, defaults to 4096):
@@ -113,14 +114,23 @@
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
# Validate unsupported parameters
if amsgrad:
raise ValueError("Adam8bit does not support amsgrad=True")

if optim_bits != 32:
# We allow the default value of 32 to maintain compatibility with the function signature,
# but any other value is invalid since Adam8bit always uses 8-bit optimization
raise ValueError("Adam8bit only supports optim_bits=32 (default value for compatibility)")

super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
8,
8, # Hardcoded to 8 bits
args,
min_8bit_size,
percentile_clipping,
@@ -283,8 +293,10 @@ def __init__(
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
Note: This parameter is not supported in PagedAdam8bit and must be False.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
Note: This parameter is not used in PagedAdam8bit as it always uses 8-bit optimization.
args (`object`, defaults to `None`):
An object with additional arguments.
min_8bit_size (`int`, defaults to 4096):
@@ -296,14 +308,23 @@
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
# Validate unsupported parameters
if amsgrad:
raise ValueError("PagedAdam8bit does not support amsgrad=True")

if optim_bits != 32:
# We allow the default value of 32 to maintain compatibility with the function signature,
# but any other value is invalid since PagedAdam8bit always uses 8-bit optimization
raise ValueError("PagedAdam8bit only supports optim_bits=32 (default value for compatibility)")

super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
8,
8, # Hardcoded to 8 bits
args,
min_8bit_size,
percentile_clipping,
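For reviewers, a minimal usage sketch (not part of the diff) of how the new validation in Adam8bit and PagedAdam8bit is expected to behave. It assumes the public bitsandbytes.optim entry point and the constructor arguments shown above; the linear model is just a placeholder.

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(4096, 4096)

# Default arguments behave as before: the optimizer state is always kept in 8 bits.
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)

# Unsupported settings now fail fast with a ValueError at construction time.
try:
    bnb.optim.Adam8bit(model.parameters(), lr=1e-3, amsgrad=True)
except ValueError as err:
    print(err)  # Adam8bit does not support amsgrad=True

try:
    bnb.optim.PagedAdam8bit(model.parameters(), lr=1e-3, optim_bits=16)
except ValueError as err:
    print(err)  # PagedAdam8bit only supports optim_bits=32 (default value for compatibility)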
27 changes: 25 additions & 2 deletions bitsandbytes/optim/adamw.py
@@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from bitsandbytes.optim.optimizer import Optimizer2State


@@ -98,8 +99,10 @@ def __init__(
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
Note: This parameter is not supported in AdamW8bit and must be False.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
Note: This parameter is not used in AdamW8bit as it always uses 8-bit optimization.
args (`object`, defaults to `None`):
An object with additional arguments.
min_8bit_size (`int`, defaults to 4096):
@@ -111,14 +114,23 @@
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
# Validate unsupported parameters
if amsgrad:
raise ValueError("AdamW8bit does not support amsgrad=True")

if optim_bits != 32:
# We allow the default value of 32 to maintain compatibility with the function signature,
# but any other value is invalid since AdamW8bit always uses 8-bit optimization
raise ValueError("AdamW8bit only supports optim_bits=32 (default value for compatibility)")

super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
8,
8, # Hardcoded to 8 bits
args,
min_8bit_size,
percentile_clipping,
@@ -279,8 +291,10 @@ def __init__(
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
Note: This parameter is not supported in PagedAdamW8bit and must be False.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
Note: This parameter is not used in PagedAdamW8bit as it always uses 8-bit optimization.
args (`object`, defaults to `None`):
An object with additional arguments.
min_8bit_size (`int`, defaults to 4096):
@@ -292,14 +306,23 @@
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
# Validate unsupported parameters
if amsgrad:
raise ValueError("PagedAdamW8bit does not support amsgrad=True")

if optim_bits != 32:
# We allow the default value of 32 to maintain compatibility with the function signature,
# but any other value is invalid since PagedAdamW8bit always uses 8-bit optimization
raise ValueError("PagedAdamW8bit only supports optim_bits=32 (default value for compatibility)")

super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
8,
8, # Hardcoded to 8 bits
args,
min_8bit_size,
percentile_clipping,
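The AdamW variants follow the same pattern; a short illustrative sketch under the same assumptions:

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(4096, 4096)

# Works as before: 8-bit optimizer state with decoupled weight decay.
optimizer = bnb.optim.PagedAdamW8bit(model.parameters(), lr=1e-3, weight_decay=0.01)

# Any optim_bits other than the default 32 is rejected, even 8,
# since the 8-bit state width is hardcoded in the call to the base class.
try:
    bnb.optim.AdamW8bit(model.parameters(), lr=1e-3, optim_bits=8)
except ValueError as err:
    print(err)  # AdamW8bit only supports optim_bits=32 (default value for compatibility)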