@@ -209,6 +209,7 @@ def llama3_70b_lora_config_gb300(precision: str = "bf16", config_variant: str =
209209 )
210210 set_llama3_common_peft_configs (cfg )
211211 set_workload_base_configs (cfg , base_cfg )
212+ cfg .comm_overlap = CommOverlapConfig (tp_comm_overlap = bool (cfg .model .tensor_model_parallel_size > 1 ))
212213
213214 # Override target_modules to only apply LoRA to QKV
214215 cfg .peft .target_modules = ["linear_qkv" ]
@@ -245,6 +246,12 @@ def llama3_70b_lora_config_gb200(precision: str = "bf16", config_variant: str =
245246 )
246247 set_llama3_common_peft_configs (cfg )
247248 set_workload_base_configs (cfg , base_cfg )
249+ # Enable pad_cu_seqlens for CUDA graphs compatibility with packed sequences.
250+ # This ensures consistent cu_seqlens tensor shapes across batches, which is required
251+ # for CUDA graphs and avoids NaN issues in attention kernels.
252+ cfg .dataset .packed_sequence_specs .pad_cu_seqlens = True
253+ cfg .dataset .dataset_kwargs ["pad_to_max_length" ] = True
254+ cfg .comm_overlap = CommOverlapConfig (tp_comm_overlap = bool (cfg .model .tensor_model_parallel_size > 1 ))
248255
249256 # Override target_modules to only apply LoRA to QKV
250257 cfg .peft .target_modules = ["linear_qkv" ]
@@ -272,6 +279,7 @@ def llama3_70b_lora_config_h100(precision: str = "bf16", config_variant: str = "
272279 )
273280 set_llama3_common_peft_configs (cfg )
274281 set_workload_base_configs (cfg , base_cfg )
282+ cfg .comm_overlap = CommOverlapConfig (tp_comm_overlap = bool (cfg .model .tensor_model_parallel_size > 1 ))
275283
276284 # Override target_modules to only apply LoRA to QKV
277285 cfg .peft .target_modules = ["linear_qkv" ]