diff --git a/src/mcp_agent/workflows/llm/augmented_llm_openai.py b/src/mcp_agent/workflows/llm/augmented_llm_openai.py
index 6e2562eac..690ab7956 100644
--- a/src/mcp_agent/workflows/llm/augmented_llm_openai.py
+++ b/src/mcp_agent/workflows/llm/augmented_llm_openai.py
@@ -68,7 +68,14 @@ def __init__(self, *args, **kwargs):
         if hasattr(self.context.config.openai, "reasoning_effort"):
             self._reasoning_effort = self.context.config.openai.reasoning_effort
 
-        self._reasoning = lambda model : model.startswith(("o1","o3","o4"))
+        self._reasoning = lambda model: model.startswith(("o1", "o3", "o4"))
+
+        self._strict_tool_validation = True  # default strict validation
+        if self.context and self.context.config and self.context.config.openai:
+            if hasattr(self.context.config.openai, "strict_tool_validation"):
+                self._strict_tool_validation = (
+                    self.context.config.openai.strict_tool_validation
+                )
 
         if self._reasoning(chosen_model):
             self.logger.info(
@@ -142,7 +149,7 @@ async def generate(self, message, request_params: RequestParams | None = None):
                         "name": tool.name,
                         "description": tool.description,
                         "parameters": tool.inputSchema,
-                        # TODO: saqadri - determine if we should specify "strict" to True by default
+                        "strict": self._strict_tool_validation,
                     },
                 )
                 for tool in response.tools
@@ -163,8 +170,7 @@ async def generate(self, message, request_params: RequestParams | None = None):
             if self._reasoning(model):
                 arguments = {
                     **arguments,
-                    # DEPRECATED: https://platform.openai.com/docs/api-reference/chat/create#chat-create-max_tokens
                     # "max_tokens": params.maxTokens,
                     "max_completion_tokens": params.maxTokens,
                     "reasoning_effort": self._reasoning_effort,
                 }