@@ -40,6 +40,7 @@
 from torch.utils.data import Sampler, Dataset, DataLoader as OrigDataLoader
 from torch.optim.lr_scheduler import LambdaLR, LRScheduler

+from lion_pytorch.foreach import Lion
 from adam_atan2_pytorch.foreach import AdamAtan2

 from ema_pytorch import EMA
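
For context on the new dependency: `Lion` is presumably the foreach (multi-tensor) implementation from the lion-pytorch package (`pip install lion-pytorch`). A standalone construction sketch, with illustrative hyperparameters only (the Lion paper suggests a smaller learning rate and larger weight decay than Adam):

```python
# hedged sketch, not part of this commit: constructing the newly
# imported optimizer directly; lr / weight_decay are illustrative
import torch
from lion_pytorch.foreach import Lion

model = torch.nn.Linear(16, 16)

opt = Lion(
    model.parameters(),
    lr = 1e-4,           # Lion typically wants a lower lr than Adam
    weight_decay = 1e-2  # and a correspondingly larger weight decay
)
```
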
@@ -60,6 +61,9 @@ def default(v, d):
 def divisible_by(num, den):
     return (num % den) == 0

+def at_most_one_of(*flags: bool) -> bool:
+    return sum([*map(int, flags)]) <= 1
+
 def cycle(dataloader: DataLoader):
     while True:
         for batch in dataloader:
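
The new `at_most_one_of` helper returns `True` when no more than one flag is truthy; the `assert` added later in this commit relies on that to reject conflicting optimizer flags. An illustrative check (flag values hypothetical):

```python
# illustrative only: expected behavior of the helper above
assert at_most_one_of()                # no flags set
assert at_most_one_of(True, False)     # exactly one set
assert not at_most_one_of(True, True)  # conflicting flags rejected
```
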
@@ -281,6 +285,7 @@ def __init__(
             use_foreach = True
         ),
         use_adam_atan2: bool = False,
+        use_lion: bool = False,
         use_torch_compile: bool = False
     ):
         super().__init__()
@@ -325,13 +330,18 @@ def __init__(
         # optimizer

         if not exists(optimizer):
-            adam_klass = Adam
+            optimizer_klass = Adam
+
+            assert at_most_one_of(use_adam_atan2, use_lion)

             if use_adam_atan2:
                 default_adam_kwargs.pop('eps', None)
-                adam_klass = AdamAtan2
+            elif use_lion:
+                default_adam_kwargs.pop('eps', None)
+                optimizer_klass = Lion
+
+            if use_adam_atan2:
+                optimizer_klass = AdamAtan2

-            optimizer = adam_klass(
+            optimizer = optimizer_klass(
                 model.parameters(),
                 lr = lr,
                 **default_adam_kwargs
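
A minimal usage sketch of the new flag, assuming this `__init__` belongs to the repo's trainer class (the class name sits outside these hunks, so `Trainer` and `model` below are stand-ins):

```python
# hypothetical usage; `Trainer` / `model` stand in for the repo's own names
trainer = Trainer(
    model,
    lr = 1e-4,
    use_lion = True  # selects Lion; 'eps' is popped since Lion takes no eps
)
```

Passing both `use_lion = True` and `use_adam_atan2 = True` would trip the `at_most_one_of` assert above.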