                 Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
+                Note: This parameter is not supported in AdamW8bit and must be False.
             optim_bits (`int`, defaults to 32):
                 The number of bits of the optimizer state.
+                Note: This parameter is not used in AdamW8bit as it always uses 8-bit optimization.
             args (`object`, defaults to `None`):
                 An object with additional arguments.
             min_8bit_size (`int`, defaults to 4096):
@@ -111,14 +114,23 @@ def __init__(
             is_paged (`bool`, defaults to `False`):
                 Whether the optimizer is a paged optimizer or not.
         """
+        # Validate unsupported parameters
+        if amsgrad:
+            raise ValueError("AdamW8bit does not support amsgrad=True")
+
+        if optim_bits != 32:
+            # We allow the default value of 32 to maintain compatibility with the function signature,
+            # but any other value is invalid since AdamW8bit always uses 8-bit optimization
+            raise ValueError("AdamW8bit only supports optim_bits=32 (default value for compatibility)")
+
         super().__init__(
             "adam",
             params,
             lr,
             betas,
             eps,
             weight_decay,
-            8,
+            8,  # Hardcoded to 8 bits
             args,
             min_8bit_size,
             percentile_clipping,
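For reference, a minimal usage sketch of how the validation added above behaves once the patch is applied. It assumes the edited file is bitsandbytes' optim/adamw.py, so the class is reachable as bitsandbytes.optim.AdamW8bit; the tiny model exists only to supply parameters and is not part of this diff.

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(128, 64)

# The defaults (amsgrad=False, optim_bits=32) still construct normally.
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=1e-3)

# Unsupported settings now fail fast instead of being silently ignored.
try:
    bnb.optim.AdamW8bit(model.parameters(), lr=1e-3, amsgrad=True)
except ValueError as err:
    print(err)  # AdamW8bit does not support amsgrad=True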
@@ -147,7 +159,7 @@ def __init__(
         32-bit AdamW optimizer.
 
         Arguments:
-            params (`torch.tensor`):
+            params (`torch.Tensor`):
                 The input parameters to optimize.
             lr (`float`, defaults to 1e-3):
                 The learning rate.
@@ -207,7 +219,7 @@ def __init__(
         Paged AdamW optimizer.
 
         Arguments:
-            params (`torch.tensor`):
+            params (`torch.Tensor`):
                 The input parameters to optimize.
             lr (`float`, defaults to 1e-3):
                 The learning rate.
@@ -229,8 +241,6 @@ def __init__(
                 Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
             block_wise (`bool`, defaults to `True`):
                 Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
-            is_paged (`bool`, defaults to `False`):
-                Whether the optimizer is a paged optimizer or not.
         """
         super().__init__(
             "adam",
@@ -267,7 +277,7 @@ def __init__(
         Paged 8-bit AdamW optimizer.
 
         Arguments:
-            params (`torch.tensor`):
+            params (`torch.Tensor`):
                 The input parameters to optimize.
             lr (`float`, defaults to 1e-3):
                 The learning rate.
@@ -279,8 +289,10 @@ def __init__(
                 The weight decay value for the optimizer.
             amsgrad (`bool`, defaults to `False`):
                 Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
+                Note: This parameter is not supported in PagedAdamW8bit and must be False.
             optim_bits (`int`, defaults to 32):
                 The number of bits of the optimizer state.
+                Note: This parameter is not used in PagedAdamW8bit as it always uses 8-bit optimization.
             args (`object`, defaults to `None`):
                 An object with additional arguments.
             min_8bit_size (`int`, defaults to 4096):
@@ -289,17 +301,24 @@ def __init__(
                 Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
             block_wise (`bool`, defaults to `True`):
                 Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
-            is_paged (`bool`, defaults to `False`):
-                Whether the optimizer is a paged optimizer or not.
         """
+        # Validate unsupported parameters
+        if amsgrad:
+            raise ValueError("PagedAdamW8bit does not support amsgrad=True")
+
+        if optim_bits != 32:
+            # We allow the default value of 32 to maintain compatibility with the function signature,
+            # but any other value is invalid since PagedAdamW8bit always uses 8-bit optimization
+            raise ValueError("PagedAdamW8bit only supports optim_bits=32 (default value for compatibility)")
+
         super().__init__(
             "adam",
             params,
             lr,
             betas,
             eps,
             weight_decay,
-            8,
+            8,  # Hardcoded to 8 bits
             args,
             min_8bit_size,
             percentile_clipping,
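The same checks apply to PagedAdamW8bit, so a pytest-style regression test could cover both classes at once. This is only a sketch: the test names and the bitsandbytes import path are assumptions, since no test file appears in this diff.

import pytest
import torch
import bitsandbytes as bnb


@pytest.mark.parametrize("cls", [bnb.optim.AdamW8bit, bnb.optim.PagedAdamW8bit])
def test_amsgrad_rejected(cls):
    params = [torch.nn.Parameter(torch.zeros(16))]
    # Both 8-bit variants refuse amsgrad=True after this patch.
    with pytest.raises(ValueError):
        cls(params, amsgrad=True)


@pytest.mark.parametrize("cls", [bnb.optim.AdamW8bit, bnb.optim.PagedAdamW8bit])
def test_optim_bits_must_stay_default(cls):
    params = [torch.nn.Parameter(torch.zeros(16))]
    # Any value other than the signature default of 32 is rejected, because
    # the optimizer state is always kept in 8 bits regardless of this argument.
    with pytest.raises(ValueError):
        cls(params, optim_bits=8)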
@@ -327,7 +346,7 @@ def __init__(
         Paged 32-bit AdamW optimizer.
 
         Arguments:
-            params (`torch.tensor`):
+            params (`torch.Tensor`):
                 The input parameters to optimize.
             lr (`float`, defaults to 1e-3):
                 The learning rate.
@@ -349,8 +368,6 @@ def __init__(
                 Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
             block_wise (`bool`, defaults to `True`):
                 Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
-            is_paged (`bool`, defaults to `False`):
-                Whether the optimizer is a paged optimizer or not.