Skip to content

Commit 61db085

Browse files
authored
doc fix signature for 8-bit optim (#1660)
* doc fix signature for 8-bit optim
* required changes
* precommit
1 parent df73d3e commit 61db085

File tree

2 files changed

+49
-5
lines changed

2 files changed

+49
-5
lines changed

bitsandbytes/optim/adam.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
76
from bitsandbytes.optim.optimizer import Optimizer2State
87

98

@@ -100,8 +99,10 @@ def __init__(
10099
The weight decay value for the optimizer.
101100
amsgrad (`bool`, defaults to `False`):
102101
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
102+
Note: This parameter is not supported in Adam8bit and must be False.
103103
optim_bits (`int`, defaults to 32):
104104
The number of bits of the optimizer state.
105+
Note: This parameter is not used in Adam8bit as it always uses 8-bit optimization.
105106
args (`object`, defaults to `None`):
106107
An object with additional arguments.
107108
min_8bit_size (`int`, defaults to 4096):
@@ -113,14 +114,23 @@ def __init__(
113114
is_paged (`bool`, defaults to `False`):
114115
Whether the optimizer is a paged optimizer or not.
115116
"""
117+
# Validate unsupported parameters
118+
if amsgrad:
119+
raise ValueError("Adam8bit does not support amsgrad=True")
120+
121+
if optim_bits != 32:
122+
# We allow the default value of 32 to maintain compatibility with the function signature,
123+
# but any other value is invalid since Adam8bit always uses 8-bit optimization
124+
raise ValueError("Adam8bit only supports optim_bits=32 (default value for compatibility)")
125+
116126
super().__init__(
117127
"adam",
118128
params,
119129
lr,
120130
betas,
121131
eps,
122132
weight_decay,
123-
8,
133+
8, # Hardcoded to 8 bits
124134
args,
125135
min_8bit_size,
126136
percentile_clipping,
@@ -283,8 +293,10 @@ def __init__(
283293
The weight decay value for the optimizer.
284294
amsgrad (`bool`, defaults to `False`):
285295
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
296+
Note: This parameter is not supported in PagedAdam8bit and must be False.
286297
optim_bits (`int`, defaults to 32):
287298
The number of bits of the optimizer state.
299+
Note: This parameter is not used in PagedAdam8bit as it always uses 8-bit optimization.
288300
args (`object`, defaults to `None`):
289301
An object with additional arguments.
290302
min_8bit_size (`int`, defaults to 4096):
@@ -296,14 +308,23 @@ def __init__(
296308
is_paged (`bool`, defaults to `False`):
297309
Whether the optimizer is a paged optimizer or not.
298310
"""
311+
# Validate unsupported parameters
312+
if amsgrad:
313+
raise ValueError("PagedAdam8bit does not support amsgrad=True")
314+
315+
if optim_bits != 32:
316+
# We allow the default value of 32 to maintain compatibility with the function signature,
317+
# but any other value is invalid since PagedAdam8bit always uses 8-bit optimization
318+
raise ValueError("PagedAdam8bit only supports optim_bits=32 (default value for compatibility)")
319+
299320
super().__init__(
300321
"adam",
301322
params,
302323
lr,
303324
betas,
304325
eps,
305326
weight_decay,
306-
8,
327+
8, # Hardcoded to 8 bits
307328
args,
308329
min_8bit_size,
309330
percentile_clipping,

bitsandbytes/optim/adamw.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5+
56
from bitsandbytes.optim.optimizer import Optimizer2State
67

78

@@ -98,8 +99,10 @@ def __init__(
9899
The weight decay value for the optimizer.
99100
amsgrad (`bool`, defaults to `False`):
100101
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
102+
Note: This parameter is not supported in AdamW8bit and must be False.
101103
optim_bits (`int`, defaults to 32):
102104
The number of bits of the optimizer state.
105+
Note: This parameter is not used in AdamW8bit as it always uses 8-bit optimization.
103106
args (`object`, defaults to `None`):
104107
An object with additional arguments.
105108
min_8bit_size (`int`, defaults to 4096):
@@ -111,14 +114,23 @@ def __init__(
111114
is_paged (`bool`, defaults to `False`):
112115
Whether the optimizer is a paged optimizer or not.
113116
"""
117+
# Validate unsupported parameters
118+
if amsgrad:
119+
raise ValueError("AdamW8bit does not support amsgrad=True")
120+
121+
if optim_bits != 32:
122+
# We allow the default value of 32 to maintain compatibility with the function signature,
123+
# but any other value is invalid since AdamW8bit always uses 8-bit optimization
124+
raise ValueError("AdamW8bit only supports optim_bits=32 (default value for compatibility)")
125+
114126
super().__init__(
115127
"adam",
116128
params,
117129
lr,
118130
betas,
119131
eps,
120132
weight_decay,
121-
8,
133+
8, # Hardcoded to 8 bits
122134
args,
123135
min_8bit_size,
124136
percentile_clipping,
@@ -279,8 +291,10 @@ def __init__(
279291
The weight decay value for the optimizer.
280292
amsgrad (`bool`, defaults to `False`):
281293
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
294+
Note: This parameter is not supported in PagedAdamW8bit and must be False.
282295
optim_bits (`int`, defaults to 32):
283296
The number of bits of the optimizer state.
297+
Note: This parameter is not used in PagedAdamW8bit as it always uses 8-bit optimization.
284298
args (`object`, defaults to `None`):
285299
An object with additional arguments.
286300
min_8bit_size (`int`, defaults to 4096):
@@ -292,14 +306,23 @@ def __init__(
292306
is_paged (`bool`, defaults to `False`):
293307
Whether the optimizer is a paged optimizer or not.
294308
"""
309+
# Validate unsupported parameters
310+
if amsgrad:
311+
raise ValueError("PagedAdamW8bit does not support amsgrad=True")
312+
313+
if optim_bits != 32:
314+
# We allow the default value of 32 to maintain compatibility with the function signature,
315+
# but any other value is invalid since PagedAdamW8bit always uses 8-bit optimization
316+
raise ValueError("PagedAdamW8bit only supports optim_bits=32 (default value for compatibility)")
317+
295318
super().__init__(
296319
"adam",
297320
params,
298321
lr,
299322
betas,
300323
eps,
301324
weight_decay,
302-
8,
325+
8, # Hardcoded to 8 bits
303326
args,
304327
min_8bit_size,
305328
percentile_clipping,

0 commit comments

Comments (0)