 k = 20
 
 def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
-    idx = torch.isclose(a, b, rtol, atol)
+    idx = torch.isclose(a, b, rtol=rtol, atol=atol)
     error_count = (idx == 0).sum().item()
     if error_count > max_error_count:
         print(f"Too many values not close: assert {error_count} < {max_error_count}")
-        torch.testing.assert_close(a, b, rtol, atol)
+        torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
 
 
 def get_temp_dir():
@@ -35,13 +35,8 @@ def get_temp_dir():
 def rm_path(path):
     shutil.rmtree(path)
 
-str2bf16support = {}
-str2bf16support['adam8bit_blockwise'] = True
-
 str2optimizers = {}
 str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
-# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
-# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
 str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
 str2optimizers["momentum_pytorch"] = (
     None,
@@ -51,8 +46,8 @@ def rm_path(path):
5146str2optimizers ["adam" ] = (torch .optim .Adam , bnb .optim .Adam )
5247str2optimizers ["paged_adamw" ] = (torch .optim .AdamW , bnb .optim .PagedAdamW )
5348str2optimizers ["paged_adam" ] = (torch .optim .Adam , bnb .optim .PagedAdam )
54- # str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
5549str2optimizers ["lion" ] = (Lion , bnb .optim .Lion )
50+ str2optimizers ["paged_lion" ] = (Lion , bnb .optim .PagedLion )
5651str2optimizers ["momentum" ] = (
5752 lambda pxx : torch .optim .SGD (pxx , 0.01 , 0.9 ),
5853 lambda pxx : bnb .optim .SGD (pxx , 0.01 , 0.9 , block_wise = False ),
@@ -76,6 +71,7 @@ def rm_path(path):
7671str2optimizers ["paged_adamw8bit_blockwise" ] = (torch .optim .AdamW , lambda pxx : bnb .optim .PagedAdamW8bit (pxx , block_wise = True ))
7772str2optimizers ["paged_adam8bit_blockwise" ] = (torch .optim .Adam , lambda pxx : bnb .optim .PagedAdam8bit (pxx , block_wise = True ))
7873str2optimizers ["lion8bit_blockwise" ] = (Lion , lambda pxx : bnb .optim .Lion8bit (pxx , block_wise = True ))
74+ str2optimizers ["paged_lion8bit_blockwise" ] = (Lion , lambda pxx : bnb .optim .PagedLion8bit (pxx , block_wise = True ))
7975str2optimizers ["momentum8bit_blockwise" ] = (
8076 lambda pxx : torch .optim .SGD (pxx , 0.01 , 0.9 ),
8177 lambda pxx : bnb .optim .SGD8bit (pxx , 0.01 , 0.9 , block_wise = True ),
@@ -90,6 +86,7 @@ def rm_path(path):
9086str2statenames ["paged_adamw" ] = [("exp_avg" , "state1" ), ("exp_avg_sq" , "state2" )]
9187str2statenames ["paged_adam" ] = [("exp_avg" , "state1" ), ("exp_avg_sq" , "state2" )]
9288str2statenames ["lion" ] = [("exp_avg" , "state1" )]
89+ str2statenames ["paged_lion" ] = [("exp_avg" , "state1" )]
9390str2statenames ["momentum" ] = [("momentum_buffer" , "state1" )]
9491str2statenames ["lamb" ] = [("exp_avg" , "state1" ), ("exp_avg_sq" , "state2" )]
9592str2statenames ["rmsprop" ] = [("square_avg" , "state1" )]
@@ -104,15 +101,17 @@ def rm_path(path):
104101str2statenames ["rmsprop8bit" ] = [("square_avg" , "state1" , "qmap1" , "max1" )]
105102str2statenames ["rmsprop8bit_blockwise" ] = [("square_avg" , "state1" , "qmap1" , "absmax1" )]
106103str2statenames ["lion8bit_blockwise" ] = [("exp_avg" , "state1" , "qmap1" , "absmax1" )]
104+ str2statenames ["paged_lion8bit_blockwise" ] = [("exp_avg" , "state1" , "qmap1" , "absmax1" )]
107105
108106dim1 = [1024 ]
109107dim2 = [32 , 1024 , 4097 , 1 ]
110- gtype = [torch .float32 , torch .float16 ]
111- optimizer_names = ["adam" , "momentum" , "rmsprop" , 'paged_adamw' , 'paged_adam' , 'lion' ]
108+ gtype = [torch .float32 , torch .float16 , torch . bfloat16 ]
109+ optimizer_names = ["adam" , "momentum" , "rmsprop" , 'paged_adamw' , 'paged_adam' , 'lion' , 'paged_lion' ]
112110values = list (product (dim1 , dim2 , gtype , optimizer_names ))
113111names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}" .format (* vals ) for vals in values ]
114112@pytest .mark .parametrize ("dim1, dim2, gtype, optim_name" , values , ids = names )
115113def test_optimizer32bit (dim1 , dim2 , gtype , optim_name ):
114+ if gtype == torch .bfloat16 and optim_name in ['momentum' , 'rmsprop' ]: pytest .skip ()
116115 if dim1 == 1 and dim2 == 1 :
117116 return
118117 p1 = torch .randn (dim1 , dim2 , device = "cuda" , dtype = gtype ) * 0.1
@@ -254,7 +253,7 @@ def test_global_config(dim1, dim2, gtype):
 
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
 def test_optimizer8bit(dim1, dim2, gtype, optim_name):
-    if gtype == torch.bfloat16 and optim_name not in str2bf16support: return
+    if gtype == torch.bfloat16 and optim_name not in ['adam8bit_blockwise', 'lion8bit_blockwise']: pytest.skip()
     if dim1 == 1 and dim2 == 1:
         return
     p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
@@ -485,7 +484,7 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
 # optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
 # optimizer_names = ['lamb_apex', 'lamb8bit']
 # optimizer_names = ['lars_apex', 'lars8bit']
-optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise']
+optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise', 'paged_lion8bit_blockwise']
 values = list(product(dim1, dim2, gtype, optimizer_names))
 names = [
     "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values