File tree Expand file tree Collapse file tree 2 files changed +18
-1
lines changed Expand file tree Collapse file tree 2 files changed +18
-1
lines changed Original file line number Diff line number Diff line change 99
1010from ema_pytorch import EMA
1111
12+ from lightning import Fabric
1213
1314# helpers
1415
@@ -49,9 +50,21 @@ def __init__(
4950 ),
5051 clip_grad_norm = 10. ,
5152 default_lambda_lr = default_lambda_lr_fn ,
53+ fabric : Fabric | None = None ,
54+ accelerator = 'auto' ,
55+ fabric_kwargs : dict = dict (),
5256 ema_kwargs : dict = dict ()
5357 ):
5458 super ().__init__ ()
59+
60+ if not exists (fabric ):
61+ fabric = Fabric (accelerator = accelerator , ** fabric_kwargs )
62+
63+ self .fabric = fabric
64+ fabric .launch ()
65+
66+ # model
67+
5568 self .model = model
5669
5770 # exponential moving average
@@ -74,6 +87,10 @@ def __init__(
7487
7588 self .optimizer = optimizer
7689
90+ # setup fabric
91+
92+ self .model , self .optimizer = fabric .setup (self .model , self .optimizer )
93+
7794 # scheduler
7895
7996 if not exists (scheduler ):
Original file line number Diff line number Diff line change @@ -29,7 +29,7 @@ dependencies = [
2929 " ema-pytorch>=0.4.8" ,
3030 " environs" ,
3131 " jaxtyping>=0.2.28" ,
32- " pytorch- lightning>=2.2.5" ,
32+ " lightning>=2.2.5" ,
3333 " taylor-series-linear-attention>=0.1.9" ,
3434 " torch>=2.1" ,
3535 " tqdm" ,
You can’t perform that action at this time.
0 commit comments