def extract_json_objects(s):
    """Extract every top-level JSON object embedded in a string.

    The text is scanned left to right. At each "{" a JSON decode is
    attempted starting from that position; on success the decoded object is
    collected and scanning resumes after its closing brace, so objects nested
    inside a successfully decoded object are not reported separately. When a
    "{" does not start valid JSON, it is skipped and the scan continues at
    the next "{".

    Using the JSON decoder (rather than counting braces) means that braces
    appearing inside JSON string values, and stray unmatched "}" characters
    in the surrounding text, do not disturb the extraction.

    Args:
        s: The text to search. An empty or None value yields an empty list.

    Returns:
        A list of dicts, one per JSON object found, in order of appearance.
    """
    if not s:
        return []

    decoder = json.JSONDecoder()
    objects = []
    pos = s.find("{")
    while pos != -1:
        try:
            # raw_decode parses one JSON value starting at `pos` and returns
            # the index just past it, ignoring any trailing text.
            obj, end = decoder.raw_decode(s, pos)
        except json.JSONDecodeError:
            # This brace does not open a valid object; try the next one.
            pos = s.find("{", pos + 1)
            continue
        objects.append(obj)
        pos = s.find("{", end)
    return objects
if message.content and not tool_calls: - tool_call_matches = re.search(r"\{.*\}", message.content) - if tool_call_matches: - tool_call = parse_json(tool_call_matches.group(0)) - # Some models use "arguments" as the key instead of "parameters" - parameters = tool_call.get("parameters", tool_call.get("arguments")) - if tool_call.get("name") and parameters: + json_objs = extract_json_objects(message.content) + for obj in json_objs: + parameters = obj.get("parameters", obj.get("arguments")) + if obj.get("name") and parameters is not None: combined_tool_calls.append( - self.create_tool_call(tool_call.get("name"), parameters) + self.create_tool_call(obj.get("name"), parameters) ) - return None, combined_tool_calls + if combined_tool_calls: + return None, combined_tool_calls return message.content, combined_tool_calls