# NeMo2.0 llama3 perf scripts #11702
Merged
## Commits (20)
- `c1e3d3c` perf scripts llama3 8b (malay-nagda)
- `cfc0a9c` Apply isort and black reformatting (malay-nagda)
- `de0917e` copyright (malay-nagda)
- `d072885` llama3 70b (malay-nagda)
- `e1ba662` Apply isort and black reformatting (malay-nagda)
- `8031797` 405b recipe (malay-nagda)
- `e005e8f` Apply isort and black reformatting (malay-nagda)
- `cd8dbb2` doc strings (malay-nagda)
- `e56b737` Apply isort and black reformatting (malay-nagda)
- `30a27ed` remove tb logging and formatting (malay-nagda)
- `483101b` Apply isort and black reformatting (malay-nagda)
- `c823e86` disable default tb and profiling (malay-nagda)
- `0180488` num steps per epoch (malay-nagda)
- `fdc174c` Apply isort and black reformatting (malay-nagda)
- `2fca981` Merge branch 'main' into malay/perf_scripts (malay-nagda)
- `b9ee569` correct filepaths (malay-nagda)
- `7fd3a46` Apply isort and black reformatting (malay-nagda)
- `3e892cf` remove param (malay-nagda)
- `c07da88` README (malay-nagda)
- `d7fd063` updated param (malay-nagda)
### scripts/llm/performance/README.md

```markdown
# Performance Recipes

- Scripts defined in `scripts/llm/performance` are recipes optimized for performance. These scripts can launch pre-training experiments on Slurm-based clusters.
- You will need a virtual environment with NeMo and NeMo-Run dependencies installed, because the experiment configuration is resolved before it is launched inside the NeMo container.

## Example

The following line shows how you can launch a pre-training experiment:

`python3 scripts/llm/performance/llama3_8b.py --account <your_slurm_account> --partition <your_slurm_partition>`

## Configuration Options

- The Slurm account and partition are mandatory arguments for launching the experiment.
- You can use the following optional arguments as needed:
  - `-l/--log_dir`: Location to store your experiment artifacts and logs.
    - Make sure the environment variable `NEMORUN_HOME=<log_dir>` is set correctly in your virtual environment.
    - You can run `export NEMORUN_HOME=<log_dir>` in your terminal, or add it to your bashrc file (or the equivalent for your OS/Linux distro) to set it permanently.
  - `-t/--time_limit`: Maximum time limit for your experiment. Your Slurm job will be cancelled after this. Defaults to 30 minutes.
  - `-i/--container_image`: The NeMo container you want to use. Defaults to the latest dev container, 'nvcr.io/nvidia/nemo:dev'.
  - `-c/--compute_dtype`: Whether to use 'bf16' or 'fp8' precision for training. Defaults to 'bf16'.
  - `-ep/--enable_profiling`: Enable nsys profiling. Disabled by default. When enabled, profiling runs for one step, from step 5 to step 6. You can change these steps in the respective recipe script.
  - `-tb/--tensorboard`: Enable TensorBoard logging. Disabled by default.
    - CAUTION: TensorBoard logging may cause performance overhead.
  - `-d/--dryrun`: Do not launch the experiment; just print the sbatch script to stdout. This can be helpful for verifying that you have set up your experiment correctly.
- `--enable_profiling`, `--tensorboard`, and `--dryrun` are flags and take no value. See the example below for reference:

`python3 scripts/llm/performance/llama3_8b.py --account <your_slurm_account> -p <your_slurm_partition> -ep --tensorboard -d`
```
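Putting the README's pieces together, a typical launch sequence looks like the sketch below; `/path/to/logs`, `my_account`, and `my_partition` are placeholders for your own values:

```bash
# NEMORUN_HOME must agree with the -l/--log_dir value passed to the script
export NEMORUN_HOME=/path/to/logs

# dry run first: prints the generated sbatch script without submitting a job
python3 scripts/llm/performance/llama3_8b.py \
    --account my_account --partition my_partition -l /path/to/logs -d

# then launch for real, e.g. with fp8 precision instead of the default bf16
python3 scripts/llm/performance/llama3_8b.py \
    --account my_account --partition my_partition -l /path/to/logs -c fp8
```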
### scripts/llm/performance/llama3_405b.py

```python
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import nemo_run as run
from nemo_run.config import NEMORUN_HOME
from utils import get_comm_overlap_callback_idx, hf_tokenizer, parse_cli_args, slurm_executor

from nemo.collections.llm.recipes.llama31_405b import pretrain_recipe
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_with_fp8_mixed
from nemo.lightning.pytorch.callbacks.garbage_collection import GarbageCollectionCallback
from nemo.lightning.run.plugins import NsysPlugin, PerfEnvPlugin
from nemo.utils import logging

NUM_NODES = 72
NUM_GPUS_PER_NODE = 8
MICRO_BATCH_SIZE = 1
GLOBAL_BATCH_SIZE = 252
TP_SIZE = 8
PP_SIZE = 9
CP_SIZE = 2
VP_SIZE = 7
MAX_STEPS = 100
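# Parallelism arithmetic implied by the constants above:
#   total GPUs        = 72 nodes * 8 GPUs per node = 576
#   model-parallel    = TP 8 * PP 9 * CP 2         = 144 GPUs per model replica
#   data-parallel     = 576 / 144                  = 4 replicas
#   virtual stages    = PP 9 * VP 7                = 63 pipeline chunks
#   grad accumulation = GBS 252 / (DP 4 * MBS 1)   = 63 micro-batches per step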


def llama3_405b_performance_recipe(
    compute_dtype: str,
    num_nodes: int,
    num_gpus_per_node: int,
    mbs: int,
    gbs: int,
    tp_size: int,
    pp_size: int,
    cp_size: int,
    vp_size: Optional[int],
    max_steps: int,
):
    """
    llama3 405b pre-train recipe aimed at achieving the best possible performance.

    NOTE: Use fp8 precision training with caution. It might not give desirable results.
    """
    recipe = pretrain_recipe(performance_mode=True)

    # data module configs
    recipe.data.micro_batch_size = mbs
    recipe.data.global_batch_size = gbs
    recipe.data.num_train_samples = max_steps * gbs  # ensure only 1 epoch for the whole run
    recipe.data.tokenizer = hf_tokenizer("meta-llama/Llama-3.1-405B")

    recipe.trainer.max_steps = max_steps
    recipe.trainer.num_nodes = num_nodes
    recipe.trainer.devices = num_gpus_per_node

    # parallelism configs
    recipe.trainer.strategy.tensor_model_parallel_size = tp_size
    recipe.trainer.strategy.pipeline_model_parallel_size = pp_size
    recipe.trainer.strategy.context_parallel_size = cp_size
    recipe.trainer.strategy.virtual_pipeline_model_parallel_size = vp_size
    # sequence parallelism is only useful when tensor parallelism is active
    recipe.trainer.strategy.sequence_parallel = tp_size > 1

    comm_overlap_callback_idx = get_comm_overlap_callback_idx(recipe.trainer.callbacks)

    # compute dtype configs
    if compute_dtype.lower() == "fp8":
        recipe.trainer.plugins = bf16_with_fp8_mixed()
        recipe.trainer.callbacks[comm_overlap_callback_idx].tp_comm_overlap_cfg.proj_fprop.fp8_buf = True
        recipe.trainer.callbacks[comm_overlap_callback_idx].tp_comm_overlap_cfg.fc2_fprop.fp8_buf = True

    recipe.trainer.plugins.grad_reduce_in_fp32 = False  # bf16 grad dtype

    # callback configs
    garbage_collection_callback = run.Config(
        GarbageCollectionCallback,
        gc_interval_train=100,
        gc_interval_val=500,
    )
    recipe.trainer.callbacks.extend(
        [
            garbage_collection_callback,
        ]
    )

    dp_size = (num_nodes * num_gpus_per_node) / (tp_size * pp_size * cp_size)
    if dp_size > 1 and pp_size > 1 and vp_size and vp_size > 1:
        if comm_overlap_callback_idx >= 0:
            recipe.trainer.callbacks[comm_overlap_callback_idx].overlap_param_gather_with_optimizer_step = True

    # Misc. for overall faster experiment runtime
    recipe.log.ckpt = None
    recipe.trainer.enable_checkpointing = False
    # cast to int: Lightning rejects a float val_check_interval > 1; the value lands far
    # beyond max_steps, so validation never interrupts the timed run
    recipe.trainer.val_check_interval = int(max_steps * gbs / dp_size)
    recipe.trainer.log_every_n_steps = 1

    return recipe


if __name__ == "__main__":
    args = parse_cli_args().parse_args()
    if args.log_dir != NEMORUN_HOME:
        import sys

        logging.error(f"Run `export NEMORUN_HOME={args.log_dir}` in your shell environment and rerun this script.")
        sys.exit(1)

    exp_name = "_".join(
        [
            "llama3_405b",
            args.compute_dtype,
            f"{NUM_NODES}nodes",
            f"tp{TP_SIZE}_pp{PP_SIZE}_cp{CP_SIZE}_vp{VP_SIZE}",
            f"{MICRO_BATCH_SIZE}mbs_{GLOBAL_BATCH_SIZE}gbs",
        ]
    )

    executor = slurm_executor(
        args.account,
        args.partition,
        args.log_dir,
        NUM_NODES,
        NUM_GPUS_PER_NODE,
        args.time_limit,
        args.container_image,
        custom_mounts=[],
        custom_env_vars={},
        retries=0,
    )

    recipe = llama3_405b_performance_recipe(
        args.compute_dtype,
        NUM_NODES,
        NUM_GPUS_PER_NODE,
        MICRO_BATCH_SIZE,
        GLOBAL_BATCH_SIZE,
        TP_SIZE,
        PP_SIZE,
        CP_SIZE,
        VP_SIZE,
        MAX_STEPS,
    )

    if not args.tensorboard:  # tensorboard adds performance overhead
        recipe.log.tensorboard = None
        recipe.trainer.logger = False
    else:
        # the default path is not intuitive: `<log_dir>/code/nemo_experiments/tb_logs/default/<tfevents_file>`
        # the following line places the file at `<log_dir>/lightning_logs/tb_logs/default/<tfevents_file>`
        recipe.log.log_dir = "/nemo_run/lightning_logs"

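    # PerfEnvPlugin below sets performance-related environment knobs; 2097152 bytes = 2 MiB
    # is the chunk size used for pipeline-parallel NCCL communication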
    plugins = [PerfEnvPlugin(enable_vboost=True, nccl_pp_comm_chunksize=2097152)]
    if args.enable_profiling:
        plugins.append(NsysPlugin(start_step=5, end_step=6))

    with run.Experiment(exp_name) as exp:
        exp.add(
            recipe,
            executor=executor,
            name=exp_name,
            plugins=plugins,
        )

        if not args.dryrun:
            exp.run(sequential=True, detach=True)
        else:
            exp.dryrun()
```
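The 70B script below follows the same template as the 405B one above: only the model recipe import, the tokenizer ID, and the parallelism constants at the top differ, so tuning a different model size mostly comes down to editing those constants.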
### scripts/llm/performance/llama3_70b.py

```python
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import nemo_run as run
from nemo_run.config import NEMORUN_HOME
from utils import get_comm_overlap_callback_idx, hf_tokenizer, parse_cli_args, slurm_executor

from nemo.collections.llm.recipes.llama3_70b import pretrain_recipe
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_with_fp8_mixed
from nemo.lightning.pytorch.callbacks.garbage_collection import GarbageCollectionCallback
from nemo.lightning.run.plugins import NsysPlugin, PerfEnvPlugin
from nemo.utils import logging

NUM_NODES = 8
NUM_GPUS_PER_NODE = 8
MICRO_BATCH_SIZE = 1
GLOBAL_BATCH_SIZE = 128
TP_SIZE = 4
PP_SIZE = 4
CP_SIZE = 2
VP_SIZE = 5
MAX_STEPS = 100
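# Parallelism arithmetic implied by the constants above:
#   total GPUs        = 8 nodes * 8 GPUs per node = 64
#   model-parallel    = TP 4 * PP 4 * CP 2        = 32 GPUs per model replica
#   data-parallel     = 64 / 32                   = 2 replicas
#   virtual stages    = PP 4 * VP 5               = 20 pipeline chunks
#   grad accumulation = GBS 128 / (DP 2 * MBS 1)  = 64 micro-batches per step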


def llama3_70b_performance_recipe(
    compute_dtype: str,
    num_nodes: int,
    num_gpus_per_node: int,
    mbs: int,
    gbs: int,
    tp_size: int,
    pp_size: int,
    cp_size: int,
    vp_size: Optional[int],
    max_steps: int,
):
    """
    llama3 70b pre-train recipe aimed at achieving the best possible performance.

    NOTE: Use fp8 precision training with caution. It might not give desirable results.
    """
    recipe = pretrain_recipe(performance_mode=True)

    # data module configs
    recipe.data.micro_batch_size = mbs
    recipe.data.global_batch_size = gbs
    recipe.data.num_train_samples = max_steps * gbs  # ensure only 1 epoch for the whole run
    recipe.data.tokenizer = hf_tokenizer("meta-llama/Meta-Llama-3-70B")

    recipe.trainer.max_steps = max_steps
    recipe.trainer.num_nodes = num_nodes
    recipe.trainer.devices = num_gpus_per_node

    # parallelism configs
    recipe.trainer.strategy.tensor_model_parallel_size = tp_size
    recipe.trainer.strategy.pipeline_model_parallel_size = pp_size
    recipe.trainer.strategy.context_parallel_size = cp_size
    recipe.trainer.strategy.virtual_pipeline_model_parallel_size = vp_size
    # sequence parallelism is only useful when tensor parallelism is active
    recipe.trainer.strategy.sequence_parallel = tp_size > 1

    comm_overlap_callback_idx = get_comm_overlap_callback_idx(recipe.trainer.callbacks)

    # compute dtype configs
    if compute_dtype.lower() == "fp8":
        recipe.trainer.plugins = bf16_with_fp8_mixed()
        recipe.trainer.callbacks[comm_overlap_callback_idx].tp_comm_overlap_cfg.proj_fprop.fp8_buf = True
        recipe.trainer.callbacks[comm_overlap_callback_idx].tp_comm_overlap_cfg.fc2_fprop.fp8_buf = True

    recipe.trainer.plugins.grad_reduce_in_fp32 = False  # bf16 grad dtype

    # callback configs
    garbage_collection_callback = run.Config(
        GarbageCollectionCallback,
        gc_interval_train=100,
        gc_interval_val=500,
    )
    recipe.trainer.callbacks.extend(
        [
            garbage_collection_callback,
        ]
    )

    dp_size = (num_nodes * num_gpus_per_node) / (tp_size * pp_size * cp_size)
    if dp_size > 1 and pp_size > 1 and vp_size and vp_size > 1:
        if comm_overlap_callback_idx >= 0:
            recipe.trainer.callbacks[comm_overlap_callback_idx].overlap_param_gather_with_optimizer_step = True

    # Misc. for overall faster experiment runtime
    recipe.log.ckpt = None
    recipe.trainer.enable_checkpointing = False
    # cast to int: Lightning rejects a float val_check_interval > 1; the value lands far
    # beyond max_steps, so validation never interrupts the timed run
    recipe.trainer.val_check_interval = int(max_steps * gbs / dp_size)
    recipe.trainer.log_every_n_steps = 1

    return recipe


if __name__ == "__main__":
    args = parse_cli_args().parse_args()
    if args.log_dir != NEMORUN_HOME:
        import sys

        logging.error(f"Run `export NEMORUN_HOME={args.log_dir}` in your shell environment and rerun this script.")
        sys.exit(1)

    exp_name = "_".join(
        [
            "llama3_70b",
            args.compute_dtype,
            f"{NUM_NODES}nodes",
            f"tp{TP_SIZE}_pp{PP_SIZE}_cp{CP_SIZE}_vp{VP_SIZE}",
            f"{MICRO_BATCH_SIZE}mbs_{GLOBAL_BATCH_SIZE}gbs",
        ]
    )

    executor = slurm_executor(
        args.account,
        args.partition,
        args.log_dir,
        NUM_NODES,
        NUM_GPUS_PER_NODE,
        args.time_limit,
        args.container_image,
        custom_mounts=[],
        custom_env_vars={},
        retries=0,
    )

    recipe = llama3_70b_performance_recipe(
        args.compute_dtype,
        NUM_NODES,
        NUM_GPUS_PER_NODE,
        MICRO_BATCH_SIZE,
        GLOBAL_BATCH_SIZE,
        TP_SIZE,
        PP_SIZE,
        CP_SIZE,
        VP_SIZE,
        MAX_STEPS,
    )

    if not args.tensorboard:  # tensorboard adds performance overhead
        recipe.log.tensorboard = None
        recipe.trainer.logger = False
    else:
        # the default path is not intuitive: `<log_dir>/code/nemo_experiments/tb_logs/default/<tfevents_file>`
        # the following line places the file at `<log_dir>/lightning_logs/tb_logs/default/<tfevents_file>`
        recipe.log.log_dir = "/nemo_run/lightning_logs"

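    # NsysPlugin (appended below when -ep/--enable_profiling is set) captures an nsys
    # profile between start_step and end_step, i.e. for exactly one training step here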
    plugins = [PerfEnvPlugin(enable_vboost=True, nccl_pp_comm_chunksize=2097152)]
    if args.enable_profiling:
        plugins.append(NsysPlugin(start_step=5, end_step=6))

    with run.Experiment(exp_name) as exp:
        exp.add(
            recipe,
            executor=executor,
            name=exp_name,
            plugins=plugins,
        )

        if not args.dryrun:
            exp.run(sequential=True, detach=True)
        else:
            exp.dryrun()
```