22#
33# This source code is licensed under the MIT license found in the
44# LICENSE file in the root directory of this source tree.
5+
56from bitsandbytes .optim .optimizer import Optimizer2State
67
78
@@ -98,8 +99,10 @@ def __init__(
9899 The weight decay value for the optimizer.
99100 amsgrad (`bool`, defaults to `False`):
100101 Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
102+ Note: This parameter is not supported in AdamW8bit and must be False.
101103 optim_bits (`int`, defaults to 32):
102104 The number of bits of the optimizer state.
105+ Note: This parameter is not used in AdamW8bit as it always uses 8-bit optimization.
103106 args (`object`, defaults to `None`):
104107 An object with additional arguments.
105108 min_8bit_size (`int`, defaults to 4096):
@@ -111,14 +114,23 @@ def __init__(
111114 is_paged (`bool`, defaults to `False`):
112115 Whether the optimizer is a paged optimizer or not.
113116 """
117+ # Validate unsupported parameters
118+ if amsgrad :
119+ raise ValueError ("AdamW8bit does not support amsgrad=True" )
120+
121+ if optim_bits != 32 :
122+ # We allow the default value of 32 to maintain compatibility with the function signature,
123+ # but any other value is invalid since AdamW8bit always uses 8-bit optimization
124+ raise ValueError ("AdamW8bit only supports optim_bits=32 (default value for compatibility)" )
125+
114126 super ().__init__ (
115127 "adam" ,
116128 params ,
117129 lr ,
118130 betas ,
119131 eps ,
120132 weight_decay ,
121- 8 ,
133+ 8 , # Hardcoded to 8 bits
122134 args ,
123135 min_8bit_size ,
124136 percentile_clipping ,
@@ -279,8 +291,10 @@ def __init__(
279291 The weight decay value for the optimizer.
280292 amsgrad (`bool`, defaults to `False`):
281293 Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
294+ Note: This parameter is not supported in PagedAdamW8bit and must be False.
282295 optim_bits (`int`, defaults to 32):
283296 The number of bits of the optimizer state.
297+ Note: This parameter is not used in PagedAdamW8bit as it always uses 8-bit optimization.
284298 args (`object`, defaults to `None`):
285299 An object with additional arguments.
286300 min_8bit_size (`int`, defaults to 4096):
@@ -292,14 +306,23 @@ def __init__(
292306 is_paged (`bool`, defaults to `False`):
293307 Whether the optimizer is a paged optimizer or not.
294308 """
309+ # Validate unsupported parameters
310+ if amsgrad :
311+ raise ValueError ("PagedAdamW8bit does not support amsgrad=True" )
312+
313+ if optim_bits != 32 :
314+ # We allow the default value of 32 to maintain compatibility with the function signature,
315+ # but any other value is invalid since PagedAdamW8bit always uses 8-bit optimization
316+ raise ValueError ("PagedAdamW8bit only supports optim_bits=32 (default value for compatibility)" )
317+
295318 super ().__init__ (
296319 "adam" ,
297320 params ,
298321 lr ,
299322 betas ,
300323 eps ,
301324 weight_decay ,
302- 8 ,
325+ 8 , # Hardcoded to 8 bits
303326 args ,
304327 min_8bit_size ,
305328 percentile_clipping ,
0 commit comments