
Commit 1d4f6b2

[Feature] Add Self-Distillation Policy Optimization (SDPO) loss
Implement SDPO, a new on-policy RL algorithm for LLM post-training that uses self-distillation for dense credit assignment. Instead of relying on sparse scalar rewards, SDPO uses the model itself as a self-teacher by conditioning it on rich feedback (runtime errors, successful solutions, etc.).

Key features:
- SDPOLoss module with KL, reverse-KL, and Jensen-Shannon divergence options
- Top-K logit distillation for memory efficiency
- EMA teacher and trust-region regularization for training stability
- AddFeedbackContext transform for preparing the self-teacher context
- Supports GRPO-style grouping to use successful rollouts as feedback

Reference: "Reinforcement Learning via Self-Distillation" (Hübotter et al., 2025)
https://arxiv.org/abs/2601.20802
1 parent c0f4594 commit 1d4f6b2
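A minimal usage sketch, following the API exercised by the new tests further down this page (the SDPOLoss constructor arguments, the SDPOLossOutput fields, update_ema_teacher(), and the AddFeedbackContext transform). The checkpoint name, the optimizer, the shape of the rollout TensorDict, and the combination of options that the tests only exercise separately are illustrative assumptions, not something this commit pins down:

import torch
from transformers import AutoModelForCausalLM

from torchrl.envs.llm.transforms import AddFeedbackContext
from torchrl.modules.llm import TransformersWrapper
from torchrl.objectives.llm.sdpo import SDPOLoss

# Wrap any HF causal LM; the checkpoint below is only an example.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
actor_network = TransformersWrapper(
    model,
    generate=False,        # loss computation only, no generation
    pad_output=True,
    input_mode="history",
)

loss_fn = SDPOLoss(
    actor_network,
    divergence_type="js",  # "kl", "reverse_kl" or "js"
    topk=50,               # top-K logit distillation; None keeps the full vocabulary
    use_ema_teacher=True,  # EMA self-teacher for training stability
    ema_decay=0.99,
)
optimizer = torch.optim.AdamW(actor_network.parameters(), lr=1e-5)
add_feedback = AddFeedbackContext()  # writes a "teacher_context" entry into the rollout


def training_step(rollout):
    # `rollout` is assumed to be a TensorDict from an LLM env/collector carrying
    # "history", "tokens", "masks", "log_probs" and environment feedback.
    rollout = add_feedback(rollout)   # build the feedback-augmented teacher context
    loss_vals = loss_fn(rollout)      # returns an SDPOLossOutput
    loss = loss_vals.loss_objective
    if loss_vals.loss_entropy is not None:
        loss = loss + loss_vals.loss_entropy  # TorchRL convention: sum the loss_* terms
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    loss_fn.update_ema_teacher()      # refresh the EMA self-teacher after the step
    return loss_vals

In a full pipeline the AddFeedbackContext transform would more likely be appended to the LLM environment so that rollouts already carry the teacher context; the tests below call it directly on a TensorDict, which is what this sketch mirrors.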

File tree

5 files changed: +1553, -0 lines changed


test/llm/test_llm_objectives.py

Lines changed: 297 additions & 0 deletions
@@ -14,6 +14,7 @@
 from tensordict import lazy_stack, TensorDict
 from torchrl._utils import logger
 from torchrl.data import History, LazyStackStorage, ReplayBuffer
+from torchrl.envs.llm.transforms.feedback import AddFeedbackContext
 from torchrl.envs.llm.transforms.kl import RetrieveLogProb
 from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
 from torchrl.modules.llm.policies.common import ChatHistory, Masks, Text, Tokens
@@ -24,6 +25,7 @@
     GRPOLossOutput,
     MCAdvantage,
 )
+from torchrl.objectives.llm.sdpo import SDPOLoss, SDPOLossOutput
 from torchrl.objectives.llm.sft import SFTLoss

 _has_transformers = importlib.util.find_spec("transformers") is not None
@@ -427,6 +429,301 @@ def test_cispo(self, mock_transformer_model):
         ), f"clip_fraction out of range: {loss_vals.clip_fraction}"


+def _mock_data_sdpo(vocab_size: int, device: torch.device | str = "cpu") -> TensorDict:
+    """Create mock data for SDPO testing."""
+    from transformers import AutoTokenizer
+
+    device = torch.device(device)
+
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
+    prompt = History(
+        role=["system", "user"],
+        content=["You are a useful assistant.", "What is 2+2?"],
+        batch_size=(2,),
+        device=device,
+    )
+    response = History(
+        role=["assistant"],
+        content=["2 + 2 = 4."],
+        batch_size=(1,),
+        device=device,
+    )
+    full_history = prompt.extend(response, inplace=False)
+    history = ChatHistory(
+        prompt=prompt,
+        response=response,
+        full=full_history,
+        device=device,
+    )
+    batch_size = 1
+
+    # Expand history to match batch size
+    history = history.expand((batch_size,))
+    next_history = ChatHistory(
+        prompt=full_history,
+        device=device,
+    )
+    next_history = next_history.expand((batch_size,))
+
+    # Get tokens
+    tokens_full = history.to_tokens(tokenizer)
+    next_tokens = next_history.to_tokens(tokenizer)
+
+    tokens_input_ids = tokens_full.get(
+        "full", as_padded_tensor=True, padding_side="left", padding_value=0
+    )
+    seq_len = tokens_input_ids.shape[-1]
+
+    # Create tensors
+    reward = torch.randn(batch_size, seq_len, 1, device=device)
+    done = torch.zeros(batch_size, seq_len, 1, dtype=torch.bool, device=device)
+    log_probs = torch.randn_like(tokens_full, dtype=torch.float32, device=device)
+    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
+
+    from tensordict import MetaData
+
+    masks = Masks(
+        all_attention_mask=attention_mask,
+        all_assistant_mask=None,
+        padded=MetaData(True),
+        device=device,
+    )
+
+    # Create teacher context (feedback-augmented prompt for self-teacher)
+    # In real usage, this would be constructed by AddFeedbackContext transform
+    teacher_context = {
+        "history": history,
+        "env_feedback": "The solution is correct.",
+    }
+
+    data = TensorDict(
+        {
+            "history": history,
+            "tokens": tokens_full % vocab_size,
+            "masks": masks,
+            "teacher_context": teacher_context,
+            "next": {
+                "history": next_history,
+                "tokens": next_tokens % vocab_size,
+                "reward": reward,
+                "done": done,
+            },
+            "log_probs": log_probs,
+        },
+        batch_size=(batch_size,),
+    )
+    return data
+
+
+class TestSDPO:
+    """Test suite for Self-Distillation Policy Optimization (SDPO) loss."""
+
+    @pytest.mark.parametrize(
+        "divergence_type", ["kl", "reverse_kl", "js"], ids=["kl", "reverse_kl", "js"]
+    )
+    def test_sdpo_basic(self, mock_transformer_model, divergence_type):
+        """Test basic SDPO loss computation with different divergence types."""
+        vocab_size = 1024
+        device = torch.device("cpu")
+
+        # Create mock model and wrap it
+        model = mock_transformer_model(vocab_size=vocab_size, device=device)
+        actor_network = TransformersWrapper(
+            model,
+            generate=False,
+            pad_output=True,
+            input_mode="history",
+        )
+
+        # Create loss module
+        loss_fn = SDPOLoss(
+            actor_network,
+            divergence_type=divergence_type,
+            entropy_bonus=True,
+            entropy_coeff=0.01,
+        )
+
+        # Create fake data
+        data = _mock_data_sdpo(vocab_size=vocab_size, device=device)
+
+        # Compute loss
+        loss_vals = loss_fn(data)
+
+        # Assertions: Check output type and structure
+        assert isinstance(
+            loss_vals, SDPOLossOutput
+        ), f"Expected SDPOLossOutput, got {type(loss_vals)}"
+
+        # Check that all expected keys are present
+        assert hasattr(loss_vals, "loss_objective"), "Missing loss_objective"
+        assert hasattr(loss_vals, "divergence"), "Missing divergence"
+        assert hasattr(loss_vals, "kl_approx"), "Missing kl_approx"
+        assert hasattr(loss_vals, "entropy"), "Missing entropy"
+        assert hasattr(loss_vals, "loss_entropy"), "Missing loss_entropy"
+
+        # Check tensor shapes (all losses should be scalars after reduction)
+        assert (
+            loss_vals.loss_objective.shape == ()
+        ), f"loss_objective should be scalar, got {loss_vals.loss_objective.shape}"
+        assert (
+            loss_vals.divergence.shape == ()
+        ), f"divergence should be scalar, got {loss_vals.divergence.shape}"
+        assert (
+            loss_vals.kl_approx.shape == ()
+        ), f"kl_approx should be scalar, got {loss_vals.kl_approx.shape}"
+
+        # Check that losses are finite
+        assert torch.isfinite(loss_vals.loss_objective), "loss_objective is not finite"
+        assert torch.isfinite(loss_vals.divergence), "divergence is not finite"
+
+        # Divergence should be non-negative
+        assert (
+            loss_vals.divergence >= 0
+        ), f"divergence should be non-negative: {loss_vals.divergence}"
+
+    @pytest.mark.parametrize("topk", [None, 50, 100], ids=["full", "topk50", "topk100"])
+    def test_sdpo_topk(self, mock_transformer_model, topk):
+        """Test SDPO with top-K logit distillation for memory efficiency."""
+        vocab_size = 1024
+        device = torch.device("cpu")
+
+        model = mock_transformer_model(vocab_size=vocab_size, device=device)
+        actor_network = TransformersWrapper(
+            model,
+            generate=False,
+            pad_output=True,
+            input_mode="history",
+        )
+
+        # Create loss with top-K
+        loss_fn = SDPOLoss(
+            actor_network,
+            divergence_type="js",
+            topk=topk,
+        )
+
+        data = _mock_data_sdpo(vocab_size=vocab_size, device=device)
+        loss_vals = loss_fn(data)
+
+        assert isinstance(loss_vals, SDPOLossOutput)
+        assert torch.isfinite(loss_vals.loss_objective)
+        assert loss_vals.divergence >= 0
+
+    def test_sdpo_ema_teacher(self, mock_transformer_model):
+        """Test SDPO with EMA teacher regularization."""
+        vocab_size = 1024
+        device = torch.device("cpu")
+
+        model = mock_transformer_model(vocab_size=vocab_size, device=device)
+        actor_network = TransformersWrapper(
+            model,
+            generate=False,
+            pad_output=True,
+            input_mode="history",
+        )
+
+        # Create loss with EMA teacher
+        loss_fn = SDPOLoss(
+            actor_network,
+            divergence_type="js",
+            use_ema_teacher=True,
+            ema_decay=0.99,
+        )
+
+        # Check that EMA params were initialized
+        assert loss_fn._ema_teacher_params is not None
+        assert len(loss_fn._ema_teacher_params) > 0
+
+        data = _mock_data_sdpo(vocab_size=vocab_size, device=device)
+        loss_vals = loss_fn(data)
+
+        assert isinstance(loss_vals, SDPOLossOutput)
+        assert torch.isfinite(loss_vals.loss_objective)
+
+        # Test EMA update
+        loss_fn.update_ema_teacher()
+        # Should still work after update
+        loss_vals_after = loss_fn(data)
+        assert torch.isfinite(loss_vals_after.loss_objective)
+
+    def test_sdpo_no_entropy(self, mock_transformer_model):
+        """Test SDPO without entropy bonus."""
+        vocab_size = 1024
+        device = torch.device("cpu")
+
+        model = mock_transformer_model(vocab_size=vocab_size, device=device)
+        actor_network = TransformersWrapper(
+            model,
+            generate=False,
+            pad_output=True,
+            input_mode="history",
+        )
+
+        loss_fn = SDPOLoss(
+            actor_network,
+            divergence_type="js",
+            entropy_bonus=False,
+        )
+
+        data = _mock_data_sdpo(vocab_size=vocab_size, device=device)
+        loss_vals = loss_fn(data)
+
+        assert isinstance(loss_vals, SDPOLossOutput)
+        assert torch.isfinite(loss_vals.loss_objective)
+        # Entropy should be None when entropy_bonus is False
+        assert loss_vals.entropy is None
+        assert loss_vals.loss_entropy is None
+
+
+class TestAddFeedbackContext:
+    """Test suite for AddFeedbackContext transform."""
+
+    def test_add_feedback_direct(self):
+        """Test adding feedback context in direct mode."""
+        transform = AddFeedbackContext()
+
+        td = TensorDict(
+            {
+                "query": "What is 2+2?",
+                ("text", "response"): "The answer is 5.",
+                "env_feedback": "Wrong answer. The correct answer is 4.",
+                ("next", "reward"): torch.tensor([0.0]),
+                ("next", "done"): torch.tensor([True]),
+            },
+            batch_size=(),
+        )
+
+        td_out = transform(td)
+
+        # Check that teacher_context was added
+        assert "teacher_context" in td_out.keys()
+        teacher_context = td_out.get("teacher_context")
+        assert teacher_context is not None
+
+    def test_add_feedback_with_success(self):
+        """Test adding feedback context with successful rollout."""
+        transform = AddFeedbackContext()
+
+        td = TensorDict(
+            {
+                "query": "What is 2+2?",
+                ("text", "response"): "The answer is 5.",
+                "env_feedback": "Wrong answer.",
+                "_successful_rollout": "The answer is 4.",
+                ("next", "reward"): torch.tensor([0.0]),
+                ("next", "done"): torch.tensor([True]),
+            },
+            batch_size=(),
+        )
+
+        td_out = transform(td)
+
+        assert "teacher_context" in td_out.keys()
+        teacher_context = td_out.get("teacher_context")
+        # Should contain both the successful solution and feedback
+        assert teacher_context is not None
+
+
 class TestSFT:
     @pytest.fixture(scope="class")
     def data(self):

torchrl/envs/llm/transforms/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -10,6 +10,7 @@
     DataLoadingPrimer,
     RayDataLoadingPrimer,
 )
+from .feedback import AddFeedbackContext, BuildTeacherContext
 from .format import TemplateTransform
 from .kl import KLComputation, KLRewardTransform, RetrieveKL, RetrieveLogProb
 from .policy_version import PolicyVersion
@@ -29,8 +30,10 @@
 )

 __all__ = [
+    "AddFeedbackContext",
     "AddThinkingPrompt",
     "BrowserTransform",
+    "BuildTeacherContext",
     "DataLoadingPrimer",
     "ExecuteToolsInOrder",
     "IncrementalTokenizer",
