
Commit bb513c9

malay-nagda authored and abhinavg4 committed
NeMo2.0 llama3 perf scripts (#11702)
* perf scripts llama3 8b
* Apply isort and black reformatting
* copyright
* llama3 70b
* 405b recipe
* doc strings
* remove tb logging and formatting
* disable default tb and profiling
* num steps per epoch
* correct filepaths
* remove param
* README
* updated param

Signed-off-by: Malay Nagda <malayn@nvidia.com>
Signed-off-by: malay-nagda <malay-nagda@users.noreply.github.com>
Co-authored-by: malay-nagda <malay-nagda@users.noreply.github.com>
Signed-off-by: Abhinav Garg <abhgarg@nvidia.com>
1 parent b408d68 commit bb513c9

File tree: 6 files changed, +759 −1 lines changed


nemo/lightning/run/plugins.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -314,7 +314,7 @@ class PerfEnvPlugin(run.Plugin):
     enable_layernorm_sm_margin: bool = True
     layernorm_sm_margin: int = 16
     enable_vboost: bool = False
-    nccl_pp_comm_chunksize: int = None
+    nccl_pp_comm_chunksize: Optional[int] = None

     def get_vboost_srun_cmd(self, nodes, job_dir):
         "Create the vboost `sudo nvidia-smi boost-slider --vboost 1` command"
```

scripts/llm/performance/README.md

Lines changed: 27 additions & 0 deletions

# Performance Recipes

- Scripts in `scripts/llm/performance` are recipes optimized for performance. These scripts can launch pre-training experiments on Slurm-based clusters.
- You will need a virtual environment with NeMo and NeMo-Run dependencies installed, since the experiment configuration is resolved before it is launched inside the NeMo container.

## Example

The following line shows an example of how you can launch a pre-training experiment:

`python3 scripts/llm/performance/llama3_8b.py --account <your_slurm_account> -p <your_slurm_partition>`

## Configuration Options

- Slurm account and partition are mandatory arguments for launching the experiment.
- You can use the following optional arguments as needed:
  - `-l/--log_dir`: Location to store your experiment artifacts and logs.
    - Make sure the environment variable `NEMORUN_HOME=<log_dir>` is set correctly in your virtual environment.
    - You can run `export NEMORUN_HOME=<log_dir>` in your terminal, or add it to your bashrc file (or the equivalent for your OS/Linux distro) to set it permanently.
  - `-t/--time_limit`: Maximum time limit for your experiment. Your Slurm job will be cancelled after this. Defaults to 30 minutes.
  - `-i/--container_image`: The NeMo container you want to use. Defaults to the latest dev container, 'nvcr.io/nvidia/nemo:dev'.
  - `-c/--compute_dtype`: Precision used for training, either 'bf16' or 'fp8'. Defaults to 'bf16'.
  - `-ep/--enable_profiling`: Enable nsys profiling. Disabled by default. When enabled, profiling runs for one step, from step 5 to step 6; you can change the steps in the respective recipe script.
  - `-tb/--tensorboard`: Enable TensorBoard logging. Disabled by default.
    - CAUTION: TensorBoard logging may cause performance overhead.
  - `-d/--dryrun`: Do not launch the experiment; simply print the sbatch script to stdout. This can be helpful to verify that you have set up your experiment correctly.
- You don't need to pass any value for `--enable_profiling`, `--tensorboard`, or `--dryrun`. See the example below for reference:

`python3 scripts/llm/performance/llama3_8b.py --account <your_slurm_account> -p <your_slurm_partition> -ep --tensorboard -d`
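The `NEMORUN_HOME` requirement above can be satisfied once per shell session; a minimal sketch (the path is a placeholder, substitute your own log directory):

```shell
# Point NeMo-Run at the same directory you pass via -l/--log_dir,
# so the launcher and the scripts agree on where artifacts live.
export NEMORUN_HOME=/path/to/experiment_logs
echo "NEMORUN_HOME is $NEMORUN_HOME"
```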
Lines changed: 179 additions & 0 deletions

```python
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import nemo_run as run
from nemo_run.config import NEMORUN_HOME
from utils import get_comm_overlap_callback_idx, hf_tokenizer, parse_cli_args, slurm_executor

from nemo.collections.llm.recipes.llama31_405b import pretrain_recipe
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_with_fp8_mixed
from nemo.lightning.pytorch.callbacks.garbage_collection import GarbageCollectionCallback
from nemo.lightning.run.plugins import NsysPlugin, PerfEnvPlugin
from nemo.utils import logging

NUM_NODES = 72
NUM_GPUS_PER_NODE = 8
MICRO_BATCH_SIZE = 1
GLOBAL_BATCH_SIZE = 252
TP_SIZE = 8
PP_SIZE = 9
CP_SIZE = 2
VP_SIZE = 7
MAX_STEPS = 100


def llama3_405b_performance_recipe(
    compute_dtype: str,
    num_nodes: int,
    num_gpus_per_node: int,
    mbs: int,
    gbs: int,
    tp_size: int,
    pp_size: int,
    cp_size: int,
    vp_size: Optional[int],
    max_steps: int,
):
    """
    llama3 405b pre-train recipe aimed at achieving best possible performance.

    NOTE: Use fp8 precision training with caution. It might not give desirable results.
    """
    recipe = pretrain_recipe(performance_mode=True)

    # data module configs
    recipe.data.micro_batch_size = mbs
    recipe.data.global_batch_size = gbs
    recipe.data.num_train_samples = max_steps * gbs  # ensure only 1 epoch for whole run
    recipe.data.tokenizer = hf_tokenizer("meta-llama/Llama-3.1-405B")

    recipe.trainer.max_steps = max_steps
    recipe.trainer.num_nodes = num_nodes
    recipe.trainer.devices = num_gpus_per_node

    # parallelism configs
    recipe.trainer.strategy.tensor_model_parallel_size = tp_size
    recipe.trainer.strategy.pipeline_model_parallel_size = pp_size
    recipe.trainer.strategy.context_parallel_size = cp_size
    recipe.trainer.strategy.virtual_pipeline_model_parallel_size = vp_size
    if tp_size > 1:
        recipe.trainer.strategy.sequence_parallel = True
    else:
        recipe.trainer.strategy.sequence_parallel = False

    comm_overlap_callback_idx = get_comm_overlap_callback_idx(recipe.trainer.callbacks)

    # compute dtype configs
    if compute_dtype.lower() == "fp8":
        recipe.trainer.plugins = bf16_with_fp8_mixed()
        recipe.trainer.callbacks[comm_overlap_callback_idx].tp_comm_overlap_cfg.proj_fprop.fp8_buf = True
        recipe.trainer.callbacks[comm_overlap_callback_idx].tp_comm_overlap_cfg.fc2_fprop.fp8_buf = True

    recipe.trainer.plugins.grad_reduce_in_fp32 = False  # bf16 grad dtype

    # callback configs
    garbage_collection_callback = run.Config(
        GarbageCollectionCallback,
        gc_interval_train=100,
        gc_interval_val=500,
    )
    recipe.trainer.callbacks.extend(
        [
            garbage_collection_callback,
        ]
    )
    dp_size = (num_nodes * num_gpus_per_node) / (tp_size * pp_size * cp_size)
    if dp_size > 1 and pp_size > 1 and vp_size and vp_size > 1:
        if comm_overlap_callback_idx >= 0:
            recipe.trainer.callbacks[comm_overlap_callback_idx].overlap_param_gather_with_optimizer_step = True

    # Misc. for overall faster experiment runtime
    recipe.log.ckpt = None
    recipe.trainer.enable_checkpointing = False
    recipe.trainer.val_check_interval = max_steps * gbs / dp_size
    recipe.trainer.log_every_n_steps = 1

    return recipe


if __name__ == "__main__":
    args = parse_cli_args().parse_args()
    if args.log_dir != NEMORUN_HOME:
        import sys

        logging.error(f"Run `export NEMORUN_HOME={args.log_dir}` in your shell environment and rerun this script.")
        sys.exit(1)

    exp_name = "_".join(
        [
            "llama3_405b",
            args.compute_dtype,
            f"{NUM_NODES}nodes",
            f"tp{TP_SIZE}_pp{PP_SIZE}_cp{CP_SIZE}_vp{VP_SIZE}",
            f"{MICRO_BATCH_SIZE}mbs_{GLOBAL_BATCH_SIZE}gbs",
        ]
    )

    executor = slurm_executor(
        args.account,
        args.partition,
        args.log_dir,
        NUM_NODES,
        NUM_GPUS_PER_NODE,
        args.time_limit,
        args.container_image,
        custom_mounts=[],
        custom_env_vars={},
        retries=0,
    )

    recipe = llama3_405b_performance_recipe(
        args.compute_dtype,
        NUM_NODES,
        NUM_GPUS_PER_NODE,
        MICRO_BATCH_SIZE,
        GLOBAL_BATCH_SIZE,
        TP_SIZE,
        PP_SIZE,
        CP_SIZE,
        VP_SIZE,
        MAX_STEPS,
    )

    if not args.tensorboard:  # tensorboard adds performance overhead.
        recipe.log.tensorboard = None
        recipe.trainer.logger = False
    else:
        # default path is NOT intuitive: `<log_dir>/code/nemo_experiments/tb_logs/default/<tfevents_file>`
        # the following line ensures the file is at `<log_dir>/lightning_logs/tb_logs/default/<tfevents_file>`
        recipe.log.log_dir = "/nemo_run/lightning_logs"

    plugins = [PerfEnvPlugin(enable_vboost=True, nccl_pp_comm_chunksize=2097152)]
    if args.enable_profiling:
        plugins.append(NsysPlugin(start_step=5, end_step=6))

    with run.Experiment(exp_name) as exp:
        exp.add(
            recipe,
            executor=executor,
            name=exp_name,
            plugins=plugins,
        )

        if not args.dryrun:
            exp.run(sequential=True, detach=True)
        else:
            exp.dryrun()
```
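The script's `dp_size` and `val_check_interval` lines encode a small piece of arithmetic worth spelling out: data parallelism is whatever is left of the cluster after tensor, pipeline, and context parallelism, and validation is scheduled to occur once, at the end of the run. A sketch using the 405B constants above:

```python
# Constants mirror the 405B script above.
num_nodes, gpus_per_node = 72, 8
tp, pp, cp = 8, 9, 2

total_gpus = num_nodes * gpus_per_node      # 576 GPUs in the job
model_parallel = tp * pp * cp               # 144-way model parallelism per replica
dp_size = total_gpus / model_parallel       # replicas of the model
assert dp_size == 4.0

# val_check_interval = max_steps * gbs / dp_size pushes the validation
# checkpoint past every training step, so it fires only once.
max_steps, gbs = 100, 252
val_check_interval = max_steps * gbs / dp_size
assert val_check_interval == 6300.0
```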
Lines changed: 179 additions & 0 deletions

```python
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import nemo_run as run
from nemo_run.config import NEMORUN_HOME
from utils import get_comm_overlap_callback_idx, hf_tokenizer, parse_cli_args, slurm_executor

from nemo.collections.llm.recipes.llama3_70b import pretrain_recipe
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_with_fp8_mixed
from nemo.lightning.pytorch.callbacks.garbage_collection import GarbageCollectionCallback
from nemo.lightning.run.plugins import NsysPlugin, PerfEnvPlugin
from nemo.utils import logging

NUM_NODES = 8
NUM_GPUS_PER_NODE = 8
MICRO_BATCH_SIZE = 1
GLOBAL_BATCH_SIZE = 128
TP_SIZE = 4
PP_SIZE = 4
CP_SIZE = 2
VP_SIZE = 5
MAX_STEPS = 100


def llama3_70b_performance_recipe(
    compute_dtype: str,
    num_nodes: int,
    num_gpus_per_node: int,
    mbs: int,
    gbs: int,
    tp_size: int,
    pp_size: int,
    cp_size: int,
    vp_size: Optional[int],
    max_steps: int,
):
    """
    llama3 70b pre-train recipe aimed at achieving best possible performance.

    NOTE: Use fp8 precision training with caution. It might not give desirable results.
    """
    recipe = pretrain_recipe(performance_mode=True)

    # data module configs
    recipe.data.micro_batch_size = mbs
    recipe.data.global_batch_size = gbs
    recipe.data.num_train_samples = max_steps * gbs  # ensure only 1 epoch for whole run
    recipe.data.tokenizer = hf_tokenizer("meta-llama/Meta-Llama-3-70B")

    recipe.trainer.max_steps = max_steps
    recipe.trainer.num_nodes = num_nodes
    recipe.trainer.devices = num_gpus_per_node

    # parallelism configs
    recipe.trainer.strategy.tensor_model_parallel_size = tp_size
    recipe.trainer.strategy.pipeline_model_parallel_size = pp_size
    recipe.trainer.strategy.context_parallel_size = cp_size
    recipe.trainer.strategy.virtual_pipeline_model_parallel_size = vp_size
    if tp_size > 1:
        recipe.trainer.strategy.sequence_parallel = True
    else:
        recipe.trainer.strategy.sequence_parallel = False

    comm_overlap_callback_idx = get_comm_overlap_callback_idx(recipe.trainer.callbacks)

    # compute dtype configs
    if compute_dtype.lower() == "fp8":
        recipe.trainer.plugins = bf16_with_fp8_mixed()
        recipe.trainer.callbacks[comm_overlap_callback_idx].tp_comm_overlap_cfg.proj_fprop.fp8_buf = True
        recipe.trainer.callbacks[comm_overlap_callback_idx].tp_comm_overlap_cfg.fc2_fprop.fp8_buf = True

    recipe.trainer.plugins.grad_reduce_in_fp32 = False  # bf16 grad dtype

    # callback configs
    garbage_collection_callback = run.Config(
        GarbageCollectionCallback,
        gc_interval_train=100,
        gc_interval_val=500,
    )
    recipe.trainer.callbacks.extend(
        [
            garbage_collection_callback,
        ]
    )
    dp_size = (num_nodes * num_gpus_per_node) / (tp_size * pp_size * cp_size)
    if dp_size > 1 and pp_size > 1 and vp_size and vp_size > 1:
        if comm_overlap_callback_idx >= 0:
            recipe.trainer.callbacks[comm_overlap_callback_idx].overlap_param_gather_with_optimizer_step = True

    # Misc. for overall faster experiment runtime
    recipe.log.ckpt = None
    recipe.trainer.enable_checkpointing = False
    recipe.trainer.val_check_interval = max_steps * gbs / dp_size
    recipe.trainer.log_every_n_steps = 1

    return recipe


if __name__ == "__main__":
    args = parse_cli_args().parse_args()
    if args.log_dir != NEMORUN_HOME:
        import sys

        logging.error(f"Run `export NEMORUN_HOME={args.log_dir}` in your shell environment and rerun this script.")
        sys.exit(1)

    exp_name = "_".join(
        [
            "llama3_70b",
            args.compute_dtype,
            f"{NUM_NODES}nodes",
            f"tp{TP_SIZE}_pp{PP_SIZE}_cp{CP_SIZE}_vp{VP_SIZE}",
            f"{MICRO_BATCH_SIZE}mbs_{GLOBAL_BATCH_SIZE}gbs",
        ]
    )

    executor = slurm_executor(
        args.account,
        args.partition,
        args.log_dir,
        NUM_NODES,
        NUM_GPUS_PER_NODE,
        args.time_limit,
        args.container_image,
        custom_mounts=[],
        custom_env_vars={},
        retries=0,
    )

    recipe = llama3_70b_performance_recipe(
        args.compute_dtype,
        NUM_NODES,
        NUM_GPUS_PER_NODE,
        MICRO_BATCH_SIZE,
        GLOBAL_BATCH_SIZE,
        TP_SIZE,
        PP_SIZE,
        CP_SIZE,
        VP_SIZE,
        MAX_STEPS,
    )

    if not args.tensorboard:  # tensorboard adds performance overhead.
        recipe.log.tensorboard = None
        recipe.trainer.logger = False
    else:
        # default path is NOT intuitive: `<log_dir>/code/nemo_experiments/tb_logs/default/<tfevents_file>`
        # the following line ensures the file is at `<log_dir>/lightning_logs/tb_logs/default/<tfevents_file>`
        recipe.log.log_dir = "/nemo_run/lightning_logs"

    plugins = [PerfEnvPlugin(enable_vboost=True, nccl_pp_comm_chunksize=2097152)]
    if args.enable_profiling:
        plugins.append(NsysPlugin(start_step=5, end_step=6))

    with run.Experiment(exp_name) as exp:
        exp.add(
            recipe,
            executor=executor,
            name=exp_name,
            plugins=plugins,
        )

        if not args.dryrun:
            exp.run(sequential=True, detach=True)
        else:
            exp.dryrun()
```
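The `num_train_samples = max_steps * gbs` line in the recipes above (with the comment "ensure only 1 epoch for whole run") relies on the fact that each optimizer step consumes exactly one global batch. A sketch with the 70B constants:

```python
# Constants mirror the 70B script above.
max_steps, gbs = 100, 128

# Sizing the dataset to max_steps * gbs samples means the run
# exhausts the data exactly when training stops: one full epoch.
num_train_samples = max_steps * gbs
steps_per_epoch = num_train_samples // gbs  # one global batch per step
assert steps_per_epoch == max_steps
```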
