diff --git a/test/llm/test_llm_objectives.py b/test/llm/test_llm_objectives.py index 6a2341c034a..4566254b900 100644 --- a/test/llm/test_llm_objectives.py +++ b/test/llm/test_llm_objectives.py @@ -14,6 +14,7 @@ from tensordict import lazy_stack, TensorDict from torchrl._utils import logger from torchrl.data import History, LazyStackStorage, ReplayBuffer +from torchrl.envs.llm.transforms.feedback import AddFeedbackContext from torchrl.envs.llm.transforms.kl import RetrieveLogProb from torchrl.modules.llm import TransformersWrapper, vLLMWrapper from torchrl.modules.llm.policies.common import ChatHistory, Masks, Text, Tokens @@ -24,6 +25,7 @@ GRPOLossOutput, MCAdvantage, ) +from torchrl.objectives.llm.sdpo import SDPOLoss, SDPOLossOutput from torchrl.objectives.llm.sft import SFTLoss _has_transformers = importlib.util.find_spec("transformers") is not None @@ -427,6 +429,301 @@ def test_cispo(self, mock_transformer_model): ), f"clip_fraction out of range: {loss_vals.clip_fraction}" +def _mock_data_sdpo(vocab_size: int, device: torch.device | str = "cpu") -> TensorDict: + """Create mock data for SDPO testing.""" + from transformers import AutoTokenizer + + device = torch.device(device) + + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B") + prompt = History( + role=["system", "user"], + content=["You are a useful assistant.", "What is 2+2?"], + batch_size=(2,), + device=device, + ) + response = History( + role=["assistant"], + content=["2 + 2 = 4."], + batch_size=(1,), + device=device, + ) + full_history = prompt.extend(response, inplace=False) + history = ChatHistory( + prompt=prompt, + response=response, + full=full_history, + device=device, + ) + batch_size = 1 + + # Expand history to match batch size + history = history.expand((batch_size,)) + next_history = ChatHistory( + prompt=full_history, + device=device, + ) + next_history = next_history.expand((batch_size,)) + + # Get tokens + tokens_full = history.to_tokens(tokenizer) + next_tokens = next_history.to_tokens(tokenizer) + + tokens_input_ids = tokens_full.get( + "full", as_padded_tensor=True, padding_side="left", padding_value=0 + ) + seq_len = tokens_input_ids.shape[-1] + + # Create tensors + reward = torch.randn(batch_size, seq_len, 1, device=device) + done = torch.zeros(batch_size, seq_len, 1, dtype=torch.bool, device=device) + log_probs = torch.randn_like(tokens_full, dtype=torch.float32, device=device) + attention_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device) + + from tensordict import MetaData + + masks = Masks( + all_attention_mask=attention_mask, + all_assistant_mask=None, + padded=MetaData(True), + device=device, + ) + + # Create teacher context (feedback-augmented prompt for self-teacher) + # In real usage, this would be constructed by AddFeedbackContext transform + teacher_context = { + "history": history, + "env_feedback": "The solution is correct.", + } + + data = TensorDict( + { + "history": history, + "tokens": tokens_full % vocab_size, + "masks": masks, + "teacher_context": teacher_context, + "next": { + "history": next_history, + "tokens": next_tokens % vocab_size, + "reward": reward, + "done": done, + }, + "log_probs": log_probs, + }, + batch_size=(batch_size,), + ) + return data + + +class TestSDPO: + """Test suite for Self-Distillation Policy Optimization (SDPO) loss.""" + + @pytest.mark.parametrize( + "divergence_type", ["kl", "reverse_kl", "js"], ids=["kl", "reverse_kl", "js"] + ) + def test_sdpo_basic(self, mock_transformer_model, divergence_type): + """Test basic SDPO loss 
computation with different divergence types.""" + vocab_size = 1024 + device = torch.device("cpu") + + # Create mock model and wrap it + model = mock_transformer_model(vocab_size=vocab_size, device=device) + actor_network = TransformersWrapper( + model, + generate=False, + pad_output=True, + input_mode="history", + ) + + # Create loss module + loss_fn = SDPOLoss( + actor_network, + divergence_type=divergence_type, + entropy_bonus=True, + entropy_coeff=0.01, + ) + + # Create fake data + data = _mock_data_sdpo(vocab_size=vocab_size, device=device) + + # Compute loss + loss_vals = loss_fn(data) + + # Assertions: Check output type and structure + assert isinstance( + loss_vals, SDPOLossOutput + ), f"Expected SDPOLossOutput, got {type(loss_vals)}" + + # Check that all expected keys are present + assert hasattr(loss_vals, "loss_objective"), "Missing loss_objective" + assert hasattr(loss_vals, "divergence"), "Missing divergence" + assert hasattr(loss_vals, "kl_approx"), "Missing kl_approx" + assert hasattr(loss_vals, "entropy"), "Missing entropy" + assert hasattr(loss_vals, "loss_entropy"), "Missing loss_entropy" + + # Check tensor shapes (all losses should be scalars after reduction) + assert ( + loss_vals.loss_objective.shape == () + ), f"loss_objective should be scalar, got {loss_vals.loss_objective.shape}" + assert ( + loss_vals.divergence.shape == () + ), f"divergence should be scalar, got {loss_vals.divergence.shape}" + assert ( + loss_vals.kl_approx.shape == () + ), f"kl_approx should be scalar, got {loss_vals.kl_approx.shape}" + + # Check that losses are finite + assert torch.isfinite(loss_vals.loss_objective), "loss_objective is not finite" + assert torch.isfinite(loss_vals.divergence), "divergence is not finite" + + # Divergence should be non-negative + assert ( + loss_vals.divergence >= 0 + ), f"divergence should be non-negative: {loss_vals.divergence}" + + @pytest.mark.parametrize("topk", [None, 50, 100], ids=["full", "topk50", "topk100"]) + def test_sdpo_topk(self, mock_transformer_model, topk): + """Test SDPO with top-K logit distillation for memory efficiency.""" + vocab_size = 1024 + device = torch.device("cpu") + + model = mock_transformer_model(vocab_size=vocab_size, device=device) + actor_network = TransformersWrapper( + model, + generate=False, + pad_output=True, + input_mode="history", + ) + + # Create loss with top-K + loss_fn = SDPOLoss( + actor_network, + divergence_type="js", + topk=topk, + ) + + data = _mock_data_sdpo(vocab_size=vocab_size, device=device) + loss_vals = loss_fn(data) + + assert isinstance(loss_vals, SDPOLossOutput) + assert torch.isfinite(loss_vals.loss_objective) + assert loss_vals.divergence >= 0 + + def test_sdpo_ema_teacher(self, mock_transformer_model): + """Test SDPO with EMA teacher regularization.""" + vocab_size = 1024 + device = torch.device("cpu") + + model = mock_transformer_model(vocab_size=vocab_size, device=device) + actor_network = TransformersWrapper( + model, + generate=False, + pad_output=True, + input_mode="history", + ) + + # Create loss with EMA teacher + loss_fn = SDPOLoss( + actor_network, + divergence_type="js", + use_ema_teacher=True, + ema_decay=0.99, + ) + + # Check that EMA params were initialized + assert loss_fn._ema_teacher_params is not None + assert len(loss_fn._ema_teacher_params) > 0 + + data = _mock_data_sdpo(vocab_size=vocab_size, device=device) + loss_vals = loss_fn(data) + + assert isinstance(loss_vals, SDPOLossOutput) + assert torch.isfinite(loss_vals.loss_objective) + + # Test EMA update + 
loss_fn.update_ema_teacher() + # Should still work after update + loss_vals_after = loss_fn(data) + assert torch.isfinite(loss_vals_after.loss_objective) + + def test_sdpo_no_entropy(self, mock_transformer_model): + """Test SDPO without entropy bonus.""" + vocab_size = 1024 + device = torch.device("cpu") + + model = mock_transformer_model(vocab_size=vocab_size, device=device) + actor_network = TransformersWrapper( + model, + generate=False, + pad_output=True, + input_mode="history", + ) + + loss_fn = SDPOLoss( + actor_network, + divergence_type="js", + entropy_bonus=False, + ) + + data = _mock_data_sdpo(vocab_size=vocab_size, device=device) + loss_vals = loss_fn(data) + + assert isinstance(loss_vals, SDPOLossOutput) + assert torch.isfinite(loss_vals.loss_objective) + # Entropy should be None when entropy_bonus is False + assert loss_vals.entropy is None + assert loss_vals.loss_entropy is None + + +class TestAddFeedbackContext: + """Test suite for AddFeedbackContext transform.""" + + def test_add_feedback_direct(self): + """Test adding feedback context in direct mode.""" + transform = AddFeedbackContext() + + td = TensorDict( + { + "query": "What is 2+2?", + ("text", "response"): "The answer is 5.", + "env_feedback": "Wrong answer. The correct answer is 4.", + ("next", "reward"): torch.tensor([0.0]), + ("next", "done"): torch.tensor([True]), + }, + batch_size=(), + ) + + td_out = transform(td) + + # Check that teacher_context was added + assert "teacher_context" in td_out.keys() + teacher_context = td_out.get("teacher_context") + assert teacher_context is not None + + def test_add_feedback_with_success(self): + """Test adding feedback context with successful rollout.""" + transform = AddFeedbackContext() + + td = TensorDict( + { + "query": "What is 2+2?", + ("text", "response"): "The answer is 5.", + "env_feedback": "Wrong answer.", + "_successful_rollout": "The answer is 4.", + ("next", "reward"): torch.tensor([0.0]), + ("next", "done"): torch.tensor([True]), + }, + batch_size=(), + ) + + td_out = transform(td) + + assert "teacher_context" in td_out.keys() + teacher_context = td_out.get("teacher_context") + # Should contain both the successful solution and feedback + assert teacher_context is not None + + class TestSFT: @pytest.fixture(scope="class") def data(self): diff --git a/torchrl/envs/llm/transforms/__init__.py b/torchrl/envs/llm/transforms/__init__.py index c3c2246079a..262962ec92c 100644 --- a/torchrl/envs/llm/transforms/__init__.py +++ b/torchrl/envs/llm/transforms/__init__.py @@ -10,6 +10,7 @@ DataLoadingPrimer, RayDataLoadingPrimer, ) +from .feedback import AddFeedbackContext, BuildTeacherContext from .format import TemplateTransform from .kl import KLComputation, KLRewardTransform, RetrieveKL, RetrieveLogProb from .policy_version import PolicyVersion @@ -29,8 +30,10 @@ ) __all__ = [ + "AddFeedbackContext", "AddThinkingPrompt", "BrowserTransform", + "BuildTeacherContext", "DataLoadingPrimer", "ExecuteToolsInOrder", "IncrementalTokenizer", diff --git a/torchrl/envs/llm/transforms/feedback.py b/torchrl/envs/llm/transforms/feedback.py new file mode 100644 index 00000000000..f90a4b2ecac --- /dev/null +++ b/torchrl/envs/llm/transforms/feedback.py @@ -0,0 +1,471 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
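+"""Transforms that build feedback-augmented self-teacher contexts for SDPO training."""
+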
+from __future__ import annotations + +from collections import defaultdict, deque +from typing import TYPE_CHECKING + +import torch +from tensordict import lazy_stack, NestedKey, TensorDictBase + +from torchrl._utils import logger as torchrl_logger +from torchrl.data.llm import History +from torchrl.envs.transforms.transforms import Transform + +if TYPE_CHECKING: + from transformers import AutoTokenizer + + +# Default template following the SDPO paper +DEFAULT_TEACHER_TEMPLATE = """{prompt} + +{feedback_section} +Correctly solve the original question.""" + +FEEDBACK_WITH_SOLUTION = """Correct solution: +{successful_rollout} + +The following is feedback from your unsuccessful earlier attempt: +{env_output}""" + +FEEDBACK_WITHOUT_SOLUTION = """The following is feedback from your unsuccessful earlier attempt: +{env_output}""" + +SUCCESS_ONLY = """Correct solution: +{successful_rollout}""" + + +class AddFeedbackContext(Transform): + """Adds self-teacher context to tensordict for SDPO training. + + This transform prepares the feedback-augmented context needed for Self-Distillation + Policy Optimization (SDPO). It constructs a "teacher context" by combining: + + 1. The original prompt + 2. Rich feedback from the environment (e.g., runtime errors, test failures) + 3. Optionally, a successful solution from another rollout on the same prompt + + The teacher context follows the template from the SDPO paper:: + + User: {original_prompt} + Correct solution: {successful_rollout} (if available) + Feedback from unsuccessful attempt: {env_output} + Correctly solve the original question. + Assistant: {original_response} + + This transform can be used in two modes: + + 1. **Direct mode**: Apply directly to a tensordict that already contains feedback + 2. **Grouped mode**: When attached to a replay buffer, accumulate rollouts and + find successful solutions within groups (similar to GRPO grouping) + + Args: + grpo_size (int | None, optional): If set, accumulate rollouts in groups of this + size per prompt and use successful rollouts as feedback for failed ones. + Defaults to ``None`` (direct mode). + + Keyword Args: + prompt_key (NestedKey): Key for the original prompt. Defaults to ``"query"``. + response_key (NestedKey): Key for the model's response. Defaults to ``("text", "response")``. + feedback_key (NestedKey): Key for environment feedback (e.g., error messages). + Defaults to ``"env_feedback"``. + reward_key (NestedKey): Key for reward signal (used to identify successful rollouts). + Defaults to ``("next", "reward")``. + done_key (NestedKey): Key for done signal. Defaults to ``("next", "done")``. + teacher_context_key (NestedKey): Key where the teacher context will be written. + Defaults to ``"teacher_context"``. + success_threshold (float): Reward threshold above which a rollout is considered + successful. Defaults to ``0.5``. + include_response_in_context (bool): Whether to include the original response + in the teacher context. Defaults to ``True``. + template (str | None): Custom template for constructing teacher context. + If None, uses the default SDPO template. Defaults to ``None``. + tokenizer (AutoTokenizer | None): Tokenizer for handling chat templates. + Defaults to ``None``. + verbose (bool): Whether to print verbose information. Defaults to ``False``. 
+ + Example: + >>> # Direct mode: apply to tensordict with feedback + >>> transform = AddFeedbackContext() + >>> td["env_feedback"] = "RuntimeError: division by zero at line 73" + >>> td_with_context = transform(td) + >>> # td_with_context["teacher_context"] now contains the augmented prompt + + >>> # Grouped mode: accumulate rollouts in replay buffer + >>> rb = ReplayBuffer(storage=LazyStackStorage(100)) + >>> rb.append_transform(AddFeedbackContext(grpo_size=4)) + >>> # Rollouts are accumulated until grpo_size is reached per prompt + >>> # Successful rollouts provide feedback for failed ones + + Note: + When using grouped mode, the transform expects complete trajectories + (ending with done=True). Incomplete trajectories will raise an error. + """ + + def __init__( + self, + grpo_size: int | None = None, + *, + prompt_key: NestedKey = "query", + response_key: NestedKey = ("text", "response"), + feedback_key: NestedKey = "env_feedback", + reward_key: NestedKey = ("next", "reward"), + done_key: NestedKey = ("next", "done"), + teacher_context_key: NestedKey = "teacher_context", + success_threshold: float = 0.5, + include_response_in_context: bool = True, + template: str | None = None, + tokenizer: AutoTokenizer | None = None, + verbose: bool = False, + ): + super().__init__() + + self.grpo_size = grpo_size + self.prompt_key = prompt_key + self.response_key = response_key + self.feedback_key = feedback_key + self.reward_key = reward_key + self.done_key = done_key + self.teacher_context_key = teacher_context_key + self.success_threshold = success_threshold + self.include_response_in_context = include_response_in_context + self.template = template if template is not None else DEFAULT_TEACHER_TEMPLATE + self.tokenizer = tokenizer + self.verbose = verbose + + # Storage for grouped mode + if grpo_size is not None: + self.queues = defaultdict(lambda: deque(maxlen=grpo_size)) + else: + self.queues = None + + self.in_keys = [prompt_key, response_key, feedback_key, reward_key, done_key] + self.out_keys = [teacher_context_key] + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + """Apply transform in forward direction (for env steps).""" + return self._add_teacher_context(tensordict) + + def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: + """Apply transform in inverse direction (for replay buffer writes). + + In grouped mode, this accumulates rollouts and processes them when + enough have been collected for each prompt. + """ + if self.grpo_size is None: + # Direct mode: just add teacher context + return self._add_teacher_context(tensordict) + + # Grouped mode: accumulate and process + return self._process_grouped(tensordict) + + def _add_teacher_context(self, tensordict: TensorDictBase) -> TensorDictBase: + """Add teacher context to a single tensordict. + + Constructs the feedback-augmented prompt for the self-teacher. 
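+
+        As a sketch, with environment feedback but no successful rollout available,
+        the resulting context renders ``DEFAULT_TEACHER_TEMPLATE`` with
+        ``FEEDBACK_WITHOUT_SOLUTION``::
+
+            {prompt}
+
+            The following is feedback from your unsuccessful earlier attempt:
+            {env_output}
+            Correctly solve the original question.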
+ """ + # Get components + prompt = self._get_value(tensordict, self.prompt_key, default="") + response = self._get_value(tensordict, self.response_key, default="") + env_feedback = self._get_value(tensordict, self.feedback_key, default=None) + successful_rollout = tensordict.get("_successful_rollout", None) + + # Build feedback section + if env_feedback is not None and successful_rollout is not None: + feedback_section = FEEDBACK_WITH_SOLUTION.format( + successful_rollout=successful_rollout, + env_output=env_feedback, + ) + elif env_feedback is not None: + feedback_section = FEEDBACK_WITHOUT_SOLUTION.format( + env_output=env_feedback, + ) + elif successful_rollout is not None: + feedback_section = SUCCESS_ONLY.format( + successful_rollout=successful_rollout, + ) + else: + # No feedback available - teacher context is just the prompt + feedback_section = "" + + # Build full teacher context + if feedback_section: + teacher_context = self.template.format( + prompt=prompt, + feedback_section=feedback_section, + ) + else: + teacher_context = prompt + + # Optionally include response + if self.include_response_in_context and response: + # The response is appended as what the assistant said + # (the teacher will re-evaluate these tokens) + teacher_context_with_response = { + "prompt": teacher_context, + "response": response, + } + tensordict.set(self.teacher_context_key, teacher_context_with_response) + else: + tensordict.set(self.teacher_context_key, teacher_context) + + return tensordict + + def _get_value(self, tensordict: TensorDictBase, key: NestedKey, default=None): + """Get a value from tensordict, handling nested keys and defaults.""" + try: + value = tensordict.get(key, default=default) + if value is None: + return default + # Handle various types + if isinstance(value, torch.Tensor): + if value.numel() == 1: + return value.item() + return value + return value + except (KeyError, AttributeError): + return default + + def _process_grouped(self, tensordict: TensorDictBase) -> TensorDictBase | None: + """Process tensordict in grouped mode. + + Accumulates rollouts per prompt and processes when grpo_size is reached. + Successful rollouts provide feedback for failed ones. 
+ """ + if self.verbose: + torchrl_logger.info( + f"Invoking AddFeedbackContext.\nData size: {tensordict.shape}.\n" + f"Current queue size: {len(self.queues)}.\n" + f"Total queue content: {sum(len(q) for q in self.queues.values())}" + ) + + # Handle different input dimensions + if tensordict.ndim == 1: + # Check how many done states we have + done = tensordict.get(self.done_key, None) + if done is None: + done = torch.ones(1, dtype=torch.bool) # Assume single complete traj + + num_done = done.sum() if isinstance(done, torch.Tensor) else 1 + if num_done > 1: + # Split into individual trajectories + done_idx = done.nonzero(as_tuple=True)[0] + 1 + splits = torch.cat([done_idx.new_zeros((1,)), done_idx], dim=0).diff() + tensordicts = tensordict.split(splits.tolist()) + tensordicts = [self._process_grouped(td) for td in tensordicts] + tensordicts = [td for td in tensordicts if td is not None] + return torch.cat(tensordicts) if tensordicts else None + + # Single trajectory + if tensordict.ndim > 0: + last_done = tensordict[-1].get(self.done_key, None) + else: + last_done = tensordict.get(self.done_key, None) + + if last_done is not None and not last_done.all(): + raise RuntimeError("Expected the trajectory to be done.") + + # Get prompt for grouping + if tensordict.ndim > 0: + prompt = self._get_value(tensordict[0], self.prompt_key, default="") + else: + prompt = self._get_value(tensordict, self.prompt_key, default="") + + if not isinstance(prompt, str): + # Convert to string for hashing + prompt = str(prompt) + + self.queues[prompt].append(tensordict) + + if len(self.queues[prompt]) == self.grpo_size: + if self.verbose: + torchrl_logger.info(f"Processing group for {prompt[:50]}...") + + # Process the group + tds = list(self.queues[prompt]) + del self.queues[prompt] + + # Find successful rollouts + successful_rollouts = [] + for td in tds: + reward = self._get_final_reward(td) + if reward is not None and reward > self.success_threshold: + response = self._get_value(td, self.response_key, default=None) + if response is not None: + successful_rollouts.append(response) + + # Add feedback context to each trajectory + processed_tds = [] + for td in tds: + td = td.clone(False) + reward = self._get_final_reward(td) + is_successful = ( + reward is not None and reward > self.success_threshold + ) + + if not is_successful and successful_rollouts: + # Use a successful rollout as feedback + td.set("_successful_rollout", successful_rollouts[0]) + + self._add_teacher_context(td) + processed_tds.append(td) + + # Stack and return + return lazy_stack(processed_tds) + + return None # Not enough rollouts yet + + elif tensordict.ndim > 2: + # Flatten extra dimensions + tensordict = tensordict.flatten(0, -2) + + # Process each trajectory + trajs = tensordict.unbind(0) + result = [] + for traj in trajs: + td_out = self._process_grouped(traj) + if td_out is None: + continue + result.append(td_out) + + if result: + return torch.cat(result, 0) + return None + + def _get_final_reward(self, tensordict: TensorDictBase) -> float | None: + """Get the final reward from a trajectory.""" + try: + if tensordict.ndim > 0: + reward = tensordict[-1].get(self.reward_key, None) + else: + reward = tensordict.get(self.reward_key, None) + + if reward is None: + return None + + if isinstance(reward, torch.Tensor): + # Get the final value + if reward.numel() > 1: + reward = reward[-1] + return float(reward.item()) + return float(reward) + except (KeyError, AttributeError, TypeError): + return None + + def reset_queues(self): + """Reset 
the accumulation queues (for grouped mode).""" + if self.queues is not None: + self.queues.clear() + + +class BuildTeacherContext(Transform): + """Builds teacher context from History objects for SDPO. + + This transform is designed for chat-based LLM training where prompts and + responses are represented as History objects. It constructs the teacher + context by appending feedback information to the conversation history. + + Args: + tokenizer (AutoTokenizer): Tokenizer for applying chat templates. + + Keyword Args: + history_key (NestedKey): Key for the chat history. Defaults to ``"history"``. + feedback_key (NestedKey): Key for environment feedback. Defaults to ``"env_feedback"``. + success_response_key (NestedKey): Key for successful response (if available). + Defaults to ``"successful_response"``. + teacher_context_key (NestedKey): Key where teacher context will be written. + Defaults to ``"teacher_context"``. + chat_template_name (str | None): Name of the chat template to use. + Defaults to ``None``. + + Example: + >>> transform = BuildTeacherContext(tokenizer) + >>> td["history"] = ChatHistory(prompt=prompt_history, response=response_history) + >>> td["env_feedback"] = "Error: index out of bounds" + >>> td_out = transform(td) + >>> # td_out["teacher_context"] now contains the augmented history + """ + + def __init__( + self, + tokenizer: AutoTokenizer, + *, + history_key: NestedKey = "history", + feedback_key: NestedKey = "env_feedback", + success_response_key: NestedKey = "successful_response", + teacher_context_key: NestedKey = "teacher_context", + chat_template_name: str | None = None, + ): + super().__init__() + + self.tokenizer = tokenizer + self.history_key = history_key + self.feedback_key = feedback_key + self.success_response_key = success_response_key + self.teacher_context_key = teacher_context_key + self.chat_template_name = chat_template_name + + self.in_keys = [history_key, feedback_key, success_response_key] + self.out_keys = [teacher_context_key] + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + """Build teacher context from history and feedback.""" + history = tensordict.get(self.history_key, None) + if history is None: + return tensordict + + feedback = tensordict.get(self.feedback_key, None) + success_response = tensordict.get(self.success_response_key, None) + + # Build feedback message + feedback_parts = [] + if success_response is not None: + feedback_parts.append(f"Correct solution:\n{success_response}") + if feedback is not None: + feedback_parts.append( + f"The following is feedback from your unsuccessful earlier attempt:\n{feedback}" + ) + if feedback_parts: + feedback_parts.append("Correctly solve the original question.") + + feedback_message = "\n\n".join(feedback_parts) + + # Create teacher context by extending history with feedback + if hasattr(history, "prompt"): + # ChatHistory-like object + prompt_history = history.prompt + + if feedback_message: + # Append feedback as a user message + feedback_turn = History( + role=["user"], + content=[feedback_message], + batch_size=(1,), + ) + teacher_prompt = prompt_history.extend(feedback_turn, inplace=False) + else: + teacher_prompt = prompt_history + + # Create teacher context dict with prompt and response + teacher_context = { + "history": teacher_prompt, + "original_response": history.response + if hasattr(history, "response") + else None, + } + else: + # Plain History object + if feedback_message: + feedback_turn = History( + role=["user"], + content=[feedback_message], + 
batch_size=(1,), + ) + teacher_context = history.extend(feedback_turn, inplace=False) + else: + teacher_context = history + + tensordict.set(self.teacher_context_key, teacher_context) + return tensordict diff --git a/torchrl/objectives/llm/__init__.py b/torchrl/objectives/llm/__init__.py index f53172ba990..c7101c9266b 100644 --- a/torchrl/objectives/llm/__init__.py +++ b/torchrl/objectives/llm/__init__.py @@ -14,6 +14,7 @@ LLMLossOutput, MCAdvantage, ) +from .sdpo import SDPOLoss, SDPOLossOutput from .sft import SFTLoss, SFTLossOutput __all__ = [ @@ -25,6 +26,8 @@ "GRPOLossOutput", "LLMLossOutput", "MCAdvantage", + "SDPOLoss", + "SDPOLossOutput", "SFTLoss", "SFTLossOutput", ] diff --git a/torchrl/objectives/llm/sdpo.py b/torchrl/objectives/llm/sdpo.py new file mode 100644 index 00000000000..77f65d36ea4 --- /dev/null +++ b/torchrl/objectives/llm/sdpo.py @@ -0,0 +1,779 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import contextlib +import math +from dataclasses import dataclass +from typing import Literal + +import torch +from tensordict import ( + is_tensor_collection, + NestedKey, + TensorClass, + TensorDict, + TensorDictBase, +) +from tensordict.nn import ( + CompositeDistribution, + ProbabilisticTensorDictModule, + ProbabilisticTensorDictSequential, + set_composite_lp_aggregate, +) +from tensordict.utils import expand_as_right +from torch import distributions as d + +from torchrl._utils import logger as torchrl_logger, VERBOSE +from torchrl.modules.llm import LLMWrapperBase +from torchrl.objectives.common import LossModule +from torchrl.objectives.utils import _reduce, _sum_td_features + + +class SDPOLossOutput(TensorClass["nocast"]): + """SDPO Loss Output. + + This class defines the output structure for Self-Distillation Policy Optimization + (SDPO) loss computation. + + Attributes: + loss_objective: The main policy objective loss. + divergence: The divergence between student and teacher distributions. + kl_approx: Approximate KL divergence for logging. + entropy: Policy entropy (if entropy_bonus is enabled). + loss_entropy: Entropy loss term (if entropy_bonus is enabled). + loss_kl_to_ref: KL divergence loss to reference policy (if kl_to_ref_coeff is set). + kl_to_ref: KL divergence to reference policy for logging. + """ + + loss_objective: torch.Tensor + divergence: torch.Tensor + kl_approx: torch.Tensor + entropy: torch.Tensor | None = None + loss_entropy: torch.Tensor | None = None + loss_kl_to_ref: torch.Tensor | None = None + kl_to_ref: torch.Tensor | None = None + + +class SDPOLoss(LossModule): + """Self-Distillation Policy Optimization (SDPO) loss. + + SDPO is an on-policy algorithm that uses self-distillation for credit assignment + in LLM post-training. Instead of relying on sparse scalar rewards, SDPO uses the + model itself as a self-teacher by conditioning it on rich feedback (runtime errors, + successful solutions, etc.). 
+
+    The core idea is that the same model can be used in two roles:
+    - **Student**: The policy generating responses without seeing feedback
+    - **Self-Teacher**: The same policy conditioned on feedback, which can retrospectively
+      identify mistakes and assign dense, per-token credit
+
+    The loss minimizes the KL divergence (or Jensen-Shannon divergence) between the
+    student and self-teacher distributions::
+
+        L_SDPO = sum_t KL(pi_theta(·|x, y_{<t}) || pi_teacher(·|x', y_{<t}))
+
+    where ``x`` is the original prompt, ``x'`` is the feedback-augmented prompt seen
+    by the self-teacher, and ``y_{<t}`` are the previously generated response tokens.
+
+    Args:
+        actor_network (LLMWrapperBase): the policy to train. The same network serves
+            as the student and (optionally through an EMA copy) as the self-teacher.
+
+    Keyword Args:
+        divergence_type (str): divergence to minimize: ``"kl"`` (student || teacher),
+            ``"reverse_kl"`` (teacher || student) or ``"js"`` (Jensen-Shannon).
+            Defaults to ``"js"``.
+        topk (int | None): if set, compute the divergence over the top-K student logits
+            plus a tail correction, for memory efficiency. Defaults to ``None``
+            (full vocabulary).
+        use_ema_teacher (bool): whether to evaluate the teacher with an exponential
+            moving average of the actor parameters. Defaults to ``False``.
+        ema_decay (float): decay of the EMA teacher update. Defaults to ``0.99``.
+        trust_region_alpha (float | None): interpolation coefficient for trust-region
+            regularization of the teacher log-probs. Defaults to ``None`` (disabled).
+        entropy_bonus (bool): whether to add an entropy bonus to the objective.
+            Defaults to ``True``.
+        samples_mc_entropy (int): number of Monte Carlo samples used when the entropy
+            has no closed form. Defaults to ``1``.
+        entropy_coeff (float): coefficient of the entropy bonus. Defaults to ``0.01``.
+        kl_to_ref_coeff (float | None): coefficient of the KL penalty to the reference
+            policy. Defaults to ``None`` (disabled).
+        aggregation (str): aggregation strategy for per-token losses, one of
+            ``"token_mean"``, ``"prompt_mean"`` or ``"none"``. Defaults to ``"token_mean"``.
+        reduction (str | None): reduction applied to the loss. Defaults to ``"mean"``.
+        masking_strategy (str): masking strategy, one of ``"sft"``, ``"rlhf"`` or
+            ``"generic"``. Defaults to ``"sft"``.
+        device (torch.device | None): device for the loss module. Defaults to ``None``.
+
+    Examples:
+        >>> from torchrl.objectives.llm import SDPOLoss
+        >>> from torchrl.modules.llm import TransformersWrapper
+        >>>
+        >>> # Wrap model
+        >>> actor = TransformersWrapper(model, tokenizer=tokenizer, generate=False, pad_output=True)
+        >>>
+        >>> # Create SDPO loss with Jensen-Shannon divergence and EMA teacher
+        >>> loss_fn = SDPOLoss(
+        ...     actor_network=actor,
+        ...     divergence_type="js",
+        ...     topk=100,
+        ...     use_ema_teacher=True,
+        ...     ema_decay=0.99,
+        ... )
+        >>>
+        >>> # Compute loss (tensordict must include teacher_context)
+        >>> loss_output = loss_fn(tensordict)
+        >>> loss_output.loss_objective.backward()
+
+    Complete training loop with feedback:
+
+        >>> from torchrl.envs.llm.transforms import AddFeedbackContext
+        >>> from torchrl.data import ReplayBuffer, LazyStackStorage
+        >>>
+        >>> # Set up replay buffer with feedback transform
+        >>> rb = ReplayBuffer(storage=LazyStackStorage(1000))
+        >>> rb.append_transform(AddFeedbackContext(grpo_size=4))
+        >>>
+        >>> # Training loop
+        >>> for epoch in range(num_epochs):
+        ...     # Collect rollouts with environment feedback
+        ...     rollouts = collector.collect()  # Should contain "env_feedback" key
+        ...     rb.extend(rollouts)
+        ...
+        ...     # Sample batch and compute loss
+        ...     batch = rb.sample(batch_size)
+        ...     loss_output = loss_fn(batch)
+        ...
+        ...     # Backward pass
+        ...     optimizer.zero_grad()
+        ...     loss_output.loss_objective.backward()
+        ...     optimizer.step()
+        ...
+        ...     # Update EMA teacher
+        ...     loss_fn.update_ema_teacher()
+
+    Note:
+        The input tensordict should contain a "teacher_context" key with the feedback-augmented
+        context for the self-teacher. This can be prepared using the ``AddFeedbackContext``
+        transform from ``torchrl.envs.llm.transforms``.
+    """
+
+    actor_network: LLMWrapperBase
+    output_type: type[SDPOLossOutput] = SDPOLossOutput
+
+    @dataclass
+    class _AcceptedKeys(LossModule._AcceptedKeys):
+        """Maintains default values for all configurable tensordict keys.
+
+        Attributes:
+            action: Key for action tokens. Defaults to ``("tokens", "full")``.
+            sample_log_prob: Key for student log probabilities. Defaults to ``("log_probs", "full")``.
+            teacher_context: Key for the self-teacher context (feedback-augmented prompt).
+                Defaults to ``"teacher_context"``.
+            ref_log_probs: Key for reference policy log probabilities (for KL penalty).
+                Defaults to ``("next", "ref_log_probs", "full")``.
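+
+        Example:
+            >>> # hypothetical relocation of the teacher context key
+            >>> loss_fn.set_keys(teacher_context=("context", "teacher"))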
+ """ + + action: NestedKey = ("tokens", "full") + sample_log_prob: NestedKey = ("log_probs", "full") + teacher_context: NestedKey = "teacher_context" + ref_log_probs: NestedKey = ("next", "ref_log_probs", "full") + + @property + def tensor_keys(self) -> _AcceptedKeys: + """Access the tensordict key configuration for this loss.""" + return self._tensor_keys + + def __init__( + self, + actor_network: LLMWrapperBase | None = None, + *, + divergence_type: Literal["kl", "reverse_kl", "js"] = "js", + topk: int | None = None, + use_ema_teacher: bool = False, + ema_decay: float = 0.99, + trust_region_alpha: float | None = None, + entropy_bonus: bool = True, + samples_mc_entropy: int = 1, + entropy_coeff: float = 0.01, + kl_to_ref_coeff: float | None = None, + aggregation: str = "token_mean", + reduction: str | None = None, + masking_strategy: Literal["sft", "rlhf", "generic"] = "sft", + device: torch.device | None = None, + ): + super().__init__() + + self.actor_network = actor_network + self.divergence_type = divergence_type + self.topk = topk + self.use_ema_teacher = use_ema_teacher + self.ema_decay = ema_decay + self.trust_region_alpha = trust_region_alpha + self.entropy_bonus = entropy_bonus + self.samples_mc_entropy = samples_mc_entropy + self.entropy_coeff = entropy_coeff + self.kl_to_ref_coeff = kl_to_ref_coeff + self.aggregation = aggregation + self.reduction = reduction if reduction is not None else "mean" + self.masking_strategy = masking_strategy + + # Determine device + if device is None: + try: + device = next(self.parameters()).device + except (AttributeError, StopIteration): + device = getattr( + torch, "get_default_device", lambda: torch.device("cpu") + )() + self._device = device + + # Initialize EMA teacher parameters if requested + if use_ema_teacher and actor_network is not None: + self._init_ema_teacher() + else: + self._ema_teacher_params = None + + # Store reference teacher log-probs for trust-region (populated on first forward) + self._ref_teacher_logprobs = None + + # Set default keys + self.set_keys( + sample_log_prob=("log_probs", "full"), + action=("tokens", "full"), + ) + self._set_in_keys() + + def _init_ema_teacher(self): + """Initialize EMA teacher parameters as a copy of actor network parameters.""" + if hasattr(self.actor_network, "parameters"): + # Create a deep copy of parameters for EMA + self._ema_teacher_params = { + name: param.detach().clone() + for name, param in self.actor_network.named_parameters() + } + else: + self._ema_teacher_params = None + + def update_ema_teacher(self): + """Update EMA teacher parameters with current actor parameters. + + Should be called after each optimization step when using EMA teacher. 
+ The update rule is: theta_ema = decay * theta_ema + (1 - decay) * theta + """ + if self._ema_teacher_params is None: + return + + with torch.no_grad(): + for name, param in self.actor_network.named_parameters(): + if name in self._ema_teacher_params: + self._ema_teacher_params[name].mul_(self.ema_decay).add_( + param.data, alpha=1 - self.ema_decay + ) + + def _set_in_keys(self): + keys = [] + if getattr(self, "actor_network", None) is not None and hasattr( + self.actor_network, "in_keys" + ): + in_keys = self.actor_network.in_keys + if isinstance(in_keys, (list, tuple)): + keys.extend(in_keys) + keys.append(self.tensor_keys.action) + keys.append(self.tensor_keys.sample_log_prob) + keys.append(self.tensor_keys.teacher_context) + keys.append(self.tensor_keys.ref_log_probs) + self._in_keys = list(dict.fromkeys(keys)) + + @property + def in_keys(self): + if getattr(self, "_in_keys", None) is None: + self._set_in_keys() + return self._in_keys + + @in_keys.setter + def in_keys(self, values): + self._in_keys = values + + @property + def out_keys(self): + if getattr(self, "_out_keys", None) is None: + keys = ["loss_objective", "divergence", "kl_approx"] + if self.entropy_bonus: + keys.extend(["entropy", "loss_entropy"]) + if self.kl_to_ref_coeff is not None: + keys.extend(["loss_kl_to_ref", "kl_to_ref"]) + self._out_keys = keys + return self._out_keys + + @out_keys.setter + def out_keys(self, values): + self._out_keys = values + + def _get_student_log_prob(self, tensordict: TensorDictBase): + """Get log probabilities from the student (policy without feedback).""" + if isinstance( + self.actor_network, + (ProbabilisticTensorDictSequential, ProbabilisticTensorDictModule), + ) or hasattr(self.actor_network, "get_dist"): + # Use the specified masking strategy + if self.masking_strategy == "sft" and hasattr( + self.actor_network, "_get_sft_dist" + ): + dist = self.actor_network._get_sft_dist(tensordict) + elif self.masking_strategy == "rlhf" and hasattr( + self.actor_network, "_get_rlhf_dist" + ): + dist = self.actor_network._get_rlhf_dist(tensordict) + elif self.masking_strategy == "generic" and hasattr( + self.actor_network, "_get_generic_dist" + ): + dist = self.actor_network._get_generic_dist(tensordict) + elif hasattr(self.actor_network, "get_dist"): + dist = self.actor_network.get_dist(tensordict, logits_key="logits") + else: + raise NotImplementedError( + f"Actor network must have get_dist method or the appropriate method for " + f"masking strategy '{self.masking_strategy}'." + ) + + action = tensordict.get( + self.tensor_keys.action, + as_padded_tensor=True, + padding_side="left", + padding_value=-100, + ) + log_prob = dist.log_prob(action) + # Also get logits for divergence computation + logits = tensordict.get("logits", None) + else: + raise NotImplementedError( + "Only probabilistic modules from tensordict.nn are currently supported." + ) + return log_prob, dist, logits + + def _get_teacher_log_prob( + self, tensordict: TensorDictBase, teacher_context: TensorDictBase + ): + """Get log probabilities from the self-teacher (policy with feedback context). + + The teacher sees the same response but with additional context (feedback). 
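+
+        When EMA teacher parameters are enabled, the actor's weights are temporarily
+        swapped for the EMA copy for the duration of this call and restored afterwards.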
+ """ + # Build teacher input by incorporating feedback context + teacher_td = tensordict.clone(False) + + # Merge teacher context into the tensordict + # The teacher_context should contain the reprompted history/tokens + if teacher_context is not None: + teacher_td.update(teacher_context) + + # Use EMA parameters if available + if self._ema_teacher_params is not None: + # Temporarily swap parameters + original_params = {} + for name, param in self.actor_network.named_parameters(): + original_params[name] = param.data.clone() + param.data.copy_(self._ema_teacher_params[name]) + + try: + # Get teacher distribution + if isinstance( + self.actor_network, + (ProbabilisticTensorDictSequential, ProbabilisticTensorDictModule), + ) or hasattr(self.actor_network, "get_dist"): + if self.masking_strategy == "sft" and hasattr( + self.actor_network, "_get_sft_dist" + ): + teacher_dist = self.actor_network._get_sft_dist(teacher_td) + elif self.masking_strategy == "rlhf" and hasattr( + self.actor_network, "_get_rlhf_dist" + ): + teacher_dist = self.actor_network._get_rlhf_dist(teacher_td) + elif self.masking_strategy == "generic" and hasattr( + self.actor_network, "_get_generic_dist" + ): + teacher_dist = self.actor_network._get_generic_dist(teacher_td) + elif hasattr(self.actor_network, "get_dist"): + teacher_dist = self.actor_network.get_dist( + teacher_td, logits_key="logits" + ) + else: + raise NotImplementedError( + "Actor network must have get_dist method." + ) + + action = tensordict.get( + self.tensor_keys.action, + as_padded_tensor=True, + padding_side="left", + padding_value=-100, + ) + teacher_log_prob = teacher_dist.log_prob(action) + teacher_logits = teacher_td.get("logits", None) + else: + raise NotImplementedError( + "Only probabilistic modules from tensordict.nn are currently supported." + ) + finally: + # Restore original parameters if we swapped them + if self._ema_teacher_params is not None: + for name, param in self.actor_network.named_parameters(): + param.data.copy_(original_params[name]) + + return teacher_log_prob, teacher_dist, teacher_logits + + def _compute_divergence( + self, + student_logits: torch.Tensor, + teacher_logits: torch.Tensor, + mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """Compute divergence between student and teacher distributions. 
+ + Args: + student_logits: Logits from student [batch, seq_len, vocab_size] + teacher_logits: Logits from teacher [batch, seq_len, vocab_size] + mask: Attention mask [batch, seq_len] + + Returns: + Per-token divergence [batch, seq_len] + """ + # Apply top-K filtering if specified + if self.topk is not None: + divergence = self._topk_divergence(student_logits, teacher_logits) + else: + divergence = self._full_divergence(student_logits, teacher_logits) + + # Apply mask if provided + if mask is not None: + divergence = torch.where( + expand_as_right(mask, divergence), divergence, divergence.new_zeros(()) + ) + + return divergence + + def _full_divergence( + self, student_logits: torch.Tensor, teacher_logits: torch.Tensor + ) -> torch.Tensor: + """Compute full divergence over all vocabulary tokens.""" + student_log_probs = torch.log_softmax(student_logits, dim=-1) + teacher_log_probs = torch.log_softmax(teacher_logits, dim=-1) + + student_probs = student_log_probs.exp() + teacher_probs = teacher_log_probs.exp() + + if self.divergence_type == "kl": + # KL(student || teacher) = sum_i p_student(i) * log(p_student(i) / p_teacher(i)) + divergence = (student_probs * (student_log_probs - teacher_log_probs)).sum( + -1 + ) + elif self.divergence_type == "reverse_kl": + # KL(teacher || student) = sum_i p_teacher(i) * log(p_teacher(i) / p_student(i)) + divergence = (teacher_probs * (teacher_log_probs - student_log_probs)).sum( + -1 + ) + elif self.divergence_type == "js": + # Jensen-Shannon divergence + # JS(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m) where m = 0.5 * (p + q) + m_log_probs = torch.logaddexp( + student_log_probs, teacher_log_probs + ) - math.log(2) + kl_student_m = (student_probs * (student_log_probs - m_log_probs)).sum(-1) + kl_teacher_m = (teacher_probs * (teacher_log_probs - m_log_probs)).sum(-1) + divergence = 0.5 * (kl_student_m + kl_teacher_m) + else: + raise ValueError(f"Unknown divergence type: {self.divergence_type}") + + return divergence + + def _topk_divergence( + self, student_logits: torch.Tensor, teacher_logits: torch.Tensor + ) -> torch.Tensor: + """Compute divergence over top-K logits only for memory efficiency. + + This approximates the full divergence by only considering the top-K most + likely tokens according to the student, plus a tail term capturing the + remaining probability mass. 
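+
+        As a sketch, for ``divergence_type="kl"`` the quantity computed below is::
+
+            KL_topk = sum_{i in topK} p_s(i) * (log p_s(i) - log p_t(i))
+                      + p_s(tail) * (log p_s(tail) - log p_t(tail))
+
+        where ``p(tail) = 1 - sum_{i in topK} p(i)``, clamped to at least ``1e-10``
+        before taking logs.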
+ """ + k = self.topk + + # Get top-K indices from student + student_log_probs_full = torch.log_softmax(student_logits, dim=-1) + topk_log_probs, topk_idx = student_log_probs_full.topk(k, dim=-1) + + # Gather corresponding teacher log-probs + teacher_log_probs_full = torch.log_softmax(teacher_logits, dim=-1) + teacher_topk_log_probs = teacher_log_probs_full.gather(-1, topk_idx) + + # Convert to probabilities for top-K + student_topk_probs = topk_log_probs.exp() + teacher_topk_probs = teacher_topk_log_probs.exp() + + # Compute tail probabilities + student_tail_prob = 1.0 - student_topk_probs.sum(-1, keepdim=True) + teacher_tail_prob = 1.0 - teacher_topk_probs.sum(-1, keepdim=True) + + # Clamp to avoid log(0) + student_tail_prob = student_tail_prob.clamp(min=1e-10) + teacher_tail_prob = teacher_tail_prob.clamp(min=1e-10) + + if self.divergence_type == "kl": + # KL over top-K + kl_topk = ( + student_topk_probs * (topk_log_probs - teacher_topk_log_probs) + ).sum(-1) + # Tail term + kl_tail = student_tail_prob.squeeze(-1) * ( + student_tail_prob.log().squeeze(-1) + - teacher_tail_prob.log().squeeze(-1) + ) + divergence = kl_topk + kl_tail + + elif self.divergence_type == "reverse_kl": + kl_topk = ( + teacher_topk_probs * (teacher_topk_log_probs - topk_log_probs) + ).sum(-1) + kl_tail = teacher_tail_prob.squeeze(-1) * ( + teacher_tail_prob.log().squeeze(-1) + - student_tail_prob.log().squeeze(-1) + ) + divergence = kl_topk + kl_tail + + elif self.divergence_type == "js": + # JS over top-K + m_topk_log_probs = torch.logaddexp( + topk_log_probs, teacher_topk_log_probs + ) - math.log(2) + kl_student_m = ( + student_topk_probs * (topk_log_probs - m_topk_log_probs) + ).sum(-1) + kl_teacher_m = ( + teacher_topk_probs * (teacher_topk_log_probs - m_topk_log_probs) + ).sum(-1) + + # JS tail term + m_tail_prob = 0.5 * (student_tail_prob + teacher_tail_prob) + m_tail_log_prob = m_tail_prob.log() + kl_s_tail = student_tail_prob * (student_tail_prob.log() - m_tail_log_prob) + kl_t_tail = teacher_tail_prob * (teacher_tail_prob.log() - m_tail_log_prob) + + divergence = 0.5 * (kl_student_m + kl_teacher_m) + 0.5 * ( + kl_s_tail + kl_t_tail + ).squeeze(-1) + else: + raise ValueError(f"Unknown divergence type: {self.divergence_type}") + + return divergence + + def _apply_trust_region( + self, teacher_log_probs: torch.Tensor, ref_teacher_log_probs: torch.Tensor + ) -> torch.Tensor: + """Apply trust-region regularization to teacher log-probs. + + Interpolates between reference and current teacher: + q*(y) ∝ exp((1-α)*log q_ref + α*log q_current) + """ + if self.trust_region_alpha is None: + return teacher_log_probs + + alpha = self.trust_region_alpha + # Interpolate in log-space + interpolated = (1 - alpha) * ref_teacher_log_probs + alpha * teacher_log_probs + # Renormalize (log-sum-exp trick) + interpolated = interpolated - interpolated.logsumexp(dim=-1, keepdim=True) + return interpolated + + def _get_entropy( + self, dist: d.Distribution, adv_shape: torch.Size + ) -> torch.Tensor | TensorDict: + """Compute entropy of the policy distribution.""" + try: + entropy = dist.entropy() + if not entropy.isfinite().all(): + del entropy + if VERBOSE: + torchrl_logger.info( + "Entropy is not finite. Using Monte Carlo sampling." + ) + raise NotImplementedError + except NotImplementedError: + if VERBOSE: + torchrl_logger.warning( + f"Entropy not implemented for {type(dist)} or is not finite. Using Monte Carlo sampling." 
+ ) + if getattr(dist, "has_rsample", False): + x = dist.rsample((self.samples_mc_entropy,)) + else: + x = dist.sample((self.samples_mc_entropy,)) + with set_composite_lp_aggregate(False) if isinstance( + dist, CompositeDistribution + ) else contextlib.nullcontext(): + log_prob = dist.log_prob(x) + if is_tensor_collection(log_prob): + if isinstance(self.tensor_keys.sample_log_prob, NestedKey): + log_prob = log_prob.get(self.tensor_keys.sample_log_prob) + else: + log_prob = log_prob.select(*self.tensor_keys.sample_log_prob) + entropy = -log_prob.mean(0) + if is_tensor_collection(entropy) and entropy.batch_size != adv_shape: + entropy.batch_size = adv_shape + return entropy.unsqueeze(-1) + + def _kl_to_ref( + self, + tensordict: TensorDictBase, + mask: torch.Tensor | None = None, + dist: d.Distribution | None = None, + ): + """Compute KL divergence to reference policy.""" + ref_log_prob = tensordict.get( + self.tensor_keys.ref_log_probs, + as_padded_tensor=True, + padding_side="left", + padding_value=0.0, + ) + if ref_log_prob is None: + raise KeyError( + f"Couldn't find the ref log-prob {self.tensor_keys.ref_log_probs} in the input data." + ) + ref_log_prob = ref_log_prob.squeeze(-1) + + cur_log_prob = tensordict.get("_cur_log_prob") + if cur_log_prob.shape != ref_log_prob.shape: + raise ValueError( + f"cur_log_prob and ref_log_prob must have the same shape, got {cur_log_prob.shape=} and {ref_log_prob.shape=}" + ) + if mask is not None: + ref_log_prob = torch.where( + expand_as_right(mask, ref_log_prob), ref_log_prob, 0.0 + ) + cur_log_prob = torch.where( + expand_as_right(mask, cur_log_prob), cur_log_prob, 0.0 + ) + diff = ref_log_prob - cur_log_prob + kl_penalty = (diff.expm1() - diff).mean() + return self.kl_to_ref_coeff * kl_penalty, kl_penalty + + def _aggregate_loss_value( + self, value: torch.Tensor, mask: torch.Tensor + ) -> torch.Tensor: + """Aggregate a per-token loss tensor using the configured strategy.""" + if self.aggregation == "none" or self.reduction == "none": + mask_exp = expand_as_right(mask, value) + return torch.where(mask_exp, value, value.new_zeros(()).expand_as(value)) + + if self.aggregation == "prompt_mean": + mask_exp = expand_as_right(mask, value).to(value.dtype) + token_sum = (value * mask_exp).sum(dim=-2, keepdim=False) + token_count = mask_exp.sum(dim=-2, keepdim=False).clamp_min(1.0) + sample_mean = token_sum / token_count + return sample_mean.mean(dim=0, keepdim=False) + + # token_mean (global masked mean) + return _reduce(value, reduction="mean", mask=mask).squeeze(-1) + + def forward(self, tensordict: TensorDictBase) -> SDPOLossOutput: + """Compute the SDPO loss. + + Args: + tensordict: Input data containing: + - action tokens (self.tensor_keys.action) + - student log probabilities (self.tensor_keys.sample_log_prob) + - teacher context (self.tensor_keys.teacher_context) + + Returns: + SDPOLossOutput containing loss_objective, divergence, and other metrics. 
+ """ + tensordict = tensordict.copy() + + # Get teacher context (feedback-augmented prompt) + teacher_context = tensordict.get(self.tensor_keys.teacher_context, None) + + # Run forward pass to get student logits + with torch.no_grad() if not self.training else contextlib.nullcontext(): + self.actor_network(tensordict) + + # Get student distribution and log-probs + student_log_prob, student_dist, student_logits = self._get_student_log_prob( + tensordict + ) + tensordict.set("_cur_log_prob", student_log_prob) + + # Get teacher log-probs (with stopgrad on teacher) + with torch.no_grad(): + teacher_log_prob, teacher_dist, teacher_logits = self._get_teacher_log_prob( + tensordict, teacher_context + ) + + # Get mask from distribution + mask = student_dist.mask + + # Compute divergence loss + if student_logits is not None and teacher_logits is not None: + # Use logit-level divergence + divergence = self._compute_divergence(student_logits, teacher_logits, mask) + else: + # Fall back to log-prob based approximation + # This is less accurate but works when logits aren't available + divergence = (student_log_prob - teacher_log_prob).abs() + if mask is not None: + divergence = torch.where( + expand_as_right(mask, divergence), + divergence, + divergence.new_zeros(()), + ) + + # Compute kl_approx for logging (log-prob level) + kl_approx = (student_log_prob - teacher_log_prob).unsqueeze(-1) + + # Build output + td_out = TensorDict( + { + "loss_objective": divergence.unsqueeze(-1), + "divergence": divergence.detach().mean(), + "kl_approx": kl_approx.detach().mean(), + } + ) + + # Add entropy bonus if requested + if self.entropy_bonus: + entropy = self._get_entropy(student_dist, adv_shape=student_log_prob.shape) + if is_tensor_collection(entropy): + td_out.set("composite_entropy", entropy.detach()) + entropy = _sum_td_features(entropy) + td_out.set("entropy", entropy.detach().mean()) + td_out.set("loss_entropy", -self.entropy_coeff * entropy) + + # Add KL to reference if requested + if self.kl_to_ref_coeff is not None and self.kl_to_ref_coeff > 0: + loss_kl, kl_penalty = self._kl_to_ref( + tensordict, mask=mask, dist=student_dist + ) + td_out["loss_kl_to_ref"] = loss_kl + td_out["kl_to_ref"] = kl_penalty.detach() + + # Aggregate loss terms + for key in list(td_out.keys()): + if isinstance(key, tuple) or not isinstance(key, str): + continue + if key.startswith("loss_"): + val = td_out.get(key) + td_out.set(key, self._aggregate_loss_value(val, mask)) + + # Clean up + del tensordict["_cur_log_prob"] + + return self.output_type.from_tensordict(td_out)