Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,6 @@ neo4j.lock

**/temp_workspace/
ms_agent/app/temp_workspace/

# webui
webui/work_dir/
79 changes: 79 additions & 0 deletions ms_agent/agent/llm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import inspect
import os.path
import sys
import threading
import uuid
from contextlib import contextmanager
from copy import deepcopy
Expand Down Expand Up @@ -84,6 +85,8 @@ class LLMAgent(Agent):

TOTAL_PROMPT_TOKENS = 0
TOTAL_COMPLETION_TOKENS = 0
TOTAL_CACHED_TOKENS = 0
TOTAL_CACHE_CREATION_INPUT_TOKENS = 0
TOKEN_LOCK = asyncio.Lock()

def __init__(self,
Expand Down Expand Up @@ -538,6 +541,41 @@ def stream(self):
DictConfig({}))
return getattr(generation_config, 'stream', False)

@property
def show_reasoning(self) -> bool:
    """Whether streamed model reasoning/thinking should be printed locally.

    Notes:
        - Affects console output only; nothing sent to the backend changes.
        - The reasoning text itself arrives via ``Message.reasoning_content``
          when the backend supplies it.
    """
    gen_cfg = getattr(self.config, 'generation_config', DictConfig({}))
    flag = getattr(gen_cfg, 'show_reasoning', False)
    return bool(flag)

@property
def reasoning_output(self) -> str:
    """Where to print reasoning content when ``show_reasoning=True``.

    Supported values:
        - "stdout" (default): interleave reasoning with assistant output on stdout
        - "stderr": keep stdout clean for the assistant's final text

    Returns:
        str: The configured stream name. Any value other than "stdout"
        is treated as stderr by ``_write_reasoning``.
    """
    generation_config = getattr(self.config, 'generation_config',
                                DictConfig({}))
    return str(getattr(generation_config, 'reasoning_output', 'stdout'))

def _write_reasoning(self, text: str):
if not text:
return
if self.reasoning_output.lower() == 'stdout':
sys.stdout.write(text)
sys.stdout.flush()
else:
# default: stderr
sys.stderr.write(text)
sys.stderr.flush()

@property
def system(self):
return getattr(
Expand Down Expand Up @@ -753,22 +791,49 @@ async def step(
if self.stream:
self.log_output('[assistant]:')
_content = ''
_reasoning = ''
is_first = True
_response_message = None
_printed_reasoning_header = False
for _response_message in self.llm.generate(
messages, tools=tools):
if is_first:
messages.append(_response_message)
is_first = False

# Optional: stream model "thinking/reasoning" if available.
if self.show_reasoning:
reasoning_text = getattr(_response_message,
'reasoning_content', '') or ''
# Some providers may reset / shorten content across chunks.
if len(reasoning_text) < len(_reasoning):
_reasoning = ''
new_reasoning = reasoning_text[len(_reasoning):]
if new_reasoning:
if not _printed_reasoning_header:
self._write_reasoning('[thinking]:\n')
_printed_reasoning_header = True
self._write_reasoning(new_reasoning)
_reasoning = reasoning_text

new_content = _response_message.content[len(_content):]
sys.stdout.write(new_content)
sys.stdout.flush()
_content = _response_message.content
messages[-1] = _response_message
yield messages
if self.show_reasoning and _printed_reasoning_header:
self._write_reasoning('\n')
sys.stdout.write('\n')
else:
_response_message = self.llm.generate(messages, tools=tools)
if self.show_reasoning:
reasoning_text = getattr(_response_message,
'reasoning_content', '') or ''
if reasoning_text:
self._write_reasoning('[thinking]:\n')
self._write_reasoning(reasoning_text)
self._write_reasoning('\n')
if _response_message.content:
self.log_output('[assistant]:')
self.log_output(_response_message.content)
Expand All @@ -791,19 +856,33 @@ async def step(
# usage
prompt_tokens = _response_message.prompt_tokens
completion_tokens = _response_message.completion_tokens
cached_tokens = getattr(_response_message, 'cached_tokens', 0) or 0
cache_creation_input_tokens = getattr(
_response_message, 'cache_creation_input_tokens', 0) or 0

async with LLMAgent.TOKEN_LOCK:
LLMAgent.TOTAL_PROMPT_TOKENS += prompt_tokens
LLMAgent.TOTAL_COMPLETION_TOKENS += completion_tokens
LLMAgent.TOTAL_CACHED_TOKENS += cached_tokens
LLMAgent.TOTAL_CACHE_CREATION_INPUT_TOKENS += cache_creation_input_tokens

# tokens in the current step
self.log_output(
f'[usage] prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}'
)
if cached_tokens or cache_creation_input_tokens:
self.log_output(
f'[usage_cache] cache_hit: {cached_tokens}, cache_created: {cache_creation_input_tokens}'
)
# total tokens for the process so far
self.log_output(
f'[usage_total] total_prompt_tokens: {LLMAgent.TOTAL_PROMPT_TOKENS}, '
f'total_completion_tokens: {LLMAgent.TOTAL_COMPLETION_TOKENS}')
if LLMAgent.TOTAL_CACHED_TOKENS or LLMAgent.TOTAL_CACHE_CREATION_INPUT_TOKENS:
self.log_output(
f'[usage_cache_total] total_cache_hit: {LLMAgent.TOTAL_CACHED_TOKENS}, '
f'total_cache_created: {LLMAgent.TOTAL_CACHE_CREATION_INPUT_TOKENS}'
)

yield messages

Expand Down
153 changes: 152 additions & 1 deletion ms_agent/llm/openai_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ class OpenAI(LLM):
'role', 'content', 'tool_calls', 'partial', 'prefix', 'tool_call_id'
}

# Providers that support cache_control in structured content blocks
CACHE_CONTROL_PROVIDERS = ['dashscope', 'anthropic']

def __init__(
self,
config: DictConfig,
Expand All @@ -51,9 +54,85 @@ def __init__(
api_key=api_key,
base_url=base_url,
)
self.base_url = base_url or ''
self.args: Dict = OmegaConf.to_container(
getattr(config, 'generation_config', DictConfig({})))

# Prefix cache configuration
# - force_prefix_cache: enable structured content with cache_control for explicit caching
# - prefix_cache_roles: which messages to cache (only these are converted to structured format)
# Supports:
# - Role names: 'system', 'user', 'assistant', 'tool'
# - Special values: 'last_message' (only cache the last message in the list)
# Default: ['system'] - system prompt is usually the longest stable prefix
self._prefix_cache_enabled = self.args.get('force_prefix_cache', False)
self._prefix_cache_roles = set(
self.args.get('prefix_cache_roles', ['system']))
self._prefix_cache_provider = self._detect_cache_provider()

def _detect_cache_provider(self) -> Optional[str]:
"""
Detect which provider-specific cache_control format to use based on base_url.

Returns:
Provider name (e.g. 'dashscope', 'anthropic') or None for native OpenAI
(which uses automatic prefix caching without explicit cache_control).
"""
if not self._prefix_cache_enabled:
return None
base_url_lower = self.base_url.lower()
for provider in self.CACHE_CONTROL_PROVIDERS:
if provider in base_url_lower:
return provider
# Native OpenAI: automatic prefix caching, no need for cache_control
return None

@staticmethod
def _to_structured_content(
content: Any,
add_cache_control: bool = False,
provider: Optional[str] = None,
) -> Any:
"""
Convert message content to structured content blocks for prefix caching.

This method is idempotent: already-structured content is returned as-is
(with optional cache_control addition for dashscope/anthropic).

Args:
content: Original content (str or list)
add_cache_control: Whether to add cache_control to text blocks

Returns:
Structured content list or original content if not applicable
"""
if not add_cache_control:
return content

# Case 1: plain string -> wrap in structured block
if isinstance(content, str):
block: Dict[str, Any] = {'type': 'text', 'text': content}
if provider in {'dashscope', 'anthropic'}:
block['cache_control'] = {'type': 'ephemeral'}
return [block]

# Case 2: already a list (multimodal or pre-structured)
if isinstance(content, list):
# Add cache_control to text blocks that don't have it
new_list = []
for item in content:
if (isinstance(item, dict) and item.get('type') == 'text'
and 'cache_control' not in item):
new_item = dict(item)
new_item['cache_control'] = {'type': 'ephemeral'}
new_list.append(new_item)
else:
new_list.append(item)
return new_list

# Other types: return as-is
return content

def format_tools(self,
tools: Optional[List[Tool]] = None
) -> List[Dict[str, Any]]:
Expand Down Expand Up @@ -143,6 +222,34 @@ def _call_llm(self,
return self.client.chat.completions.create(
model=self.model, messages=messages, tools=tools, **kwargs)

@staticmethod
def _extract_cache_info(usage_obj: Any) -> tuple:
"""
Extract cache info from an OpenAI-compatible usage object.

Returns:
tuple: (cached_tokens, cache_creation_input_tokens)
- cached_tokens: tokens that hit existing cache
- cache_creation_input_tokens: tokens used to create new cache (explicit cache only)

OpenAI/DashScope format: usage.prompt_tokens_details.{cached_tokens, cache_creation_input_tokens}
"""
if not usage_obj:
return 0, 0
details = getattr(usage_obj, 'prompt_tokens_details', None)
if details is None and isinstance(usage_obj, dict):
details = usage_obj.get('prompt_tokens_details')
if details is None:
return 0, 0
if isinstance(details, dict):
cached = int(details.get('cached_tokens', 0) or 0)
created = int(details.get('cache_creation_input_tokens', 0) or 0)
else:
cached = int(getattr(details, 'cached_tokens', 0) or 0)
created = int(
getattr(details, 'cache_creation_input_tokens', 0) or 0)
return cached, created

def _merge_stream_message(self, pre_message_chunk: Optional[Message],
message_chunk: Message) -> Optional[Message]:
"""Merges a new chunk of message into the previous chunks during streaming.
Expand Down Expand Up @@ -227,6 +334,10 @@ def _stream_continue_generate(self,
try:
next_chunk = next(completion)
message.prompt_tokens += next_chunk.usage.prompt_tokens
cached, created = self._extract_cache_info(
getattr(next_chunk, 'usage', None))
message.cached_tokens += cached
message.cache_creation_input_tokens += created
message.completion_tokens += next_chunk.usage.completion_tokens
except (StopIteration, AttributeError):
# The stream may end without a final usage chunk, which is acceptable.
Expand Down Expand Up @@ -323,13 +434,17 @@ def _format_output_message(completion) -> Message:
tool_name=tool_call.function.name) for idx, tool_call in
enumerate(completion.choices[0].message.tool_calls)
]
cached, created = OpenAI._extract_cache_info(
getattr(completion, 'usage', None))
return Message(
role='assistant',
content=content,
reasoning_content=reasoning_content,
tool_calls=tool_calls,
id=completion.id,
prompt_tokens=completion.usage.prompt_tokens,
cached_tokens=cached,
cache_creation_input_tokens=created,
completion_tokens=completion.usage.completion_tokens)

@staticmethod
Expand All @@ -343,6 +458,9 @@ def _merge_partial_message(messages: List[Message], new_message: Message):
messages[-1].reasoning_content += new_message.reasoning_content
messages[-1].content += new_message.content
messages[-1].prompt_tokens += new_message.prompt_tokens
messages[-1].cached_tokens += new_message.cached_tokens
messages[
-1].cache_creation_input_tokens += new_message.cache_creation_input_tokens
messages[-1].completion_tokens += new_message.completion_tokens
if new_message.tool_calls:
if messages[-1].tool_calls:
Expand Down Expand Up @@ -432,12 +550,44 @@ def _format_input_message(self,
Returns:
List[Dict[str, Any]]: List of dictionaries compatible with OpenAI's input format.
"""
# Determine if we need to add cache_control (for dashscope/anthropic)
add_cache_control = self._prefix_cache_provider is not None

# Determine which message index should have cache_control (the last matching one)
cache_indice = None
if self._prefix_cache_enabled and add_cache_control:
cache_indices = set()
# Check for 'last_message' special value
if 'last_message' in self._prefix_cache_roles and messages:
cache_indices.add(len(messages) - 1)
# Check for role-based caching
role_cache = self._prefix_cache_roles - {'last_message'}
for idx, msg in enumerate(messages):
msg_role = msg.role if isinstance(msg, Message) else msg.get(
'role', '')
if msg_role in role_cache:
cache_indices.add(idx)
cache_indice = max(cache_indices) if cache_indices else None

openai_messages = []
for message in messages:
for idx, message in enumerate(messages):
if isinstance(message, Message):
if isinstance(message.content, str):
message.content = message.content.strip()
message = message.to_dict_clean()
else:
message = dict(message)

content = message.get('content', '')
if isinstance(content, str):
content = content.strip()

# Apply prefix cache structured content transformation
if cache_indice is not None and idx == cache_indice:
content = self._to_structured_content(
content,
add_cache_control=True,
provider=self._prefix_cache_provider)

message = {
key: value.strip() if isinstance(value, str) else value
Expand All @@ -446,6 +596,7 @@ def _format_input_message(self,
}
if 'content' not in message:
message['content'] = ''
message['content'] = content if content else ''
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The expression content if content else '' can cause issues when content is an empty list ([]), which is a valid structured content format. In that case, it would be incorrectly converted to an empty string ''. It should be assigned directly to preserve the data type.

            message['content'] = content


openai_messages.append(message)

Expand Down
6 changes: 6 additions & 0 deletions ms_agent/llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ class Message:
# usage
completion_tokens: int = 0
prompt_tokens: int = 0

# tokens that hit existing cache (billed at reduced rate like 0.1x)
cached_tokens: int = 0
# tokens used to create new cache (explicit cache only, billed at higher rate like 1.25x)
cache_creation_input_tokens: int = 0

api_calls: int = 1

def to_dict(self):
Expand Down
1 change: 1 addition & 0 deletions ms_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@
from .filesystem_tool import FileSystemTool
from .mcp_client import MCPClient
from .split_task import SplitTask
from .todolist_tool import TodoListTool
from .tool_manager import ToolManager
Loading
Loading