Skip to content

Commit d7af5ae

Browse files
rhmukundan and NeMo Bot
authored and committed
Enabling TP Comm Overlap and Packed Sequencing Configs for LLAMA3 70B… (#2247)
Signed-off-by: Raghav Hrishikeshan Mukundan <rmukundan@nvidia.com> Signed-off-by: Raghav Hrishikeshan Mukundan <102543536+rhmukundan@users.noreply.github.com> Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
1 parent b10d7e3 commit d7af5ae

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

scripts/performance/configs/llama/llama3_llm_finetune.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ def llama3_70b_lora_config_gb300(precision: str = "bf16", config_variant: str =
209209
)
210210
set_llama3_common_peft_configs(cfg)
211211
set_workload_base_configs(cfg, base_cfg)
212+
cfg.comm_overlap = CommOverlapConfig(tp_comm_overlap=bool(cfg.model.tensor_model_parallel_size > 1))
212213

213214
# Enable pad_cu_seqlens for CUDA graphs compatibility with packed sequences.
214215
# This ensures consistent cu_seqlens tensor shapes across batches, which is required
@@ -239,6 +240,12 @@ def llama3_70b_lora_config_gb200(precision: str = "bf16", config_variant: str =
239240
)
240241
set_llama3_common_peft_configs(cfg)
241242
set_workload_base_configs(cfg, base_cfg)
243+
# Enable pad_cu_seqlens for CUDA graphs compatibility with packed sequences.
244+
# This ensures consistent cu_seqlens tensor shapes across batches, which is required
245+
# for CUDA graphs and avoids NaN issues in attention kernels.
246+
cfg.dataset.packed_sequence_specs.pad_cu_seqlens = True
247+
cfg.dataset.dataset_kwargs["pad_to_max_length"] = True
248+
cfg.comm_overlap = CommOverlapConfig(tp_comm_overlap=bool(cfg.model.tensor_model_parallel_size > 1))
242249

243250
return cfg
244251

@@ -263,5 +270,6 @@ def llama3_70b_lora_config_h100(precision: str = "bf16", config_variant: str = "
263270
)
264271
set_llama3_common_peft_configs(cfg)
265272
set_workload_base_configs(cfg, base_cfg)
273+
cfg.comm_overlap = CommOverlapConfig(tp_comm_overlap=bool(cfg.model.tensor_model_parallel_size > 1))
266274

267275
return cfg

0 commit comments

Comments (0)