1 change: 1 addition & 0 deletions .devcontainer/recipes/requirements.txt
@@ -14,3 +14,4 @@ transformers
typer
wandb
zstandard
nvdlfw_inspect @ git+https://github.com/NVIDIA/nvidia-dlfw-inspect
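Note: the nvdlfw_inspect package added above provides the debug API that train_ddp.py and train_fsdp2.py call further down in this diff. A minimal sketch of the lifecycle the PR relies on is shown here; the paths are placeholders (the feature directory depends on where Transformer Engine is installed), and only calls that appear in the diff are used.

import nvdlfw_inspect.api as debug_api
import torch


def enable_fp8_stat_logging(model: torch.nn.Module, config_file: str, log_dir: str) -> None:
    # Initialize before the model is wrapped in DDP/FSDP so that layer names
    # resolve to the underlying Transformer Engine modules.
    debug_api.initialize(
        config_file=config_file,  # e.g. one of the fp8_stats_*.yaml files added in this PR
        feature_dirs=["/path/to/transformer_engine/debug/features/"],  # placeholder path
        log_dir=log_dir,
        default_logging_enabled=True,
    )
    # Name the layers so the layer_types patterns in the YAML config can match them.
    debug_api.infer_and_assign_layer_names(model)

# In the training loop, call debug_api.step() once per iteration (after loss.backward())
# to advance the logging step counter.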
@@ -0,0 +1,27 @@
example_fp8_tensor_stat_collection:
enabled: True
layers:
# Match the actual linear layers within attention that support FP8 stats
layer_types: [layernorm_qkv, proj]
transformer_engine:
LogFp8TensorStats:
enabled: True
tensors_struct:
- tensor: activation
stats: [fp8_block_scaling_underflows%]
freq: 1
- tensor: activation
stats: [fp8_block_scaling_scale_inv_min, fp8_block_scaling_scale_inv_max]
freq: 1
- tensor: activation
stats: [scale_inv_max]
freq: 1
- tensor: activation
stats: [fp8_block_scaling_mse]
freq: 1
- tensor: gradient
stats: [fp8_block_scaling_underflows%]
freq: 5
start_step: 0
end_step: 80

19 changes: 19 additions & 0 deletions bionemo-recipes/recipes/esm2_native_te/fp8_stats_mxfp8.yaml
@@ -0,0 +1,19 @@
example_fp8_tensor_stat_collection:
enabled: True
layers:
# Match the TE linear layers that support FP8 stats (attention projections plus the dense and decoder heads)
layer_types: [layernorm_qkv, proj, dense, decoder]

transformer_engine:
LogFp8TensorStats:
enabled: True
tensors_struct:
- tensor: activation
stats: [underflows%, scale_inv_min, scale_inv_max, mse]
freq: 50
- tensor: gradient
stats: [underflows%, scale_inv_min, scale_inv_max, mse]
freq: 50
- tensor: weight
stats: [underflows%, scale_inv_min, scale_inv_max, mse]
freq: 50
@@ -0,0 +1,18 @@
example_fp8_tensor_stat_collection:
enabled: True
layers:
# Match the actual linear layers within attention that support FP8 stats
layer_types: [layernorm_qkv, proj]
transformer_engine:
LogFp8TensorStats:
enabled: True
tensors_struct:
- tensor: activation
stats: [underflows%, scale_inv_min, scale_inv_max, mse]
freq: 1000
- tensor: gradient
stats: [underflows%, scale_inv_min, scale_inv_max, mse]
freq: 1000
- tensor: weight
stats: [underflows%, scale_inv_min, scale_inv_max, mse]
freq: 1000
@@ -0,0 +1,18 @@
example_fp8_tensor_stat_collection:
enabled: True
layers:
# Match the TE linear layers that support FP8 stats (attention projections plus the dense and decoder heads)
layer_types: [layernorm_qkv, proj, dense, decoder]
transformer_engine:
LogFp8TensorStats:
enabled: True
tensors_struct:
- tensor: activation
stats: [underflows%, scale_inv_min, scale_inv_max, mse]
freq: 1
- tensor: gradient
stats: [underflows%, scale_inv_min, scale_inv_max, mse]
freq: 1
- tensor: weight
stats: [underflows%, scale_inv_min, scale_inv_max, mse]
freq: 1
@@ -51,6 +51,7 @@ fp8_config:
fp8_model_init_kwargs:
enabled: false # If this is set to true, fp8_config.enabled must also be set to true.


# Optimizer config
adamw_kwargs:
lr: 4e-4
@@ -75,3 +76,7 @@ checkpoint:

logger:
frequency: 100

fp8_stats_config:
fp8_stats_file: null
fp8_log_dir: null
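The two keys above default to null; train_fsdp2.py falls back to fp8_stats_mxfp8.yaml and ./log_fsdp2_mxfp8 when they are unset. A small sketch of that fallback, mirroring the logic in train_fsdp2.py (the literal defaults are the ones used there):

from omegaconf import OmegaConf

# Resolve the FP8 stats settings the same way train_fsdp2.py does.
args = OmegaConf.create({"fp8_stats_config": {"fp8_stats_file": None, "fp8_log_dir": None}})
fp8_stats_file = args.fp8_stats_config.fp8_stats_file or "fp8_stats_mxfp8.yaml"
fp8_log_dir = args.fp8_stats_config.fp8_log_dir or "./log_fsdp2_mxfp8"
print(fp8_stats_file, fp8_log_dir)

Both keys can also be set per run with standard Hydra overrides, e.g. fp8_stats_config.fp8_stats_file=fp8_stats_block_scaling.yaml.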
20 changes: 17 additions & 3 deletions bionemo-recipes/recipes/esm2_native_te/train_ddp.py
@@ -17,6 +17,7 @@
from pathlib import Path

import hydra
import nvdlfw_inspect.api as debug_api
import torch
import transformer_engine.pytorch
from omegaconf import DictConfig
@@ -43,6 +44,13 @@ def main(args: DictConfig) -> float | None:
Returns:
float: The loss value for the final batch.
"""
# TE debug feature logging - must be initialized before the model is wrapped in DDP
debug_api.initialize(
config_file="/workspaces/bionemo-framework/bionemo-recipes/recipes/esm2_native_te/fp8_stats_block_scaling.yaml",
feature_dirs=["/usr/local/lib/python3.12/dist-packages/transformer_engine/debug/features/"],
log_dir="./logddp",
default_logging_enabled=True,
)
# Initialize the distributed configuration, including creating the distributed process group.
dist_config = DistributedConfig()
logger.info("Initializing distributed training: %s", dist_config)
@@ -65,6 +73,8 @@ def main(args: DictConfig) -> float | None:
if args.use_sequence_packing:
config.attn_input_format = "thd"



# Optionally use transformer engine to initialize only fp8 versions of weights by setting
# `fp8_config.fp8_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16 and fp8
# versions of weights are kept.
@@ -84,13 +94,16 @@ def main(args: DictConfig) -> float | None:
optimizer = AdamW(model.parameters(), **args.adamw_kwargs)
scheduler = get_linear_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs)

debug_api.infer_and_assign_layer_names(model)

model = model.to(device=device)
model = torch.nn.parallel.DistributedDataParallel(
model,
device_ids=[dist_config.local_rank],
output_device=dist_config.local_rank,
device_mesh=device_mesh["ddp"],
)


# If we're using sequence packing, create a THD dataloader, otherwise create a BSHD dataloader.
train_dataloader, dataset_or_sampler = (
@@ -99,9 +112,9 @@ def main(args: DictConfig) -> float | None:
else create_bshd_dataloader(dist_config, **args.dataset)
)

if args.use_torch_compile:
# If we're using torch.compile, we need to do this before loading the checkpoint to ensure key consistency.
model = torch.compile(model)
# if args.use_torch_compile:
# # If we're using torch.compile, we need to do this before loading the checkpoint to ensure key consistency.
# model = torch.compile(model)

# If we're resuming from a checkpoint, load it and set the start step. Otherwise, start from step 0.
ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_ddp" if args.checkpoint.ckpt_dir else None
@@ -134,6 +147,7 @@ def main(args: DictConfig) -> float | None:
loss = outputs.loss
loss.backward()

debug_api.step()
# Compute and clip gradient norms.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()

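Unlike the FSDP2 script, the DDP path above hard-codes an absolute config path, feature directory, and log directory in debug_api.initialize. A sketch of deriving those values instead is shown below; it assumes the debug feature definitions ship inside the installed transformer_engine package (consistent with the site-packages path used above) and that the fp8_stats_config keys from the recipe config are also plumbed into this script. Neither assumption is part of the diff.

import os
from pathlib import Path

import nvdlfw_inspect.api as debug_api
import transformer_engine


def init_fp8_stat_logging(config_file: str, log_dir: str, local_rank: int) -> None:
    # Locate the TE debug feature definitions relative to the installed package
    # rather than a hard-coded site-packages path (assumed package layout).
    feature_dir = Path(transformer_engine.__file__).parent / "debug" / "features"
    # One log subdirectory per rank, matching the pattern used in train_fsdp2.py.
    rank_log_dir = os.path.join(log_dir, f"rank_{local_rank}")
    os.makedirs(rank_log_dir, exist_ok=True)
    debug_api.initialize(
        config_file=config_file,
        feature_dirs=[str(feature_dir)],
        log_dir=rank_log_dir,
        default_logging_enabled=True,
    )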
45 changes: 41 additions & 4 deletions bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py
@@ -17,7 +17,9 @@
from contextlib import nullcontext
from pathlib import Path

import os
import hydra
import nvdlfw_inspect.api as debug_api
import torch
import transformer_engine.pytorch
from omegaconf import DictConfig, OmegaConf
@@ -35,7 +37,7 @@
from distributed_config import DistributedConfig
from perf_logger import PerfLogger
from scheduler import get_linear_schedule_with_warmup

from torch.utils.tensorboard import SummaryWriter

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -55,6 +57,23 @@ def main(args: DictConfig) -> float | None:
torch.distributed.init_process_group(backend="nccl", device_id=device)
torch.cuda.set_device(dist_config.local_rank)

# TE Debug feature logging - MUST be done BEFORE FSDP wrapping
tb_writer = SummaryWriter('./tensorboard_dir/run1')
fp8_stats_file = args.fp8_stats_config.fp8_stats_file if args.fp8_stats_config.fp8_stats_file else "fp8_stats_mxfp8.yaml"
fp8_log_dir = args.fp8_stats_config.fp8_log_dir if args.fp8_stats_config.fp8_log_dir else "./log_fsdp2_mxfp8"
# Make a subdir for the current rank.
fp8_log_dir = os.path.join(fp8_log_dir, f"rank_{dist_config.local_rank}")
os.makedirs(fp8_log_dir, exist_ok=True)
logger.info(f"Logging FP8 stats to {fp8_log_dir}")
debug_api.initialize(
config_file=fp8_stats_file,
feature_dirs=["/usr/local/lib/python3.12/dist-packages/transformer_engine/debug/features/"],
log_dir=fp8_log_dir,
default_logging_enabled=True,
tb_writer=tb_writer,
)


# Create a device mesh for FSDP.
device_mesh = init_device_mesh(
"cuda",
@@ -84,12 +103,27 @@ def main(args: DictConfig) -> float | None:

logger.info("Initialized Model:\n%s", model)


# Debug: Print module types to verify what we're working with
if dist_config.local_rank == 0:
logger.info("=== DEBUG: Module types in model ===")
for name, module in model.named_modules():
if "layernorm_qkv" in name or "proj" in name or "self_attention" in name:
logger.info(f" -----> {name}: {type(module)} <----")
logger.info(
f"=== DEBUG: FP8 config enabled={args.fp8_config.enabled}, recipe={args.fp8_config.fp8_recipe} ==="
)

# We call the transformer stack "layers" in our TE models, but it's called "layer" in the original ESM-2 models.
transformer_stack = model.esm.encoder.layers if hasattr(model.esm.encoder, "layers") else model.esm.encoder.layer

for layer in transformer_stack:
fully_shard(layer, mesh=device_mesh["dp"])
fully_shard(model, mesh=device_mesh["dp"])

# Assign names to layers so debug API can identify them - MUST be done BEFORE FSDP wrapping
debug_api.infer_and_assign_layer_names(model)

# Create optimizer. Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873).
optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True)) # type: ignore
scheduler = get_linear_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs)
@@ -105,9 +139,9 @@ def main(args: DictConfig) -> float | None:
else create_bshd_dataloader(dist_config, **args.dataset)
)

if args.use_torch_compile:
# If we're using torch.compile, we need to do this before loading the checkpoint to ensure key consistency.
model = torch.compile(model)
# if args.use_torch_compile:
# # If we're using torch.compile, we need to do this before loading the checkpoint to ensure key consistency.
# model = torch.compile(model)

# If we're resuming from a checkpoint, load it and set the start step. Otherwise, start from step 0.
ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_fsdp2" if args.checkpoint.ckpt_dir else None
@@ -146,6 +180,9 @@ def main(args: DictConfig) -> float | None:
# Step optimizer.
optimizer.step()
scheduler.step()

debug_api.step()

optimizer.zero_grad()

perf_logger.log_step(
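The FSDP2 script also routes the FP8 stats to TensorBoard through the SummaryWriter created above (./tensorboard_dir/run1). A sketch for inspecting those scalars offline with TensorBoard's event reader; the exact tag names come from nvdlfw_inspect and are not pinned down here.

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

# Load the event files written during training and dump every logged scalar.
acc = EventAccumulator("./tensorboard_dir/run1")
acc.Reload()
for tag in acc.Tags()["scalars"]:  # tag names depend on nvdlfw_inspect's naming scheme
    for event in acc.Scalars(tag):
        print(tag, event.step, event.value)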