Optimize CP loss communication #1102
Conversation
Commits
- Restore Qwen2.5 model in external rollout test
- Handle CP OPSM masks centrally and restore loss guardrails
- Simplify context-parallel seq KL helper
- Fix OPSM CP reduction and clean up interfaces
- Prefer binary installs in build script
- Adjust flash-attn wheel selection for torch pin
@PopSoda2002 Thanks a lot for helping with the review, or please let me know who I should reach out to instead.
This PR only optimizes Megatron CP performance; for FSDP it only makes it compatible with the change. Is my understanding correct?
Yes. Does FSDP also need optimization? I’ll take a look.
We can potentially do FSDP in the next PR; otherwise this PR would be too large to review.
Summary
This PR optimises the computation of context‑parallel (CP) loss by avoiding full all‑gathers when multiple CP ranks are used. It introduces sequence‑level KL preparation and adapts the computation depending on whether cp_size equals 1 or exceeds it.
Key Changes
Sequence‑level KL preparation: When cp_size > 1, log‑probability and mask segments from each rank are used to compute per‑rank partial KL values, which are then aggregated. This avoids gathering the entire log‑probability tensor across ranks.
Explicit CP metadata: The code now explicitly determines cp_size, cp_rank, and cp_group. If multiple ranks are present, these are used to coordinate partial computations; if cp_size is 1, the logic simply uses the local tensors.
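A minimal sketch of how the CP metadata might be resolved, assuming Megatron-LM’s parallel_state helpers (the actual implementation may obtain these values differently):

```python
# Sketch only: assumes Megatron-LM's parallel_state API for context parallelism.
from megatron.core import parallel_state


def get_cp_info():
    """Return (cp_size, cp_rank, cp_group); cp_group is None when CP is unused."""
    cp_size = parallel_state.get_context_parallel_world_size()
    cp_rank = parallel_state.get_context_parallel_rank()
    cp_group = parallel_state.get_context_parallel_group() if cp_size > 1 else None
    return cp_size, cp_rank, cp_group
```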
KL divergence formula: For a sequence j with prompt length p_j and total length T_j, the sequence‑level KL is defined as

$$\mathrm{KL}_j \;=\; \sum_{t=p_j}^{T_j} \mathrm{mask}_j(t)\,\bigl(\log p_{\mathrm{new},j}(t) - \log p_{\mathrm{old},j}(t)\bigr)$$
In this expression:
p_j denotes the prompt length of sequence j.
T_j is the total length of sequence j.
log p_new_j(t) and log p_old_j(t) are the log‑probabilities under the new and old policies, respectively, at position t in sequence j.
mask_j(t) is a mask value (typically 0 or 1) that indicates whether the token at position t should contribute to the KL sum.
When cp_size == 1, this sum is computed directly on the entire sequence. When cp_size > 1, each rank computes the sum over its local segment (adjusted by its token offset), and the partial sums from all ranks are added together to recover the same result as a full computation.
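Translated to code, the per-rank contribution is just a masked sum over the local token segment. The helper below is a hypothetical sketch; the tensor names and shapes are assumptions, not the PR’s actual interface:

```python
import torch


def local_seq_kl(new_logprobs: torch.Tensor,
                 old_logprobs: torch.Tensor,
                 loss_mask: torch.Tensor) -> torch.Tensor:
    """Per-sequence masked sum of the log-probability difference over the tokens
    held by this rank. All inputs are [batch, local_seq_len]; the mask already
    zeroes prompt tokens and padding, so this equals the sum over t in [p_j, T_j]
    from the formula above, restricted to the local segment.
    """
    return ((new_logprobs - old_logprobs) * loss_mask).sum(dim=-1)  # [batch]
```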
OPSM inputs: Added logic to generate OPSM inputs from sequence‑level KLs when OPSM is enabled. This uses the computed seq_kls and chunked_loss_masks.
No unnecessary communication: By distinguishing between cp_size == 1 and cp_size > 1, the PR ensures that data is only communicated across ranks when needed, preserving correctness while reducing overhead.
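Putting the two cases together, the dispatch could look roughly like the following sketch (again with hypothetical names, reusing the local_seq_kl helper from the previous sketch):

```python
import torch.distributed as dist


def sequence_kl(new_logprobs, old_logprobs, loss_mask, cp_size, cp_group=None):
    # Each rank sums the masked log-ratio over its local token segment.
    partial = local_seq_kl(new_logprobs, old_logprobs, loss_mask)  # [batch]
    if cp_size == 1:
        # Single CP rank: the local segment is the whole sequence, no comms needed.
        return partial
    # cp_size > 1: add the per-rank partial sums across the CP group. Only one
    # scalar per sequence crosses the wire instead of the full token-level
    # log-probability tensors.
    dist.all_reduce(partial, op=dist.ReduceOp.SUM, group=cp_group)
    return partial
```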
Motivation
#1062 (comment)
The previous implementation gathered full log‑probability tensors across all context‑parallel ranks regardless of configuration, which was inefficient. By computing sequence‑level KLs locally and aggregating only the necessary values, this PR reduces communication overhead and makes loss computation scale more gracefully with the number of CP ranks. It maintains mathematical equivalence with the single‑rank calculation thanks to the linearity of the KL formula.
Impact
These changes should improve training performance in multi‑rank CP configurations and provide clearer, more explicit handling of context‑parallel metadata. Making the KL formula explicit clarifies the computation and shows that splitting the sum across ranks yields the same result as the single‑pass computation used when cp_size equals 1.
From a theoretical perspective, the speedup comes directly from reducing the communication volume before synchronization. A standard communication cost model writes the latency as $$T \approx \alpha + \beta \cdot \text{bytes}$$, where $$\alpha$$ is the fixed startup and synchronization overhead and $$\beta$$ reflects the inverse effective bandwidth. In the all-gather formulation, each rank must exchange token-level tensors whose size scales with the total sequence length $$L_{\text{total}}$$, leading to a cost on the order of $$\alpha + \beta \cdot s \cdot B \cdot L_{\text{total}}$$. In contrast, the sequence-level KL approach locally reduces over the token dimension first and only communicates one scalar per sequence, so the cost becomes $$\alpha + \beta \cdot s \cdot B$$. The resulting theoretical optimization factor can therefore be approximated as

$$\frac{\alpha + \beta \cdot s \cdot B \cdot L_{\text{total}}}{\alpha + \beta \cdot s \cdot B}.$$

When sequence lengths are sufficiently large and the bandwidth term dominates, this ratio grows roughly with $$L_{\text{total}}$$, explaining why eliminating token-level communication yields substantial speedups in practice.
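As a rough, purely illustrative check of that limit (the sequence length below is an assumed example, not a measurement):

$$\frac{\alpha + \beta \cdot s \cdot B \cdot L_{\text{total}}}{\alpha + \beta \cdot s \cdot B} \;\approx\; L_{\text{total}} \quad \text{when } \beta \cdot s \cdot B \cdot L_{\text{total}} \gg \alpha, \qquad \text{e.g. } L_{\text{total}} = 4096 \;\Rightarrow\; \text{roughly } 4096\times \text{ less data on the wire.}$$

The end-to-end speedup is necessarily smaller than this ratio, since the fixed $$\alpha$$ term and the surrounding computation are unchanged; this is consistent with the roughly 2x wall-clock improvement reported below.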
Using examples/reproducibility/run-qwen2.5-0.5B-gsm8k.sh to run tests on 4 GPUs, the results show that the seq-KL approach is about twice as fast as all_gather, with the time reduced from 0.002 to 0.001.