Commit 1b8772a

Added PagedLion and bf16 Lion.
1 parent 2bce175 commit 1b8772a

7 files changed: +46 -97 lines changed

bitsandbytes/functional.py

Lines changed: 2 additions & 4 deletions

@@ -37,10 +37,7 @@ def prod(iterable):
     lib.crmsprop32bit_grad_32,
     lib.crmsprop32bit_grad_16,
 )
-str2optimizer32bit["lion"] = (
-    lib.clion32bit_grad_32,
-    lib.clion32bit_grad_16,
-)
+str2optimizer32bit["lion"] = (lib.clion32bit_grad_fp32, lib.clion32bit_grad_fp16, lib.clion32bit_grad_bf16)
 str2optimizer32bit["adagrad"] = (
     lib.cadagrad32bit_grad_32,
     lib.cadagrad32bit_grad_16,

@@ -89,6 +86,7 @@ def prod(iterable):
 str2optimizer8bit_blockwise["lion"] = (
     lib.clion_8bit_blockwise_grad_fp32,
     lib.clion_8bit_blockwise_grad_fp16,
+    lib.clion_8bit_blockwise_grad_bf16,
 )
 str2optimizer8bit_blockwise["adagrad"] = (
     lib.cadagrad_8bit_blockwise_grad_fp32,
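str2optimizer32bit maps each optimizer name to the C entry points for the supported gradient dtypes; this commit adds the bf16 Lion kernel and renames the fp32/fp16 entries to match. A minimal sketch of how such a table can be used to pick a kernel by dtype (the helper below is hypothetical and only illustrates the lookup; the library's own dispatch code does the real work):

import torch
from bitsandbytes.functional import str2optimizer32bit  # table shown in the diff above

def select_lion32bit_kernel(grad: torch.Tensor):
    # Tuple order follows this commit: (fp32, fp16, bf16).
    fn_fp32, fn_fp16, fn_bf16 = str2optimizer32bit["lion"]
    if grad.dtype == torch.float32:
        return fn_fp32
    if grad.dtype == torch.float16:
        return fn_fp16
    if grad.dtype == torch.bfloat16:
        return fn_bf16  # newly available with this commit
    raise ValueError(f"no 32-bit Lion kernel for dtype {grad.dtype}")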

bitsandbytes/optim/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -12,5 +12,5 @@
 from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
 from .optimizer import GlobalOptimManager
 from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
-from .lion import Lion, Lion8bit, Lion32bit
+from .lion import Lion, Lion8bit, Lion32bit, PagedLion, PagedLion8bit, PagedLion32bit
 from .sgd import SGD, SGD8bit, SGD32bit

bitsandbytes/optim/lion.py

Lines changed: 19 additions & 76 deletions

@@ -4,84 +4,27 @@
 # LICENSE file in the root directory of this source tree.
 from bitsandbytes.optim.optimizer import Optimizer1State

-
 class Lion(Optimizer1State):
-    def __init__(
-        self,
-        params,
-        lr=1e-4,
-        betas=(0.9, 0.99),
-        weight_decay=0,
-        optim_bits=32,
-        args=None,
-        min_8bit_size=4096,
-        percentile_clipping=100,
-        block_wise=True,
-    ):
-        super().__init__(
-            "lion",
-            params,
-            lr,
-            betas,
-            0.,
-            weight_decay,
-            optim_bits,
-            args,
-            min_8bit_size,
-            percentile_clipping,
-            block_wise,
-        )
-
+    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)

 class Lion8bit(Optimizer1State):
-    def __init__(
-        self,
-        params,
-        lr=1e-4,
-        betas=(0.9, 0.99),
-        weight_decay=0,
-        args=None,
-        min_8bit_size=4096,
-        percentile_clipping=100,
-        block_wise=True,
-    ):
-        super().__init__(
-            "lion",
-            params,
-            lr,
-            betas,
-            0.,
-            weight_decay,
-            8,
-            args,
-            min_8bit_size,
-            percentile_clipping,
-            block_wise,
-        )
-
+    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)

 class Lion32bit(Optimizer1State):
-    def __init__(
-        self,
-        params,
-        lr=1e-4,
-        betas=(0.9, 0.99),
-        weight_decay=0,
-        args=None,
-        min_8bit_size=4096,
-        percentile_clipping=100,
-        block_wise=True,
-    ):
-        super().__init__(
-            "lion",
-            params,
-            lr,
-            betas,
-            0.,
-            weight_decay,
-            32,
-            args,
-            min_8bit_size,
-            percentile_clipping,
-            block_wise,
-        )
+    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
+
+
+class PagedLion(Optimizer1State):
+    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
+
+class PagedLion8bit(Optimizer1State):
+    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
+
+class PagedLion32bit(Optimizer1State):
+    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)

csrc/kernels.cu

Lines changed: 3 additions & 0 deletions

@@ -3666,6 +3666,7 @@ MAKE_PreconditionOptimizer32bit1State(RMSPROP, half)
 MAKE_PreconditionOptimizer32bit1State(RMSPROP, float)
 MAKE_PreconditionOptimizer32bit1State(LION, half)
 MAKE_PreconditionOptimizer32bit1State(LION, float)
+MAKE_PreconditionOptimizer32bit1State(LION, __nv_bfloat16)
 MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half)
 MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)

@@ -3679,6 +3680,7 @@ MAKE_Optimizer32bit1State(RMSPROP, half)
 MAKE_Optimizer32bit1State(RMSPROP, float)
 MAKE_Optimizer32bit1State(LION, half)
 MAKE_Optimizer32bit1State(LION, float)
+MAKE_Optimizer32bit1State(LION, __nv_bfloat16)
 MAKE_Optimizer32bit1State(ADAGRAD, half)
 MAKE_Optimizer32bit1State(ADAGRAD, float)

@@ -3852,5 +3854,6 @@ MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8)
+MAKE_OptimizerStatic8bit1StateBlockwise(LION, __nv_bfloat16, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8)

csrc/ops.cu

Lines changed: 2 additions & 0 deletions

@@ -802,6 +802,7 @@ MAKE_optimizer32bit(RMSPROP, half)
 MAKE_optimizer32bit(RMSPROP, float)
 MAKE_optimizer32bit(LION, half)
 MAKE_optimizer32bit(LION, float)
+MAKE_optimizer32bit(LION, __nv_bfloat16)
 MAKE_optimizer32bit(ADAGRAD, half)
 MAKE_optimizer32bit(ADAGRAD, float)

@@ -837,6 +838,7 @@ MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
 MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
 MAKE_optimizerStatic8bitBlockwise(half, LION);
 MAKE_optimizerStatic8bitBlockwise(float, LION);
+MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, LION);
 MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD);
 MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);

csrc/pythonInterface.c

Lines changed: 8 additions & 4 deletions

The Lion 32-bit suffixes change from 32/16 to fp32/fp16 so the exported symbol names line up with the new bf16 variant and with the lookup table updated in bitsandbytes/functional.py.

@@ -51,8 +51,9 @@ MAKE_FUNC32(adam, ADAM, half, fp16)
 MAKE_FUNC32(adam, ADAM, __nv_bfloat16, bf16)
 MAKE_FUNC32(rmsprop, RMSPROP, float, 32)
 MAKE_FUNC32(rmsprop, RMSPROP, half, 16)
-MAKE_FUNC32(lion, LION, float, 32)
-MAKE_FUNC32(lion, LION, half, 16)
+MAKE_FUNC32(lion, LION, float, fp32)
+MAKE_FUNC32(lion, LION, half, fp16)
+MAKE_FUNC32(lion, LION, __nv_bfloat16, bf16)
 MAKE_FUNC32(adagrad, ADAGRAD, float, 32)
 MAKE_FUNC32(adagrad, ADAGRAD, half, 16)

@@ -95,6 +96,7 @@ MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, fp32)
 MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
 MAKE_BLOCKWISE8(lion, LION, half, fp16)
 MAKE_BLOCKWISE8(lion, LION, float, fp32)
+MAKE_BLOCKWISE8(lion, LION, __nv_bfloat16, bf16)


 void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); }

@@ -201,8 +203,9 @@ extern "C"
     MAKE_CFUNC32(momentum, half, 16)
     MAKE_CFUNC32(rmsprop, float, 32)
    MAKE_CFUNC32(rmsprop, half, 16)
-    MAKE_CFUNC32(lion, float, 32)
-    MAKE_CFUNC32(lion, half, 16)
+    MAKE_CFUNC32(lion, float, fp32)
+    MAKE_CFUNC32(lion, half, fp16)
+    MAKE_CFUNC32(lion, __nv_bfloat16, bf16)
     MAKE_CFUNC32(adagrad, float, 32)
     MAKE_CFUNC32(adagrad, half, 16)

@@ -245,6 +248,7 @@ extern "C"
     MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
     MAKE_CBLOCKWISE8(lion, LION, half, fp16)
     MAKE_CBLOCKWISE8(lion, LION, float, fp32)
+    MAKE_CBLOCKWISE8(lion, LION, __nv_bfloat16, bf16)

     void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); }
     void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); }
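These macro instantiations determine which symbols the shared library exports to Python. A purely illustrative way to check that the new bf16 entry points are visible (assuming bitsandbytes.cextension exposes the loaded library as lib, as bitsandbytes/functional.py imports it; the hasattr probe itself is just an example, not part of this commit):

from bitsandbytes.cextension import lib

# Symbol names as referenced by the updated lookup tables in bitsandbytes/functional.py
assert hasattr(lib, "clion32bit_grad_bf16")
assert hasattr(lib, "clion_8bit_blockwise_grad_bf16")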

tests/test_optim.py

Lines changed: 11 additions & 12 deletions

@@ -19,11 +19,11 @@
 k = 20

 def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
-    idx = torch.isclose(a, b, rtol, atol)
+    idx = torch.isclose(a, b, rtol=rtol, atol=atol)
     error_count = (idx == 0).sum().item()
     if error_count > max_error_count:
         print(f"Too many values not close: assert {error_count} < {max_error_count}")
-        torch.testing.assert_close(a, b, rtol, atol)
+        torch.testing.assert_close(a, b, rtol=rtol, atol=atol)


 def get_temp_dir():

@@ -35,13 +35,8 @@ def get_temp_dir():
 def rm_path(path):
     shutil.rmtree(path)

-str2bf16support = {}
-str2bf16support['adam8bit_blockwise'] = True
-
 str2optimizers = {}
 str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
-# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
-# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
 str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
 str2optimizers["momentum_pytorch"] = (
     None,

@@ -51,8 +46,8 @@ def rm_path(path):
 str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
 str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW)
 str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam)
-# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
 str2optimizers["lion"] = (Lion, bnb.optim.Lion)
+str2optimizers["paged_lion"] = (Lion, bnb.optim.PagedLion)
 str2optimizers["momentum"] = (
     lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),

@@ -76,6 +71,7 @@ def rm_path(path):
 str2optimizers["paged_adamw8bit_blockwise"] = (torch.optim.AdamW, lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True))
 str2optimizers["paged_adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True))
 str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
+str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True))
 str2optimizers["momentum8bit_blockwise"] = (
     lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),

@@ -90,6 +86,7 @@ def rm_path(path):
 str2statenames["paged_adamw"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
 str2statenames["paged_adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
 str2statenames["lion"] = [("exp_avg", "state1")]
+str2statenames["paged_lion"] = [("exp_avg", "state1")]
 str2statenames["momentum"] = [("momentum_buffer", "state1")]
 str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
 str2statenames["rmsprop"] = [("square_avg", "state1")]

@@ -104,15 +101,17 @@ def rm_path(path):
 str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
 str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
 str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
+str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]

 dim1 = [1024]
 dim2 = [32, 1024, 4097, 1]
-gtype = [torch.float32, torch.float16]
-optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion']
+gtype = [torch.float32, torch.float16, torch.bfloat16]
+optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion', 'paged_lion']
 values = list(product(dim1, dim2, gtype, optimizer_names))
 names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
 def test_optimizer32bit(dim1, dim2, gtype, optim_name):
+    if gtype == torch.bfloat16 and optim_name in ['momentum', 'rmsprop']: pytest.skip()
     if dim1 == 1 and dim2 == 1:
         return
     p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1

@@ -254,7 +253,7 @@ def test_global_config(dim1, dim2, gtype):

 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
 def test_optimizer8bit(dim1, dim2, gtype, optim_name):
-    if gtype == torch.bfloat16 and optim_name not in str2bf16support: return
+    if gtype == torch.bfloat16 and optim_name not in ['adam8bit_blockwise', 'lion8bit_blockwise']: pytest.skip()
     if dim1 == 1 and dim2 == 1:
         return
     p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1

@@ -485,7 +484,7 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
 # optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
 # optimizer_names = ['lamb_apex', 'lamb8bit']
 # optimizer_names = ['lars_apex', 'lars8bit']
-optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise']
+optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise', 'paged_lion8bit_blockwise']
 values = list(product(dim1, dim2, gtype, optimizer_names))
 names = [
     "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
