Conversation

@cklxx cklxx commented Dec 12, 2025

Summary

This PR optimises the computation of context‑parallel (CP) loss by avoiding full all‑gathers when multiple CP ranks are used. It introduces sequence‑level KL preparation and branches the computation on whether cp_size equals 1 or is greater than 1.

Key Changes

Sequence‑level KL preparation: When cp_size > 1, log‑probability and mask segments from each rank are used to compute per‑rank partial KL values, which are then aggregated. This avoids gathering the entire log‑probability tensor across ranks.

Explicit CP metadata: The code now explicitly determines cp_size, cp_rank, and cp_group. If multiple ranks are present, these are used to coordinate partial computations; if cp_size is 1, the logic simply uses the local tensors.

KL divergence formula: For a sequence j with prompt length p_j and total length T_j, the sequence‑level KL is defined as

$$\mathrm{seq\_kl}_j = \sum_{t = p_j}^{T_j - 1} \bigl(\log p^{\mathrm{new}}_j(t) - \log p^{\mathrm{old}}_j(t)\bigr)\,\mathrm{mask}_j(t)$$

In this expression:
p_j denotes the prompt length of sequence j.
T_j is the total length of sequence j.
log p_new_j(t) and log p_old_j(t) are the log‑probabilities under the new and old policies, respectively, at position t in sequence j.
mask_j(t) is a mask value (typically 0 or 1) that indicates whether the token at position t should contribute to the KL sum.
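In code, the per‑sequence sum above can be written directly. The sketch below is a minimal pure‑Python illustration; the function name and toy numbers are hypothetical, not taken from the PR:

```python
def seq_kl(logp_new, logp_old, mask, prompt_len):
    """Sequence-level KL: masked sum of (log p_new - log p_old)
    over the response tokens, i.e. positions prompt_len .. T-1."""
    return sum(
        (ln - lo) * m
        for ln, lo, m in zip(
            logp_new[prompt_len:], logp_old[prompt_len:], mask[prompt_len:]
        )
    )

# Toy sequence: 2 prompt tokens, 3 response tokens, last token masked out.
logp_new = [-1.0, -1.2, -0.5, -0.7, -0.9]
logp_old = [-1.0, -1.2, -0.6, -0.8, -0.9]
mask     = [0, 0, 1, 1, 0]
print(round(seq_kl(logp_new, logp_old, mask, prompt_len=2), 6))  # → 0.2
```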

When cp_size == 1, this sum is computed directly on the entire sequence. When cp_size > 1, each rank computes the sum over its local segment (adjusted by its token offset), and the partial sums from all ranks are added together to recover the same result as a full computation.
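Because the sum is linear, splitting it over contiguous token chunks and adding the partials reproduces the full result exactly. Here is a minimal sketch that simulates CP ranks by slicing a single list (names are illustrative; the real code would replace `sum(partials)` with a cross‑rank reduction):

```python
def chunked_seq_kl(delta_logp, mask, num_ranks):
    """Simulate cp_size > 1: split the token dimension into contiguous
    chunks, compute a partial masked sum per 'rank', then add them up.
    delta_logp[t] stands for log p_new(t) - log p_old(t)."""
    n = len(delta_logp)
    chunk = (n + num_ranks - 1) // num_ranks
    partials = []
    for r in range(num_ranks):
        lo, hi = r * chunk, min((r + 1) * chunk, n)
        partials.append(sum(d * m for d, m in zip(delta_logp[lo:hi], mask[lo:hi])))
    return sum(partials)  # stands in for the cross-rank reduction

delta = [0.0, 0.0, 0.1, -0.05, 0.2, 0.0]
mask  = [0, 0, 1, 1, 1, 0]
full = sum(d * m for d, m in zip(delta, mask))  # the cp_size == 1 path
assert abs(chunked_seq_kl(delta, mask, num_ranks=2) - full) < 1e-12
```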

OPSM inputs: Added logic to generate OPSM inputs from sequence‑level KLs when OPSM is enabled. This uses the computed seq_kls and chunked_loss_masks.

No unnecessary communication: By distinguishing between cp_size == 1 and cp_size > 1, the PR ensures that data is only communicated across ranks when needed, preserving correctness while reducing overhead.

Motivation
#1062 (comment)

The previous implementation gathered full log‑probability tensors across all context‑parallel ranks regardless of configuration, which was inefficient. By computing sequence‑level KLs locally and aggregating only the necessary values, this PR reduces communication overhead and makes loss computation scale more gracefully with the number of CP ranks. It maintains mathematical equivalence with the single‑rank calculation thanks to the linearity of the KL formula.

Impact

These changes should improve training performance in multi‑rank CP configurations and make the handling of context‑parallel metadata clearer and more explicit. The stated KL formula clarifies the computation and shows that, by linearity, splitting the sum across ranks yields the same result as the single‑pass computation used when cp_size equals 1.

From a theoretical perspective, the speedup comes directly from reducing the communication volume before synchronization. A standard communication cost model writes the latency as $$T \approx \alpha + \beta \cdot \text{bytes}$$, where $$\alpha$$ is the fixed startup and synchronization overhead and $$\beta$$ is the inverse effective bandwidth. In the all-gather formulation, each rank must exchange token-level tensors whose size scales with the total sequence length $$L_{\text{total}}$$, leading to a cost on the order of $$\alpha + \beta \cdot s \cdot B \cdot L_{\text{total}}$$, where $$s$$ is the element size in bytes and $$B$$ is the number of sequences per batch. In contrast, the sequence-level KL approach first reduces locally over the token dimension and communicates only one scalar per sequence, so the cost becomes $$\alpha + \beta \cdot s \cdot B$$. The resulting theoretical optimization factor can therefore be approximated as

$$T_{\text{allgather}} / T_{\text{seqkl}} \approx (\alpha + \beta \cdot s \cdot B \cdot L_{\text{total}}) / (\alpha + \beta \cdot s \cdot B)$$.

When sequence lengths are sufficiently large and the bandwidth term dominates, this ratio grows roughly with $$L_{\text{total}}$$, explaining why eliminating token-level communication yields substantial speedups in practice.
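Plugging numbers into the cost model makes the ratio concrete. All constants below (overhead, bandwidth, element size, batch size, sequence length) are illustrative assumptions, not measurements from this PR:

```python
# Illustrative alpha-beta cost model; every constant here is a made-up example.
alpha = 10e-6         # fixed startup/sync overhead per collective (10 us)
beta = 1 / 100e9      # inverse effective bandwidth (100 GB/s), seconds per byte
s = 4                 # bytes per element (fp32 log-prob)
B = 32                # sequences per batch
L_total = 8192        # total tokens per sequence

t_allgather = alpha + beta * s * B * L_total  # communicate token-level tensors
t_seqkl = alpha + beta * s * B                # communicate one scalar per sequence
print(f"theoretical ratio ~ {t_allgather / t_seqkl:.1f}x")  # → ~2.0x
```

With these particular numbers the fixed overhead $$\alpha$$ dominates the seq‑KL side, so the ratio comes out near 2x; as $$L_{\text{total}}$$ grows, the bandwidth term takes over and the ratio grows roughly linearly in $$L_{\text{total}}$$.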

Using examples/reproducibility/run-qwen2.5-0.5B-gsm8k.sh to run tests on 4 GPUs, the results show that the seq-KL approach is about twice as fast as all_gather: perf/cp_seq_kl_prep_time drops from roughly 0.002 s to roughly 0.001 s.

cp2 = True
(MegatronTrainRayActor pid=38610) [2026-01-05 12:20:29] train_metric_utils.py:44 - perf 2: {'perf/sleep_time': 4.035936594009399, 'perf/update_weights_time': 0.3353748321533203, 'perf/wake_up_time': 1.9440078735351562, 'perf/data_preprocess_time': 0.002149343490600586, 'perf/train_wait_time': 11.333874011039734, 'perf/ref_log_probs_time': 1.4039483070373535, 'perf/log_probs_time': 0.10894346237182617, 'perf/cp_seq_kl_prep_time': 0.0009074211120605469, 'perf/actor_train_time': 1.0247735977172852, 'perf/train_time': 2.8442583084106445, 'perf/log_probs_tflops': 11.64935490539547, 'perf/ref_log_probs_tflops': 0.9039656598682972, 'perf/actor_train_tflops': 3.7153212981452874, 'perf/actor_train_tok_per_s': 2355.6422661329875, 'perf/step_time': 14.182998418807983, 'perf/wait_time_ratio': 0.7994600137133985}

(MegatronTrainRayActor pid=38610) [2026-01-05 12:21:16] train_metric_utils.py:44 - perf 5: {'perf/sleep_time': 6.253698110580444, 'perf/update_weights_time': 0.3034975528717041, 'perf/wake_up_time': 2.5065340995788574, 'perf/data_preprocess_time': 0.004125833511352539, 'perf/train_wait_time': 13.583344006538391, 'perf/ref_log_probs_time': 1.7108128070831299, 'perf/log_probs_time': 0.22948408126831055, 'perf/cp_seq_kl_prep_time': 0.0014464855194091797, 'perf/actor_train_time': 1.1403539180975561, 'perf/train_time': 3.1958839893341064, 'perf/log_probs_tflops': 5.450082672120884, 'perf/ref_log_probs_tflops': 0.7310602362045722, 'perf/actor_train_tflops': 3.2903132835076376, 'perf/actor_train_tok_per_s': 2087.948278388983, 'perf/step_time': 16.779324054718018, 'perf/wait_time_ratio': 0.8095344020466971}

(MegatronTrainRayActor pid=38610) [2026-01-05 12:22:55] train_metric_utils.py:44 - perf 11: {'perf/sleep_time': 6.354093074798584, 'perf/update_weights_time': 0.37139129638671875, 'perf/wake_up_time': 2.7856333255767822, 'perf/data_preprocess_time': 0.003217935562133789, 'perf/train_wait_time': 14.108158349990845, 'perf/ref_log_probs_time': 1.1324266738739014, 'perf/log_probs_time': 0.1786973476409912, 'perf/cp_seq_kl_prep_time': 0.0017007553863525391, 'perf/actor_train_time': 1.0746178622701416, 'perf/train_time': 2.5313310623168945, 'perf/log_probs_tflops': 7.096067806017753, 'perf/ref_log_probs_tflops': 1.1197620580983496, 'perf/actor_train_tflops': 3.5399983741987984, 'perf/actor_train_tok_per_s': 2244.518804048744, 'perf/step_time': 16.63948941230774, 'perf/wait_time_ratio': 0.8478720710958508}

(MegatronTrainRayActor pid=38610) [2026-01-05 12:26:48] train_metric_utils.py:44 - perf 28: {'perf/sleep_time': 3.8717112541198773, 'perf/update_weights_time': 0.33193492889404297, 'perf/wake_up_time': 2.0198421478271484, 'perf/data_preprocess_time': 0.001870870590209961, 'perf/train_wait_time': 11.433748483657837, 'perf/ref_log_probs_time': 0.8980967998504639, 'perf/log_probs_time': 0.10044336318969727, 'perf/cp_seq_kl_prep_time': 0.0009725093841552734, 'perf/actor_train_time': 0.7867650985717773, 'perf/train_time': 2.0016472339630127, 'perf/log_probs_tflops': 12.53492864580966, 'perf/ref_log_probs_tflops': 1.4019094497805094, 'perf/actor_train_tflops': 4.800862644314931, 'perf/actor_train_tok_per_s': 3045.3816575614287, 'perf/step_time': 13.43539571762085, 'perf/wait_time_ratio': 0.8510168754212574}

cp2 = False
(MegatronTrainRayActor pid=53217) [2026-01-05 12:43:22] train_metric_utils.py:44 - perf 25: {'perf/sleep_time': 3.710853099822998, 'perf/update_weights_time': 0.5368313789367676, 'perf/wake_up_time': 1.8624987602233887, 'perf/data_preprocess_time': 0.0019109249114990234, 'perf/train_wait_time': 11.519951581954956, 'perf/ref_log_probs_time': 0.343092679977417, 'perf/log_probs_time': 0.17170119285583496, 'perf/cp_seq_kl_prep_time': 0.0024895668029785156, 'perf/actor_train_time': 0.6906688213348389, 'perf/train_time': 1.31532621383667, 'perf/log_probs_tflops': 7.453253218214382, 'perf/ref_log_probs_tflops': 3.7299905911960423, 'perf/actor_train_tflops': 5.558666159639401, 'perf/actor_train_tok_per_s': 3522.67240802589, 'perf/step_time': 12.835277795791626, 'perf/wait_time_ratio': 0.8975225752988429}

(MegatronTrainRayActor pid=53217) [2026-01-05 12:43:54] train_metric_utils.py:44 - perf 27: {'perf/sleep_time': 4.464445352554321, 'perf/update_weights_time': 0.32878923416137695, 'perf/wake_up_time': 1.952591896057129, 'perf/data_preprocess_time': 0.0019826889038085938, 'perf/train_wait_time': 14.692675828933716, 'perf/ref_log_probs_time': 0.40941762924194336, 'perf/log_probs_time': 0.0947260856628418, 'perf/cp_seq_kl_prep_time': 0.0025134151077270508, 'perf/actor_train_time': 0.6804351806640625, 'perf/train_time': 1.2983801364898682, 'perf/log_probs_tflops': 13.080430406574324, 'perf/ref_log_probs_tflops': 3.0263913488390233, 'perf/actor_train_tflops': 5.462936102116691, 'perf/actor_train_tok_per_s': 3468.3685780279416, 'perf/step_time': 15.991055965423584, 'perf/wait_time_ratio': 0.9188058537661759}


(MegatronTrainRayActor pid=53217) [2026-01-05 12:44:22] train_metric_utils.py:44 - perf 29: {'perf/sleep_time': 3.905529737472534, 'perf/update_weights_time': 0.32970690727233887, 'perf/wake_up_time': 2.0948638916015625, 'perf/data_preprocess_time': 0.00199127197265625, 'perf/train_wait_time': 11.27786922454834, 'perf/ref_log_probs_time': 0.2802879811033325, 'perf/log_probs_time': 0.09743428230285645, 'perf/cp_seq_kl_prep_time': 0.0026300472183227539, 'perf/actor_train_time': 1.0780537128448486, 'perf/train_time': 1.6341447830200195, 'perf/log_probs_tflops': 12.922992866660508, 'perf/ref_log_probs_tflops': 4.492317260718691, 'perf/actor_train_tflops': 3.5039326524240075, 'perf/actor_train_tok_per_s': 2222.5237680201076, 'perf/step_time': 12.91201400756836, 'perf/wait_time_ratio': 0.8734399775230907}

cklxx and others added 30 commits December 12, 2025 16:01
…arallelism

Restore Qwen2.5 model in external rollout test
…nges

Handle CP OPSM masks centrally and restore loss guardrails
…source

Simplify context-parallel seq KL helper
…d-clarity-issues

Fix OPSM CP reduction and clean up interfaces
…stallation

Prefer binary installs in build script
…nto modify-build_conda.sh-for-direct-installation-wkwuyr
…stallation-wkwuyr

Adjust flash-attn wheel selection for torch pin
@cklxx cklxx force-pushed the codex/optimize-training-time-for-context-parallelism branch from 90fba9b to ca4b729 on January 5, 2026 12:46
@cklxx cklxx marked this pull request as ready for review January 5, 2026 12:46

cklxx commented Jan 5, 2026

@PopSoda2002 Could you help with the review? Or please let me know who I should reach out to for it.

@Hecate0821
Collaborator

This PR only optimizes Megatron CP performance, and for FSDP it only makes it compatible with the change. Is my understanding correct?


cklxx commented Jan 6, 2026

This PR only optimizes Megatron CP performance, and for FSDP it only makes it compatible with the change. Is my understanding correct?

Yes. Does FSDP also need optimization? I’ll take a look.

@Hecate0821
Collaborator

This PR only optimizes Megatron CP performance, and for FSDP it only makes it compatible with the change. Is my understanding correct?

Yes. Does FSDP also need optimization? I’ll take a look.

We can do FSDP in next PR potentially. Otherwise the PR would be too large to review.
