Commit c6695f8
committed: Minimal transformer examples
1 parent 8ce5287 · commit c6695f8

File tree

4 files changed: +200 -0 lines changed

examples/fabric/fp8_fsdp2_compile/README.md

Whitespace-only changes.
Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.distributed.device_mesh import DeviceMesh

from torchao.float8 import convert_to_float8_training, Float8LinearConfig

import lightning as L
from lightning.fabric.strategies import ModelParallelStrategy
from lightning.pytorch.demos import Transformer, WikiText2

from tqdm import tqdm


def configure_model(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
    float8_config = Float8LinearConfig(
        # pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
        pad_inner_dim=True,
    )

    def module_filter_fn(mod: torch.nn.Module, fqn: str):
        # we skip the decoder because its vocabulary size is typically
        # not divisible by 16 as required by float8
        if fqn == "decoder":
            return False
        return True

    # swap eligible nn.Linear layers for float8 training
    convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn)

    # shard each transformer block, then the root module (FSDP2)
    for module in model.modules():
        if isinstance(module, (torch.nn.TransformerEncoderLayer, torch.nn.TransformerDecoderLayer)):
            fully_shard(module, mesh=device_mesh)

    fully_shard(model, mesh=device_mesh)

    model = torch.compile(model)

    return model


def train():
    L.seed_everything(42)

    batch_size = 8
    micro_batch_size = 1

    dataset = WikiText2()
    dataloader = DataLoader(dataset, num_workers=8, batch_size=micro_batch_size)

    with torch.device("meta"):
        model = Transformer(
            vocab_size=dataset.vocab_size,
            nlayers=16,
            nhid=4096,
            ninp=1024,
            nhead=32,
        )

    strategy = ModelParallelStrategy(
        data_parallel_size=4,
        tensor_parallel_size=1,
        parallelize_fn=configure_model,
    )

    fabric = L.Fabric(precision="bf16-true", strategy=strategy)
    fabric.launch()

    model = fabric.setup(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    optimizer = fabric.setup_optimizers(optimizer)

    dataloader = fabric.setup_dataloaders(dataloader)

    iterable = tqdm(enumerate(dataloader), total=len(dataloader)) if fabric.is_global_zero else enumerate(dataloader)

    for i, batch in iterable:
        input, target = batch

        # accumulate gradients over batch_size // micro_batch_size micro-batches
        is_accumulating = i % (batch_size // micro_batch_size) != 0

        with fabric.no_backward_sync(model, enabled=is_accumulating):
            output = model(input, target)
            loss = F.nll_loss(output, target.view(-1))
            fabric.backward(loss)

        if not is_accumulating:
            fabric.clip_gradients(model, optimizer, max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()

        if fabric.is_global_zero:
            iterable.set_postfix_str(f"train_loss={loss.item():.2f}")

        if i // (batch_size // micro_batch_size) > 100:
            break

    fabric.print(torch.cuda.memory_summary())


if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")

    train()
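
A quick way to sanity-check the float8 swap in isolation, before wiring up FSDP2 and Fabric, is to instantiate the same demo Transformer on the meta device and count module classes after convert_to_float8_training. This is a minimal sketch, not part of the commit; the count_modules helper and the vocab_size value are illustrative.

import torch
import torch.nn as nn

from torchao.float8 import convert_to_float8_training, Float8LinearConfig
from lightning.pytorch.demos import Transformer


def count_modules(model: nn.Module) -> dict:
    # Illustrative helper: tally module class names so the Linear -> Float8Linear swap is visible.
    counts: dict = {}
    for module in model.modules():
        name = type(module).__name__
        counts[name] = counts.get(name, 0) + 1
    return counts


if __name__ == "__main__":
    with torch.device("meta"):
        # vocab_size here is an arbitrary illustrative value; the real script uses dataset.vocab_size
        model = Transformer(vocab_size=33_000, nlayers=16, nhid=4096, ninp=1024, nhead=32)

    convert_to_float8_training(
        model,
        config=Float8LinearConfig(pad_inner_dim=True),
        # same filter as above: leave the decoder head in its original precision
        module_filter_fn=lambda mod, fqn: fqn != "decoder",
    )

    # Expect Float8Linear entries for the converted layers and a plain Linear for the skipped decoder.
    print(count_modules(model))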

examples/pytorch/fp8_fsdp2_compile/README.md

Whitespace-only changes.
Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.distributed._composable.fsdp.fully_shard import fully_shard

from torchao.float8 import convert_to_float8_training, Float8LinearConfig

import lightning as L
from lightning.pytorch.strategies import ModelParallelStrategy
from lightning.pytorch.demos import Transformer, WikiText2


class LanguageModel(L.LightningModule):
    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.model = None

    def configure_model(self):
        # configure_model can be called more than once; only build the model the first time
        if self.model is not None:
            return

        with torch.device("meta"):
            model = Transformer(
                vocab_size=self.vocab_size,
                nlayers=16,
                nhid=4096,
                ninp=1024,
                nhead=32,
            )

        float8_config = Float8LinearConfig(
            # pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
            pad_inner_dim=True,
        )

        def module_filter_fn(mod: torch.nn.Module, fqn: str):
            # we skip the decoder because its vocabulary size is typically
            # not divisible by 16 as required by float8
            if fqn == "decoder":
                return False
            return True

        # swap eligible nn.Linear layers for float8 training
        convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn)

        # shard each transformer block, then the root module (FSDP2)
        for module in model.modules():
            if isinstance(module, (nn.TransformerEncoderLayer, nn.TransformerDecoderLayer)):
                fully_shard(module, mesh=self.device_mesh)

        fully_shard(model, mesh=self.device_mesh)

        self.model = torch.compile(model)

    def training_step(self, batch):
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)


def train():
    L.seed_everything(42)

    dataset = WikiText2()
    train_dataloader = DataLoader(dataset, num_workers=8, batch_size=1)

    model = LanguageModel(vocab_size=dataset.vocab_size)

    mp_strategy = ModelParallelStrategy(
        data_parallel_size=4,
        tensor_parallel_size=1,
    )

    trainer = L.Trainer(
        strategy=mp_strategy,
        max_steps=100,
        precision="bf16-true",
        accumulate_grad_batches=8,
    )

    trainer.fit(model, train_dataloader)

    trainer.print(torch.cuda.memory_summary())


if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")

    train()
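
One detail worth spelling out in the Trainer variant: the DataLoader uses a per-device batch size of 1, the Trainer accumulates 8 batches per optimizer step, and ModelParallelStrategy runs 4 data-parallel ranks, so each optimizer update is computed from 32 samples. A small sketch of that accounting (values copied from the script above; not part of the commit):

# Effective batch accounting for the Trainer example above (illustrative sketch).
per_device_batch = 1   # DataLoader batch_size in the script
accumulate = 8         # accumulate_grad_batches passed to the Trainer
data_parallel = 4      # data_parallel_size in ModelParallelStrategy

samples_per_update = per_device_batch * accumulate * data_parallel
print(samples_per_update)  # 32 samples contribute to each optimizer step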
