diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 4a29c0b7d..30cfb2a02 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -612,46 +612,53 @@ def timestamp(self) -> datetime: @staticmethod def _try_get_output_tool_from_text(text: str, output_tools: dict[str, ToolDefinition]) -> ToolCallPart | None: - output_json: dict[str, Any] | None = pydantic_core.from_json(text, allow_partial='trailing-strings') - if output_json: - for output_tool in output_tools.values(): - # NOTE: Additional verification to prevent JSON validation to crash - # Ensures required parameters in the JSON schema are respected, especially for stream-based return types. - # Example with BaseModel and required fields. - if not MistralStreamedResponse._validate_required_json_schema( - output_json, output_tool.parameters_json_schema - ): - continue - - # The following part_id will be thrown away - return ToolCallPart(tool_name=output_tool.name, args=output_json) + output_json = pydantic_core.from_json(text, allow_partial='trailing-strings') + if not output_json: + return None + for output_tool in output_tools.values(): + # NOTE: Additional verification to prevent JSON validation to crash + # Ensures required parameters in the JSON schema are respected, especially for stream-based return types. + if not MistralStreamedResponse._validate_required_json_schema( + output_json, output_tool.parameters_json_schema + ): + continue + return ToolCallPart(tool_name=output_tool.name, args=output_json) + return None # Added fallback to ensure None is returned if nothing matches @staticmethod def _validate_required_json_schema(json_dict: dict[str, Any], json_schema: dict[str, Any]) -> bool: """Validate that all required parameters in the JSON schema are present in the JSON dictionary.""" - required_params = json_schema.get('required', []) + required_params = json_schema.get('required') + if not required_params: + return True properties = json_schema.get('properties', {}) for param in required_params: if param not in json_dict: return False - param_schema = properties.get(param, {}) + param_schema = properties.get(param) + if not param_schema: + return False param_type = param_schema.get('type') - param_items_type = param_schema.get('items', {}).get('type') - - if param_type == 'array' and param_items_type: - if not isinstance(json_dict[param], list): + if param_type == 'array': + value = json_dict[param] + if not isinstance(value, list): + return False + param_items_type = param_schema.get('items', {}).get('type') + if param_items_type: + target_cls = VALID_JSON_TYPE_MAPPING[param_items_type] + for item in value: + if not isinstance(item, target_cls): + return False + elif param_type: + target_cls = VALID_JSON_TYPE_MAPPING[param_type] + if not isinstance(json_dict[param], target_cls): return False - for item in json_dict[param]: - if not isinstance(item, VALID_JSON_TYPE_MAPPING[param_items_type]): - return False - elif param_type and not isinstance(json_dict[param], VALID_JSON_TYPE_MAPPING[param_type]): - return False - if isinstance(json_dict[param], dict) and 'properties' in param_schema: - nested_schema = param_schema - if not MistralStreamedResponse._validate_required_json_schema(json_dict[param], nested_schema): + value = json_dict[param] + if isinstance(value, dict) and 'properties' in param_schema: + if not MistralStreamedResponse._validate_required_json_schema(value, param_schema): return False return True