Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sample_workloads/lit-gpt-demo/LitGPT.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] http://packages.c
&& apt-get update -y && apt-get install google-cloud-cli -y

COPY scripts /workspace/scripts
COPY utilities /workspace/pretrain/utilities
COPY openwebtext_trainer.py /workspace/pretrain/
COPY gpt_fsdp.py /workspace/pretrain/

ENTRYPOINT ["/bin/bash", "/workspace/scripts/litgpt_container_entrypoint.sh"]

2 changes: 1 addition & 1 deletion sample_workloads/lit-gpt-demo/build_and_push_litgpt.sh
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ BASE_SHORT_TAG="${LITGPT_SHA}"
BASE_LONG_TAG="${BASE_IMAGE}:${BASE_SHORT_TAG}"

FULL_SHORT_TAG="${BASE_SHORT_TAG}-${SOME_UUID}"
FULL_LONG_TAG="${FULL_IMAGE}:${FULL_SHORT_TAG}"
FULL_LONG_TAG="${FULL_IMAGE}:latest"

DOCKER_BUILDKIT=1 docker build -f $LITGPT_PATH/Dockerfile -t $BASE_LONG_TAG $LITGPT_PATH

Expand Down
66 changes: 66 additions & 0 deletions sample_workloads/lit-gpt-demo/gpt_fsdp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

"""Full definition of a GPT NeoX Language Model, all of it in this single file.

Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and
https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model.
"""

import math
from typing import Any, Optional, Tuple

import torch
import torch.nn as nn
from typing_extensions import Self

from lit_gpt.config import Config
from lit_gpt.model import Block, GPT


class GPTFSDP(GPT):
    """GPT variant whose decoder stack can fuse several ``Block``s into one
    ``MultiBlock`` so FSDP wraps/checkpoints a group of layers as a single unit.
    """

    def __init__(self, config: Config) -> None:
        super().__init__(config)
        print("##### Using gpt_fsdp #####")
        # Overload the transformer built by the base class so ``h`` can hold
        # MultiBlocks instead of individual Blocks.
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                h=self.construct_blocks(config),
                ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
            )
        )

    def construct_blocks(self, config: Config) -> nn.ModuleList:
        """Build the decoder stack, optionally fusing blocks into MultiBlocks.

        ``num_blocks_to_combine`` is injected onto ``config`` externally (e.g.
        from the NUM_BLOCKS_TO_COMBINE env var, so it may arrive as a string).
        Fix: default to 1 via ``getattr`` so a plain ``Config`` without the
        attribute no longer raises AttributeError.
        """
        config.num_blocks_to_combine = int(getattr(config, "num_blocks_to_combine", 1))
        # Fusion disabled: plain per-layer Blocks, as in the base GPT.
        if config.num_blocks_to_combine <= 1:
            return nn.ModuleList(Block(config) for _ in range(config.n_layer))

        # Using multiblocks: n_layer may not divide evenly, so any remainder
        # becomes one final, smaller MultiBlock.
        num_multi_blocks = config.n_layer // config.num_blocks_to_combine
        num_last_blocks = config.n_layer - num_multi_blocks * config.num_blocks_to_combine

        print(f"##### num_multi_blocks: {num_multi_blocks}, num_blocks: {num_last_blocks} #####")

        h = nn.ModuleList(
            MultiBlock(config, config.num_blocks_to_combine) for _ in range(num_multi_blocks)
        )
        if num_last_blocks > 0:
            h.append(MultiBlock(config, num_last_blocks))
        return h

class MultiBlock(nn.Module):
    """A group of consecutive decoder ``Block``s treated as one module, so
    FSDP wraps and activation-checkpoints the whole group at once."""

    def __init__(self, config, num_blocks_to_combine):
        super().__init__()
        self.blocks = nn.ModuleList([Block(config) for _ in range(num_blocks_to_combine)])

    def forward(
        self,
        x,
        rope,
        max_seq_length,
        mask=None,
        input_pos=None,
        kv_cache=None,
    ):
        # NOTE(review): mask/input_pos/kv_cache are accepted for signature
        # compatibility with Block but are NOT forwarded to the inner blocks.
        # Fine for training (they are None there), but this would break
        # masked / KV-cache inference — confirm before reusing for generation.
        hidden = x
        for layer in self.blocks:
            hidden, _ = layer(hidden, rope, max_seq_length)
        return hidden, kv_cache
5 changes: 5 additions & 0 deletions sample_workloads/lit-gpt-demo/helm/templates/litgpt.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,14 @@ spec:
value: "{{$root.Values.workload.microBatchSize}}"
- name: MODEL_NAME
value: "{{$root.Values.workload.modelName}}"
- name: NUM_BLOCKS_TO_COMBINE
value: "{{$root.Values.workload.numBlocksToCombine}}"
- name: WARMUP_ITERS
value: "{{$root.Values.workload.warmupIters}}"
- name: MAX_ITERS
value: "{{$root.Values.workload.maxIters}}"
- name: COLLECT_NSYS_PROFILE
value: "{{$root.Values.workload.collectNsysProfile}}"
- name: CLUSTER_TYPE
value: GKE
volumeMounts:
Expand All @@ -182,3 +186,4 @@ spec:
nvidia.com/gpu: !!int 8
---
{{end}}

17 changes: 9 additions & 8 deletions sample_workloads/lit-gpt-demo/helm/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@ network:
rxdmContainer: us-docker.pkg.dev/gce-ai-infra/gpudirect-tcpx/tcpgpudmarxd-dev:v2.0.9
disablePmtu: "yes"
workload:
jobTimestamp: # Must be defined
gcsExperimentBucket: # Must be defined
experimentDir: llama2-70b
jobTimestamp: 1 # Must be defined
gcsExperimentBucket: tejas_gpu # Must be defined
experimentDir: llama2-7b-8nodes-bs6-original
gcsDataBucket: litgpt-public-bucket
dataDir: openwebtext_dataset
image: us-docker.pkg.dev/gce-ai-infra/litgpt-full/litgpt
modelName: Llama-2-70b-hf
image: us-central1-docker.pkg.dev/supercomputer-testing/tejasnama-gcr/litgpt-full:latest
modelName: Llama-2-7b-hf
batchSize: 6
microBatchSize: 6
microBatchSize: 4
warmupIters: 10
maxIters: 1000

maxIters: 100
numBlocksToCombine: 4
collectNsysProfile: 'no' # Set to 'yes' for profiles
185 changes: 123 additions & 62 deletions sample_workloads/lit-gpt-demo/openwebtext_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,25 @@
from lightning.pytorch.strategies import FSDPStrategy, XLAStrategy
from torch.utils.data import DataLoader, IterableDataset

import torch.multiprocessing as mp
import nvtx

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

mp.set_start_method("spawn", force=True)
import utilities.monitor_collectives
utilities.monitor_collectives.shunt_torch_communication()


from lit_gpt import Config
from lit_gpt.model import GPT, Block
from lit_gpt.speed_monitor import SpeedMonitorCallback, estimate_flops, measure_flops
from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, step_csv_logger

from gpt_fsdp import GPTFSDP, MultiBlock

model_name = os.getenv("MODEL_NAME", "Llama-2-70b-hf")
name = "openwebtext"
out_dir = Path(os.getenv("EXPERIMENT_LOCAL_DIR", "")) / "out"
Expand All @@ -32,6 +42,7 @@
eval_iters = 100
log_interval = 1
num_nodes = int(os.getenv("NNODES", "1"))
num_blocks_to_combine = os.getenv("NUM_BLOCKS_TO_COMBINE", 1)

# Hyperparameters
learning_rate = 6e-4
Expand All @@ -57,9 +68,11 @@ def __init__(self, config: Config) -> None:
self.config = config
self.module: Optional[torch.nn.Module] = None
self.measured_flops: Optional[int] = None
self.nsys_profile_step_multiple = 5
self.backward_nvtx_range = None

def configure_model(self) -> None:
    """Instantiate the FSDP-aware GPT and run its weight initialization."""
    model = GPTFSDP(self.config)
    model.apply(model._init_weights)
    self.module = model

def configure_optimizers(self) -> torch.optim.Optimizer:
Expand All @@ -69,6 +82,7 @@ def configure_optimizers(self) -> torch.optim.Optimizer:

def on_fit_start(self) -> None:
trainer = self.trainer

with torch.device("meta"):
meta_model = GPT(self.module.config)
# "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild.
Expand All @@ -88,14 +102,57 @@ def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
for optimizer in self.trainer.strategy.optimizers:
for param_group in optimizer.param_groups:
param_group["lr"] = lr

global_batch_idx = batch_idx / gradient_accumulation_steps
if (
global_batch_idx > 0
and global_batch_idx % self.nsys_profile_step_multiple == 0
):
print(f"Starting Nsys profiling")
torch.cuda.cudart().cudaProfilerStart()


def on_train_batch_end(
    self, outputs, batch: Any, batch_idx: int, unused: int = 0
) -> None:
    """After each micro-batch: on the last micro-batch of a global step,
    stop any active Nsys capture window and emit a heartbeat with peak
    GPU memory, flushing stdout/stderr so logs survive preemption."""
    global_batch_idx = batch_idx // gradient_accumulation_steps
    global_batch_offset = batch_idx % gradient_accumulation_steps
    is_last_microbatch = global_batch_offset == gradient_accumulation_steps - 1

    if is_last_microbatch:
        # Close the profiling window opened in on_train_batch_start every
        # nsys_profile_step_multiple global steps (skipping the first step).
        at_profile_boundary = (
            global_batch_idx > 1
            and global_batch_idx % self.nsys_profile_step_multiple == 0
        )
        if at_profile_boundary:
            self.print(f"Stopping Nsys profiling")
            torch.cuda.cudart().cudaProfilerStop()

        self.print(f"HEARTBEAT: {global_batch_idx=}, {batch_idx=}")
        self.print(
            f"Max memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB"
        )
        sys.stdout.flush()
        sys.stderr.flush()

@nvtx.annotate(color='green')
def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
    """Forward pass for one micro-batch: compute logits, chunked
    cross-entropy loss, and log it per step."""
    input_ids, targets = batch
    predictions = self.module(input_ids)
    loss = chunked_cross_entropy(predictions, targets, chunk_size=0)
    self.log("train_loss", loss, on_step=True, on_epoch=False, prog_bar=True)
    return loss

def on_before_backward(self, loss):
    """Open an NVTX range covering backward; closed in on_after_backward."""
    self.backward_nvtx_range = nvtx.start_range(color="red", message="backward")

def on_after_backward(self):
    """Close the NVTX range opened in on_before_backward, if one is active."""
    active_range = self.backward_nvtx_range
    if active_range:
        nvtx.end_range(active_range)

@nvtx.annotate(color='orange')
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
    """Run the optimizer update inside an orange NVTX range so it is
    visible as a distinct span in Nsys traces."""
    optimizer.step(closure=optimizer_closure)

def validation_step(self, batch: Any, batch_idx: int) -> None:
input_ids, targets = batch
logits = self.module(input_ids)
Expand All @@ -104,68 +161,71 @@ def validation_step(self, batch: Any, batch_idx: int) -> None:


def main(devices: int = 1, precision: Optional[str] = None, tpu: bool = False) -> None:
precision = precision or get_default_supported_precision(training=True, tpu=tpu)

if devices > 1:
if tpu:
# For multi-host TPU training, the device count for Fabric is limited to the count on a single host.
devices = "auto"
strategy = XLAStrategy(sync_module_states=False)
cm = torch.autograd.profiler.emit_nvtx()
with cm:
precision = precision or get_default_supported_precision(training=True, tpu=tpu)

if devices > 1:
if tpu:
# For multi-host TPU training, the device count for Fabric is limited to the count on a single host.
devices = "auto"
strategy = XLAStrategy(sync_module_states=False)
else:
strategy = FSDPStrategy(
auto_wrap_policy={MultiBlock},
activation_checkpointing_policy={MultiBlock},
# the argument is not available in the Trainer strategy, but it's the default anyways
# state_dict_type="full",
limit_all_gathers=True,
cpu_offload=False,
)
else:
strategy = FSDPStrategy(
auto_wrap_policy={Block},
activation_checkpointing_policy={Block},
# the argument is not available in the Trainer strategy, but it's the default anyways
# state_dict_type="full",
limit_all_gathers=True,
cpu_offload=False,
)
else:
strategy = "auto"

logger = step_csv_logger(out_dir, name, cls=CSVLogger, flush_logs_every_n_steps=log_interval)
speed_monitor = SpeedMonitorCallback(
length_fn=lambda batch: batch[0].size(1), batch_size=micro_batch_size, window_size=10, time_unit="seconds"
)
model_checkpoint = ModelCheckpoint(dirpath=out_dir, every_n_train_steps=save_interval, save_last=True, verbose=True)
trainer = L.Trainer(
devices=devices,
strategy=strategy,
precision=precision,
logger=logger,
callbacks=[speed_monitor, model_checkpoint],
max_steps=max_iters,
max_epochs=1,
limit_val_batches=eval_iters,
accumulate_grad_batches=gradient_accumulation_steps,
log_every_n_steps=log_interval,
val_check_interval=eval_interval,
num_nodes=num_nodes
)

L.seed_everything(1337, workers=True) # same seed for every process to init model (FSDP)

trainer.print(hparams)

if trainer.global_rank == 0:
out_dir.mkdir(parents=True, exist_ok=True)

config = Config.from_name(model_name)
trainer.print(f"Loading model with {config.__dict__}")
t0 = time.perf_counter()
model = LightningGPTModule(config)
trainer.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")

train_data = Dataset(str(data_dir / "train.bin"), config.block_size)
val_data = Dataset(str(data_dir / "val.bin"), config.block_size)
train_dataloader = DataLoader(train_data, batch_size=micro_batch_size, num_workers=2)
val_dataloader = DataLoader(val_data, batch_size=micro_batch_size, num_workers=2)

t0 = time.perf_counter()
trainer.fit(model, train_dataloader, val_dataloader, ckpt_path="last")
trainer.print(f"Training time: {(time.perf_counter()-t0):.2f}s")
if trainer.strategy.root_device.type == "cuda":
trainer.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
strategy = "auto"

logger = step_csv_logger(out_dir, name, cls=CSVLogger, flush_logs_every_n_steps=log_interval)
speed_monitor = SpeedMonitorCallback(
length_fn=lambda batch: batch[0].size(1), batch_size=micro_batch_size, window_size=10, time_unit="seconds"
)
model_checkpoint = ModelCheckpoint(dirpath=out_dir, every_n_train_steps=save_interval, save_last=True, verbose=True)
trainer = L.Trainer(
devices=devices,
strategy=strategy,
precision=precision,
logger=logger,
callbacks=[speed_monitor, model_checkpoint],
max_steps=max_iters,
max_epochs=1,
limit_val_batches=eval_iters,
accumulate_grad_batches=gradient_accumulation_steps,
log_every_n_steps=log_interval,
val_check_interval=eval_interval,
num_nodes=num_nodes
)

L.seed_everything(1337, workers=True) # same seed for every process to init model (FSDP)

trainer.print(hparams)

if trainer.global_rank == 0:
out_dir.mkdir(parents=True, exist_ok=True)

config = Config.from_name(model_name)
config.num_blocks_to_combine = num_blocks_to_combine
trainer.print(f"Loading model with {config.__dict__}")
t0 = time.perf_counter()
model = LightningGPTModule(config)
trainer.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")

train_data = Dataset(str(data_dir / "train.bin"), config.block_size)
val_data = Dataset(str(data_dir / "val.bin"), config.block_size)
train_dataloader = DataLoader(train_data, batch_size=micro_batch_size, num_workers=2)
val_dataloader = DataLoader(val_data, batch_size=micro_batch_size, num_workers=2)

t0 = time.perf_counter()
trainer.fit(model, train_dataloader, val_dataloader, ckpt_path="last")
trainer.print(f"Training time: {(time.perf_counter()-t0):.2f}s")
if trainer.strategy.root_device.type == "cuda":
trainer.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")


class Dataset(IterableDataset):
Expand Down Expand Up @@ -206,3 +266,4 @@ def get_lr(it):
from jsonargparse import CLI

CLI(main)

Loading