 import argparse
-from dataclasses import dataclass
 import logging
+from dataclasses import dataclass
 
-import torch.distributed as dist
 import lightning as L
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
-from lightning.pytorch.demos import Transformer, WikiText2
+from lightning.pytorch.demos import WikiText2
 from lightning.pytorch.strategies import FSDPStrategy, ModelParallelStrategy
 from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision
 from torch.utils.data import DataLoader
 
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 log = logging.getLogger(__name__)
 
+
 @dataclass
 class Args:
     vocab_size: int = 32000
@@ -24,6 +25,7 @@ class Args:
     enable_gradient_checkpointing: bool = False
     enable_fsdp2: bool = False
 
+
 class SimpleLayer(nn.Module):
     def __init__(self, hidden_size):
         super(SimpleLayer, self).__init__()
@@ -37,6 +39,7 @@ def forward(self, x):
         x = self.activation(x)
         return x
 
+
 class InnerModel(nn.Module):
     def __init__(self, num_layers, hidden_size, vocab_size=32000):
         super(InnerModel, self).__init__()
@@ -46,7 +49,6 @@ def __init__(self, num_layers, hidden_size, vocab_size=32000):
         self.layers = nn.ModuleList([SimpleLayer(hidden_size) for _ in range(num_layers)])
         self.lm_head = nn.Linear(hidden_size, vocab_size)
 
-
     def forward(self, x):
         x = self.embedding(x)
         # Pass the input through each layer sequentially
@@ -66,14 +68,15 @@ def forward(self, *args, **kwargs):
 
 
 class LanguageModel(L.LightningModule):
-    def __init__(self,
-                 vocab_size=32000,
-                 enable_fp8=False,
-                 enable_fsdp2=False,
-                 enable_torch_compile=False,
-                 enable_gradient_checkpointing=False,
-                 enable_cpu_offload=False
-                 ):
+    def __init__(
+        self,
+        vocab_size=32000,
+        enable_fp8=False,
+        enable_fsdp2=False,
+        enable_torch_compile=False,
+        enable_gradient_checkpointing=False,
+        enable_cpu_offload=False,
+    ):
         super().__init__()
         self.model = None
         self.vocab_size = vocab_size
@@ -88,10 +91,11 @@ def __init__(self,
         }  # only used for FP8 training
 
     def log_model_stage(self, stage: str):
-        """
-        Logs the current state of the model with a description of the stage.
+        """Logs the current state of the model with a description of the stage.
+
         Args:
             stage (str): Description of the current model stage.
+
         """
         log.warning(f"Model at stage: {stage}\n{self.model}")
 
@@ -129,7 +133,7 @@ def configure_fsdp2(self):
 
     def configure_fp8(self):
         # Setup fp8 training, if enable_fp8 is false, it will create a fake handler
-        from handlers.fp8_training_handler import FP8Config, Float8TrainingHandler
+        from handlers.fp8_training_handler import Float8TrainingHandler, FP8Config
 
         fp8_config = FP8Config(
             enable_fp8=self.enable_fp8,
@@ -207,13 +211,14 @@ def train(args):
     dataset = WikiText2()
     train_dataloader = DataLoader(dataset, num_workers=8, batch_size=1)
 
-    model = LanguageModel(vocab_size=args.vocab_size,
-                          enable_fp8=args.enable_fp8,
-                          enable_fsdp2=args.enable_fsdp2,
-                          enable_torch_compile=args.enable_torch_compile,
-                          enable_gradient_checkpointing=args.enable_gradient_checkpointing,
-                          enable_cpu_offload=args.enable_cpu_offload,
-                          )
+    model = LanguageModel(
+        vocab_size=args.vocab_size,
+        enable_fp8=args.enable_fp8,
+        enable_fsdp2=args.enable_fsdp2,
+        enable_torch_compile=args.enable_torch_compile,
+        enable_gradient_checkpointing=args.enable_gradient_checkpointing,
+        enable_cpu_offload=args.enable_cpu_offload,
+    )
 
     if args.enable_fsdp2:
         strategy = ModelParallelStrategy(