1 change: 0 additions & 1 deletion pyproject.toml
@@ -16,7 +16,6 @@ requires-python = ">=3.10"
 dependencies = [
     "aiohttp>=3.11.13",
     "fastapi>=0.115.6",
-    "instructor>=1.7.9",
     "jsonref>=1.1.0",
     "mcp>=1.10.1",
     "numpy>=2.1.3",
2 changes: 2 additions & 0 deletions src/mcp_agent/config.py
@@ -10,6 +10,7 @@
 import threading
 import warnings
 
+from httpx import URL
 from pydantic import AliasChoices, BaseModel, ConfigDict, Field, field_validator
 from pydantic_settings import BaseSettings, SettingsConfigDict
 
@@ -205,6 +206,7 @@ class AnthropicSettings(BaseSettings, VertexAIMixin, BedrockMixin):
             "provider", "ANTHROPIC_PROVIDER", "anthropic__provider"
         ),
     )
+    base_url: str | URL | None = Field(default=None)
 
     model_config = SettingsConfigDict(
         env_prefix="ANTHROPIC_",
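
A minimal sketch of how the new base_url setting could be used. The gateway URL is a placeholder, and the ANTHROPIC_BASE_URL environment mapping is an assumption based on the env_prefix="ANTHROPIC_" in the settings class above:

from mcp_agent.config import AnthropicSettings

# Point the Anthropic client at a proxy or gateway (placeholder URL).
settings = AnthropicSettings(base_url="https://llm-gateway.example.com")

# Equivalently via the environment, assuming pydantic-settings maps the
# ANTHROPIC_ prefix onto this field:
#   export ANTHROPIC_BASE_URL=https://llm-gateway.example.com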
5 changes: 5 additions & 0 deletions src/mcp_agent/workflows/llm/augmented_llm.py
@@ -168,6 +168,11 @@ class RequestParams(CreateMessageRequestParams):
     This is used to stably identify the user in the LLM provider's logs.
     """
 
+    strict: bool = False
+    """
+    Whether models that support strict mode should strictly enforce the response schema.
+    """
+
 
 class AugmentedLLMProtocol(Protocol, Generic[MessageParamT, MessageT]):
     """Protocol defining the interface for augmented LLMs"""
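
A short usage sketch for the new strict flag. Here llm stands for any AugmentedLLM instance and the Pydantic model is hypothetical; of the providers touched in this PR, only the Azure path below actually reads the flag:

from pydantic import BaseModel
from mcp_agent.workflows.llm.augmented_llm import RequestParams

class WeatherReport(BaseModel):
    city: str
    temperature_c: float

# Inside an async function: ask providers that support strict schema
# enforcement to reject outputs that do not match WeatherReport's schema.
params = RequestParams(strict=True)
report = await llm.generate_structured(
    "What's the weather in Paris?",
    response_model=WeatherReport,
    request_params=params,
)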
161 changes: 66 additions & 95 deletions src/mcp_agent/workflows/llm/augmented_llm_anthropic.py
@@ -48,7 +48,7 @@
 from mcp_agent.tracing.telemetry import get_tracer, is_otel_serializable, telemetry
 from mcp_agent.tracing.token_tracking_decorator import track_tokens
 from mcp_agent.utils.common import ensure_serializable, typed_dict_extras, to_string
-from mcp_agent.utils.pydantic_type_serializer import serialize_model, deserialize_model
+
 from mcp_agent.workflows.llm.augmented_llm import (
     AugmentedLLM,
     ModelT,
@@ -83,15 +83,6 @@ class RequestCompletionRequest(BaseModel):
     payload: dict
 
 
-class RequestStructuredCompletionRequest(BaseModel):
-    config: AnthropicSettings
-    params: RequestParams
-    response_model: Type[ModelT] | None = None
-    serialized_response_model: str | None = None
-    response_str: str
-    model: str
-
-
 def create_anthropic_instance(settings: AnthropicSettings):
     """Select and initialise the appropriate anthropic client instance based on settings"""
     if settings.provider == "bedrock":
@@ -419,68 +410,86 @@ async def generate_structured(
         response_model: Type[ModelT],
         request_params: RequestParams | None = None,
     ) -> ModelT:
-        # First we invoke the LLM to generate a string response
-        # We need to do this in a two-step process because Instructor doesn't
-        # know how to invoke MCP tools via call_tool, so we'll handle all the
-        # processing first and then pass the final response through Instructor
+        # Use Anthropic's native structured output via a forced tool call carrying JSON input
+        import json
 
         tracer = get_tracer(self.context)
         with tracer.start_as_current_span(
             f"{self.__class__.__name__}.{self.name}.generate_structured"
         ) as span:
             span.set_attribute(GEN_AI_AGENT_NAME, self.agent.name)
             self._annotate_span_for_generation_message(span, message)
 
-            response = await self.generate_str(
-                message=message,
-                request_params=request_params,
-            )
-
             params = self.get_request_params(request_params)
 
             if self.context.tracing_enabled:
                 AugmentedLLM.annotate_span_with_request_params(span, params)
 
-            model = await self.select_model(params)
-            span.set_attribute(GEN_AI_REQUEST_MODEL, model)
-
             span.set_attribute("response_model", response_model.__name__)
 
-            serialized_response_model: str | None = None
-
-            if self.executor and self.executor.execution_engine == "temporal":
-                # Serialize the response model to a string
-                serialized_response_model = serialize_model(response_model)
-
-            structured_response = await self.executor.execute(
-                AnthropicCompletionTasks.request_structured_completion_task,
-                RequestStructuredCompletionRequest(
-                    config=self.context.config.anthropic,
-                    params=params,
-                    response_model=response_model
-                    if not serialized_response_model
-                    else None,
-                    serialized_response_model=serialized_response_model,
-                    response_str=response,
-                    model=model,
-                ),
-            )
+            model_name = (
+                await self.select_model(params) or self.default_request_params.model
+            )
+            span.set_attribute(GEN_AI_REQUEST_MODEL, model_name)
 
-            # TODO: saqadri (MAC) - fix request_structured_completion_task to return ensure_serializable
-            # Convert dict back to the proper model instance if needed
-            if isinstance(structured_response, dict):
-                structured_response = response_model.model_validate(structured_response)
+            # Convert message(s) to Anthropic format
+            messages: List[MessageParam] = []
+            if params.use_history:
+                messages.extend(self.history.get())
+            messages.extend(
+                AnthropicConverter.convert_mixed_messages_to_anthropic(message)
+            )
 
-            if self.context.tracing_enabled:
-                try:
-                    span.set_attribute(
-                        "structured_response_json",
-                        structured_response.model_dump_json(),
-                    )
-                # pylint: disable=broad-exception-caught
-                except Exception:
-                    span.set_attribute("unstructured_response", response)
+            # Define a single tool that matches the Pydantic schema
+            schema = response_model.model_json_schema()
+            tools: List[ToolParam] = [
+                {
+                    "name": "return_structured_output",
+                    "description": "Return the response in the required JSON format",
+                    "input_schema": schema,
+                }
+            ]
 
+            args = {
+                "model": model_name,
+                "messages": messages,
+                "system": self.instruction or params.systemPrompt,
+                "tools": tools,
+                "tool_choice": {"type": "tool", "name": "return_structured_output"},
+            }
+            if params.maxTokens is not None:
+                args["max_tokens"] = params.maxTokens
+            if params.stopSequences:
+                args["stop_sequences"] = params.stopSequences
 
+            # Call Anthropic directly (one-turn streaming for consistency)
+            base_url = None
+            if self.context and self.context.config and self.context.config.anthropic:
+                base_url = self.context.config.anthropic.base_url
+                api_key = self.context.config.anthropic.api_key
+                client = AsyncAnthropic(api_key=api_key, base_url=base_url)
+            else:
+                client = AsyncAnthropic()
Review comment (Diamond) on lines +464 to +469:

The AsyncAnthropic client should be used within an async context manager to ensure proper resource cleanup. Currently, the client is instantiated but never properly closed, which could lead to connection leaks. Consider refactoring to:

async with AsyncAnthropic(api_key=api_key, base_url=base_url) as client:
    async with client.messages.stream(**args) as stream:
        final = await stream.get_final_message()

This ensures the client is properly closed after use, preventing potential resource leaks in production environments.


+            async with client:
+                async with client.messages.stream(**args) as stream:
+                    final = await stream.get_final_message()
+
+            # Extract tool_use input and validate
+            for block in final.content:
+                if (
+                    getattr(block, "type", None) == "tool_use"
+                    and getattr(block, "name", "") == "return_structured_output"
+                ):
+                    data = getattr(block, "input", None)
+                    try:
+                        if isinstance(data, str):
+                            return response_model.model_validate(json.loads(data))
+                        return response_model.model_validate(data)
+                    except Exception:
+                        # Fallthrough to error
+                        break
Review comment (Diamond) on lines +483 to +488:

The broad except Exception: clause may obscure important errors during structured output parsing. Consider catching specific exceptions like json.JSONDecodeError or pydantic.ValidationError instead. This would provide clearer error messages and help distinguish between data format issues versus more serious system problems. The current approach makes debugging challenging as it silently falls through to a generic error message.

Suggested change
-                        if isinstance(data, str):
-                            return response_model.model_validate(json.loads(data))
-                        return response_model.model_validate(data)
-                    except Exception:
-                        # Fallthrough to error
-                        break
+                        if isinstance(data, str):
+                            return response_model.model_validate(json.loads(data))
+                        return response_model.model_validate(data)
+                    except json.JSONDecodeError:
+                        # JSON parsing error - invalid JSON format
+                        logger.error("Failed to parse JSON response")
+                        break
+                    except pydantic.ValidationError:
+                        # Validation error - JSON structure doesn't match expected model
+                        logger.error("Response data failed validation against model")
+                        break


-            return structured_response
+            raise ValueError(
+                "Failed to obtain structured output from Anthropic response"
+            )
Review comment (Diamond) on lines +476 to +492:

The error handling in this section could be improved for better diagnostics. Currently, the code uses getattr() with default values which can mask actual errors in the API response structure. If block.input is None or block.name doesn't match expectations, the code will silently continue to the generic ValueError at the end rather than providing specific information about what went wrong.

Consider adding more specific error handling to distinguish between different failure modes:

for block in final.content:
    block_type = getattr(block, "type", None)
    if block_type != "tool_use":
        continue
        
    block_name = getattr(block, "name", None)
    if block_name != "return_structured_output":
        continue
        
    data = getattr(block, "input", None)
    if data is None:
        raise ValueError("Tool use block found but input data is missing")
        
    try:
        if isinstance(data, str):
            return response_model.model_validate(json.loads(data))
        return response_model.model_validate(data)
    except Exception as e:
        raise ValueError(f"Failed to validate response data: {str(e)}") from e

raise ValueError("No structured output tool use found in Anthropic response")
Suggested change
-            for block in final.content:
-                if (
-                    getattr(block, "type", None) == "tool_use"
-                    and getattr(block, "name", "") == "return_structured_output"
-                ):
-                    data = getattr(block, "input", None)
-                    try:
-                        if isinstance(data, str):
-                            return response_model.model_validate(json.loads(data))
-                        return response_model.model_validate(data)
-                    except Exception:
-                        # Fallthrough to error
-                        break
-            return structured_response
-            raise ValueError(
-                "Failed to obtain structured output from Anthropic response"
-            )
+            for block in final.content:
+                block_type = getattr(block, "type", None)
+                if block_type != "tool_use":
+                    continue
+                block_name = getattr(block, "name", None)
+                if block_name != "return_structured_output":
+                    continue
+                data = getattr(block, "input", None)
+                if data is None:
+                    raise ValueError("Tool use block found but input data is missing")
+                try:
+                    if isinstance(data, str):
+                        return response_model.model_validate(json.loads(data))
+                    return response_model.model_validate(data)
+                except Exception as e:
+                    raise ValueError(f"Failed to validate response data: {str(e)}") from e
+            raise ValueError("No structured output tool use found in Anthropic response")


     @classmethod
     def convert_message_to_message_param(
@@ -770,44 +779,6 @@ async def request_completion_task(
         response = ensure_serializable(response)
         return response
 
-    @staticmethod
-    @workflow_task
-    @telemetry.traced()
-    async def request_structured_completion_task(
-        request: RequestStructuredCompletionRequest,
-    ):
-        """
-        Request a structured completion using Instructor's Anthropic API.
-        """
-        import instructor
-
-        if request.response_model:
-            response_model = request.response_model
-        elif request.serialized_response_model:
-            response_model = deserialize_model(request.serialized_response_model)
-        else:
-            raise ValueError(
-                "Either response_model or serialized_response_model must be provided for structured completion."
-            )
-
-        # We pass the text through instructor to extract structured data
-        client = instructor.from_anthropic(create_anthropic_instance(request.config))
-
-        # Extract structured data from natural language without blocking the loop
-        loop = asyncio.get_running_loop()
-        structured_response = await loop.run_in_executor(
-            None,
-            functools.partial(
-                client.chat.completions.create,
-                model=request.model,
-                response_model=response_model,
-                messages=[{"role": "user", "content": request.response_str}],
-                max_tokens=request.params.maxTokens,
-            ),
-        )
-
-        return structured_response


 class AnthropicMCPTypeConverter(ProviderToMCPConverter[MessageParam, Message]):
     """
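
For readers skimming the diff, a minimal standalone sketch of the tool-forcing pattern the new generate_structured implements, outside the AugmentedLLM plumbing. The model id and prompt are placeholders, and ANTHROPIC_API_KEY is assumed to be set in the environment:

import asyncio
import json
from anthropic import AsyncAnthropic
from pydantic import BaseModel

class Sentiment(BaseModel):
    label: str
    confidence: float

async def main() -> None:
    schema = Sentiment.model_json_schema()
    async with AsyncAnthropic() as client:  # reads ANTHROPIC_API_KEY
        message = await client.messages.create(
            model="claude-sonnet-4-20250514",  # placeholder model id
            max_tokens=512,
            messages=[{"role": "user", "content": "Rate: 'I love this!'"}],
            # One tool whose input schema is the Pydantic schema; tool_choice
            # forces the model to call it, so the reply carries JSON input.
            tools=[{
                "name": "return_structured_output",
                "description": "Return the response in the required JSON format",
                "input_schema": schema,
            }],
            tool_choice={"type": "tool", "name": "return_structured_output"},
        )
    for block in message.content:
        if block.type == "tool_use":
            print(Sentiment.model_validate(block.input))

asyncio.run(main())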
39 changes: 33 additions & 6 deletions src/mcp_agent/workflows/llm/augmented_llm_azure.py
@@ -2,6 +2,7 @@
 import functools
 import json
 from typing import Any, Iterable, Optional, Type, Union
+from azure.core.exceptions import HttpResponseError
 from azure.ai.inference import ChatCompletionsClient
 from azure.ai.inference.models import (
     ChatCompletions,
@@ -351,6 +352,7 @@ async def generate_structured(
             name=response_model.__name__,
             description=response_model.__doc__,
             schema=json_schema,
+            strict=request_params.strict,
         )
         request_params.metadata = metadata
 
@@ -362,7 +364,7 @@

     @classmethod
     def convert_message_to_message_param(
-        cls, message: ResponseMessage, **kwargs
+        cls, message: ResponseMessage
     ) -> AssistantMessage:
         """Convert a response object to an input parameter object to allow LLM calls to be chained."""
         assistant_message = AssistantMessage(
@@ -539,12 +541,37 @@ async def request_completion_task(
             ),
         )
 
-        payload = request.payload
+        # Offload sync SDK call to a thread to avoid blocking the event loop
+        payload = request.payload.copy()
         loop = asyncio.get_running_loop()
-        response = await loop.run_in_executor(
-            None, functools.partial(azure_client.complete, **payload)
-        )
 
+        try:
+            response = await loop.run_in_executor(
+                None, functools.partial(azure_client.complete, **payload)
+            )
+        except HttpResponseError as e:
+            logger = get_logger(__name__)
+
+            if e.status_code != 400:
+                logger.error(f"Azure API call failed: {e}")
+                raise
+
+            logger.warning(
+                f"Initial Azure API call failed: {e}. Retrying with fallback parameters."
+            )
+
+            # Create a new payload with fallback values for commonly problematic parameters
+            fallback_payload = {**payload, "max_tokens": None, "temperature": 1}
Review comment (Diamond):

The hardcoded fallback parameters may lead to unexpected behavior. Setting temperature=1 could significantly alter the model's output characteristics compared to what was originally requested, potentially disrupting applications that rely on specific temperature settings. Consider either preserving the original temperature value or implementing a more conservative fallback strategy that maintains output consistency with the original request parameters.

Suggested change
-            fallback_payload = {**payload, "max_tokens": None, "temperature": 1}
+            fallback_payload = {**payload, "max_tokens": None}


Review comment (Diamond):

The Azure API fallback logic sets max_tokens: None, but Azure typically requires max_tokens to be a positive integer. This could cause the fallback request to also fail with a 400 error, defeating the purpose of the retry mechanism. Consider either removing max_tokens entirely from the fallback payload or setting it to a safe default value (like 1024) instead of None.

Suggested change
-            fallback_payload = {**payload, "max_tokens": None, "temperature": 1}
+            fallback_payload = {**payload, "max_tokens": 1024, "temperature": 1}



+            try:
+                response = await loop.run_in_executor(
+                    None, functools.partial(azure_client.complete, **fallback_payload)
+                )
+            except Exception as retry_error:
+                # If retry also fails, raise a more informative error
+                raise RuntimeError(
+                    f"Azure API call failed even with fallback parameters. "
+                    f"Original error: {e}. Retry error: {retry_error}"
+                ) from retry_error
         return response

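The Azure generate_structured hunk above wires the new strict flag into a JSON-schema response format. A hedged standalone sketch of that mechanism with the azure-ai-inference SDK; the endpoint and key env vars are placeholders, and the JsonSchemaFormat/response_format usage is an assumption inferred from the kwargs in the hunk and recent SDK versions:

import os
from azure.ai.inference import ChatCompletionsClient
from azure.ai.inference.models import JsonSchemaFormat, SystemMessage, UserMessage
from azure.core.credentials import AzureKeyCredential
from pydantic import BaseModel

class Invoice(BaseModel):
    """Extracted invoice fields."""
    vendor: str
    total: float

client = ChatCompletionsClient(
    endpoint=os.environ["AZURE_INFERENCE_ENDPOINT"],  # placeholder env var
    credential=AzureKeyCredential(os.environ["AZURE_INFERENCE_KEY"]),
)

# strict=True mirrors request_params.strict in the diff: the service is asked
# to enforce the schema exactly rather than best-effort.
response = client.complete(
    messages=[
        SystemMessage("Extract the invoice fields."),
        UserMessage("ACME Corp billed $1,250.00."),
    ],
    response_format=JsonSchemaFormat(
        name=Invoice.__name__,
        description=Invoice.__doc__,
        schema=Invoice.model_json_schema(),
        strict=True,
    ),
)
print(response.choices[0].message.content)  # JSON string matching the schema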
