Commit 2e8409c
[Feature] Add Self-Distillation Policy Optimization (SDPO) loss
Implement SDPO, a new on-policy RL algorithm for LLM post-training that uses self-distillation for dense credit assignment. Instead of relying on sparse scalar rewards, SDPO uses the model itself as a self-teacher by conditioning it on rich feedback (runtime errors, successful solutions, etc.).

Key features:
- SDPOLoss module with KL, reverse-KL, and Jensen-Shannon divergence options
- Top-K logit distillation for memory efficiency
- EMA teacher and trust-region regularization for training stability
- AddFeedbackContext transform for preparing self-teacher context
- Supports GRPO-style grouping to use successful rollouts as feedback

Reference: "Reinforcement Learning via Self-Distillation" (Hübotter et al., 2025), https://arxiv.org/abs/2601.20802
1 parent c0f4594 commit 2e8409c
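
To make the new pieces concrete, below is a minimal usage sketch assembled from the tests added in this commit: the TransformersWrapper settings, the SDPOLoss arguments, the SDPOLossOutput fields, update_ema_teacher(), and the AddFeedbackContext transform are all exercised in the diff further down. The model handle, the rollout TensorDict, the TransformersWrapper import path, and the backward call are illustrative assumptions, not part of this change; in practice the wrapper and loss module would be built once rather than per step.

from tensordict import TensorDict
from torchrl.envs.llm.transforms import AddFeedbackContext
from torchrl.modules.llm import TransformersWrapper  # import path assumed
from torchrl.objectives.llm.sdpo import SDPOLoss, SDPOLossOutput


def sdpo_step(model, rollout_td: TensorDict) -> SDPOLossOutput:
    """One illustrative SDPO loss evaluation over a collected rollout.

    `model` is a Hugging Face causal LM; `rollout_td` is laid out like the
    mock data in the tests (history / tokens / masks / next).
    """
    # Policy wrapped for log-prob evaluation, as in the tests.
    actor_network = TransformersWrapper(
        model,
        generate=False,
        pad_output=True,
        input_mode="history",
    )

    # Self-distillation loss: the same policy, conditioned on the feedback
    # stored under "teacher_context", acts as its own teacher.
    loss_fn = SDPOLoss(
        actor_network,
        divergence_type="js",  # "kl", "reverse_kl", or "js"
        topk=50,               # optional top-K logit distillation
        use_ema_teacher=True,  # optional EMA teacher for stability
        ema_decay=0.99,
        entropy_bonus=True,
        entropy_coeff=0.01,
    )

    # AddFeedbackContext fills in "teacher_context" from environment feedback
    # (and, when available, a successful rollout from the same group).
    data = AddFeedbackContext()(rollout_td)

    loss_vals = loss_fn(data)
    (loss_vals.loss_objective + loss_vals.loss_entropy).backward()
    loss_fn.update_ema_teacher()  # refresh the EMA teacher after each step
    return loss_vals
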

File tree

5 files changed: +1556 -0 lines changed

test/llm/test_llm_objectives.py

Lines changed: 300 additions & 0 deletions
@@ -24,6 +24,7 @@
     GRPOLossOutput,
     MCAdvantage,
 )
+from torchrl.objectives.llm.sdpo import SDPOLoss, SDPOLossOutput
 from torchrl.objectives.llm.sft import SFTLoss

 _has_transformers = importlib.util.find_spec("transformers") is not None
@@ -427,6 +428,305 @@ def test_cispo(self, mock_transformer_model):
         ), f"clip_fraction out of range: {loss_vals.clip_fraction}"


+def _mock_data_sdpo(vocab_size: int, device: torch.device | str = "cpu") -> TensorDict:
+    """Create mock data for SDPO testing."""
+    from transformers import AutoTokenizer
+
+    device = torch.device(device)
+
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
+    prompt = History(
+        role=["system", "user"],
+        content=["You are a useful assistant.", "What is 2+2?"],
+        batch_size=(2,),
+        device=device,
+    )
+    response = History(
+        role=["assistant"],
+        content=["2 + 2 = 4."],
+        batch_size=(1,),
+        device=device,
+    )
+    full_history = prompt.extend(response, inplace=False)
+    history = ChatHistory(
+        prompt=prompt,
+        response=response,
+        full=full_history,
+        device=device,
+    )
+    batch_size = 1
+
+    # Expand history to match batch size
+    history = history.expand((batch_size,))
+    next_history = ChatHistory(
+        prompt=full_history,
+        device=device,
+    )
+    next_history = next_history.expand((batch_size,))
+
+    # Get tokens
+    tokens_full = history.to_tokens(tokenizer)
+    next_tokens = next_history.to_tokens(tokenizer)
+
+    tokens_input_ids = tokens_full.get(
+        "full", as_padded_tensor=True, padding_side="left", padding_value=0
+    )
+    seq_len = tokens_input_ids.shape[-1]
+
+    # Create tensors
+    reward = torch.randn(batch_size, seq_len, 1, device=device)
+    done = torch.zeros(batch_size, seq_len, 1, dtype=torch.bool, device=device)
+    log_probs = torch.randn_like(tokens_full, dtype=torch.float32, device=device)
+    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
+
+    from tensordict import MetaData
+
+    masks = Masks(
+        all_attention_mask=attention_mask,
+        all_assistant_mask=None,
+        padded=MetaData(True),
+        device=device,
+    )
+
+    # Create teacher context (feedback-augmented prompt for self-teacher)
+    # In real usage, this would be constructed by AddFeedbackContext transform
+    teacher_context = {
+        "history": history,
+        "env_feedback": "The solution is correct.",
+    }
+
+    data = TensorDict(
+        {
+            "history": history,
+            "tokens": tokens_full % vocab_size,
+            "masks": masks,
+            "teacher_context": teacher_context,
+            "next": {
+                "history": next_history,
+                "tokens": next_tokens % vocab_size,
+                "reward": reward,
+                "done": done,
+            },
+            "log_probs": log_probs,
+        },
+        batch_size=(batch_size,),
+    )
+    return data
+
+
+class TestSDPO:
+    """Test suite for Self-Distillation Policy Optimization (SDPO) loss."""
+
+    @pytest.mark.parametrize(
+        "divergence_type", ["kl", "reverse_kl", "js"], ids=["kl", "reverse_kl", "js"]
+    )
+    def test_sdpo_basic(self, mock_transformer_model, divergence_type):
+        """Test basic SDPO loss computation with different divergence types."""
+        vocab_size = 1024
+        device = torch.device("cpu")
+
+        # Create mock model and wrap it
+        model = mock_transformer_model(vocab_size=vocab_size, device=device)
+        actor_network = TransformersWrapper(
+            model,
+            generate=False,
+            pad_output=True,
+            input_mode="history",
+        )
+
+        # Create loss module
+        loss_fn = SDPOLoss(
+            actor_network,
+            divergence_type=divergence_type,
+            entropy_bonus=True,
+            entropy_coeff=0.01,
+        )
+
+        # Create fake data
+        data = _mock_data_sdpo(vocab_size=vocab_size, device=device)
+
+        # Compute loss
+        loss_vals = loss_fn(data)
+
+        # Assertions: Check output type and structure
+        assert isinstance(
+            loss_vals, SDPOLossOutput
+        ), f"Expected SDPOLossOutput, got {type(loss_vals)}"
+
+        # Check that all expected keys are present
+        assert hasattr(loss_vals, "loss_objective"), "Missing loss_objective"
+        assert hasattr(loss_vals, "divergence"), "Missing divergence"
+        assert hasattr(loss_vals, "kl_approx"), "Missing kl_approx"
+        assert hasattr(loss_vals, "entropy"), "Missing entropy"
+        assert hasattr(loss_vals, "loss_entropy"), "Missing loss_entropy"
+
+        # Check tensor shapes (all losses should be scalars after reduction)
+        assert (
+            loss_vals.loss_objective.shape == ()
+        ), f"loss_objective should be scalar, got {loss_vals.loss_objective.shape}"
+        assert (
+            loss_vals.divergence.shape == ()
+        ), f"divergence should be scalar, got {loss_vals.divergence.shape}"
+        assert (
+            loss_vals.kl_approx.shape == ()
+        ), f"kl_approx should be scalar, got {loss_vals.kl_approx.shape}"
+
+        # Check that losses are finite
+        assert torch.isfinite(loss_vals.loss_objective), "loss_objective is not finite"
+        assert torch.isfinite(loss_vals.divergence), "divergence is not finite"
+
+        # Divergence should be non-negative
+        assert (
+            loss_vals.divergence >= 0
+        ), f"divergence should be non-negative: {loss_vals.divergence}"
+
+    @pytest.mark.parametrize("topk", [None, 50, 100], ids=["full", "topk50", "topk100"])
+    def test_sdpo_topk(self, mock_transformer_model, topk):
+        """Test SDPO with top-K logit distillation for memory efficiency."""
+        vocab_size = 1024
+        device = torch.device("cpu")
+
+        model = mock_transformer_model(vocab_size=vocab_size, device=device)
+        actor_network = TransformersWrapper(
+            model,
+            generate=False,
+            pad_output=True,
+            input_mode="history",
+        )
+
+        # Create loss with top-K
+        loss_fn = SDPOLoss(
+            actor_network,
+            divergence_type="js",
+            topk=topk,
+        )
+
+        data = _mock_data_sdpo(vocab_size=vocab_size, device=device)
+        loss_vals = loss_fn(data)
+
+        assert isinstance(loss_vals, SDPOLossOutput)
+        assert torch.isfinite(loss_vals.loss_objective)
+        assert loss_vals.divergence >= 0
+
+    def test_sdpo_ema_teacher(self, mock_transformer_model):
+        """Test SDPO with EMA teacher regularization."""
+        vocab_size = 1024
+        device = torch.device("cpu")
+
+        model = mock_transformer_model(vocab_size=vocab_size, device=device)
+        actor_network = TransformersWrapper(
+            model,
+            generate=False,
+            pad_output=True,
+            input_mode="history",
+        )
+
+        # Create loss with EMA teacher
+        loss_fn = SDPOLoss(
+            actor_network,
+            divergence_type="js",
+            use_ema_teacher=True,
+            ema_decay=0.99,
+        )
+
+        # Check that EMA params were initialized
+        assert loss_fn._ema_teacher_params is not None
+        assert len(loss_fn._ema_teacher_params) > 0
+
+        data = _mock_data_sdpo(vocab_size=vocab_size, device=device)
+        loss_vals = loss_fn(data)
+
+        assert isinstance(loss_vals, SDPOLossOutput)
+        assert torch.isfinite(loss_vals.loss_objective)
+
+        # Test EMA update
+        loss_fn.update_ema_teacher()
+        # Should still work after update
+        loss_vals_after = loss_fn(data)
+        assert torch.isfinite(loss_vals_after.loss_objective)
+
+    def test_sdpo_no_entropy(self, mock_transformer_model):
+        """Test SDPO without entropy bonus."""
+        vocab_size = 1024
+        device = torch.device("cpu")
+
+        model = mock_transformer_model(vocab_size=vocab_size, device=device)
+        actor_network = TransformersWrapper(
+            model,
+            generate=False,
+            pad_output=True,
+            input_mode="history",
+        )
+
+        loss_fn = SDPOLoss(
+            actor_network,
+            divergence_type="js",
+            entropy_bonus=False,
+        )
+
+        data = _mock_data_sdpo(vocab_size=vocab_size, device=device)
+        loss_vals = loss_fn(data)
+
+        assert isinstance(loss_vals, SDPOLossOutput)
+        assert torch.isfinite(loss_vals.loss_objective)
+        # Entropy should be None when entropy_bonus is False
+        assert loss_vals.entropy is None
+        assert loss_vals.loss_entropy is None
+
+
+class TestAddFeedbackContext:
+    """Test suite for AddFeedbackContext transform."""
+
+    def test_add_feedback_direct(self):
+        """Test adding feedback context in direct mode."""
+        from torchrl.envs.llm.transforms.feedback import AddFeedbackContext
+
+        transform = AddFeedbackContext()
+
+        td = TensorDict(
+            {
+                "query": "What is 2+2?",
+                ("text", "response"): "The answer is 5.",
+                "env_feedback": "Wrong answer. The correct answer is 4.",
+                ("next", "reward"): torch.tensor([0.0]),
+                ("next", "done"): torch.tensor([True]),
+            },
+            batch_size=(),
+        )
+
+        td_out = transform(td)
+
+        # Check that teacher_context was added
+        assert "teacher_context" in td_out.keys()
+        teacher_context = td_out.get("teacher_context")
+        assert teacher_context is not None
+
+    def test_add_feedback_with_success(self):
+        """Test adding feedback context with successful rollout."""
+        from torchrl.envs.llm.transforms.feedback import AddFeedbackContext
+
+        transform = AddFeedbackContext()
+
+        td = TensorDict(
+            {
+                "query": "What is 2+2?",
+                ("text", "response"): "The answer is 5.",
+                "env_feedback": "Wrong answer.",
+                "_successful_rollout": "The answer is 4.",
+                ("next", "reward"): torch.tensor([0.0]),
+                ("next", "done"): torch.tensor([True]),
+            },
+            batch_size=(),
+        )
+
+        td_out = transform(td)
+
+        assert "teacher_context" in td_out.keys()
+        teacher_context = td_out.get("teacher_context")
+        # Should contain both the successful solution and feedback
+        assert teacher_context is not None
+
+
 class TestSFT:
     @pytest.fixture(scope="class")
     def data(self):
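
A note on the top-K option exercised by test_sdpo_topk above: the actual computation lives in the new SDPO module (presumably torchrl/objectives/llm/sdpo.py, not shown in this excerpt). As a rough, self-contained sketch of the technique named in the commit message, one common formulation restricts the divergence to the teacher's top-K logits and renormalizes on that support. The function below is an illustration under those assumptions, not the SDPOLoss internals.

import torch
import torch.nn.functional as F


def topk_js_divergence(
    student_logits: torch.Tensor, teacher_logits: torch.Tensor, k: int = 50
) -> torch.Tensor:
    """Per-token Jensen-Shannon divergence over the teacher's top-K tokens.

    Both inputs have shape [batch, seq, vocab]; the output has shape
    [batch, seq]. Restricting the support to K entries keeps the memory
    cost proportional to K rather than to the full vocabulary.
    """
    # Keep the teacher's top-K vocabulary entries, gather the matching
    # student logits, and renormalize both distributions on that support.
    teacher_topk, topk_idx = teacher_logits.topk(k, dim=-1)
    student_topk = student_logits.gather(-1, topk_idx)
    p = F.softmax(teacher_topk, dim=-1)  # teacher
    q = F.softmax(student_topk, dim=-1)  # student
    m = 0.5 * (p + q)
    kl_pm = (p * (p.clamp_min(1e-8).log() - m.log())).sum(-1)
    kl_qm = (q * (q.clamp_min(1e-8).log() - m.log())).sum(-1)
    return 0.5 * (kl_pm + kl_qm)

The forward-KL and reverse-KL variants exposed via divergence_type would compare p and q directly rather than through the mixture m.
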

torchrl/envs/llm/transforms/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -10,6 +10,7 @@
     DataLoadingPrimer,
     RayDataLoadingPrimer,
 )
+from .feedback import AddFeedbackContext, BuildTeacherContext
 from .format import TemplateTransform
 from .kl import KLComputation, KLRewardTransform, RetrieveKL, RetrieveLogProb
 from .policy_version import PolicyVersion
@@ -29,8 +30,10 @@
 )

 __all__ = [
+    "AddFeedbackContext",
     "AddThinkingPrompt",
     "BrowserTransform",
+    "BuildTeacherContext",
     "DataLoadingPrimer",
     "ExecuteToolsInOrder",
     "IncrementalTokenizer",
