Skip to content

Commit 28be04b

Browse files
authored
Make VBoost activation conditional (#14458)
1 parent 310a862 commit 28be04b

32 files changed

+304
-271
lines changed

scripts/performance/argument_parser.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,5 +393,13 @@ def list_of_strings(arg):
393393
required=False,
394394
default=None,
395395
)
396+
parser.add_argument(
397+
"-vb",
398+
"--enable_vboost",
399+
help="Enable VBoost which steers more power towards tensor cores. Disabled by default",
400+
type=bool_arg,
401+
required=False,
402+
default=None,
403+
)
396404

397405
return parser

scripts/performance/diffusion/pretrain_flux_12b.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,17 @@
1717
import nemo_run as run
1818

1919
from nemo.collections.diffusion.recipes.flux_12b import pretrain_recipe
20-
from nemo.lightning.run.plugins import NsysPlugin, PerfEnvPlugin
20+
from nemo.lightning.run.plugins import NsysPlugin
2121

2222
from ..argument_parser import parse_cli_args
2323
from ..executors import slurm_executor
24-
from ..helpers import args_sanity_check, get_user_configs, set_exp_logging_configs, set_primary_perf_configs
24+
from ..helpers import (
25+
args_sanity_check,
26+
build_perf_env_plugin,
27+
get_user_configs,
28+
set_exp_logging_configs,
29+
set_primary_perf_configs,
30+
)
2531

2632

2733
def override_recipe_configs(
@@ -106,13 +112,7 @@ def override_recipe_configs(
106112
nemo_home=args.nemo_home,
107113
)
108114

109-
plugins = [
110-
PerfEnvPlugin(
111-
enable_vboost=True,
112-
nccl_pp_comm_chunksize=2097152 if pp_size > 1 else None,
113-
gpu_sm100_or_newer=(args.gpu.lower() in ['b200', 'gb200']),
114-
),
115-
]
115+
plugins = [build_perf_env_plugin(args, pp_size=pp_size)]
116116
if args.enable_nsys:
117117
plugins.append(NsysPlugin(start_step=5, end_step=6))
118118

scripts/performance/helpers.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,3 +483,32 @@ def args_sanity_check(args: dict) -> None:
483483
assert args.wandb_key is not None, "wandb logger needs \"wandb_key\""
484484
assert args.wandb_prj_name is not None, "wandb logger needs \"wandb_prj_name\""
485485
assert args.wandb_job_name is not None, "wandb logger needs \"wandb_job_name\""
486+
487+
488+
def build_perf_env_plugin(args, pp_size: Optional[int] = None, user_buffer_registration: Optional[bool] = None):
    """
    Create a PerfEnvPlugin with consistent defaults across scripts.

    - enable_vboost is taken from the ``--enable_vboost`` CLI flag
      (``args.enable_vboost``); it is disabled unless explicitly enabled.
    - nccl_pp_comm_chunksize is set (2 MiB) only when pipeline parallelism is used.
    - gpu_sm100_or_newer is set when gpu is in ['b200', 'gb200'].

    Args:
        args: Parsed CLI args that include `gpu` and `enable_vboost`.
        pp_size: Pipeline parallel size used to decide the NCCL comm chunk size.
        user_buffer_registration: Optional flag to enable user buffer registration.

    Returns:
        A configured PerfEnvPlugin instance.
    """
    # Imported lazily so importing this helpers module does not require nemo.lightning.
    from nemo.lightning.run.plugins import PerfEnvPlugin

    gpu_str = getattr(args, "gpu", "").lower()
    # The CLI default for --enable_vboost is None; coerce to a definite bool so the
    # plugin always receives True/False rather than None.
    enable_vboost = bool(getattr(args, "enable_vboost", None))
    gpu_sm100_or_newer = gpu_str in ["b200", "gb200"]
    # 2097152 bytes == 2 MiB chunk size for NCCL pipeline-parallel communication.
    nccl_pp_comm_chunksize = 2097152 if (pp_size is not None and pp_size > 1) else None
    user_buf = bool(user_buffer_registration) if user_buffer_registration is not None else False

    return PerfEnvPlugin(
        enable_vboost=enable_vboost,
        nccl_pp_comm_chunksize=nccl_pp_comm_chunksize,
        gpu_sm100_or_newer=gpu_sm100_or_newer,
        user_buffer_registration=user_buf,
    )

scripts/performance/llm/finetune_deepseek_v3.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@
2222
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
2323
from nemo.lightning.pytorch.callbacks.megatron_enable_experimental_callback import MegatronEnableExperimentalCallback
2424
from nemo.lightning.pytorch.callbacks.moe_token_drop import MegatronTokenDropCallback
25-
from nemo.lightning.run.plugins import MemoryProfilePlugin, NsysPlugin, PerfEnvPlugin
25+
from nemo.lightning.run.plugins import MemoryProfilePlugin, NsysPlugin
2626

2727
from ..argument_parser import parse_cli_args
2828
from ..executors import slurm_executor
29-
from ..helpers import args_sanity_check, get_user_configs, set_primary_perf_configs
29+
from ..helpers import args_sanity_check, build_perf_env_plugin, get_user_configs, set_primary_perf_configs
3030
from ..utils import hf_tokenizer, import_ckpt_experiment, isfile_train_pack_metadata
3131

3232
HF_MODEL_URI = "deepseek-ai/DeepSeek-V3-Base"
@@ -167,13 +167,7 @@ def override_recipe_configs(
167167
network='sharp' if args.use_sharp else None,
168168
)
169169

170-
plugins = [
171-
PerfEnvPlugin(
172-
enable_vboost=True,
173-
nccl_pp_comm_chunksize=2097152 if pp_size > 1 else None,
174-
gpu_sm100_or_newer=(args.gpu.lower() in ['b200', 'gb200']),
175-
)
176-
]
170+
plugins = [build_perf_env_plugin(args, pp_size=pp_size)]
177171
if args.enable_nsys:
178172
plugins.append(NsysPlugin(start_step=10, end_step=12, gen_shape=True))
179173
if args.enable_memory_profile:

scripts/performance/llm/finetune_llama31_405b.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,17 @@
2222
from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import (
2323
userbuffers_fp8_h100_h16384_tp4_mbs1_seqlen2048_lora,
2424
)
25-
from nemo.lightning.run.plugins import MemoryProfilePlugin, NsysPlugin, PerfEnvPlugin
25+
from nemo.lightning.run.plugins import MemoryProfilePlugin, NsysPlugin
2626

2727
from ..argument_parser import parse_cli_args
2828
from ..executors import slurm_executor
29-
from ..helpers import args_sanity_check, get_user_configs, set_exp_logging_configs, set_primary_perf_configs
29+
from ..helpers import (
30+
args_sanity_check,
31+
build_perf_env_plugin,
32+
get_user_configs,
33+
set_exp_logging_configs,
34+
set_primary_perf_configs,
35+
)
3036
from ..utils import (
3137
get_comm_overlap_callback_idx,
3238
hf_tokenizer,
@@ -190,13 +196,7 @@ def override_recipe_configs(
190196
network='sharp' if args.use_sharp else None,
191197
)
192198

193-
plugins = [
194-
PerfEnvPlugin(
195-
enable_vboost=True,
196-
nccl_pp_comm_chunksize=2097152 if pp_size > 1 else None,
197-
gpu_sm100_or_newer=(args.gpu.lower() in ['b200', 'gb200']),
198-
)
199-
]
199+
plugins = [build_perf_env_plugin(args, pp_size=pp_size)]
200200
if args.enable_nsys:
201201
plugins.append(NsysPlugin(start_step=5, end_step=6))
202202
if args.enable_memory_profile:

scripts/performance/llm/finetune_llama3_70b.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,17 @@
2222
from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import (
2323
userbuffers_fp8_h100_h8192_tp2_mbs1_seqlen4096_lora,
2424
)
25-
from nemo.lightning.run.plugins import MemoryProfilePlugin, NsysPlugin, PerfEnvPlugin
25+
from nemo.lightning.run.plugins import MemoryProfilePlugin, NsysPlugin
2626

2727
from ..argument_parser import parse_cli_args
2828
from ..executors import slurm_executor
29-
from ..helpers import args_sanity_check, get_user_configs, set_exp_logging_configs, set_primary_perf_configs
29+
from ..helpers import (
30+
args_sanity_check,
31+
build_perf_env_plugin,
32+
get_user_configs,
33+
set_exp_logging_configs,
34+
set_primary_perf_configs,
35+
)
3036
from ..utils import (
3137
get_comm_overlap_callback_idx,
3238
hf_tokenizer,
@@ -197,13 +203,7 @@ def override_recipe_configs(
197203
network='sharp' if args.use_sharp else None,
198204
)
199205

200-
plugins = [
201-
PerfEnvPlugin(
202-
enable_vboost=True,
203-
nccl_pp_comm_chunksize=2097152 if pp_size > 1 else None,
204-
gpu_sm100_or_newer=(args.gpu.lower() in ['b200', 'gb200']),
205-
)
206-
]
206+
plugins = [build_perf_env_plugin(args, pp_size=pp_size)]
207207
if args.enable_nsys:
208208
plugins.append(NsysPlugin(start_step=5, end_step=6))
209209
if args.enable_memory_profile:

scripts/performance/llm/finetune_llama3_8b.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,17 @@
1717
import nemo_run as run
1818

1919
from nemo.collections.llm.recipes.llama3_8b import finetune_recipe, model
20-
from nemo.lightning.run.plugins import MemoryProfilePlugin, NsysPlugin, PerfEnvPlugin
20+
from nemo.lightning.run.plugins import MemoryProfilePlugin, NsysPlugin
2121

2222
from ..argument_parser import parse_cli_args
2323
from ..executors import slurm_executor
24-
from ..helpers import args_sanity_check, get_user_configs, set_exp_logging_configs, set_primary_perf_configs
24+
from ..helpers import (
25+
args_sanity_check,
26+
build_perf_env_plugin,
27+
get_user_configs,
28+
set_exp_logging_configs,
29+
set_primary_perf_configs,
30+
)
2531
from ..utils import hf_tokenizer, import_ckpt_experiment, prepare_squad_dataset_experiment
2632

2733
HF_MODEL_URI = "meta-llama/Meta-Llama-3-8B"
@@ -135,13 +141,7 @@ def override_recipe_configs(
135141
network='sharp' if args.use_sharp else None,
136142
)
137143

138-
plugins = [
139-
PerfEnvPlugin(
140-
enable_vboost=True,
141-
nccl_pp_comm_chunksize=2097152 if pp_size > 1 else None,
142-
gpu_sm100_or_newer=(args.gpu.lower() in ['b200', 'gb200']),
143-
)
144-
]
144+
plugins = [build_perf_env_plugin(args, pp_size=pp_size)]
145145
if args.enable_nsys:
146146
plugins.append(NsysPlugin(start_step=5, end_step=6))
147147
if args.enable_memory_profile:

scripts/performance/llm/finetune_llama4_e128.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,17 @@
1818

1919
from nemo.collections.llm.recipes.llama4_e128 import finetune_recipe, model
2020
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_with_fp8_mixed
21-
from nemo.lightning.run.plugins import MemoryProfilePlugin, NsysPlugin, PerfEnvPlugin
21+
from nemo.lightning.run.plugins import MemoryProfilePlugin, NsysPlugin
2222

2323
from ..argument_parser import parse_cli_args
24+
from ..executors import slurm_executor
25+
from ..helpers import (
26+
args_sanity_check,
27+
build_perf_env_plugin,
28+
get_user_configs,
29+
set_exp_logging_configs,
30+
set_primary_perf_configs,
31+
)
2432
from ..utils import (
2533
args_sanity_check,
2634
get_user_configs,
@@ -162,13 +170,7 @@ def override_recipe_configs(
162170
)
163171
exp_name = f"{splitext(basename(__file__))[0]}_{args.compute_dtype}_{exp_config}"
164172

165-
plugins = [
166-
PerfEnvPlugin(
167-
enable_vboost=True,
168-
nccl_pp_comm_chunksize=2097152 if pp_size > 1 else None,
169-
gpu_sm100_or_newer=(args.gpu.lower() in ['b200', 'gb200']),
170-
)
171-
]
173+
plugins = [build_perf_env_plugin(args, pp_size=pp_size)]
172174

173175
if args.enable_nsys:
174176
plugins.append(NsysPlugin(start_step=5, end_step=6))

scripts/performance/llm/mlperf_lora_llama2_70b.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@
2424
from nemo.collections.llm.gpt.model.llama import *
2525
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
2626
from nemo.lightning.pytorch.optim import CosineAnnealingScheduler
27-
from nemo.lightning.run.plugins import MemoryProfilePlugin, NsysPlugin, PerfEnvPlugin
27+
from nemo.lightning.run.plugins import MemoryProfilePlugin, NsysPlugin
2828

2929
from ..argument_parser import parse_cli_args
3030
from ..executors import slurm_executor
31-
from ..helpers import args_sanity_check
31+
from ..helpers import args_sanity_check, build_perf_env_plugin
3232
from ..utils import import_ckpt_experiment
3333

3434
NUM_NODES = 1
@@ -345,13 +345,7 @@ def mlperf_lora_llama2_70b_recipe(
345345

346346
recipe.log.wandb = wandb_logger(project=args.wandb_prj_name, name=args.wandb_job_name)
347347

348-
plugins = [
349-
PerfEnvPlugin(
350-
enable_vboost=True,
351-
nccl_pp_comm_chunksize=2097152 if PP_SIZE > 1 else None,
352-
gpu_sm100_or_newer=(args.gpu.lower() in ['b200', 'gb200']),
353-
)
354-
]
348+
plugins = [build_perf_env_plugin(args, pp_size=PP_SIZE)]
355349
if args.enable_nsys:
356350
plugins.append(NsysPlugin(start_step=5, end_step=6))
357351
if args.enable_memory_profile:

scripts/performance/llm/pretrain_automodel_llama3_8b.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@
1919
from nemo import lightning as nl
2020
from nemo.collections.llm.gpt.data.hf_dataset import HFMockDataModule
2121
from nemo.collections.llm.recipes import hf_auto_model_for_causal_lm
22-
from nemo.lightning.run.plugins import MemoryProfilePlugin, NsysPlugin, PerfEnvPlugin
22+
from nemo.lightning.run.plugins import MemoryProfilePlugin, NsysPlugin
2323

2424
from ..argument_parser import parse_cli_args
2525
from ..executors import slurm_executor
26-
from ..helpers import args_sanity_check, get_user_configs
26+
from ..helpers import args_sanity_check, build_perf_env_plugin, get_user_configs
2727

2828
SEQ_LENGTH = 2048
2929
NUM_GPUS_PER_NODE = 8
@@ -105,13 +105,7 @@ def override_recipe_configs(
105105
network='sharp' if args.use_sharp else None,
106106
)
107107

108-
plugins = [
109-
PerfEnvPlugin(
110-
enable_vboost=True,
111-
nccl_pp_comm_chunksize=2097152 if pp_size > 1 else None,
112-
gpu_sm100_or_newer=(args.gpu.lower() in ['b200', 'gb200']),
113-
),
114-
]
108+
plugins = [build_perf_env_plugin(args, pp_size=pp_size)]
115109
if args.enable_nsys:
116110
plugins.append(NsysPlugin(start_step=5, end_step=6))
117111
if args.enable_memory_profile:

0 commit comments

Comments
 (0)