38 commits
808ecf7
Optimize CP loss communication
cklxx Dec 12, 2025
7b81652
Batch CP reductions for seq KL and OPSM
cklxx Dec 12, 2025
4026027
Reduce CP collectives for KL and OPSM
cklxx Dec 12, 2025
94dd6e3
Add CP validation guide in Chinese
cklxx Dec 12, 2025
25fb9e3
Fix lint issues in CP loss utilities
cklxx Dec 12, 2025
3c835af
upd
cklxx Dec 12, 2025
47415ac
Merge pull request #8 from THUDM/main
cklxx Dec 21, 2025
739cafc
Restore external rollout test model
cklxx Dec 22, 2025
39eb031
Merge pull request #9 from cklxx/optimize-training-time-for-context-p…
cklxx Dec 22, 2025
4a8bd24
Delete docs/zh/advanced/cp_validation.md
cklxx Dec 22, 2025
6cb5c2e
Update index.rst
cklxx Dec 22, 2025
873fe1d
Refine CP OPSM handling and loss safety
cklxx Dec 22, 2025
d398b45
Merge pull request #10 from cklxx/review-performance-optimization-cha…
cklxx Dec 22, 2025
80db40b
Simplify context-parallel seq KL helper
cklxx Dec 22, 2025
8687913
Merge pull request #11 from cklxx/fix-design-issues-in-loss-and-data-…
cklxx Dec 22, 2025
232d9c1
Fix OPSM CP reduction and clean up interfaces
cklxx Dec 22, 2025
06787cf
Merge pull request #13 from cklxx/fix-compute_opsm_mask-redundancy-an…
cklxx Dec 22, 2025
12ff9a2
Merge branch 'main' into codex/optimize-training-time-for-context-par…
cklxx Dec 22, 2025
677bb64
Optimize CP seq-KL reduction and OPSM
cklxx Dec 22, 2025
8db2cf5
Prefer binary installs in build script
cklxx Dec 22, 2025
8ffca92
Merge pull request #15 from cklxx/modify-build_conda.sh-for-direct-in…
cklxx Dec 22, 2025
8b83984
Make flash-attn wheel match torch pin
cklxx Dec 22, 2025
a1865f4
Merge branch 'codex/optimize-training-time-for-context-parallelism' i…
cklxx Dec 22, 2025
afad254
Merge pull request #16 from cklxx/modify-build_conda.sh-for-direct-in…
cklxx Dec 22, 2025
1365e99
upd
cklxx Dec 23, 2025
7adec05
upd
cklxx Dec 23, 2025
9355324
Fix CP collectives and empty value loss
cklxx Dec 23, 2025
558e19b
Revert build_conda.sh to main
cklxx Dec 23, 2025
9792e0c
Add perf timer for CP seq-KL prep
cklxx Dec 24, 2025
1cfa03d
upd
cklxx Dec 27, 2025
a01178a
upd
cklxx Dec 27, 2025
c61fcbd
upd
cklxx Dec 27, 2025
04949ce
Merge branch 'main' into codex/optimize-training-time-for-context-par…
cklxx Jan 5, 2026
430f9e8
Refactor CP helper arg passing
cklxx Jan 5, 2026
ca4b729
upd
cklxx Jan 5, 2026
bfae19b
upd
cklxx Jan 5, 2026
1d6ee30
upd
cklxx Jan 6, 2026
53c9bae
Merge branch 'main' into codex/optimize-training-time-for-context-par…
cklxx Jan 20, 2026
20 changes: 15 additions & 5 deletions slime/backends/fsdp_utils/actor.py
@@ -17,8 +17,15 @@
from slime.utils.logging_utils import init_tracking
from slime.utils.memory_utils import clear_memory, print_memory
from slime.utils.metric_utils import compute_rollout_step
from slime.utils.misc import Box
from slime.utils.ppo_utils import compute_approx_kl, compute_gspo_kl, compute_opsm_mask, compute_policy_loss
from slime.utils.misc import load_function, Box
from slime.utils.ppo_utils import (
    build_opsm_inputs_from_log_probs,
    compute_approx_kl,
    compute_gspo_kl,
    compute_opsm_mask,
    compute_policy_loss,
    vanilla_tis_function,
)
from slime.utils.processing_utils import load_processor, load_tokenizer
from slime.utils.profile_utils import TrainProfiler
from slime.utils.timer import Timer, inverse_timer, timer, with_defer
@@ -582,13 +589,16 @@ def _train_step(self, packed_batch, reported_accum, mbs_id, grad_accum):
ppo_kl = old_log_probs - log_probs

if self.args.use_opsm:
opsm_mask, opsm_clipfrac = compute_opsm_mask(
args=self.args,
opsm_inputs = build_opsm_inputs_from_log_probs(
full_log_probs=[batch["cur_log_probs"] for batch in unpacked_batches],
full_old_log_probs=[batch[old_log_prob_key] for batch in unpacked_batches],
advantages=[batch["advantages"] for batch in unpacked_batches],
loss_masks=loss_masks,
)
opsm_mask, opsm_clipfrac = compute_opsm_mask(
args=self.args,
advantages=[batch["advantages"] for batch in unpacked_batches],
opsm_inputs=opsm_inputs,
)

if self.args.advantage_estimator == "gspo":
ppo_kl = compute_gspo_kl(
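For reference, below is a minimal, self-contained sketch of the refactored OPSM call pattern in this hunk: the per-sample tensors are bundled once by build_opsm_inputs_from_log_probs, and the bundle is then handed to compute_opsm_mask. The real helpers live in slime.utils.ppo_utils and their bodies are not part of this diff, so everything below the two signatures is an assumption made only to keep the sketch runnable; only the keyword arguments mirror the _train_step call site. Factoring the input construction out of the mask computation matches the intent of the "Batch CP reductions" commits, since the caller can build these tensors once and reuse them.

import torch


def build_opsm_inputs_from_log_probs(full_log_probs, full_old_log_probs, advantages, loss_masks):
    # Bundle the per-sample tensors once so the mask computation can reuse them.
    return {
        "log_probs": full_log_probs,
        "old_log_probs": full_old_log_probs,
        "advantages": advantages,
        "loss_masks": loss_masks,
    }


def compute_opsm_mask(args, advantages, opsm_inputs):
    # Toy rule: keep a token when its |log-prob ratio| stays under a threshold.
    # The real mask also depends on `advantages`; it is accepted here only to match
    # the call signature, and `args` is a plain dict standing in for the args namespace.
    masks = []
    clipped = torch.zeros(())
    total = torch.zeros(())
    for log_probs, old_log_probs, mask in zip(
        opsm_inputs["log_probs"], opsm_inputs["old_log_probs"], opsm_inputs["loss_masks"]
    ):
        keep = ((log_probs - old_log_probs).abs() < args["opsm_threshold"]).float() * mask
        masks.append(keep)
        clipped = clipped + ((mask - keep) * mask).sum()
        total = total + mask.sum()
    return masks, clipped / torch.clamp_min(total, 1)


# Usage mirroring the call site above, with two fake responses of different lengths.
log_probs = [torch.randn(5), torch.randn(3)]
old_log_probs = [lp + 0.1 * torch.randn_like(lp) for lp in log_probs]
advantages = [torch.randn(5), torch.randn(3)]
loss_masks = [torch.ones(5), torch.ones(3)]

opsm_inputs = build_opsm_inputs_from_log_probs(
    full_log_probs=log_probs,
    full_old_log_probs=old_log_probs,
    advantages=advantages,
    loss_masks=loss_masks,
)
opsm_mask, opsm_clipfrac = compute_opsm_mask(
    args={"opsm_threshold": 0.5},
    advantages=advantages,
    opsm_inputs=opsm_inputs,
)
print(opsm_clipfrac.item())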
97 changes: 61 additions & 36 deletions slime/backends/megatron_utils/cp_utils.py
@@ -50,6 +50,48 @@ def get_logits_and_tokens_offset_with_cp(
return chunk_size, (chunk_0, chunk_1), (logits_0, logits_1), (token_0, token_1)


def get_chunked_loss_masks(
    total_lengths: list[int],
    response_lengths: list[int],
    loss_masks: list[torch.Tensor],
    qkv_format: str = "thd",
    max_seq_lens: list[int] | None = None,
) -> tuple[list[torch.Tensor], list[int]]:
    """Slice loss masks to the local CP segments and return chunk lengths."""

    cp_size = mpu.get_context_parallel_world_size()
    if cp_size == 1:
        return loss_masks, response_lengths

    chunked_loss_masks: list[torch.Tensor] = []
    chunk_lengths: list[int] = []
    for i, (total_length, response_length, loss_mask) in enumerate(
        zip(total_lengths, response_lengths, loss_masks, strict=False)
    ):
        max_seq_len = max_seq_lens[i] if max_seq_lens is not None else None
        prompt_length = total_length - response_length
        _, _, _, tokens_offset = get_logits_and_tokens_offset_with_cp(
            total_length,
            response_length,
            qkv_format,
            max_seq_len,
        )

        local_chunks: list[torch.Tensor] = []
        for start, end in tokens_offset:
            local_start, local_end = start - prompt_length, end - prompt_length
            if local_end > local_start:
                local_chunks.append(loss_mask[local_start:local_end])

        if local_chunks:
            chunked_mask = torch.cat(local_chunks, dim=0)
        else:
            chunked_mask = loss_mask.new_zeros((0,))

        chunked_loss_masks.append(chunked_mask)
        chunk_lengths.append(chunked_mask.size(0))

    return chunked_loss_masks, chunk_lengths


def get_sum_of_sample_mean(
total_lengths: list[int],
response_lengths: list[int],
@@ -62,59 +104,39 @@ def get_sum_of_sample_mean(
Calculate correct sample mean for CP
"""
cp_size = mpu.get_context_parallel_world_size()
    chunked_loss_masks, chunk_lengths = get_chunked_loss_masks(
        total_lengths,
        response_lengths,
        loss_masks,
        qkv_format,
        max_seq_lens,
    )

if cp_size == 1:

def sum_of_sample_mean(x: torch.Tensor) -> torch.Tensor:
return sum(
[
(x_i * loss_mask_i).sum() / torch.clamp_min(loss_mask_i.sum(), 1)
for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks, strict=False)
]
)
return sum([(x_i * loss_mask_i).sum() / torch.clamp_min(loss_mask_i.sum(), 1) for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks, strict=False)])

def sum_of_token(x: torch.Tensor) -> torch.Tensor:
return sum(
[
(x_i * loss_mask_i).sum()
for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks, strict=False)
]
)
return sum([(x_i * loss_mask_i).sum() for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks, strict=False)])

else:
cp_chunk_lengths = []
chunked_loss_masks = []
for i, (total_length, response_length, loss_mask) in enumerate(
zip(total_lengths, response_lengths, loss_masks, strict=False)
):
max_seq_len = max_seq_lens[i] if max_seq_lens is not None else None
prompt_length = total_length - response_length
_, _, _, tokens_offset = get_logits_and_tokens_offset_with_cp(
total_length, response_length, qkv_format, max_seq_len
)
loss_mask_0 = loss_mask[tokens_offset[0][0] - prompt_length : tokens_offset[0][1] - prompt_length]
loss_mask_1 = loss_mask[tokens_offset[1][0] - prompt_length : tokens_offset[1][1] - prompt_length]
chunked_loss_masks.append(torch.cat([loss_mask_0, loss_mask_1], dim=0))
cp_chunk_lengths.append(chunked_loss_masks[i].size(0))

def sum_of_sample_mean(x: torch.Tensor) -> torch.Tensor:
return sum(
[
(x_i * chunked_loss_mask).sum() / torch.clamp_min(loss_mask.sum(), 1)
for x_i, chunked_loss_mask, loss_mask in zip(
x.split(cp_chunk_lengths, dim=0), chunked_loss_masks, loss_masks, strict=False
x.split(chunk_lengths, dim=0),
chunked_loss_masks,
loss_masks,
strict=False,
)
]
)

def sum_of_token(x: torch.Tensor) -> torch.Tensor:
return sum(
[
(x_i * chunked_loss_mask).sum()
for x_i, chunked_loss_mask in zip(
x.split(cp_chunk_lengths, dim=0), chunked_loss_masks, strict=False
)
]
)
return sum([(x_i * chunked_loss_mask).sum() for x_i, chunked_loss_mask in zip(x.split(chunk_lengths, dim=0), chunked_loss_masks, strict=False)])

return sum_of_sample_mean if not calculate_per_token_loss else sum_of_token

@@ -231,7 +253,10 @@ def slice_log_prob_with_cp(

prompt_length = total_length - response_length
_, _, logits_offset, _ = get_logits_and_tokens_offset_with_cp(
total_length, response_length, qkv_format, max_token_len
total_length,
response_length,
qkv_format,
max_token_len,
)

chunk_1 = log_prob[logits_offset[0][0] - (prompt_length - 1) : logits_offset[0][1] - (prompt_length - 1)]
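The cp_utils.py change factors the CP loss-mask slicing into get_chunked_loss_masks and reuses the returned chunks and lengths inside sum_of_sample_mean. The sketch below checks the invariant that reduction relies on, under simplifying assumptions: a single sample, cp_size = 2 simulated by two hand-rolled contiguous offsets instead of the offsets from get_logits_and_tokens_offset_with_cp, and the cross-rank reduction replaced by a plain Python loop. Each simulated rank multiplies its local slice of the loss tensor by its local mask chunk and divides by the full-sample mask sum, so the per-rank partials add up to the same per-sample mean the cp_size == 1 path computes.

import torch

torch.manual_seed(0)

# One sample: 8 response tokens, of which the first 5 are trainable.
response_length = 8
x = torch.randn(response_length)  # per-token loss values
loss_mask = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0])

# Reference: the per-sample mean the cp_size == 1 branch computes.
full_mean = (x * loss_mask).sum() / torch.clamp_min(loss_mask.sum(), 1)

# Simulated cp_size == 2: each "rank" owns one contiguous half of the response.
offsets = [(0, 4), (4, 8)]

partial = torch.zeros(())
for start, end in offsets:
    chunked_mask = loss_mask[start:end]  # what get_chunked_loss_masks would return
    x_local = x[start:end]               # the rank-local slice of the loss tensor
    # Divide by the *full* mask sum (clamped to 1), not the local one, so that the
    # per-rank partials add up to the true per-sample mean after the reduction.
    partial = partial + (x_local * chunked_mask).sum() / torch.clamp_min(loss_mask.sum(), 1)

assert torch.allclose(partial, full_mean)
print(partial.item(), full_mean.item())

The clamp to 1 also makes a fully masked sample contribute zero instead of dividing by zero, which is why get_chunked_loss_masks can safely return a zero-length chunk for a rank that holds only prompt tokens.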