Skip to content

Commit b74858a

Browse files
zijiexia and GuanxingLu authored
[FSDP] Add Masked importance sampling (#1122)
Co-authored-by: Guanxing Lu <747398423@qq.com>
1 parent ece2624 commit b74858a

File tree

5 files changed

+281
-20
lines changed

5 files changed

+281
-20
lines changed

examples/train_infer_mismatch_helper/mis.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,3 +395,45 @@ def add_ppl_metrics(
395395
rho_squared_seq = torch.exp(2.0 * log_ratio_sum_safe) # (Π ρ_t)²
396396
chi2_seq = rho_squared_seq - 1.0
397397
metrics_append(metrics, "chi2_seq", chi2_seq)
398+
399+
400+
def compute_mis_weights_fsdp(
    args,
    *,
    pg_loss: torch.Tensor,
    train_log_probs: list[torch.Tensor],
    rollout_log_probs: list[torch.Tensor],
    loss_masks: list[torch.Tensor],
    **kwargs: Any,
) -> tuple[torch.Tensor, list[torch.Tensor], dict[str, torch.Tensor]]:
    """FSDP entry point for masked importance sampling (no context parallelism).

    Delegates to ``compute_mis_weights`` and adapts its per-sequence outputs
    to the flattened layout used for ``pg_loss``.

    Args:
        args: Arguments containing MIS settings (use_tis, tis_mode, etc.).
        pg_loss: Policy gradient loss, flattened tensor [total_tokens].
        train_log_probs: Training log probs, one 1D tensor per sequence.
        rollout_log_probs: Rollout log probs, one 1D tensor per sequence.
        loss_masks: Loss masks, one 1D tensor per sequence.
        **kwargs: Extra keyword arguments (cp_rank, cp_size, etc.) accepted
            for call-site compatibility; they are not used here.

    Returns:
        A ``(pg_loss, modified_masks, mis_metrics)`` tuple:
        ``pg_loss`` multiplied by the concatenated IS weights (returned
        unchanged when ``compute_mis_weights`` yields no weights), the masks
        returned by ``compute_mis_weights``, and a dict mapping
        ``"mis_<key>"`` to each metric concatenated along dim 0.
    """
    per_seq_weights, masks_after_mis, per_seq_metrics = compute_mis_weights(
        args=args,
        train_log_probs=train_log_probs,
        rollout_log_probs=rollout_log_probs,
        loss_masks=loss_masks,
    )

    # Weights come back per sequence; flatten them to line up with pg_loss.
    if per_seq_weights is not None:
        pg_loss = pg_loss * torch.cat(per_seq_weights, dim=0)

    flat_metrics = {
        f"mis_{name}": torch.cat(chunks, dim=0) for name, chunks in per_seq_metrics.items()
    }
    return pg_loss, masks_after_mis, flat_metrics
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
#!/bin/bash

# Launch script: FSDP training of Qwen3-4B with a custom TIS/MIS function.
#
# Clean up leftovers from a previous run so the task can be re-run.
# These commands intentionally run BEFORE `set -e`: pkill exits non-zero
# when nothing matches, which must not abort the script.
pkill -9 sglang
sleep 3
ray stop --force
pkill -9 ray
pkill -9 python
sleep 3
pkill -9 ray
pkill -9 python

set -ex

# Prevent Python from buffering stdout/stderr so logs show up promptly.
# The variable Python actually reads is PYTHONUNBUFFERED (any non-empty value).
export PYTHONUNBUFFERED=1
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

# Count "NVLink" mentions in the nvidia-smi output to detect NVLink.
NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l)
if [ "$NVLINK_COUNT" -gt 0 ]; then
    HAS_NVLINK=1
else
    HAS_NVLINK=0
fi
echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"

SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"

RUN_ID=${RUN_ID:-"run_$(date +%Y%m%d_%H%M%S)"}
LOAD_SAVE_PATH="/root/shared_data/${RUN_ID}/checkpoints"

CKPT_ARGS=(
    --hf-checkpoint /root/Qwen3-4B
    --load /root/Qwen3-4B
    --ref-load /root/Qwen3-4B
)

ROLLOUT_ARGS=(
    --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl
    --input-key prompt
    --label-key label
    --apply-chat-template
    --rollout-shuffle
    --balance-data
    --rm-type deepscaler
    --num-rollout 100
    --rollout-batch-size 8
    --n-samples-per-prompt 8
    --rollout-max-response-len 4096
    --rollout-temperature 0.8
    --global-batch-size 64
)

GRPO_ARGS=(
    --use-kl-loss
    --advantage-estimator grpo
    --kl-loss-coef 0.00
    --kl-loss-type low_var_kl
    --kl-coef 0.00
    --entropy-coef 0.00
    --eps-clip 0.2
    --eps-clip-high 0.28
    --use-tis
)

OPTIMIZER_ARGS=(
    --optimizer adam
    --lr 1e-6
    --lr-decay-style constant
    --weight-decay 0.1
    --adam-beta1 0.9
    --adam-beta2 0.98
)

WANDB_ARGS=(
    --use-wandb
    --wandb-project slime-dev-mcore-fsdp
    --wandb-group qwen3-4B-fsdp-1130-ref
    --wandb-key ${WANDB_API_KEY}
)

SGLANG_ARGS=(
    --rollout-num-gpus-per-engine 1
    --sglang-mem-fraction-static 0.75
    --sglang-decode-log-interval 1000
    --sglang-chunked-prefill-size 4096
    --sglang-attention-backend fa3
)

TRAIN_BACKEND_ARGS=(
    --train-backend fsdp
    --update-weight-buffer-size 536870912
    --gradient-checkpointing
    --attn-implementation flash_attention_3
    --train-env-vars '{"PYTORCH_CUDA_ALLOC_CONF":"expandable_segments:True"}'
)

PERF_ARGS=(
    --use-dynamic-batch-size
    --max-tokens-per-gpu 9216
)

MISC_ARGS=(
    --actor-num-nodes 1
    --actor-num-gpus-per-node 8
    --colocate
    --use-fault-tolerance
    --dump-details /root/shared_data/qwen3-4B-fsdp-1116-noref/dump_details
    # --fsdp-cpu-offload
)

CUSTOM_ARGS=(
    --custom-config-path examples/train_infer_mismatch_helper/mis.yaml
    --custom-tis-function-path examples.train_infer_mismatch_helper.mis.compute_mis_weights_fsdp
)

# Launch the ray head node in the container - 8 GPUs for training.
export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus 8 --disable-usage-stats

RUNTIME_ENV_JSON="{
  \"env_vars\": {
    \"PYTHONPATH\": \"/root/Megatron-LM/:${SCRIPT_DIR}\",
    \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\"
  }
}"

# Array expansions are quoted so each element reaches train.py as exactly one
# argument, with no word-splitting or glob expansion.
ray job submit --address="http://127.0.0.1:8265" \
    --runtime-env-json="${RUNTIME_ENV_JSON}" \
    -- python3 train.py \
    "${CKPT_ARGS[@]}" \
    "${ROLLOUT_ARGS[@]}" \
    "${OPTIMIZER_ARGS[@]}" \
    "${GRPO_ARGS[@]}" \
    "${WANDB_ARGS[@]}" \
    "${SGLANG_ARGS[@]}" \
    "${TRAIN_BACKEND_ARGS[@]}" \
    "${PERF_ARGS[@]}" \
    "${MISC_ARGS[@]}" \
    "${CUSTOM_ARGS[@]}"
147+
148+

slime/backends/fsdp_utils/actor.py

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,14 @@
1919
from slime.utils.distributed_utils import get_gloo_group
2020
from slime.utils.memory_utils import clear_memory, print_memory
2121
from slime.utils.metric_utils import compute_rollout_step
22-
from slime.utils.ppo_utils import compute_approx_kl, compute_gspo_kl, compute_opsm_mask, compute_policy_loss
22+
from slime.utils.misc import load_function
23+
from slime.utils.ppo_utils import (
24+
compute_approx_kl,
25+
compute_gspo_kl,
26+
compute_opsm_mask,
27+
compute_policy_loss,
28+
vanilla_tis_function,
29+
)
2330
from slime.utils.processing_utils import load_processor, load_tokenizer
2431
from slime.utils.ray_utils import Box
2532
from slime.utils.timer import Timer, inverse_timer, timer
@@ -656,26 +663,41 @@ def _has_rollout_log_probs(batch) -> bool:
656663
else None
657664
)
658665

659-
# Apply TIS before sample mean calculation
666+
# Apply off-policy correction using importance sampling if enabled
660667
if self.args.use_tis:
661-
# Apply TIS off-policy correction using importance sampling
662668
assert (
663669
has_rollout_log_probs and rollout_log_probs is not None
664-
), "rollout_log_probs must be provided as non-empty torch.Tensor for TIS"
670+
), "rollout_log_probs must be provided as non-empty torch.Tensor for TIS/MIS"
665671

666-
tis = torch.exp(old_log_probs - rollout_log_probs)
672+
train_log_probs_list = list(log_probs.split(response_lengths, dim=0))
673+
rollout_log_probs_list = list(rollout_log_probs.split(response_lengths, dim=0))
667674
ois = (-ppo_kl).exp()
668-
tis_clip = torch.clamp(
669-
tis, min=getattr(self.args, "tis_clip_low", 0.1), max=getattr(self.args, "tis_clip", 2.0)
670-
)
671-
tis_clipfrac = tis_clip != tis
672-
673-
pg_loss = pg_loss * tis_clip
674-
675-
assert not self.args.calculate_per_token_loss, "calculate_per_token_loss not yet implemented"
676-
pg_loss = sum_of_sample_mean(pg_loss, response_lengths, loss_masks)
677-
pg_clipfrac = sum_of_sample_mean(pg_clipfrac, response_lengths, loss_masks)
678-
ppo_kl = sum_of_sample_mean(ppo_kl.abs(), response_lengths, loss_masks)
675+
tis_kwargs = {
676+
"args": self.args,
677+
"pg_loss": pg_loss,
678+
"train_log_probs": train_log_probs_list,
679+
"rollout_log_probs": rollout_log_probs_list,
680+
"loss_masks": loss_masks,
681+
"response_lengths": response_lengths,
682+
"cp_rank": self.cp_rank,
683+
"cp_size": self.cp_size,
684+
"cp_group": self.cp_group,
685+
}
686+
687+
if self.args.custom_tis_function_path is not None:
688+
tis_func = load_function(self.args.custom_tis_function_path)
689+
else:
690+
tis_func = vanilla_tis_function
691+
pg_loss, loss_masks, tis_metrics = tis_func(**tis_kwargs)
692+
693+
if self.args.calculate_per_token_loss:
694+
pg_loss = sum_of_token(pg_loss, response_lengths, loss_masks)
695+
pg_clipfrac = sum_of_token(pg_clipfrac, response_lengths, loss_masks)
696+
ppo_kl = sum_of_token(ppo_kl.abs(), response_lengths, loss_masks)
697+
else:
698+
pg_loss = sum_of_sample_mean(pg_loss, response_lengths, loss_masks)
699+
pg_clipfrac = sum_of_sample_mean(pg_clipfrac, response_lengths, loss_masks)
700+
ppo_kl = sum_of_sample_mean(ppo_kl.abs(), response_lengths, loss_masks)
679701

680702
# Only compare rollout vs. train log probs when they originate from different stages.
681703
train_rollout_logprob_abs_diff = None
@@ -722,10 +744,13 @@ def _has_rollout_log_probs(batch) -> bool:
722744
if self.args.use_opsm:
723745
reported["opsm_clipfrac"] = opsm_clipfrac
724746

725-
if self.args.use_tis and tis is not None:
726-
reported["tis"] = sum_of_sample_mean(tis, response_lengths, loss_masks).detach()
747+
if self.args.use_tis and tis_metrics:
727748
reported["ois"] = sum_of_sample_mean(ois, response_lengths, loss_masks).detach()
728-
reported["tis_clipfrac"] = sum_of_sample_mean(tis_clipfrac.float(), response_lengths, loss_masks).detach()
749+
for k, v in tis_metrics.items():
750+
if self.args.calculate_per_token_loss:
751+
reported[k] = sum_of_token(v, response_lengths, loss_masks).detach()
752+
else:
753+
reported[k] = sum_of_sample_mean(v, response_lengths, loss_masks).detach()
729754

730755
# Scale loss for gradient accumulation
731756
loss = loss * self.dp_size / self.args.global_batch_size
@@ -1104,3 +1129,12 @@ def apply_fsdp2(model, mesh=None, cpu_offload=False, args=None):
11041129
fully_shard(model, **fsdp_kwargs)
11051130

11061131
return model
1132+
1133+
1134+
def sum_of_token(x: torch.Tensor, response_lengths: list[int], loss_masks: list[torch.Tensor]) -> torch.Tensor:
1135+
return sum(
1136+
[
1137+
(x_i * loss_mask_i).sum()
1138+
for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks, strict=False)
1139+
]
1140+
)

slime/backends/megatron_utils/loss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ def policy_loss_function(
420420
Computes current log-probabilities and entropy from model logits, then
421421
calculates PPO-style clipped policy gradient loss. For GSPO, gathers
422422
full sequences via context-parallel all-gather before computing per-sample
423-
KL. Optionally applies TIS (Temporal Importance Sampling) correction and
423+
KL. Optionally applies TIS (Truncated Importance Sampling) correction and
424424
adds KL loss term if configured.
425425
426426
Args:

slime/utils/ppo_utils.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,3 +662,40 @@ def calculate_log_probs_and_entropy(logits, tokens, tp_group, with_entropy: bool
662662
else:
663663
entropy = None
664664
return log_prob, entropy
665+
666+
667+
def vanilla_tis_function(
    args,
    *,
    pg_loss: torch.Tensor,
    train_log_probs: list[torch.Tensor],
    rollout_log_probs: list[torch.Tensor],
    loss_masks: list[torch.Tensor],
    **kwargs,
) -> tuple[torch.Tensor, list[torch.Tensor], dict[str, torch.Tensor]]:
    """Apply TIS (truncated importance sampling) off-policy correction.

    The per-token ratio ``exp(train_log_prob - rollout_log_prob)`` is clamped
    to ``[tis_clip_low, tis_clip]`` and multiplied into ``pg_loss``. The ratio
    is detached first, so it acts as a constant coefficient and no gradient
    flows back through it even if ``train_log_probs`` tracks gradients.

    Parameters:
        args: Arguments containing TIS settings. ``tis_clip_low`` and
            ``tis_clip`` are read; when missing or ``None`` they default to
            0.1 and 2.0 respectively.
        pg_loss: Policy gradient loss tensor of shape [total_seq_len - 1].
        train_log_probs: List of tensors containing training log-probabilities
            for each sequence.
        rollout_log_probs: List of tensors containing rollout log-probabilities
            for each sequence.
        loss_masks: List of tensors containing loss masks for each sequence.
        **kwargs: Extra keyword arguments, accepted and ignored.

    Returns:
        A ``(pg_loss, loss_masks, metrics)`` tuple: the weighted loss, the
        input ``loss_masks`` unchanged, and a dict of detached per-token
        tensors under ``"tis"``, ``"tis_clipfrac"`` and ``"tis_abs"``.
    """
    rollout_flat = torch.cat(rollout_log_probs, dim=0)
    train_flat = torch.cat(train_log_probs, dim=0)

    # Detach: the importance ratio is a correction weight, not part of the
    # objective, so it must not contribute gradients of its own.
    tis = torch.exp(train_flat - rollout_flat).detach()
    tis_abs = (tis - 1).abs()

    # getattr keeps the defaults usable when args does not define the fields.
    clip_low = getattr(args, "tis_clip_low", None)
    clip_high = getattr(args, "tis_clip", None)
    tis_clip_low = 0.1 if clip_low is None else clip_low
    tis_clip_high = 2.0 if clip_high is None else clip_high

    tis_weights = torch.clamp(tis, min=tis_clip_low, max=tis_clip_high)
    tis_clipfrac = (tis_weights != tis).float()
    metrics = {
        "tis": tis.clone().detach(),
        "tis_clipfrac": tis_clipfrac.clone().detach(),
        "tis_abs": tis_abs.clone().detach(),
    }
    pg_loss = pg_loss * tis_weights
    return pg_loss, loss_masks, metrics

0 commit comments

Comments
 (0)