@@ -209,6 +209,7 @@ def llama3_70b_lora_config_gb300(precision: str = "bf16", config_variant: str =
209209 )
210210 set_llama3_common_peft_configs (cfg )
211211 set_workload_base_configs (cfg , base_cfg )
212+ cfg .comm_overlap = CommOverlapConfig (tp_comm_overlap = bool (cfg .model .tensor_model_parallel_size > 1 ))
212213
213214 # Enable pad_cu_seqlens for CUDA graphs compatibility with packed sequences.
214215 # This ensures consistent cu_seqlens tensor shapes across batches, which is required
@@ -239,6 +240,12 @@ def llama3_70b_lora_config_gb200(precision: str = "bf16", config_variant: str =
239240 )
240241 set_llama3_common_peft_configs (cfg )
241242 set_workload_base_configs (cfg , base_cfg )
243+ # Enable pad_cu_seqlens for CUDA graphs compatibility with packed sequences.
244+ # This ensures consistent cu_seqlens tensor shapes across batches, which is required
245+ # for CUDA graphs and avoids NaN issues in attention kernels.
246+ cfg .dataset .packed_sequence_specs .pad_cu_seqlens = True
247+ cfg .dataset .dataset_kwargs ["pad_to_max_length" ] = True
248+ cfg .comm_overlap = CommOverlapConfig (tp_comm_overlap = bool (cfg .model .tensor_model_parallel_size > 1 ))
242249
243250 return cfg
244251
@@ -263,5 +270,6 @@ def llama3_70b_lora_config_h100(precision: str = "bf16", config_variant: str = "
263270 )
264271 set_llama3_common_peft_configs (cfg )
265272 set_workload_base_configs (cfg , base_cfg )
273+ cfg .comm_overlap = CommOverlapConfig (tp_comm_overlap = bool (cfg .model .tensor_model_parallel_size > 1 ))
266274
267275 return cfg
0 commit comments