1 change: 1 addition & 0 deletions skyrl-train/skyrl_train/config/ppo_base_config.yaml
@@ -235,6 +235,7 @@ generator:
  http_endpoint_host: "127.0.0.1"
  http_endpoint_port: 8000
  max_turns: 1
  summarize_chat: false

  # chat template configuration
  chat_template:
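For reference, the new flag is off by default; enabling mid-rollout summarization is then a one-line config change. A minimal sketch of the generator section with the flag flipped:

generator:
  max_turns: 1
  summarize_chat: true  # summarize and compress the chat history once the prompt grows near max_input_length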
91 changes: 89 additions & 2 deletions skyrl-train/skyrl_train/generators/skyrl_gym_generator.py
@@ -7,6 +7,7 @@

import asyncio
import copy
import re
from uuid import uuid4
import skyrl_gym
from typing import List, Dict, Any, Optional, Union, Tuple
@@ -47,7 +48,6 @@ class StepWiseOutput:

    step_outputs: List[TrajectoryOutput]


@dataclass
class AgentLoopState:
    chat_history: ConversationType
@@ -57,6 +57,32 @@ class AgentLoopState:
    response_end_idx: Optional[int]
    done: bool

    def summarize_chat_history(self, initial_chat_history_length: int) -> ConversationType:
        """
        Summarize the chat history.
        """
        summary_prompt = """
Your operational context is full. Generate a concise summary by populating the template below.
This summary will be your sole context for continuing this task. Be brief but ensure all critical data is present.
- Mission Objective
  - Original query: [State the user's verbatim query.]
  - Verification checklist: [Status (VERIFIED/PENDING)] [Checklist item]
- Key Findings
  - Sources: [List the most critical, verified facts with sources.]
  - Discrepancies: [Note any conflicting information found between sources.]
- Tactical Plan
  - Promising leads: [List the best remaining keywords, sources, or angles to investigate.]
  - Known dead ends: [List queries or sources that proved useless to avoid repetition.]
  - Immediate next action: [State the exact tool call or query you were about to execute next.]
Now generate the summary, and put your summary inside tag <summary></summary>.
"""
Contributor comment (medium):

This large summary prompt is hardcoded within the summarize_chat_history method. For better readability and maintainability, consider moving it to a module-level constant (e.g., _SUMMARY_PROMPT) at the top of the file.
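A minimal sketch of that refactor; _SUMMARY_PROMPT is the reviewer's illustrative name, and the prompt body is elided here:

# Module-level constant near the top of skyrl_gym_generator.py
# (body elided; same text as summary_prompt above)
_SUMMARY_PROMPT = """
Your operational context is full. Generate a concise summary by populating the template below.
...
"""

# Inside AgentLoopState.summarize_chat_history, the hardcoded literal then becomes:
        summarize_request.append({"role": "user", "content": _SUMMARY_PROMPT})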


        history_to_summarize = self.chat_history[initial_chat_history_length:]
        summarize_request = self.chat_history[:initial_chat_history_length].copy()
        summarize_request.extend(history_to_summarize)
Contributor comment (medium):

The logic to create summarize_request can be simplified. These three lines are equivalent to self.chat_history.copy().

Suggested change:
-        history_to_summarize = self.chat_history[initial_chat_history_length:]
-        summarize_request = self.chat_history[:initial_chat_history_length].copy()
-        summarize_request.extend(history_to_summarize)
+        summarize_request = self.chat_history.copy()

        summarize_request.append({"role": "user", "content": summary_prompt})

        return summarize_request

@dataclass
class TurnOutput:
@@ -119,6 +145,7 @@ def __init__(
        self.max_turns = generator_cfg.max_turns
        self.batched = generator_cfg.batched
        self.use_conversation_multi_turn = generator_cfg.use_conversation_multi_turn
        self.summarize_chat = generator_cfg.summarize_chat
        # optionally use custom chat template to get loss masks (i.e. for Qwen3)
        self.custom_chat_template = get_custom_chat_template(generator_cfg.chat_template)
        # get generation prompt ids for the tokenizer if needed
@@ -182,6 +209,58 @@ async def _run_in_executor_if_available(self, func, *args, **kwargs):
            return await loop.run_in_executor(executor, func, *args, **kwargs)
        else:
            return func(*args, **kwargs)

    async def _summarize_and_compress_history(
        self,
        agent_loop_state: AgentLoopState,
        initial_chat_history_length: int,
        session_id: str,
        sampling_params: Optional[Dict[str, Any]] = None,
    ) -> AgentLoopState:

        summarize_request = agent_loop_state.summarize_chat_history(initial_chat_history_length)

        summary_input_ids = self.tokenizer.apply_chat_template(
            summarize_request,
            add_generation_prompt=True,
            tokenize=True,
            **self.generator_cfg.chat_template_kwargs,
        )

        engine_input = InferenceEngineInput(
            prompt_token_ids=[summary_input_ids],
            session_ids=[session_id],
            sampling_params=sampling_params,
        )

        engine_output = await self.inference_engine_client.generate(engine_input)
        summary_text = engine_output["responses"][0]

        # extract the summary from the <summary> tags; falls back to the full response if they are absent
        match = re.search(r'<summary>(.*?)</summary>', summary_text, re.DOTALL)
        if match:
            summary_text = match.group(1).strip()
Contributor comment (high):

If the language model does not return the summary within <summary> tags, re.search will return None, and the original, unparsed summary_text will be used. This could lead to a malformed context for the next turn. It's safer to handle this case, for instance, by logging a warning.

Suggested change:
         if match:
             summary_text = match.group(1).strip()
+        else:
+            logger.warning("Could not find <summary> tags in the summarization response. Using the full response as summary.")

        new_chat_history = agent_loop_state.chat_history[:initial_chat_history_length].copy()
        new_chat_history.append({
            "role": "user",
            "content": f"[Previous conversation summary]:\n{summary_text}\n\nPlease continue the task.",
        })

        # re-tokenize the compressed history
        new_input_ids = self.tokenizer.apply_chat_template(
            new_chat_history,
            add_generation_prompt=True,
            tokenize=True,
            **self.generator_cfg.chat_template_kwargs,
        )

        agent_loop_state.chat_history = new_chat_history
        agent_loop_state.input_ids = new_input_ids
        # drop the stale per-token loss mask / logprobs; they no longer align with the new input_ids
        agent_loop_state.loss_mask = []
        if agent_loop_state.rollout_logprobs is not None:
            agent_loop_state.rollout_logprobs = []

        return agent_loop_state

    async def agent_loop(
        self,
@@ -269,12 +348,20 @@ async def agent_loop(
            done=False,
        )

        threshold = int(max_input_length * 0.8)
Contributor comment (medium):

The summarization threshold is hardcoded as 80% of max_input_length. This magic number makes the code harder to maintain. It's better to make it a configurable parameter with a default value.

Suggested change:
-        threshold = int(max_input_length * 0.8)
+        threshold = int(max_input_length * self.generator_cfg.get("summarization_threshold_ratio", 0.8))
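The suggested generator_cfg.get call implies a new config key; if the generator config is validated strictly, a matching default would also belong in ppo_base_config.yaml. A minimal sketch of that addition (summarization_threshold_ratio is the reviewer's illustrative name):

generator:
  max_turns: 1
  summarize_chat: false
  summarization_threshold_ratio: 0.8  # summarize once input length exceeds this fraction of max_input_length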


        while not agent_loop_state.done:

            logger.info(f"[AgentLoop] Running context length (tokens): {len(agent_loop_state.input_ids)}")

            if self.summarize_chat and len(agent_loop_state.input_ids) > threshold:
                agent_loop_state = await self._summarize_and_compress_history(agent_loop_state, initial_chat_history_length, session_id, sampling_params)
                logger.info(f"Summarized chat history. New length: {len(agent_loop_state.input_ids)}")

            if len(agent_loop_state.input_ids) > max_input_length:
                stop_reason = "length"
                break

            # 1. Generate output
            if is_step_wise or retokenize_chat_history:
                # re-apply whole chat template so length check is correct