bitsandbytes/nn/modules.py (61 additions, 22 deletions)
@@ -21,16 +21,7 @@
 
 class StableEmbedding(torch.nn.Embedding):
     """
-    Custom embedding layer designed for stable training in NLP tasks. The stable
-    embedding layer improves stability during optimization for models with word
-    embeddings, addressing issues related to the non-uniform distribution of input
-    tokens.
-
-    This stable embedding layer is initialized with Xavier uniform initialization,
-    followed by layer normalization. It is designed to support aggressive quantization,
-    addressing extreme gradient variations in non-uniform input distributions. The
-    stability of training is enhanced by using 32-bit optimizer states specifically
-    for this layer.
+    Custom embedding layer designed to improve stability during training for NLP tasks by using 32-bit optimizer states. It is designed to reduce gradient variations that can result from quantization. This embedding layer is initialized with Xavier uniform initialization followed by layer normalization.
 
     Example:
 
@@ -47,14 +38,11 @@ class StableEmbedding(torch.nn.Embedding):
     ```
 
     Attributes:
-        norm (torch.nn.LayerNorm): Layer normalization applied after the embedding.
+        norm (`torch.nn.LayerNorm`): Layer normalization applied after the embedding.
 
     Methods:
         reset_parameters(): Reset embedding parameters using Xavier uniform initialization.
         forward(input: Tensor) -> Tensor: Forward pass through the stable embedding layer.
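
For reference, a minimal usage sketch of the layer documented above. This is an illustration only: it assumes the class is exposed as `bitsandbytes.nn.StableEmbedding` (the import path is not part of this diff) and that it accepts the same constructor arguments as `torch.nn.Embedding`.

```python
import torch
import bitsandbytes as bnb

# Drop-in replacement for torch.nn.Embedding: Xavier-uniform initialization and
# layer normalization are handled inside the class, per the docstring above.
emb = bnb.nn.StableEmbedding(num_embeddings=1000, embedding_dim=64)

token_ids = torch.randint(0, 1000, (8, 16))  # batch of 8 sequences, 16 tokens each
out = emb(token_ids)                         # shape: (8, 16, 64)
print(out.shape)
```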
bitsandbytes/optim/adagrad.py (81 additions, 0 deletions)
@@ -20,6 +20,33 @@ def __init__(
         percentile_clipping=100,
         block_wise=True,
     ):
+        """
+        Base Adagrad optimizer.
+
+        Arguments:
+            params (`torch.tensor`):
+                The input parameters to optimize.
+            lr (`float`, defaults to 1e-2):
+                The learning rate.
+            lr_decay (`int`, defaults to 0):
+                The learning rate decay.
+            weight_decay (`float`, defaults to 0.0):
+                The weight decay value for the optimizer.
+            initial_accumulator_value (`int`, defaults to 0):
+                The initial momentum values.
+            eps (`float`, defaults to 1e-10):
+                The epsilon value, which prevents division by zero in the optimizer.
+            optim_bits (`int`, defaults to 32):
+                The number of bits of the optimizer state.
+            args (`dict`, defaults to `None`):
+                A dictionary with additional arguments.
+            min_8bit_size (`int`, defaults to 4096):
+                The minimum number of elements of the parameter tensors for 8-bit optimization.
+            percentile_clipping (`int`, defaults to 100):
+                Adapts the clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
+            block_wise (`bool`, defaults to `True`):
+                Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
+        """
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
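
The docstring above lists the constructor arguments; a minimal training-step sketch follows. It assumes the optimizer is exposed as `bitsandbytes.optim.Adagrad` (the class name is not visible in this hunk) and otherwise uses only the documented defaults.

```python
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(128, 10)
opt = bnb.optim.Adagrad(model.parameters(), lr=1e-2, weight_decay=0.0)

x = torch.randn(32, 128)
target = torch.randint(0, 10, (32,))

loss = torch.nn.functional.cross_entropy(model(x), target)
loss.backward()   # compute gradients
opt.step()        # apply the Adagrad update
opt.zero_grad()   # clear gradients for the next step
```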
@@ -62,6 +89,33 @@ def __init__(
         percentile_clipping=100,
         block_wise=True,
     ):
+        """
+        8-bit Adagrad optimizer.
+
+        Arguments:
+            params (`torch.tensor`):
+                The input parameters to optimize.
+            lr (`float`, defaults to 1e-2):
+                The learning rate.
+            lr_decay (`int`, defaults to 0):
+                The learning rate decay.
+            weight_decay (`float`, defaults to 0.0):
+                The weight decay value for the optimizer.
+            initial_accumulator_value (`int`, defaults to 0):
+                The initial momentum values.
+            eps (`float`, defaults to 1e-10):
+                The epsilon value, which prevents division by zero in the optimizer.
+            optim_bits (`int`, defaults to 8):
+                The number of bits of the optimizer state.
+            args (`dict`, defaults to `None`):
+                A dictionary with additional arguments.
+            min_8bit_size (`int`, defaults to 4096):
+                The minimum number of elements of the parameter tensors for 8-bit optimization.
+            percentile_clipping (`int`, defaults to 100):
+                Adapts the clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
+            block_wise (`bool`, defaults to `True`):
+                Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
+        """
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
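
A sketch of how the stability-related arguments documented above might be passed explicitly. It assumes the 8-bit variant is exposed as `bitsandbytes.optim.Adagrad8bit` (the class name is not shown in this hunk) and that a GPU build of bitsandbytes is available; per the docstring, parameters smaller than `min_8bit_size` keep 32-bit state.

```python
import torch
import bitsandbytes as bnb

# A 4096x4096 weight comfortably exceeds min_8bit_size, so its optimizer state
# can be held in 8 bits; smaller tensors (e.g. the bias) stay in 32 bits.
model = torch.nn.Linear(4096, 4096).cuda()
opt = bnb.optim.Adagrad8bit(
    model.parameters(),
    lr=1e-2,
    min_8bit_size=4096,       # minimum tensor size for 8-bit optimizer state
    percentile_clipping=100,  # adaptive clipping over the last 100 gradient norms
    block_wise=True,          # quantize optimizer state block by block
)

x = torch.randn(8, 4096, device="cuda")
loss = model(x).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()
```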
@@ -105,6 +159,33 @@ def __init__(
         percentile_clipping=100,
         block_wise=True,
     ):
+        """
+        32-bit Adagrad optimizer.
+
+        Arguments:
+            params (`torch.tensor`):
+                The input parameters to optimize.
+            lr (`float`, defaults to 1e-2):
+                The learning rate.
+            lr_decay (`int`, defaults to 0):
+                The learning rate decay.
+            weight_decay (`float`, defaults to 0.0):
+                The weight decay value for the optimizer.
+            initial_accumulator_value (`int`, defaults to 0):
+                The initial momentum values.
+            eps (`float`, defaults to 1e-10):
+                The epsilon value, which prevents division by zero in the optimizer.
+            optim_bits (`int`, defaults to 32):
+                The number of bits of the optimizer state.
+            args (`dict`, defaults to `None`):
+                A dictionary with additional arguments.
+            min_8bit_size (`int`, defaults to 4096):
+                The minimum number of elements of the parameter tensors for 8-bit optimization.
+            percentile_clipping (`int`, defaults to 100):
+                Adapts the clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
+            block_wise (`bool`, defaults to `True`):
+                Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
+        """