Commit 130672d

Get structured outputs using LLM native APIs (#418)

1 parent df6fa89 commit 130672d

10 files changed with 393 additions and 340 deletions

pyproject.toml

Lines changed: 0 additions & 1 deletion

@@ -16,7 +16,6 @@ requires-python = ">=3.10"
 dependencies = [
     "aiohttp>=3.11.13",
     "fastapi>=0.115.6",
-    "instructor>=1.7.9",
     "jsonref>=1.1.0",
     "mcp>=1.10.1",
     "numpy>=2.1.3",

src/mcp_agent/config.py

Lines changed: 2 additions & 0 deletions

@@ -10,6 +10,7 @@
 import threading
 import warnings

+from httpx import URL
 from pydantic import AliasChoices, BaseModel, ConfigDict, Field, field_validator
 from pydantic_settings import BaseSettings, SettingsConfigDict

@@ -205,6 +206,7 @@ class AnthropicSettings(BaseSettings, VertexAIMixin, BedrockMixin):
             "provider", "ANTHROPIC_PROVIDER", "anthropic__provider"
         ),
     )
+    base_url: str | URL | None = Field(default=None)

     model_config = SettingsConfigDict(
         env_prefix="ANTHROPIC_",
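
The new base_url field means the Anthropic client can be pointed at a proxy or an API-compatible gateway instead of the default endpoint. A minimal sketch of how the setting flows into client construction, mirroring the change in augmented_llm_anthropic.py below (the gateway URL is hypothetical; values can also come from ANTHROPIC_-prefixed environment variables per the env_prefix above):

    # Sketch only: construct an Anthropic client from AnthropicSettings,
    # honoring the new optional base_url.
    from anthropic import AsyncAnthropic
    from mcp_agent.config import AnthropicSettings

    settings = AnthropicSettings(
        api_key="sk-ant-...",  # placeholder key
        base_url="https://llm-gateway.example.com",  # hypothetical proxy/gateway
    )
    client = AsyncAnthropic(api_key=settings.api_key, base_url=settings.base_url)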

src/mcp_agent/workflows/llm/augmented_llm.py

Lines changed: 5 additions & 0 deletions

@@ -168,6 +168,11 @@ class RequestParams(CreateMessageRequestParams):
     This is used to stably identify the user in the LLM provider's logs.
     """

+    strict: bool = False
+    """
+    Whether models that support strict mode should strictly enforce the response schema.
+    """
+

 class AugmentedLLMProtocol(Protocol, Generic[MessageParamT, MessageT]):
     """Protocol defining the interface for augmented LLMs"""

src/mcp_agent/workflows/llm/augmented_llm_anthropic.py

Lines changed: 66 additions & 95 deletions

@@ -48,7 +48,7 @@
 from mcp_agent.tracing.telemetry import get_tracer, is_otel_serializable, telemetry
 from mcp_agent.tracing.token_tracking_decorator import track_tokens
 from mcp_agent.utils.common import ensure_serializable, typed_dict_extras, to_string
-from mcp_agent.utils.pydantic_type_serializer import serialize_model, deserialize_model
+
 from mcp_agent.workflows.llm.augmented_llm import (
     AugmentedLLM,
     ModelT,

@@ -83,15 +83,6 @@ class RequestCompletionRequest(BaseModel):
     payload: dict


-class RequestStructuredCompletionRequest(BaseModel):
-    config: AnthropicSettings
-    params: RequestParams
-    response_model: Type[ModelT] | None = None
-    serialized_response_model: str | None = None
-    response_str: str
-    model: str
-
-
 def create_anthropic_instance(settings: AnthropicSettings):
     """Select and initialise the appropriate anthropic client instance based on settings"""
     if settings.provider == "bedrock":

@@ -419,68 +410,86 @@ async def generate_structured(
         response_model: Type[ModelT],
         request_params: RequestParams | None = None,
     ) -> ModelT:
-        # First we invoke the LLM to generate a string response
-        # We need to do this in a two-step process because Instructor doesn't
-        # know how to invoke MCP tools via call_tool, so we'll handle all the
-        # processing first and then pass the final response through Instructor
+        # Use Anthropic's native structured output via a forced tool call carrying JSON input
+        import json
+
         tracer = get_tracer(self.context)
         with tracer.start_as_current_span(
             f"{self.__class__.__name__}.{self.name}.generate_structured"
         ) as span:
             span.set_attribute(GEN_AI_AGENT_NAME, self.agent.name)
             self._annotate_span_for_generation_message(span, message)

-            response = await self.generate_str(
-                message=message,
-                request_params=request_params,
-            )
-
             params = self.get_request_params(request_params)
-
             if self.context.tracing_enabled:
                 AugmentedLLM.annotate_span_with_request_params(span, params)

-            model = await self.select_model(params)
-            span.set_attribute(GEN_AI_REQUEST_MODEL, model)
-
-            span.set_attribute("response_model", response_model.__name__)
-
-            serialized_response_model: str | None = None
-
-            if self.executor and self.executor.execution_engine == "temporal":
-                # Serialize the response model to a string
-                serialized_response_model = serialize_model(response_model)
-
-            structured_response = await self.executor.execute(
-                AnthropicCompletionTasks.request_structured_completion_task,
-                RequestStructuredCompletionRequest(
-                    config=self.context.config.anthropic,
-                    params=params,
-                    response_model=response_model
-                    if not serialized_response_model
-                    else None,
-                    serialized_response_model=serialized_response_model,
-                    response_str=response,
-                    model=model,
-                ),
+            model_name = (
+                await self.select_model(params) or self.default_request_params.model
             )
+            span.set_attribute(GEN_AI_REQUEST_MODEL, model_name)

-            # TODO: saqadri (MAC) - fix request_structured_completion_task to return ensure_serializable
-            # Convert dict back to the proper model instance if needed
-            if isinstance(structured_response, dict):
-                structured_response = response_model.model_validate(structured_response)
+            # Convert message(s) to Anthropic format
+            messages: List[MessageParam] = []
+            if params.use_history:
+                messages.extend(self.history.get())
+            messages.extend(
+                AnthropicConverter.convert_mixed_messages_to_anthropic(message)
+            )

-            if self.context.tracing_enabled:
-                try:
-                    span.set_attribute(
-                        "structured_response_json",
-                        structured_response.model_dump_json(),
-                    )
-                # pylint: disable=broad-exception-caught
-                except Exception:
-                    span.set_attribute("unstructured_response", response)
+            # Define a single tool that matches the Pydantic schema
+            schema = response_model.model_json_schema()
+            tools: List[ToolParam] = [
+                {
+                    "name": "return_structured_output",
+                    "description": "Return the response in the required JSON format",
+                    "input_schema": schema,
+                }
+            ]
+
+            args = {
+                "model": model_name,
+                "messages": messages,
+                "system": self.instruction or params.systemPrompt,
+                "tools": tools,
+                "tool_choice": {"type": "tool", "name": "return_structured_output"},
+            }
+            if params.maxTokens is not None:
+                args["max_tokens"] = params.maxTokens
+            if params.stopSequences:
+                args["stop_sequences"] = params.stopSequences
+
+            # Call Anthropic directly (one-turn streaming for consistency)
+            base_url = None
+            if self.context and self.context.config and self.context.config.anthropic:
+                base_url = self.context.config.anthropic.base_url
+                api_key = self.context.config.anthropic.api_key
+                client = AsyncAnthropic(api_key=api_key, base_url=base_url)
+            else:
+                client = AsyncAnthropic()
+
+            async with client:
+                async with client.messages.stream(**args) as stream:
+                    final = await stream.get_final_message()
+
+            # Extract tool_use input and validate
+            for block in final.content:
+                if (
+                    getattr(block, "type", None) == "tool_use"
+                    and getattr(block, "name", "") == "return_structured_output"
+                ):
+                    data = getattr(block, "input", None)
+                    try:
+                        if isinstance(data, str):
+                            return response_model.model_validate(json.loads(data))
+                        return response_model.model_validate(data)
+                    except Exception:
+                        # Fallthrough to error
+                        break

-            return structured_response
+            raise ValueError(
+                "Failed to obtain structured output from Anthropic response"
+            )

     @classmethod
     def convert_message_to_message_param(

@@ -770,44 +779,6 @@ async def request_completion_task(
         response = ensure_serializable(response)
         return response

-    @staticmethod
-    @workflow_task
-    @telemetry.traced()
-    async def request_structured_completion_task(
-        request: RequestStructuredCompletionRequest,
-    ):
-        """
-        Request a structured completion using Instructor's Anthropic API.
-        """
-        import instructor
-
-        if request.response_model:
-            response_model = request.response_model
-        elif request.serialized_response_model:
-            response_model = deserialize_model(request.serialized_response_model)
-        else:
-            raise ValueError(
-                "Either response_model or serialized_response_model must be provided for structured completion."
-            )
-
-        # We pass the text through instructor to extract structured data
-        client = instructor.from_anthropic(create_anthropic_instance(request.config))
-
-        # Extract structured data from natural language without blocking the loop
-        loop = asyncio.get_running_loop()
-        structured_response = await loop.run_in_executor(
-            None,
-            functools.partial(
-                client.chat.completions.create,
-                model=request.model,
-                response_model=response_model,
-                messages=[{"role": "user", "content": request.response_str}],
-                max_tokens=request.params.maxTokens,
-            ),
-        )
-
-        return structured_response
-

 class AnthropicMCPTypeConverter(ProviderToMCPConverter[MessageParam, Message]):
     """

src/mcp_agent/workflows/llm/augmented_llm_azure.py

Lines changed: 33 additions & 6 deletions

@@ -2,6 +2,7 @@
 import functools
 import json
 from typing import Any, Iterable, Optional, Type, Union
+from azure.core.exceptions import HttpResponseError
 from azure.ai.inference import ChatCompletionsClient
 from azure.ai.inference.models import (
     ChatCompletions,

@@ -351,6 +352,7 @@ async def generate_structured(
                 name=response_model.__name__,
                 description=response_model.__doc__,
                 schema=json_schema,
+                strict=request_params.strict,
             )
             request_params.metadata = metadata

@@ -362,7 +364,7 @@

     @classmethod
     def convert_message_to_message_param(
-        cls, message: ResponseMessage, **kwargs
+        cls, message: ResponseMessage
     ) -> AssistantMessage:
         """Convert a response object to an input parameter object to allow LLM calls to be chained."""
         assistant_message = AssistantMessage(

@@ -539,12 +541,37 @@ async def request_completion_task(
             ),
         )

-        payload = request.payload
-        # Offload sync SDK call to a thread to avoid blocking the event loop
+        payload = request.payload.copy()
         loop = asyncio.get_running_loop()
-        response = await loop.run_in_executor(
-            None, functools.partial(azure_client.complete, **payload)
-        )
+
+        try:
+            response = await loop.run_in_executor(
+                None, functools.partial(azure_client.complete, **payload)
+            )
+        except HttpResponseError as e:
+            logger = get_logger(__name__)
+
+            if e.status_code != 400:
+                logger.error(f"Azure API call failed: {e}")
+                raise
+
+            logger.warning(
+                f"Initial Azure API call failed: {e}. Retrying with fallback parameters."
+            )
+
+            # Create a new payload with fallback values for commonly problematic parameters
+            fallback_payload = {**payload, "max_tokens": None, "temperature": 1}
+
+            try:
+                response = await loop.run_in_executor(
+                    None, functools.partial(azure_client.complete, **fallback_payload)
+                )
+            except Exception as retry_error:
+                # If retry also fails, raise a more informative error
+                raise RuntimeError(
+                    f"Azure API call failed even with fallback parameters. "
+                    f"Original error: {e}. Retry error: {retry_error}"
+                ) from retry_error
         return response
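
The retry above fires only on HTTP 400, on the theory that a bad request usually means a parameter the targeted deployment rejects (commonly max_tokens or temperature), so the call is repeated once with permissive defaults. A condensed sketch of the same pattern, assuming a synchronous azure-ai-inference ChatCompletionsClient:

    import asyncio
    import functools

    from azure.core.exceptions import HttpResponseError


    async def complete_with_fallback(azure_client, payload: dict):
        """Run the blocking SDK call off the event loop; on HTTP 400,
        retry once with fallback values for commonly rejected parameters."""
        loop = asyncio.get_running_loop()
        try:
            return await loop.run_in_executor(
                None, functools.partial(azure_client.complete, **payload)
            )
        except HttpResponseError as e:
            if e.status_code != 400:
                raise  # not a parameter problem; propagate as-is
            fallback_payload = {**payload, "max_tokens": None, "temperature": 1}
            return await loop.run_in_executor(
                None, functools.partial(azure_client.complete, **fallback_payload)
            )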