25 changes: 25 additions & 0 deletions skyrl-train/skyrl_train/config/ppo_base_config.yaml
@@ -235,6 +235,31 @@ generator:
  http_endpoint_host: "127.0.0.1"
  http_endpoint_port: 8000
  max_turns: 1
  summarize_chat: false
  context_folding:
    enabled: false
    trigger_ratio: 0.8
    min_tokens: 256
    max_folds: 2
    keep_initial_prompt_tokens: -1
    keep_last_messages: 0
    include_summary_in_training: true
    summary_role: "user"
    summary_max_tokens: 256
    summary_prompt: |
      Your context window is full. Summarize the conversation so far so another model can continue.
      Be concise and structured. Include:
      - Objective
      - Key facts/constraints
      - Current plan and open questions
      - Next action to take
      Return only the summary wrapped in <summary></summary>.
    summary_prefix: |
      [Previous conversation summary]
      {summary}

      Please continue the task.
    summary_sampling_params: {}

  # chat template configuration
  chat_template:
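Reviewer note: a minimal sketch (not part of the diff) of how the trigger settings above interact; the config values mirror the YAML defaults, and max_input_length is invented for illustration.

from omegaconf import OmegaConf

cfg = OmegaConf.create({"enabled": True, "trigger_ratio": 0.8, "min_tokens": 256, "max_folds": 2})
max_input_length = 4096
threshold = int(max_input_length * cfg.trigger_ratio)  # 3276 tokens

def should_fold(current_len: int, fold_count: int) -> bool:
    # Mirrors ContextFolder.fold_trigger below: fold once the prompt reaches
    # trigger_ratio of the token budget, at most max_folds times per trajectory.
    return (
        cfg.enabled
        and fold_count < cfg.max_folds
        and current_len >= cfg.min_tokens
        and current_len >= threshold
    )

print(should_fold(3300, 0))  # True: past the 3276-token threshold
print(should_fold(3300, 2))  # False: the fold budget is exhausted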
230 changes: 230 additions & 0 deletions skyrl-train/skyrl_train/generators/context_folding.py
@@ -0,0 +1,230 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
import re

from loguru import logger
from omegaconf import DictConfig, OmegaConf

from skyrl_train.inference_engines.base import ConversationType, InferenceEngineInput
from skyrl_train.inference_engines.utils import get_sampling_params_for_backend


DEFAULT_SUMMARY_PROMPT = (
    "Your context window is full. Summarize the conversation so far so another model can continue the task.\n"
    "Be concise and structured. Include:\n"
    "- Objective\n"
    "- Key facts/constraints\n"
    "- Current plan and open questions\n"
    "- Next action to take\n"
    "Return only the summary wrapped in <summary></summary>."
)


@dataclass
class FoldResult:
    folded: bool
    summary_text: Optional[str] = None
    summary_prompt_ids: Optional[List[int]] = None
    summary_output_ids: Optional[List[int]] = None
    summary_logprobs: Optional[List[float]] = None
    summary_stop_reason: Optional[str] = None
    new_chat_history: Optional[ConversationType] = None
    new_input_ids: Optional[List[int]] = None


class ContextFolder:
    """Summarizes overlong chat histories so multi-turn rollouts can continue within the context budget."""

    def __init__(
        self,
        cfg: DictConfig,
        tokenizer,
        inference_engine_client,
        backend: str,
        base_sampling_params: DictConfig,
        chat_template_kwargs: Dict[str, Any],
    ):
        self.cfg = cfg
        self.tokenizer = tokenizer
        self.inference_engine_client = inference_engine_client
        self.backend = backend
        self.base_sampling_params = base_sampling_params
        self.chat_template_kwargs = chat_template_kwargs

        self.enabled = bool(cfg.get("enabled", False))
        self.trigger_ratio = float(cfg.get("trigger_ratio", 0.8))
        self.min_tokens = int(cfg.get("min_tokens", 0))
        self.max_folds = int(cfg.get("max_folds", 1))
        self.keep_initial_prompt_tokens = int(cfg.get("keep_initial_prompt_tokens", -1))
        self.keep_last_messages = int(cfg.get("keep_last_messages", 0))
        summary_prompt = cfg.get("summary_prompt", None)
        if summary_prompt is None:
            summary_prompt = DEFAULT_SUMMARY_PROMPT
        self.summary_prompt = str(summary_prompt)

        summary_prefix = cfg.get("summary_prefix", None)
        if summary_prefix is None:
            summary_prefix = "[Previous conversation summary]\n{summary}\n\nPlease continue the task."
        self.summary_prefix = str(summary_prefix)

        self.summary_role = str(cfg.get("summary_role", "user"))
        self.summary_max_tokens = cfg.get("summary_max_tokens", None)
        self.summary_sampling_params = cfg.get("summary_sampling_params", None)
        self.include_summary_in_training = bool(cfg.get("include_summary_in_training", False))
        self.summary_pattern = re.compile(r"<summary>(.*?)</summary>", re.DOTALL)

        if self.summary_role not in {"user", "system"}:
            raise ValueError("context_folding.summary_role must be 'user' or 'system'")

    def fold_trigger(self, current_input_length: int, max_input_length: int, fold_count: int) -> bool:
        if not self.enabled:
            return False
        if fold_count >= self.max_folds:
            return False
        if current_input_length < self.min_tokens:
            return False
        threshold_length = int(max_input_length * self.trigger_ratio)
        return current_input_length >= threshold_length
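    # NOTE (review): with max_input_length=4096 and trigger_ratio=0.8,
    # threshold_length = int(4096 * 0.8) = 3276, so a 3300-token prompt
    # triggers a fold as long as fold_count < max_folds.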

    async def fold(
        self,
        chat_history: ConversationType,
        current_input_length: int,
        max_input_length: int,
        initial_chat_history_length: int,
        session_id: str,
        fold_count: int,
    ) -> FoldResult:
        if not self.fold_trigger(current_input_length, max_input_length, fold_count):
            return FoldResult(folded=False)

        keep_initial = self._resolve_keep_initial(initial_chat_history_length, len(chat_history))
        keep_last = self._resolve_keep_last(keep_initial, len(chat_history))

        if keep_initial + keep_last >= len(chat_history):
            logger.debug("Context folding skipped: not enough history to summarize")
            return FoldResult(folded=False)

        summary_request, summary_prompt_ids, tail_messages = self._build_summary_request(
            chat_history, keep_initial, keep_last, max_input_length
        )
        if summary_request is None or summary_prompt_ids is None:
            logger.warning("Context folding skipped: summary prompt exceeds max input length")
            return FoldResult(folded=False)

        summary_sampling_params = self._build_summary_sampling_params()
        summary_session_id = f"{session_id}_summary_{fold_count}"

        engine_input = InferenceEngineInput(
            prompt_token_ids=[summary_prompt_ids],
            session_ids=[summary_session_id],
            sampling_params=summary_sampling_params,
        )
        engine_output = await self.inference_engine_client.generate(engine_input)
        summary_text = engine_output["responses"][0]
        summary_output_ids = engine_output["response_ids"][0]
        summary_stop_reason = engine_output["stop_reasons"][0]
        summary_logprobs = None
        if engine_output.get("response_logprobs") is not None:
            summary_logprobs = engine_output["response_logprobs"][0]

        summary_text = self._extract_summary(summary_text)
        if not summary_text:
            logger.warning("Context folding skipped: empty summary")
            return FoldResult(folded=False)

        summary_message = {
            "role": self.summary_role,
            "content": self._render_summary_prefix(summary_text),
        }

        initial_messages = chat_history[:keep_initial]
        new_chat_history = initial_messages + [summary_message] + tail_messages
        new_input_ids = self.tokenizer.apply_chat_template(
            new_chat_history,
            add_generation_prompt=True,
            tokenize=True,
            **self.chat_template_kwargs,
        )

        logger.info(
            f"Context folded: {len(chat_history)} -> {len(new_chat_history)} messages "
            f"(summary tokens: {len(summary_output_ids)})"
        )

        return FoldResult(
            folded=True,
            summary_text=summary_text,
            summary_prompt_ids=summary_prompt_ids,
            summary_output_ids=summary_output_ids,
            summary_logprobs=summary_logprobs,
            summary_stop_reason=summary_stop_reason,
            new_chat_history=new_chat_history,
            new_input_ids=new_input_ids,
        )

    def _resolve_keep_initial(self, initial_chat_history_length: int, total_messages: int) -> int:
        keep_initial = self.keep_initial_prompt_tokens
        if keep_initial < 0:
            keep_initial = initial_chat_history_length
        keep_initial = max(0, min(keep_initial, total_messages))
        return keep_initial

    def _resolve_keep_last(self, keep_initial: int, total_messages: int) -> int:
        keep_last = max(0, self.keep_last_messages)
        if keep_initial + keep_last > total_messages:
            keep_last = max(0, total_messages - keep_initial)
        return keep_last
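    # NOTE (review): with keep_initial_prompt_tokens=-1, keep_initial falls back
    # to the initial prompt's message count. For a 10-message history whose
    # initial prompt had 3 messages and keep_last_messages=2, messages 3..7 are
    # summarized while messages 8..9 are kept verbatim.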

    def _build_summary_request(
        self,
        chat_history: ConversationType,
        keep_initial: int,
        keep_last: int,
        max_input_length: int,
    ) -> Tuple[Optional[ConversationType], Optional[List[int]], Optional[ConversationType]]:
        initial_messages = chat_history[:keep_initial]
        tail_messages = chat_history[len(chat_history) - keep_last :] if keep_last > 0 else []
        history_to_summarize = chat_history[keep_initial : len(chat_history) - keep_last]
        if not history_to_summarize:
            return None, None, None

        summary_instruction = {"role": "user", "content": self.summary_prompt}
        trimmed_history = list(history_to_summarize)

        while True:
            summary_request = initial_messages + trimmed_history + [summary_instruction]
            summary_prompt_ids = self.tokenizer.apply_chat_template(
                summary_request,
                add_generation_prompt=True,
                tokenize=True,
                **self.chat_template_kwargs,
            )
            if len(summary_prompt_ids) <= max_input_length:
                return summary_request, summary_prompt_ids, tail_messages
            if not trimmed_history:
                break
            # Drop the oldest summarizable message and retry until the prompt fits.
            trimmed_history = trimmed_history[1:]

        return None, None, None

    def _build_summary_sampling_params(self) -> Optional[Dict[str, Any]]:
        summary_cfg = OmegaConf.create({})
        if self.base_sampling_params is not None:
            summary_cfg = OmegaConf.merge(summary_cfg, self.base_sampling_params)
        if self.summary_sampling_params is not None:
            summary_cfg = OmegaConf.merge(summary_cfg, self.summary_sampling_params)
        if self.summary_max_tokens is not None:
            summary_cfg = OmegaConf.merge(summary_cfg, {"max_generate_length": int(self.summary_max_tokens)})
        if len(summary_cfg) == 0:
            return None
        return get_sampling_params_for_backend(self.backend, summary_cfg)

    def _extract_summary(self, summary_text: str) -> str:
        match = self.summary_pattern.search(summary_text)
        if match:
            return match.group(1).strip()
        return summary_text.strip()

    def _render_summary_prefix(self, summary_text: str) -> str:
        if "{summary}" in self.summary_prefix:
            return self.summary_prefix.format(summary=summary_text)
        return f"{self.summary_prefix}{summary_text}"
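Reviewer note: a self-contained sketch (not part of the diff) of the summary post-processing implemented by _extract_summary and _render_summary_prefix above; the raw model output is invented.

import re

summary_pattern = re.compile(r"<summary>(.*?)</summary>", re.DOTALL)
summary_prefix = "[Previous conversation summary]\n{summary}\n\nPlease continue the task."

raw = "Sure.\n<summary>\nObjective: fix the failing test.\nNext action: rerun CI.\n</summary>"
match = summary_pattern.search(raw)
# Fall back to the whole response when the model omits the <summary> tags.
summary = match.group(1).strip() if match else raw.strip()
summary_message = {"role": "user", "content": summary_prefix.format(summary=summary)}
print(summary_message["content"])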
76 changes: 72 additions & 4 deletions skyrl-train/skyrl_train/generators/skyrl_gym_generator.py
@@ -7,6 +7,7 @@

import asyncio
import copy
import re
from uuid import uuid4
import skyrl_gym
from typing import List, Dict, Any, Optional, Union, Tuple
@@ -16,6 +17,7 @@
from loguru import logger

from skyrl_train.generators.base import GeneratorInterface, GeneratorInput, GeneratorOutput, TrajectoryID
from skyrl_train.generators.context_folding import ContextFolder
from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient
from skyrl_train.inference_engines.base import InferenceEngineInput, ConversationType
from omegaconf import DictConfig
@@ -47,7 +49,6 @@ class StepWiseOutput:

    step_outputs: List[TrajectoryOutput]


@dataclass
class AgentLoopState:
    chat_history: ConversationType
@@ -57,7 +58,6 @@
    response_end_idx: Optional[int]
    done: bool


@dataclass
class TurnOutput:
    output: str
@@ -123,6 +123,17 @@ def __init__(
        self.custom_chat_template = get_custom_chat_template(generator_cfg.chat_template)
        # get generation prompt ids for the tokenizer if needed
        self.generation_prompt_ids = get_generation_prompt_ids(tokenizer) if self.use_conversation_multi_turn else None
        self.context_folder = None
        context_folding_cfg = getattr(generator_cfg, "context_folding", None)
        if context_folding_cfg is not None and context_folding_cfg.get("enabled", False):
            self.context_folder = ContextFolder(
                context_folding_cfg,
                tokenizer=self.tokenizer,
                inference_engine_client=self.inference_engine_client,
                backend=self.generator_cfg.backend,
                base_sampling_params=self.generator_cfg.sampling_params,
                chat_template_kwargs=self.generator_cfg.chat_template_kwargs,
            )
        if self.skyrl_gym_cfg.max_env_workers > 0:
            self.env_executor = ThreadPoolExecutor(
                max_workers=self.skyrl_gym_cfg.max_env_workers, thread_name_prefix="skyrl-gym-env-"
@@ -176,6 +187,13 @@ def _validate_cfg(self, generator_cfg: DictConfig):
        if not self.use_conversation_multi_turn:
            raise ValueError("`step_wise_trajectories` doesn't support `use_conversation_multi_turn=False`")

        context_folding_cfg = getattr(generator_cfg, "context_folding", None)
        if context_folding_cfg is not None and context_folding_cfg.get("enabled", False):
            if not self.generator_cfg.step_wise_trajectories:
                raise ValueError("`context_folding.enabled` requires `step_wise_trajectories=True`")
            if self.custom_chat_template is not None:
                raise ValueError("`context_folding.enabled` doesn't support custom chat templates")

    async def _run_in_executor_if_available(self, func, *args, **kwargs):
        if (executor := self.env_executor) is not None:
            loop = asyncio.get_running_loop()
@@ -268,13 +286,15 @@ async def agent_loop(
            response_end_idx=None,
            done=False,
        )
        context_folder = self.context_folder
        fold_count = 0

        while not agent_loop_state.done:

            if len(agent_loop_state.input_ids) > max_input_length:
                stop_reason = "length"
                break

            # 1. Generate output
            if is_step_wise or retokenize_chat_history:
                # re-apply whole chat template so length check is correct
@@ -387,6 +407,54 @@

            per_step_rewards.append((step_reward, agent_loop_state.response_end_idx))

            if context_folder is not None and not agent_loop_state.done:
                fold_result = await context_folder.fold(
                    chat_history=agent_loop_state.chat_history,
                    current_input_length=len(agent_loop_state.input_ids),
                    max_input_length=max_input_length,
                    initial_chat_history_length=initial_chat_history_length,
                    session_id=session_id,
                    fold_count=fold_count,
                )
                if fold_result.folded:
                    fold_count += 1
                    if is_step_wise and context_folder.include_summary_in_training:
                        if fold_result.summary_prompt_ids is None or fold_result.summary_output_ids is None:
                            raise ValueError("Context folding summary output is missing prompt or response IDs")
                        summary_turn_output = TurnOutput(
                            output=fold_result.summary_text,
                            output_ids=fold_result.summary_output_ids,
                            output_logprobs=fold_result.summary_logprobs,
                            new_obs=[],
                            reward=0.0,
                            obs_ids=[],
                        )
                        if not summary_turn_output.output_ids:
                            logger.warning("Context folding summary produced empty output IDs; skipping summary step")
                        else:
                            summary_loss_mask = summary_turn_output.get_turn_loss_mask()
                            summary_rollout_logprobs = summary_turn_output.get_turn_rollout_logprobs()
                            summary_step_output = TrajectoryOutput(
                                response_ids=summary_turn_output.output_ids,
                                reward=0.0,
                                stop_reason=fold_result.summary_stop_reason or "summary",
                                loss_mask=summary_loss_mask,
                                prompt_ids=fold_result.summary_prompt_ids,
                                rollout_logprobs=summary_rollout_logprobs,
                                env_metrics={},
                            )
                            agent_loop_output.step_outputs.append(summary_step_output)
                            per_step_rewards.append((0.0, len(summary_turn_output.output_ids) - 1))
                    elif context_folder.include_summary_in_training:
                        logger.warning(
                            "Context folding summary steps require `step_wise_trajectories=True`; skipping summary output"
                        )

                    agent_loop_state.chat_history = fold_result.new_chat_history
                    if fold_result.new_input_ids is not None:
                        agent_loop_state.input_ids = fold_result.new_input_ids
                    agent_loop_state.response_end_idx = None
        # Get environment-specific metrics after the episode is done
        env_metrics = env.get_metrics()
        # Close the environment
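Reviewer note: a hypothetical illustration (not part of the diff) of a folded trajectory's step-wise outputs; field names mirror the PR's TrajectoryOutput, and all values are invented.

env_step = {
    "prompt_ids": [1, 2, 3],
    "response_ids": [11, 12, 13],
    "reward": 1.0,  # environment reward for this step
    "stop_reason": "stop",
    "loss_mask": [1, 1, 1],
    "rollout_logprobs": None,
    "env_metrics": {},
}
summary_step = {
    "prompt_ids": [4, 5, 6],  # the summary request built by ContextFolder
    "response_ids": [21, 22],  # the generated summary tokens
    "reward": 0.0,  # summary steps carry no environment reward
    "stop_reason": "summary",
    "loss_mask": [1, 1],  # trained on only when include_summary_in_training=true
    "rollout_logprobs": None,
    "env_metrics": {},
}
step_outputs = [env_step, summary_step]
print(f"{len(step_outputs)} steps; summary reward = {summary_step['reward']}")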