Skip to content

⚡️ Speed up method MistralStreamedResponse._try_get_output_tool_from_text by 5% #16

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: try-refinement
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 34 additions & 27 deletions pydantic_ai_slim/pydantic_ai/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,46 +612,53 @@ def timestamp(self) -> datetime:

@staticmethod
def _try_get_output_tool_from_text(text: str, output_tools: dict[str, ToolDefinition]) -> ToolCallPart | None:
output_json: dict[str, Any] | None = pydantic_core.from_json(text, allow_partial='trailing-strings')
if output_json:
for output_tool in output_tools.values():
# NOTE: Additional verification to prevent JSON validation to crash
# Ensures required parameters in the JSON schema are respected, especially for stream-based return types.
# Example with BaseModel and required fields.
if not MistralStreamedResponse._validate_required_json_schema(
output_json, output_tool.parameters_json_schema
):
continue

# The following part_id will be thrown away
return ToolCallPart(tool_name=output_tool.name, args=output_json)
output_json = pydantic_core.from_json(text, allow_partial='trailing-strings')
if not output_json:
return None
for output_tool in output_tools.values():
# NOTE: Additional verification to prevent JSON validation to crash
# Ensures required parameters in the JSON schema are respected, especially for stream-based return types.
if not MistralStreamedResponse._validate_required_json_schema(
output_json, output_tool.parameters_json_schema
):
continue
return ToolCallPart(tool_name=output_tool.name, args=output_json)
return None # Added fallback to ensure None is returned if nothing matches

@staticmethod
def _validate_required_json_schema(json_dict: dict[str, Any], json_schema: dict[str, Any]) -> bool:
"""Validate that all required parameters in the JSON schema are present in the JSON dictionary."""
required_params = json_schema.get('required', [])
required_params = json_schema.get('required')
if not required_params:
return True
properties = json_schema.get('properties', {})

for param in required_params:
if param not in json_dict:
return False

param_schema = properties.get(param, {})
param_schema = properties.get(param)
if not param_schema:
return False
param_type = param_schema.get('type')
param_items_type = param_schema.get('items', {}).get('type')

if param_type == 'array' and param_items_type:
if not isinstance(json_dict[param], list):
if param_type == 'array':
value = json_dict[param]
if not isinstance(value, list):
return False
param_items_type = param_schema.get('items', {}).get('type')
if param_items_type:
target_cls = VALID_JSON_TYPE_MAPPING[param_items_type]
for item in value:
if not isinstance(item, target_cls):
return False
elif param_type:
target_cls = VALID_JSON_TYPE_MAPPING[param_type]
if not isinstance(json_dict[param], target_cls):
return False
for item in json_dict[param]:
if not isinstance(item, VALID_JSON_TYPE_MAPPING[param_items_type]):
return False
elif param_type and not isinstance(json_dict[param], VALID_JSON_TYPE_MAPPING[param_type]):
return False

if isinstance(json_dict[param], dict) and 'properties' in param_schema:
nested_schema = param_schema
if not MistralStreamedResponse._validate_required_json_schema(json_dict[param], nested_schema):
value = json_dict[param]
if isinstance(value, dict) and 'properties' in param_schema:
if not MistralStreamedResponse._validate_required_json_schema(value, param_schema):
return False

return True
Expand Down
Loading