
Commit f396120

Enabling TP Comm Overlap and Packed Sequencing Configs for LLAMA3 70B… (#2247)
Signed-off-by: Raghav Hrishikeshan Mukundan <rmukundan@nvidia.com>
Signed-off-by: Raghav Hrishikeshan Mukundan <102543536+rhmukundan@users.noreply.github.com>
1 parent aa10ef7 commit f396120


1 file changed, 8 insertions(+), 0 deletions(-)


scripts/performance/configs/llama/llama3_llm_finetune.py

@@ -209,6 +209,7 @@ def llama3_70b_lora_config_gb300(precision: str = "bf16", config_variant: str =
     )
     set_llama3_common_peft_configs(cfg)
     set_workload_base_configs(cfg, base_cfg)
+    cfg.comm_overlap = CommOverlapConfig(tp_comm_overlap=bool(cfg.model.tensor_model_parallel_size > 1))
 
     # Override target_modules to only apply LoRA to QKV
     cfg.peft.target_modules = ["linear_qkv"]
@@ -245,6 +246,12 @@ def llama3_70b_lora_config_gb200(precision: str = "bf16", config_variant: str =
     )
     set_llama3_common_peft_configs(cfg)
     set_workload_base_configs(cfg, base_cfg)
+    # Enable pad_cu_seqlens for CUDA graphs compatibility with packed sequences.
+    # This ensures consistent cu_seqlens tensor shapes across batches, which is required
+    # for CUDA graphs and avoids NaN issues in attention kernels.
+    cfg.dataset.packed_sequence_specs.pad_cu_seqlens = True
+    cfg.dataset.dataset_kwargs["pad_to_max_length"] = True
+    cfg.comm_overlap = CommOverlapConfig(tp_comm_overlap=bool(cfg.model.tensor_model_parallel_size > 1))
 
     # Override target_modules to only apply LoRA to QKV
     cfg.peft.target_modules = ["linear_qkv"]
@@ -272,6 +279,7 @@ def llama3_70b_lora_config_h100(precision: str = "bf16", config_variant: str = "
     )
     set_llama3_common_peft_configs(cfg)
     set_workload_base_configs(cfg, base_cfg)
+    cfg.comm_overlap = CommOverlapConfig(tp_comm_overlap=bool(cfg.model.tensor_model_parallel_size > 1))
 
     # Override target_modules to only apply LoRA to QKV
     cfg.peft.target_modules = ["linear_qkv"]
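
The `cfg.comm_overlap` line added to all three hardware variants (GB300, GB200, H100) gates tensor-parallel communication overlap on the parallelism layout: overlap is only requested when the model is actually sharded across more than one tensor-parallel rank, since with TP size 1 there is no tensor-parallel collective traffic to hide behind compute. A minimal, self-contained sketch of that gating pattern follows; the `CommOverlapConfig` dataclass and the `SimpleNamespace` config object here are illustrative stand-ins, not the classes used by this repository.

```python
from dataclasses import dataclass
from types import SimpleNamespace


@dataclass
class CommOverlapConfig:
    # Stand-in mirroring the single field exercised in the diff above;
    # the real CommOverlapConfig comes from the training library.
    tp_comm_overlap: bool = False


def maybe_enable_tp_overlap(cfg) -> None:
    # Overlapping TP communication with compute only pays off when the model
    # is sharded across more than one tensor-parallel rank; with TP size 1
    # there are no TP all-gathers/reduce-scatters to overlap.
    tp_size = cfg.model.tensor_model_parallel_size
    cfg.comm_overlap = CommOverlapConfig(tp_comm_overlap=bool(tp_size > 1))


# LLAMA3 70B LoRA configs typically run with TP > 1, so overlap gets enabled.
cfg = SimpleNamespace(model=SimpleNamespace(tensor_model_parallel_size=4))
maybe_enable_tp_overlap(cfg)
print(cfg.comm_overlap)  # CommOverlapConfig(tp_comm_overlap=True)
```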

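The GB200 variant additionally pins the packed-sequence pipeline to static shapes: `pad_cu_seqlens = True` and `pad_to_max_length = True` keep the `cu_seqlens` tensor and the packed batch at a fixed size from step to step, which CUDA graphs need because a captured graph is replayed with unchanging tensor shapes. As a rough illustration of the idea (not code from this repository), one common way to make `cu_seqlens` shape-static is to repeat its final cumulative value, which appends zero-length dummy segments without changing the sequences the attention kernel sees:

```python
import torch


def pad_cu_seqlens(cu_seqlens: torch.Tensor, max_segments: int) -> torch.Tensor:
    # cu_seqlens holds cumulative lengths of the sequences packed into one
    # sample, e.g. [0, 5, 9, 16] for three sequences of lengths 5, 4 and 7.
    # Repeating the last value adds zero-length dummy segments, so every batch
    # produces a tensor with max_segments + 1 entries -- a static shape that a
    # captured CUDA graph can replay safely.
    pad = (max_segments + 1) - cu_seqlens.numel()
    if pad <= 0:
        return cu_seqlens
    return torch.cat([cu_seqlens, cu_seqlens[-1:].expand(pad)])


print(pad_cu_seqlens(torch.tensor([0, 5, 9, 16], dtype=torch.int32), max_segments=6))
# tensor([ 0,  5,  9, 16, 16, 16, 16], dtype=torch.int32)
```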