NeMo2.0 llama3 perf scripts #11702 (Merged)
Changes from 2 commits (20 commits total, all by malay-nagda):
- c1e3d3c: perf scripts llama3 8b
- cfc0a9c: Apply isort and black reformatting
- de0917e: copyright
- d072885: llama3 70b
- e1ba662: Apply isort and black reformatting
- 8031797: 405b recipe
- e005e8f: Apply isort and black reformatting
- cd8dbb2: doc strings
- e56b737: Apply isort and black reformatting
- 30a27ed: remove tb logging and formatting
- 483101b: Apply isort and black reformatting
- c823e86: disable default tb and profiling
- 0180488: num steps per epoch
- fdc174c: Apply isort and black reformatting
- 2fca981: Merge branch 'main' into malay/perf_scripts
- b9ee569: correct filepaths
- 7fd3a46: Apply isort and black reformatting
- 3e892cf: remove param
- c07da88: README
- d7fd063: updated param
First file shown (+129 lines): the Llama3 8B pretraining performance script. It builds a performance-mode recipe, then applies batch, parallelism, precision, and callback settings:

```python
from typing import Optional

import nemo_run as run
from utils import get_comm_overlap_callback_idx, hf_tokenizer, parse_cli_args, slurm_executor

from nemo.collections.llm.recipes.llama3_8b import pretrain_recipe
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_with_fp8_mixed
from nemo.lightning.pytorch.callbacks.garbage_collection import GarbageCollectionCallback
from nemo.lightning.run.plugins import NsysPlugin, PerfEnvPlugin


def llama3_8b_performance(
    compute_dtype: str,
    num_nodes: int,
    num_gpus_per_node: int,
    mbs: int,
    gbs: int,
    tp_size: int,
    pp_size: int,
    cp_size: int,
    vp_size: Optional[int],
    max_steps: int,
):
    """Build a performance-tuned Llama3 8B pretraining recipe from the given experiment settings."""
    recipe = pretrain_recipe(performance_mode=True)

    # data module configs
    recipe.data.micro_batch_size = mbs
    recipe.data.global_batch_size = gbs
    recipe.data.num_train_samples = max_steps * (num_nodes * num_gpus_per_node)  # ensure only 1 epoch for whole run
    recipe.data.tokenizer = hf_tokenizer("meta-llama/Meta-Llama-3-8B")

    recipe.trainer.max_steps = max_steps
    recipe.trainer.num_nodes = num_nodes
    recipe.trainer.devices = num_gpus_per_node

    # parallelism configs
    recipe.trainer.strategy.tensor_model_parallel_size = tp_size
    recipe.trainer.strategy.pipeline_model_parallel_size = pp_size
    recipe.trainer.strategy.context_parallel_size = cp_size
    recipe.trainer.strategy.virtual_pipeline_model_parallel_size = vp_size
    recipe.trainer.strategy.sequence_parallel = tp_size > 1  # sequence parallelism only applies with TP

    comm_overlap_callback_idx = get_comm_overlap_callback_idx(recipe.trainer.callbacks)

    # compute dtype configs
    if compute_dtype.lower() == "fp8":
        recipe.trainer.plugins = bf16_with_fp8_mixed()
    recipe.trainer.plugins.grad_reduce_in_fp32 = False  # bf16 grad dtype

    # callback configs
    garbage_collection_callback = run.Config(
        GarbageCollectionCallback,
        gc_interval_train=100,
        gc_interval_val=500,
    )
    recipe.trainer.callbacks.extend([garbage_collection_callback])

    # overlap param all-gather with the optimizer step only when DP, PP and VP are all active
    dp_size = (num_nodes * num_gpus_per_node) / (tp_size * pp_size * cp_size)
    if dp_size > 1 and pp_size > 1 and vp_size and vp_size > 1:
        if comm_overlap_callback_idx >= 0:
            recipe.trainer.callbacks[comm_overlap_callback_idx].overlap_param_gather_with_optimizer_step = True

    return recipe
```
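The overlap guard above only fires when data, pipeline, and virtual-pipeline parallelism are all active at once. A minimal standalone sketch of that arithmetic, not part of the PR, with illustrative numbers:

```python
from typing import Optional


def overlap_enabled(num_nodes: int, gpus_per_node: int, tp: int, pp: int, cp: int, vp: Optional[int]) -> bool:
    """Mirror the script's condition for overlapping param all-gather with the optimizer step."""
    dp = (num_nodes * gpus_per_node) / (tp * pp * cp)  # data-parallel size: world size left after TP/PP/CP sharding
    return dp > 1 and pp > 1 and bool(vp) and vp > 1


# defaults in this script (1 node x 8 GPUs, tp=1, pp=1, cp=2, vp=None): overlap stays off
print(overlap_enabled(1, 8, 1, 1, 2, None))  # False
# a hypothetical multi-node layout (dp=4, pp=4, vp=2) where it would be enabled
print(overlap_enabled(4, 8, 2, 4, 1, 2))     # True
```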
The `__main__` block pins the parallelism layout and batch sizes for a single 8-GPU node, then submits the experiment through a Slurm executor:

```python
if __name__ == "__main__":
    args = parse_cli_args().parse_args()

    # experiment-level settings (single node, 8 GPUs)
    num_nodes = 1
    num_gpus_per_node = 8
    mbs = 1
    gbs = 128
    tp_size = 1
    pp_size = 1
    cp_size = 2
    vp_size = None
    max_steps = 100

    exp_name = (
        f"llama3_8b_{args.compute_dtype}_{num_nodes}nodes_"
        f"tp{tp_size}_pp{pp_size}_cp{cp_size}_vp{vp_size}_{mbs}mbs_{gbs}gbs"
    )

    executor = slurm_executor(
        args.account,
        args.partition,
        args.log_dir,
        num_nodes,
        num_gpus_per_node,
        args.time_limit,
        args.container_image,
        custom_mounts=[],
        custom_env_vars={},
        retries=0,
    )

    recipe = llama3_8b_performance(
        args.compute_dtype,
        num_nodes,
        num_gpus_per_node,
        mbs,
        gbs,
        tp_size,
        pp_size,
        cp_size,
        vp_size,
        max_steps,
    )

    with run.Experiment(exp_name) as exp:
        exp.add(
            recipe,
            executor=executor,
            name=exp_name,
            plugins=[
                PerfEnvPlugin(enable_vboost=True),     # raise GPU clocks for the run
                NsysPlugin(start_step=5, end_step=6),  # short Nsight Systems profile after warmup
            ],
        )

        if not args.dryrun:
            exp.run(sequential=True, detach=True)
        else:
            exp.dryrun()
```
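Not part of the PR: a quick standalone check of what the experiment name renders to with these defaults (`compute_dtype` taken from the parser default):

```python
compute_dtype = "bf16"  # parser default
num_nodes, mbs, gbs = 1, 1, 128
tp_size = pp_size = 1
cp_size = 2
vp_size = None

exp_name = (
    f"llama3_8b_{compute_dtype}_{num_nodes}nodes_"
    f"tp{tp_size}_pp{pp_size}_cp{cp_size}_vp{vp_size}_{mbs}mbs_{gbs}gbs"
)
print(exp_name)  # llama3_8b_bf16_1nodes_tp1_pp1_cp2_vpNone_1mbs_128gbs
```

Passing `-d`/`--dryrun` on the command line prints the generated sbatch script instead of submitting it.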
Second file shown (+150 lines): the shared `utils` module imported by the script above. First, the Slurm executor factory:

```python
import argparse
import os
from typing import Dict, List, Optional

import nemo_run as run
from lightning.pytorch.callbacks.callback import Callback

from nemo.collections.common.tokenizers.huggingface import AutoTokenizer
from nemo.collections.llm.recipes.llama3_8b import MegatronCommOverlapCallback


def slurm_executor(
    account: str,
    partition: str,
    log_dir: str,
    nodes: int,
    num_gpus_per_node: int,
    time_limit: str = "01:00:00",
    container_image: str = "nvcr.io/nvidia/nemo:dev",
    custom_mounts: Optional[List[str]] = None,
    custom_env_vars: Optional[Dict[str, str]] = None,
    retries: int = 0,
) -> run.SlurmExecutor:
    """Configure a run.SlurmExecutor for the performance experiments."""
    if not (log_dir and account and partition and nodes and num_gpus_per_node):
        raise RuntimeError(
            "Please set account, partition, log_dir, nodes and num_gpus_per_node args to use this function."
        )

    mounts = []
    if custom_mounts:
        mounts.extend(custom_mounts)

    # performance-run defaults; caller-supplied vars override these below
    env_vars = {
        "TRANSFORMERS_OFFLINE": "1",
        "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",
        "NCCL_NVLS_ENABLE": "0",
        "NVTE_DP_AMAX_REDUCE_INTERVAL": "0",
        "NVTE_ASYNC_AMAX_REDUCTION": "1",
        "NVTE_FUSED_ATTN": "1",
        "NVTE_FLASH_ATTN": "0",
    }
    if custom_env_vars:
        env_vars |= custom_env_vars

    executor = run.SlurmExecutor(
        account=account,
        partition=partition,
        tunnel=run.LocalTunnel(job_dir=log_dir),
        nodes=nodes,
        ntasks_per_node=num_gpus_per_node,
        gpus_per_node=num_gpus_per_node,
        mem="0",
        exclusive=True,
        gres=f"gpu:{num_gpus_per_node}",  # was hardcoded to 'gpu:8'; tied to the argument instead
        packager=run.GitArchivePackager(),
    )

    executor.container_image = container_image
    executor.container_mounts = mounts
    executor.env_vars = env_vars
    executor.retries = retries
    executor.time = time_limit

    return executor
```
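The `|=` on the env dict (Python 3.9+) gives caller-supplied variables precedence over the defaults; a two-line illustration, not from the PR, with illustrative values:

```python
env_vars = {"NVTE_FUSED_ATTN": "1", "NVTE_FLASH_ATTN": "0"}
env_vars |= {"NVTE_FUSED_ATTN": "0"}  # caller override wins
print(env_vars)  # {'NVTE_FUSED_ATTN': '0', 'NVTE_FLASH_ATTN': '0'}
```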
The remaining helpers in the same file: the tokenizer config, the comm-overlap callback lookup, and the CLI parser shared by the per-model scripts:

```python
def hf_tokenizer(model_name: str) -> run.Config[AutoTokenizer]:
    """
    HuggingFace tokenizer config.

    AutoTokenizer first searches for tokenizer files locally under the 'NEMO_HOME' env var.
    If the files are not found locally, AutoTokenizer tries to download them from HuggingFace.
    In that case, make sure the env vars 'TRANSFORMERS_OFFLINE=0' and 'HF_TOKEN=<token_value>'
    are set inside the NeMo container.
    """
    return run.Config(
        AutoTokenizer,
        pretrained_model_name=model_name,
        use_fast=True,
    )


def get_comm_overlap_callback_idx(callbacks: List[Callback]) -> int:
    """Return the index of the MegatronCommOverlapCallback in 'callbacks', or -1 if absent."""
    if callbacks:  # default is None in lightning
        for idx, callback in enumerate(callbacks):
            if isinstance(callback, MegatronCommOverlapCallback):
                return idx
    return -1


def parse_cli_args():
    """Command-line arguments shared by the NeMo 2.0 performance scripts."""
    parser = argparse.ArgumentParser(description="NeMo2.0 Performance Pretraining and Fine-Tuning")

    parser.add_argument(
        "-a",
        "--account",
        type=str,
        help="Slurm account to use for the experiment",
        required=True,
    )
    parser.add_argument(
        "-p",
        "--partition",
        type=str,
        help="Slurm partition to use for the experiment",
        required=True,
    )
    parser.add_argument(
        "-l",
        "--log_dir",
        type=str,
        help="Directory for logging experiment results. Defaults to '~/nemo_log_dir'",
        required=False,
        default=os.path.expanduser("~/nemo_log_dir"),
    )
    parser.add_argument(
        "-t",
        "--time_limit",
        type=str,
        help="Maximum time limit for the experiment, in 'HH:MM:SS' format. Defaults to 30 minutes",
        required=False,
        default="00:30:00",
    )
    parser.add_argument(
        "-i",
        "--container_image",
        type=str,
        help="NeMo container to use for the experiment. Defaults to the latest dev container, "
        "'nvcr.io/nvidia/nemo:dev'. Make sure your NGC credentials are accessible in your environment.",
        required=False,
        default="nvcr.io/nvidia/nemo:dev",
    )
    parser.add_argument(
        "-c",
        "--compute_dtype",
        type=str,
        help="Compute precision. Options: 'bf16' or 'fp8'. Defaults to 'bf16'",
        required=False,
        default="bf16",
    )
    parser.add_argument(
        "-d",
        "--dryrun",
        help="If set, prints the sbatch script to the terminal without launching the experiment.",
        required=False,
        action="store_true",
    )

    return parser
```
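A quick way to sanity-check the parser without submitting anything; not from the PR, and the account/partition values are placeholders:

```python
from utils import parse_cli_args

# parse a synthetic argv instead of the real command line
args = parse_cli_args().parse_args(["-a", "my_account", "-p", "my_partition", "--dryrun"])
print(args.compute_dtype)  # 'bf16' (default)
print(args.time_limit)     # '00:30:00' (default)
print(args.dryrun)         # True
```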