diff --git a/skyrl-train/skyrl_train/config/ppo_base_config.yaml b/skyrl-train/skyrl_train/config/ppo_base_config.yaml
index 012b4e43d..8b631cbff 100644
--- a/skyrl-train/skyrl_train/config/ppo_base_config.yaml
+++ b/skyrl-train/skyrl_train/config/ppo_base_config.yaml
@@ -235,6 +235,31 @@ generator:
   http_endpoint_host: "127.0.0.1"
   http_endpoint_port: 8000
   max_turns: 1
+  summarize_chat: false
+  context_folding:
+    enabled: false
+    trigger_ratio: 0.8
+    min_tokens: 256
+    max_folds: 2
+    keep_initial_prompt_tokens: -1
+    keep_last_messages: 0
+    include_summary_in_training: true
+    summary_role: "user"
+    summary_max_tokens: 256
+    summary_prompt: |
+      Your context window is full. Summarize the conversation so far so another model can continue.
+      Be concise and structured. Include:
+      - Objective
+      - Key facts/constraints
+      - Current plan and open questions
+      - Next action to take
+      Return only the summary wrapped in <summary></summary>.
+    summary_prefix: |
+      [Previous conversation summary]
+      {summary}
+
+      Please continue the task.
+    summary_sampling_params: {}
 
   # chat template configuration
   chat_template:
diff --git a/skyrl-train/skyrl_train/generators/context_folding.py b/skyrl-train/skyrl_train/generators/context_folding.py
new file mode 100644
index 000000000..3614967d0
--- /dev/null
+++ b/skyrl-train/skyrl_train/generators/context_folding.py
@@ -0,0 +1,230 @@
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple
+import re
+
+from loguru import logger
+from omegaconf import DictConfig, OmegaConf
+
+from skyrl_train.inference_engines.base import ConversationType, InferenceEngineInput
+from skyrl_train.inference_engines.utils import get_sampling_params_for_backend
+
+
+DEFAULT_SUMMARY_PROMPT = (
+    "Your context window is full. Summarize the conversation so far so another model can continue the task.\n"
+    "Be concise and structured. Include:\n"
+    "- Objective\n"
+    "- Key facts/constraints\n"
+    "- Current plan and open questions\n"
+    "- Next action to take\n"
+    "Return only the summary wrapped in <summary></summary>."
+)
+
+
+@dataclass
+class FoldResult:
+    folded: bool
+    summary_text: Optional[str] = None
+    summary_prompt_ids: Optional[List[int]] = None
+    summary_output_ids: Optional[List[int]] = None
+    summary_logprobs: Optional[List[float]] = None
+    summary_stop_reason: Optional[str] = None
+    new_chat_history: Optional[ConversationType] = None
+    new_input_ids: Optional[List[int]] = None
+
+
+class ContextFolder:
+    def __init__(
+        self,
+        cfg: DictConfig,
+        tokenizer,
+        inference_engine_client,
+        backend: str,
+        base_sampling_params: DictConfig,
+        chat_template_kwargs: Dict[str, Any],
+    ):
+        self.cfg = cfg
+        self.tokenizer = tokenizer
+        self.inference_engine_client = inference_engine_client
+        self.backend = backend
+        self.base_sampling_params = base_sampling_params
+        self.chat_template_kwargs = chat_template_kwargs
+
+        self.enabled = bool(cfg.get("enabled", False))
+        self.trigger_ratio = float(cfg.get("trigger_ratio", 0.8))
+        self.min_tokens = int(cfg.get("min_tokens", 0))
+        self.max_folds = int(cfg.get("max_folds", 1))
+        self.keep_initial_prompt_tokens = int(cfg.get("keep_initial_prompt_tokens", -1))
+        self.keep_last_messages = int(cfg.get("keep_last_messages", 0))
+        summary_prompt = cfg.get("summary_prompt", None)
+        if summary_prompt is None:
+            summary_prompt = DEFAULT_SUMMARY_PROMPT
+        self.summary_prompt = str(summary_prompt)
+
+        summary_prefix = cfg.get("summary_prefix", None)
+        if summary_prefix is None:
+            summary_prefix = "[Previous conversation summary]\n{summary}\n\nPlease continue the task."
+        self.summary_prefix = str(summary_prefix)
+
+        self.summary_role = str(cfg.get("summary_role", "user"))
+        self.summary_max_tokens = cfg.get("summary_max_tokens", None)
+        self.summary_sampling_params = cfg.get("summary_sampling_params", None)
+        self.include_summary_in_training = bool(cfg.get("include_summary_in_training", False))
+        self.summary_pattern = re.compile(r"<summary>(.*?)</summary>", re.DOTALL)
+
+        if self.summary_role not in {"user", "system"}:
+            raise ValueError("context_folding.summary_role must be 'user' or 'system'")
+
+    def fold_trigger(self, current_input_length: int, max_input_length: int, fold_count: int) -> bool:
+        if not self.enabled:
+            return False
+        if fold_count >= self.max_folds:
+            return False
+        if current_input_length < self.min_tokens:
+            return False
+        threshold_length = int(max_input_length * self.trigger_ratio)
+        return current_input_length >= threshold_length
+
+    async def fold(
+        self,
+        chat_history: ConversationType,
+        current_input_length: int,
+        max_input_length: int,
+        initial_chat_history_length: int,
+        session_id: str,
+        fold_count: int,
+    ) -> FoldResult:
+        if not self.fold_trigger(current_input_length, max_input_length, fold_count):
+            return FoldResult(folded=False)
+
+        keep_initial = self._resolve_keep_initial(initial_chat_history_length, len(chat_history))
+        keep_last = self._resolve_keep_last(keep_initial, len(chat_history))
+
+        if keep_initial + keep_last >= len(chat_history):
+            logger.debug("Context folding skipped: not enough history to summarize")
+            return FoldResult(folded=False)
+
+        summary_request, summary_prompt_ids, tail_messages = self._build_summary_request(
+            chat_history, keep_initial, keep_last, max_input_length
+        )
+        if summary_request is None or summary_prompt_ids is None:
+            logger.warning("Context folding skipped: summary prompt exceeds max input length")
+            return FoldResult(folded=False)
+
+        summary_sampling_params = self._build_summary_sampling_params()
+        summary_session_id = f"{session_id}_summary_{fold_count}"
+
+        engine_input = InferenceEngineInput(
+            prompt_token_ids=[summary_prompt_ids],
+            session_ids=[summary_session_id],
+            sampling_params=summary_sampling_params,
+        )
+        engine_output = await self.inference_engine_client.generate(engine_input)
+        summary_text = engine_output["responses"][0]
+        summary_output_ids = engine_output["response_ids"][0]
+        summary_stop_reason = engine_output["stop_reasons"][0]
+        summary_logprobs = None
+        if engine_output.get("response_logprobs") is not None:
+            summary_logprobs = engine_output["response_logprobs"][0]
+
+        summary_text = self._extract_summary(summary_text)
+        if not summary_text:
+            logger.warning("Context folding skipped: empty summary")
+            return FoldResult(folded=False)
+
+        summary_message = {
+            "role": self.summary_role,
+            "content": self._render_summary_prefix(summary_text),
+        }
+
+        initial_messages = chat_history[:keep_initial]
+        new_chat_history = initial_messages + [summary_message] + tail_messages
+        new_input_ids = self.tokenizer.apply_chat_template(
+            new_chat_history,
+            add_generation_prompt=True,
+            tokenize=True,
+            **self.chat_template_kwargs,
+        )
+
+        logger.info(
+            f"Context folded: {len(chat_history)} -> {len(new_chat_history)} messages "
+            f"(summary tokens: {len(summary_output_ids)})"
+        )
+
+        return FoldResult(
+            folded=True,
+            summary_text=summary_text,
+            summary_prompt_ids=summary_prompt_ids,
+            summary_output_ids=summary_output_ids,
+            summary_logprobs=summary_logprobs,
+            summary_stop_reason=summary_stop_reason,
+            new_chat_history=new_chat_history,
+            new_input_ids=new_input_ids,
+        )
+
+    def _resolve_keep_initial(self, initial_chat_history_length: int, total_messages: int) -> int:
+        keep_initial = self.keep_initial_prompt_tokens
+        if keep_initial < 0:
+            keep_initial = initial_chat_history_length
+        keep_initial = max(0, min(keep_initial, total_messages))
+        return keep_initial
+
+    def _resolve_keep_last(self, keep_initial: int, total_messages: int) -> int:
+        keep_last = max(0, self.keep_last_messages)
+        if keep_initial + keep_last > total_messages:
+            keep_last = max(0, total_messages - keep_initial)
+        return keep_last
+
+    def _build_summary_request(
+        self,
+        chat_history: ConversationType,
+        keep_initial: int,
+        keep_last: int,
+        max_input_length: int,
+    ) -> Tuple[Optional[ConversationType], Optional[List[int]], Optional[ConversationType]]:
+        initial_messages = chat_history[:keep_initial]
+        tail_messages = chat_history[len(chat_history) - keep_last :] if keep_last > 0 else []
+        history_to_summarize = chat_history[keep_initial : len(chat_history) - keep_last]
+        if not history_to_summarize:
+            return None, None, None
+
+        summary_instruction = {"role": "user", "content": self.summary_prompt}
+        trimmed_history = list(history_to_summarize)
+
+        while True:
+            summary_request = initial_messages + trimmed_history + [summary_instruction]
+            summary_prompt_ids = self.tokenizer.apply_chat_template(
+                summary_request,
+                add_generation_prompt=True,
+                tokenize=True,
+                **self.chat_template_kwargs,
+            )
+            if len(summary_prompt_ids) <= max_input_length:
+                return summary_request, summary_prompt_ids, tail_messages
+            if not trimmed_history:
+                break
+            trimmed_history = trimmed_history[1:]
+
+        return None, None, None
+
+    def _build_summary_sampling_params(self) -> Optional[Dict[str, Any]]:
+        summary_cfg = OmegaConf.create({})
+        if self.base_sampling_params is not None:
+            summary_cfg = OmegaConf.merge(summary_cfg, self.base_sampling_params)
+        if self.summary_sampling_params is not None:
+            summary_cfg = OmegaConf.merge(summary_cfg, self.summary_sampling_params)
+        if self.summary_max_tokens is not None:
+            summary_cfg = OmegaConf.merge(summary_cfg, {"max_generate_length": int(self.summary_max_tokens)})
+        if len(summary_cfg) == 0:
+            return None
+        return get_sampling_params_for_backend(self.backend, summary_cfg)
+
+    def _extract_summary(self, summary_text: str) -> str:
+        match = self.summary_pattern.search(summary_text)
+        if match:
+            return match.group(1).strip()
+        return summary_text.strip()
+
+    def _render_summary_prefix(self, summary_text: str) -> str:
+        if "{summary}" in self.summary_prefix:
+            return self.summary_prefix.format(summary=summary_text)
+        return f"{self.summary_prefix}{summary_text}"
diff --git a/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py b/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py
index 9988ec880..512b96392 100644
--- a/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py
+++ b/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py
@@ -7,6 +7,7 @@
 
 import asyncio
 import copy
+import re
 from uuid import uuid4
 import skyrl_gym
 from typing import List, Dict, Any, Optional, Union, Tuple
@@ -16,6 +17,7 @@
 from loguru import logger
 
 from skyrl_train.generators.base import GeneratorInterface, GeneratorInput, GeneratorOutput, TrajectoryID
+from skyrl_train.generators.context_folding import ContextFolder
 from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient
 from skyrl_train.inference_engines.base import InferenceEngineInput, ConversationType
 from omegaconf import DictConfig
@@ -47,7 +49,6 @@ class StepWiseOutput:
     step_outputs: List[TrajectoryOutput]
 
 
-
 @dataclass
 class AgentLoopState:
     chat_history: ConversationType
@@ -57,7 +58,6 @@ class AgentLoopState:
     response_end_idx: Optional[int]
     done: bool
 
-
 @dataclass
 class TurnOutput:
     output: str
@@ -123,6 +123,17 @@ def __init__(
         self.custom_chat_template = get_custom_chat_template(generator_cfg.chat_template)
         # get generation prompt ids for the tokenizer if needed
         self.generation_prompt_ids = get_generation_prompt_ids(tokenizer) if self.use_conversation_multi_turn else None
+        self.context_folder = None
+        context_folding_cfg = getattr(generator_cfg, "context_folding", None)
+        if context_folding_cfg is not None and context_folding_cfg.get("enabled", False):
+            self.context_folder = ContextFolder(
+                context_folding_cfg,
+                tokenizer=self.tokenizer,
+                inference_engine_client=self.inference_engine_client,
+                backend=self.generator_cfg.backend,
+                base_sampling_params=self.generator_cfg.sampling_params,
+                chat_template_kwargs=self.generator_cfg.chat_template_kwargs,
+            )
         if self.skyrl_gym_cfg.max_env_workers > 0:
             self.env_executor = ThreadPoolExecutor(
                 max_workers=self.skyrl_gym_cfg.max_env_workers, thread_name_prefix="skyrl-gym-env-"
@@ -176,6 +187,13 @@ def _validate_cfg(self, generator_cfg: DictConfig):
         if not self.use_conversation_multi_turn:
             raise ValueError("`step_wise_trajectories` doesn't support `use_conversation_multi_turn=False`")
 
+        context_folding_cfg = getattr(generator_cfg, "context_folding", None)
+        if context_folding_cfg is not None and context_folding_cfg.get("enabled", False):
+            if not self.generator_cfg.step_wise_trajectories:
+                raise ValueError("`context_folding.enabled` requires `step_wise_trajectories=True`")
+            if self.custom_chat_template is not None:
+                raise ValueError("`context_folding.enabled` doesn't support custom chat templates")
+
     async def _run_in_executor_if_available(self, func, *args, **kwargs):
         if (executor := self.env_executor) is not None:
             loop = asyncio.get_running_loop()
@@ -268,13 +286,15 @@ async def agent_loop(
             response_end_idx=None,
             done=False,
         )
+        context_folder = self.context_folder
+        fold_count = 0
 
         while not agent_loop_state.done:
-
+
             if len(agent_loop_state.input_ids) > max_input_length:
                 stop_reason = "length"
                 break
-
+
             # 1. Generate output
             if is_step_wise or retokenize_chat_history:
                 # re-apply whole chat template so length check is correct
@@ -387,6 +407,54 @@
 
             per_step_rewards.append((step_reward, agent_loop_state.response_end_idx))
 
+            if context_folder is not None and not agent_loop_state.done:
+                fold_result = await context_folder.fold(
+                    chat_history=agent_loop_state.chat_history,
+                    current_input_length=len(agent_loop_state.input_ids),
+                    max_input_length=max_input_length,
+                    initial_chat_history_length=initial_chat_history_length,
+                    session_id=session_id,
+                    fold_count=fold_count,
+                )
+                if fold_result.folded:
+                    fold_count += 1
+                    if is_step_wise and context_folder.include_summary_in_training:
+                        if fold_result.summary_prompt_ids is None or fold_result.summary_output_ids is None:
+                            raise ValueError("Context folding summary output is missing prompt or response IDs")
+                        summary_turn_output = TurnOutput(
+                            output=fold_result.summary_text,
+                            output_ids=fold_result.summary_output_ids,
+                            output_logprobs=fold_result.summary_logprobs,
+                            new_obs=[],
+                            reward=0.0,
+                            obs_ids=[],
+                        )
+                        if not summary_turn_output.output_ids:
+                            logger.warning("Context folding summary produced empty output IDs; skipping summary step")
+                        else:
+                            summary_loss_mask = summary_turn_output.get_turn_loss_mask()
+                            summary_rollout_logprobs = summary_turn_output.get_turn_rollout_logprobs()
+                            summary_step_output = TrajectoryOutput(
+                                response_ids=summary_turn_output.output_ids,
+                                reward=0.0,
+                                stop_reason=fold_result.summary_stop_reason or "summary",
+                                loss_mask=summary_loss_mask,
+                                prompt_ids=fold_result.summary_prompt_ids,
+                                rollout_logprobs=summary_rollout_logprobs,
+                                env_metrics={},
+                            )
+                            agent_loop_output.step_outputs.append(summary_step_output)
+                            per_step_rewards.append((0.0, len(summary_turn_output.output_ids) - 1))
+                    elif context_folder.include_summary_in_training:
+                        logger.warning(
+                            "Context folding summary steps require `step_wise_trajectories=True`; skipping summary output"
+                        )
+
+                    agent_loop_state.chat_history = fold_result.new_chat_history
+                    if fold_result.new_input_ids is not None:
+                        agent_loop_state.input_ids = fold_result.new_input_ids
+                    agent_loop_state.response_end_idx = None
+
         # Get environment-specific metrics after the episode is done
         env_metrics = env.get_metrics()
         # Close the environment
diff --git a/skyrl-train/tests/cpu/generators/test_skyrl_context.py b/skyrl-train/tests/cpu/generators/test_skyrl_context.py
new file mode 100644
index 000000000..c666ec40d
--- /dev/null
+++ b/skyrl-train/tests/cpu/generators/test_skyrl_context.py
@@ -0,0 +1,171 @@
+import asyncio
+import os
+from typing import List, Dict, Any
+
+import pytest
+from omegaconf import DictConfig
+from transformers import AutoTokenizer
+
+from skyrl_train.generators.context_folding import ContextFolder
+
+
+# Mock inference engine client that mimics the real inference engine's generate() method
+class MockInferenceEngineClient:
+    def __init__(self, tokenizer):
+        self.tokenizer = tokenizer
+        self.call_count = 0
+
+    async def generate(self, engine_input):
+        """Mock the inference engine's generate method for summarization"""
+        self.call_count += 1
+
+        # engine_input is a dict, not an object
+        prompt_ids = engine_input["prompt_token_ids"][0]
+        prompt_text = self.tokenizer.decode(prompt_ids)
+
+        print(f"\n📞 Inference Engine Called (call #{self.call_count})")
+        print(f"📝 Prompt length: {len(prompt_ids)} tokens")
+        print(f"📝 Session ID: {engine_input['session_ids'][0]}")
+        print(f"📝 Sampling params: {engine_input['sampling_params']}")
+        print(f"📋 Prompt preview (first 200 chars):\n{prompt_text[:200]}...")
+
+        # Generate a mock summary response wrapped in <summary></summary> tags
+        summary_text = (
+            "<summary>"
+            "The conversation covered machine learning fundamentals including supervised, "
+            "unsupervised, and reinforcement learning. Discussed neural networks with "
+            "backpropagation and deep learning. Explained transformer architecture with "
+            "attention mechanisms and their effectiveness in NLP tasks. User now asking "
+            "about implementation details."
+            "</summary>"
+        )
+
+        # Encode the summary
+        summary_ids = self.tokenizer.encode(summary_text, add_special_tokens=False)
+
+        # Return in the format expected by ContextFolder
+        return {
+            "responses": [summary_text],
+            "response_ids": [summary_ids],
+            "stop_reasons": ["stop"],
+            "response_logprobs": [[0.0] * len(summary_ids)]  # Mock logprobs
+        }
+
+
+@pytest.mark.asyncio
+async def test_context_folding():
+    print("🧪 Testing Context Folding...")
+
+    # Setup tokenizer
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    # Setup mock inference client
+    mock_client = MockInferenceEngineClient(tokenizer)
+
+    # Context folding configuration
+    folding_cfg = DictConfig({
+        "enabled": True,
+        "trigger_ratio": 0.7,
+        "min_tokens": 50,
+        "max_folds": 3,
+        "summary_max_tokens": 200,
+        "summary_prompt": "Your context window is full. Summarize the conversation so far. Wrap your summary in <summary></summary> tags.",
+        "summary_prefix": "[Previous conversation summary]\n{summary}\n\nPlease continue.",
+        "summary_role": "user",
+        "keep_initial_prompt_tokens": 1,  # Keep system message
+        "keep_last_messages": 2,  # Keep last 2 messages
+        "include_summary_in_training": False
+    })
+
+    # Create context folder with complete sampling params for vLLM
+    base_sampling_params = DictConfig({
+        "temperature": 0.7,
+        "top_p": 0.9,
+        "top_k": -1,
+        "min_p": 0.0,
+        "logprobs": None,
+        "max_generate_length": 100
+    })
+
+    context_folder = ContextFolder(
+        cfg=folding_cfg,
+        tokenizer=tokenizer,
+        inference_engine_client=mock_client,
+        backend="vllm",
+        base_sampling_params=base_sampling_params,
+        chat_template_kwargs={}
+    )
+
+    # Create mock conversation that's long enough to trigger folding
+    mock_conversation = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Tell me about machine learning algorithms."},
+        {"role": "assistant", "content": "Machine learning algorithms are computational methods that enable computers to learn patterns from data without being explicitly programmed. There are several main categories: supervised learning (like linear regression and decision trees), unsupervised learning (like clustering and dimensionality reduction), and reinforcement learning (where agents learn through trial and error). Each type has different use cases and strengths."},
+        {"role": "user", "content": "Can you explain neural networks in more detail?"},
+        {"role": "assistant", "content": "Neural networks are inspired by biological neural networks in the brain. They consist of interconnected nodes (neurons) organized in layers. Each connection has a weight that determines its strength. During training, these weights are adjusted using backpropagation to minimize prediction errors. Deep neural networks with multiple hidden layers can learn complex patterns and representations, making them powerful for tasks like image recognition, natural language processing, and game playing."},
+        {"role": "user", "content": "What about transformers?"},
+        {"role": "assistant", "content": "Transformers revolutionized natural language processing and are the architecture behind models like GPT and BERT. Key innovations include the attention mechanism, which allows the model to focus on relevant parts of the input sequence, and parallel processing capabilities. The self-attention mechanism computes relationships between all positions in a sequence simultaneously, making transformers very effective for understanding context and long-range dependencies in text."},
+        {"role": "user", "content": "How do I implement a simple transformer?"},
+    ]
+
+    # Calculate token length to see if it triggers folding
+    full_text = tokenizer.apply_chat_template(mock_conversation, tokenize=False)
+    input_ids = tokenizer.encode(full_text)
+    current_length = len(input_ids)
+    max_length = 300  # Small max length to force folding
+
+    print(f"📊 Current conversation length: {current_length} tokens")
+    print(f"📊 Max allowed length: {max_length} tokens")
+    print(f"📊 Trigger ratio: {folding_cfg.trigger_ratio}")
+    print(f"📊 Trigger threshold: {int(max_length * folding_cfg.trigger_ratio)} tokens")
+
+    # Test fold trigger
+    should_fold = context_folder.fold_trigger(
+        current_input_length=current_length,
+        max_input_length=max_length,
+        fold_count=0
+    )
+
+    print(f"🤔 Should fold? {should_fold}")
+
+    if should_fold:
+        print("\n🔄 Attempting to fold context...")
+
+        # Test the actual folding
+        fold_result = await context_folder.fold(
+            chat_history=mock_conversation,
+            current_input_length=current_length,
+            max_input_length=max_length,
+            initial_chat_history_length=len(mock_conversation),
+            session_id="test_session",
+            fold_count=0
+        )
+
+        if fold_result.folded:
+            print("✅ Context folding successful!")
+            print(f"📝 Original messages: {len(mock_conversation)}")
+            print(f"📝 After folding: {len(fold_result.new_chat_history)}")
+            print(f"🎯 Summary tokens: {len(fold_result.summary_output_ids) if fold_result.summary_output_ids else 0}")
+            print(f"🔍 Extracted summary text: {fold_result.summary_text}")
+
+            print("\n📋 New conversation structure:")
+            for i, msg in enumerate(fold_result.new_chat_history):
+                role = msg["role"]
+                content = msg["content"][:150] + "..." if len(msg["content"]) > 150 else msg["content"]
+                print(f"  {i}: [{role}] {content}")
+
+            print(f"\n📊 Token counts:")
+            print(f"  Before: {current_length} tokens")
+            new_length = len(fold_result.new_input_ids)
+            print(f"  After: {new_length} tokens")
+            print(f"  Saved: {current_length - new_length} tokens ({100 * (current_length - new_length) / current_length:.1f}%)")
+
+            print(f"\n🔧 Inference engine stats:")
+            print(f"  Total calls: {mock_client.call_count}")
+        else:
+            print("❌ Context folding did not occur")
+    else:
+        print("ℹ️ Context folding not triggered (conversation too short)")
+
+    print("\n🎉 Test completed!")
+
+
+if __name__ == "__main__":
+    asyncio.run(test_context_folding())
\ No newline at end of file
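Usage sketch (illustrative, not part of the patch): with the keys introduced in ppo_base_config.yaml above, context folding would be enabled through the generator config, together with `step_wise_trajectories=true`, which the new `_validate_cfg` check requires. All key names below come from this diff; the surrounding run configuration is assumed.

    generator:
      step_wise_trajectories: true
      context_folding:
        enabled: true
        trigger_ratio: 0.8
        max_folds: 2
        summary_max_tokens: 256
        include_summary_in_training: true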