diff --git a/.devcontainer/recipes/requirements.txt b/.devcontainer/recipes/requirements.txt
index ed7a41e81f..4a994d1a6b 100644
--- a/.devcontainer/recipes/requirements.txt
+++ b/.devcontainer/recipes/requirements.txt
@@ -14,3 +14,4 @@ transformers
 typer
 wandb
 zstandard
+nvdlfw_inspect @ git+https://github.com/NVIDIA/nvidia-dlfw-inspect
diff --git a/bionemo-recipes/recipes/esm2_native_te/fp8_stats_block_scaling.yaml b/bionemo-recipes/recipes/esm2_native_te/fp8_stats_block_scaling.yaml
new file mode 100644
index 0000000000..524342ddd5
--- /dev/null
+++ b/bionemo-recipes/recipes/esm2_native_te/fp8_stats_block_scaling.yaml
@@ -0,0 +1,27 @@
+example_fp8_tensor_stat_collection:
+  enabled: True
+  layers:
+    # Match the actual linear layers within attention that support FP8 stats
+    layer_types: [layernorm_qkv, proj]
+  transformer_engine:
+    LogFp8TensorStats:
+      enabled: True
+      tensors_struct:
+        - tensor: activation
+          stats: [fp8_block_scaling_underflows%]
+          freq: 1
+        - tensor: activation
+          stats: [fp8_block_scaling_scale_inv_min, fp8_block_scaling_scale_inv_max]
+          freq: 1
+        - tensor: activation
+          stats: [scale_inv_max]
+          freq: 1
+        - tensor: activation
+          stats: [fp8_block_scaling_mse]
+          freq: 1
+        - tensor: gradient
+          stats: [fp8_block_scaling_underflows%]
+          freq: 5
+      start_step: 0
+      end_step: 80
+
diff --git a/bionemo-recipes/recipes/esm2_native_te/fp8_stats_mxfp8.yaml b/bionemo-recipes/recipes/esm2_native_te/fp8_stats_mxfp8.yaml
new file mode 100644
index 0000000000..133a7445e2
--- /dev/null
+++ b/bionemo-recipes/recipes/esm2_native_te/fp8_stats_mxfp8.yaml
@@ -0,0 +1,19 @@
+example_fp8_tensor_stat_collection:
+  enabled: True
+  layers:
+    # Match the actual linear layers within attention that support FP8 stats
+    layer_types: [layernorm_qkv, proj, dense, decoder]
+
+  transformer_engine:
+    LogFp8TensorStats:
+      enabled: True
+      tensors_struct:
+        - tensor: activation
+          stats: [underflows%, scale_inv_min, scale_inv_max, mse]
+          freq: 50
+        - tensor: gradient
+          stats: [underflows%, scale_inv_min, scale_inv_max, mse]
+          freq: 50
+        - tensor: weight
+          stats: [underflows%, scale_inv_min, scale_inv_max, mse]
+          freq: 50
diff --git a/bionemo-recipes/recipes/esm2_native_te/fp8_stats_mxfp8_1ksteps.yaml b/bionemo-recipes/recipes/esm2_native_te/fp8_stats_mxfp8_1ksteps.yaml
new file mode 100644
index 0000000000..97df05d1ae
--- /dev/null
+++ b/bionemo-recipes/recipes/esm2_native_te/fp8_stats_mxfp8_1ksteps.yaml
@@ -0,0 +1,18 @@
+example_fp8_tensor_stat_collection:
+  enabled: True
+  layers:
+    # Match the actual linear layers within attention that support FP8 stats
+    layer_types: [layernorm_qkv, proj]
+  transformer_engine:
+    LogFp8TensorStats:
+      enabled: True
+      tensors_struct:
+        - tensor: activation
+          stats: [underflows%, scale_inv_min, scale_inv_max, mse]
+          freq: 1000
+        - tensor: gradient
+          stats: [underflows%, scale_inv_min, scale_inv_max, mse]
+          freq: 1000
+        - tensor: weight
+          stats: [underflows%, scale_inv_min, scale_inv_max, mse]
+          freq: 1000
diff --git a/bionemo-recipes/recipes/esm2_native_te/fp8_stats_mxfp8_freq1.yaml b/bionemo-recipes/recipes/esm2_native_te/fp8_stats_mxfp8_freq1.yaml
new file mode 100644
index 0000000000..ee9eb6c661
--- /dev/null
+++ b/bionemo-recipes/recipes/esm2_native_te/fp8_stats_mxfp8_freq1.yaml
@@ -0,0 +1,18 @@
+example_fp8_tensor_stat_collection:
+  enabled: True
+  layers:
+    # Match the actual linear layers within attention that support FP8 stats
+    layer_types: [layernorm_qkv, proj, dense, decoder]
+  transformer_engine:
+    LogFp8TensorStats:
+      enabled: True
+      tensors_struct:
+        - tensor: activation
+          stats: [underflows%, scale_inv_min, scale_inv_max, mse]
+          freq: 1
+        - tensor: gradient
+          stats: [underflows%, scale_inv_min, scale_inv_max, mse]
+          freq: 1
+        - tensor: weight
+          stats: [underflows%, scale_inv_min, scale_inv_max, mse]
+          freq: 1
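The four YAML files above are nvdlfw_inspect feature configs: a named collection section, a `layers` block that selects which Transformer Engine layer types to match, and a `LogFp8TensorStats` feature that lists, per tensor (activation, gradient, weight), which FP8 statistics to log and how often (`freq`, counted in `debug_api.step()` calls); `fp8_stats_block_scaling.yaml` additionally limits collection to the `start_step`/`end_step` window. A minimal sketch for sanity-checking one of these configs before launching a run, assuming PyYAML is available and the script is run from the recipe directory (the path is illustrative):

import yaml  # PyYAML, available alongside the recipe's Hydra/OmegaConf dependencies

# Load one of the stat-collection configs added above and print what it will log.
with open("fp8_stats_mxfp8.yaml") as f:
    section = yaml.safe_load(f)["example_fp8_tensor_stat_collection"]

print("enabled:", section["enabled"])
print("layer_types:", section["layers"]["layer_types"])
for entry in section["transformer_engine"]["LogFp8TensorStats"]["tensors_struct"]:
    print(f"  {entry['tensor']}: stats={entry['stats']} every {entry['freq']} steps")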
diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml
index 3bc3635ba7..cbaac1f22b 100644
--- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml
+++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml
@@ -51,6 +51,7 @@ fp8_config:
   fp8_model_init_kwargs:
     enabled: false # If this is set to true, fp8_config.enabled must also be set to true.

+
 # Optimizer config
 adamw_kwargs:
   lr: 4e-4
@@ -75,3 +76,7 @@ checkpoint:

 logger:
   frequency: 100
+
+fp8_stats_config:
+  fp8_stats_file: null
+  fp8_log_dir: null
diff --git a/bionemo-recipes/recipes/esm2_native_te/train_ddp.py b/bionemo-recipes/recipes/esm2_native_te/train_ddp.py
index 7c0794f589..f202fe781d 100644
--- a/bionemo-recipes/recipes/esm2_native_te/train_ddp.py
+++ b/bionemo-recipes/recipes/esm2_native_te/train_ddp.py
@@ -17,6 +17,7 @@
 from pathlib import Path

 import hydra
+import nvdlfw_inspect.api as debug_api
 import torch
 import transformer_engine.pytorch
 from omegaconf import DictConfig
@@ -43,6 +44,13 @@ def main(args: DictConfig) -> float | None:
     Returns:
         float: The loss value for the final batch.
     """
+    # TE Debug feature logging - must be initialized before the model is wrapped with DDP
+    debug_api.initialize(
+        config_file="/workspaces/bionemo-framework/bionemo-recipes/recipes/esm2_native_te/fp8_stats_block_scaling.yaml",
+        feature_dirs=["/usr/local/lib/python3.12/dist-packages/transformer_engine/debug/features/"],
+        log_dir="./logddp",
+        default_logging_enabled=True,
+    )
     # Initialize the distributed configuration, including creating the distributed process group.
     dist_config = DistributedConfig()
     logger.info("Initializing distributed training: %s", dist_config)
@@ -65,6 +73,8 @@ def main(args: DictConfig) -> float | None:
     if args.use_sequence_packing:
         config.attn_input_format = "thd"

+
+
     # Optionally use transformer engine to initialize only fp8 versions of weights by setting
     # `fp8_config.fp8_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16 and fp8
     # versions of weights are kept.
@@ -84,6 +94,8 @@ def main(args: DictConfig) -> float | None:
     optimizer = AdamW(model.parameters(), **args.adamw_kwargs)
     scheduler = get_linear_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs)

+    debug_api.infer_and_assign_layer_names(model)
+
     model = model.to(device=device)
     model = torch.nn.parallel.DistributedDataParallel(
         model,
@@ -91,6 +103,7 @@ def main(args: DictConfig) -> float | None:
         output_device=dist_config.local_rank,
         device_mesh=device_mesh["ddp"],
     )
+

     # If we're using sequence packing, create a THD dataloader, otherwise create a BSHD dataloader.
     train_dataloader, dataset_or_sampler = (
@@ -99,9 +112,9 @@ def main(args: DictConfig) -> float | None:
         else create_bshd_dataloader(dist_config, **args.dataset)
     )

-    if args.use_torch_compile:
-        # If we're using torch.compile, we need to do this before loading the checkpoint to ensure key consistency.
-        model = torch.compile(model)
+    # if args.use_torch_compile:
+    #     # If we're using torch.compile, we need to do this before loading the checkpoint to ensure key consistency.
+    #     model = torch.compile(model)

     # If we're resuming from a checkpoint, load it and set the start step. Otherwise, start from step 0.
     ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_ddp" if args.checkpoint.ckpt_dir else None
@@ -134,6 +147,7 @@ def main(args: DictConfig) -> float | None:

         loss = outputs.loss
         loss.backward()
+        debug_api.step()

         # Compute and clip gradient norms.
         total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()
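train_ddp.py wires in nvdlfw_inspect with the three calls used throughout this patch: debug_api.initialize() before the model is wrapped, debug_api.infer_and_assign_layer_names() so the `layers` section of the YAML can match modules by name, and debug_api.step() once per iteration so `freq`, `start_step`, and `end_step` are honored. A stripped-down sketch of that wiring, assuming an FP8-capable GPU, Transformer Engine installed at the feature_dirs path shown, and illustrative file paths; the toy te.Linear stands in for the ESM-2 layers:

import nvdlfw_inspect.api as debug_api
import torch
import transformer_engine.pytorch as te

# Initialize before wrapping the model with DDP/FSDP so debug hooks attach to the TE modules.
debug_api.initialize(
    config_file="fp8_stats_mxfp8.yaml",  # one of the configs added above (illustrative path)
    feature_dirs=["/usr/local/lib/python3.12/dist-packages/transformer_engine/debug/features/"],
    log_dir="./fp8_stats_log",
    default_logging_enabled=True,
)

model = te.Linear(1024, 1024).cuda()  # stand-in for the ESM-2 TE layers

# Give every module a stable name; stats are only collected for modules whose
# assigned names match the `layers` section of the YAML config.
debug_api.infer_and_assign_layer_names(model)

optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
for _ in range(10):
    # FP8 stats are only produced when the forward pass actually runs in FP8.
    with te.fp8_autocast(enabled=True):
        out = model(torch.randn(32, 1024, device="cuda"))
    out.sum().backward()
    optimizer.step()
    optimizer.zero_grad()
    # Advance the debug step counter so `freq`, `start_step`, and `end_step` take effect.
    debug_api.step()

train_fsdp2.py below follows the same pattern, but takes the config file and log directory from the new fp8_stats_config Hydra block and writes each rank's stats to its own subdirectory.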
diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py
index 293eafe84d..e46d03757a 100644
--- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py
+++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py
@@ -17,7 +17,9 @@
 from contextlib import nullcontext
 from pathlib import Path

+import os
 import hydra
+import nvdlfw_inspect.api as debug_api
 import torch
 import transformer_engine.pytorch
 from omegaconf import DictConfig, OmegaConf
@@ -35,7 +37,7 @@
 from distributed_config import DistributedConfig
 from perf_logger import PerfLogger
 from scheduler import get_linear_schedule_with_warmup
-
+from torch.utils.tensorboard import SummaryWriter
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
@@ -55,6 +57,23 @@ def main(args: DictConfig) -> float | None:
     torch.distributed.init_process_group(backend="nccl", device_id=device)
     torch.cuda.set_device(dist_config.local_rank)

+    # TE Debug feature logging - MUST be done BEFORE FSDP wrapping
+    tb_writer = SummaryWriter('./tensorboard_dir/run1')
+    fp8_stats_file = args.fp8_stats_config.fp8_stats_file if args.fp8_stats_config.fp8_stats_file else "fp8_stats_mxfp8.yaml"
+    fp8_log_dir = args.fp8_stats_config.fp8_log_dir if args.fp8_stats_config.fp8_log_dir else "./log_fsdp2_mxfp8"
+    # Make a subdir for the current rank.
+    fp8_log_dir = os.path.join(fp8_log_dir, f"rank_{dist_config.local_rank}")
+    os.makedirs(fp8_log_dir, exist_ok=True)
+    logger.info(f"Logging FP8 stats to {fp8_log_dir}")
+    debug_api.initialize(
+        config_file=fp8_stats_file,
+        feature_dirs=["/usr/local/lib/python3.12/dist-packages/transformer_engine/debug/features/"],
+        log_dir=fp8_log_dir,
+        default_logging_enabled=True,
+        tb_writer=tb_writer,
+    )
+
+
     # Create a device mesh for FSDP.
     device_mesh = init_device_mesh(
         "cuda",
@@ -84,12 +103,27 @@ def main(args: DictConfig) -> float | None:

     logger.info("Initialized Model:\n%s", model)

+
+    # Debug: Print module types to verify what we're working with
+    if dist_config.local_rank == 0:
+        logger.info("=== DEBUG: Module types in model ===")
+        for name, module in model.named_modules():
+            if "layernorm_qkv" in name or "proj" in name or "self_attention" in name:
+                logger.info(f" -----> {name}: {type(module)} <----")
+        logger.info(
+            f"=== DEBUG: FP8 config enabled={args.fp8_config.enabled}, recipe={args.fp8_config.fp8_recipe} ==="
+        )
+
     # We call the transformer stack "layers" in our TE models, but it's called "layer" in the original ESM-2 models.
     transformer_stack = model.esm.encoder.layers if hasattr(model.esm.encoder, "layers") else model.esm.encoder.layer
+
     for layer in transformer_stack:
         fully_shard(layer, mesh=device_mesh["dp"])
     fully_shard(model, mesh=device_mesh["dp"])

+    # Assign names to layers so the debug API can identify them in the stat-collection config
+    debug_api.infer_and_assign_layer_names(model)
+
     # Create optimizer. Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873).
     optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True))  # type: ignore
     scheduler = get_linear_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs)
@@ -105,9 +139,9 @@ def main(args: DictConfig) -> float | None:
         else create_bshd_dataloader(dist_config, **args.dataset)
     )

-    if args.use_torch_compile:
-        # If we're using torch.compile, we need to do this before loading the checkpoint to ensure key consistency.
-        model = torch.compile(model)
+    # if args.use_torch_compile:
+    #     # If we're using torch.compile, we need to do this before loading the checkpoint to ensure key consistency.
+    #     model = torch.compile(model)

     # If we're resuming from a checkpoint, load it and set the start step. Otherwise, start from step 0.
     ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_fsdp2" if args.checkpoint.ckpt_dir else None
@@ -146,6 +180,9 @@ def main(args: DictConfig) -> float | None:

         # Step optimizer.
         optimizer.step()
         scheduler.step()
+
+        debug_api.step()
+
         optimizer.zero_grad()

         perf_logger.log_step(
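In train_fsdp2.py every rank constructs SummaryWriter('./tensorboard_dir/run1'), so all ranks write TensorBoard event files into the same run directory, while the FP8 text logs already get a per-rank rank_{local_rank} subdirectory. A small sketch of giving the TensorBoard writer the same per-rank treatment, assuming the DistributedConfig.local_rank field used elsewhere in the script (a possible variation, not what the patch does):

import os

from torch.utils.tensorboard import SummaryWriter


def make_rank_writer(base_dir: str, local_rank: int) -> SummaryWriter:
    """Create a TensorBoard writer in a rank-specific subdirectory, mirroring the fp8_log_dir handling."""
    run_dir = os.path.join(base_dir, f"rank_{local_rank}")
    os.makedirs(run_dir, exist_ok=True)
    return SummaryWriter(run_dir)


# e.g. tb_writer = make_rank_writer("./tensorboard_dir/run1", dist_config.local_rank)

Since fp8_stats_config.fp8_stats_file and fp8_stats_config.fp8_log_dir default to null in hydra_config/defaults.yaml, both can also be overridden per run with standard Hydra syntax, e.g. fp8_stats_config.fp8_stats_file=fp8_stats_mxfp8_freq1.yaml fp8_stats_config.fp8_log_dir=./log_fsdp2_freq1.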