import json
import logging
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, List, Optional, Type

import litellm
from litellm import completion
from openai.types.chat import ChatCompletion as OpenAIChatCompletion

from agentlab.llm.base_api import BaseModelArgs
from agentlab.llm.response_api import (
    AgentlabAction,
    APIPayload,
    BaseModelWithPricing,
    LLMOutput,
    Message,
    MessageBuilder,
    OpenAIChatCompletionAPIMessageBuilder,
    ToolCall,
    ToolCalls,
)

litellm.modify_params = True
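# Note: with modify_params enabled, LiteLLM is allowed to adjust provider-specific
# request parameters to keep calls compatible across providers, rather than raising
# on unsupported arguments.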


class LiteLLMModel(BaseModelWithPricing):
    def __init__(
        self,
        model_name: str,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = 100,
        use_only_first_toolcall: bool = False,
    ):
        super().__init__(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        self.action_space_as_tools = True  # TODO: this should be a config
        client_args = {}
        if base_url is not None:
            client_args["base_url"] = base_url
        if api_key is not None:
            client_args["api_key"] = api_key
        self.client = partial(completion, **client_args)
        self.init_pricing_tracker(pricing_api="litellm")
        self.use_only_first_toolcall = use_only_first_toolcall
        try:
            self.litellm_info = litellm.get_model_info(model_name)
            # litellm.get_model_info returns metadata such as the context window
            # and per-token costs; maybe log this in xray.
        except Exception as e:
            logging.error(f"Failed to get litellm model info: {e}")

    def _call_api(self, payload: APIPayload) -> OpenAIChatCompletion:
        """
        Calls the LiteLLM API with the given payload.

        Args:
            payload (APIPayload): The payload to send to the API.

        Returns:
            OpenAIChatCompletion: The completion response (LiteLLM returns an
                object matching the OpenAI ChatCompletion schema).
        """
        input_messages = []
        for msg in payload.messages:  # type: ignore
            input_messages.extend(msg.prepare_message())
        api_params: Dict[str, Any] = {
            "model": self.model_name,
            "messages": input_messages,
        }
        if self.temperature is not None:
            api_params["temperature"] = self.temperature

        if self.max_tokens is not None:
            api_params["max_completion_tokens"] = self.max_tokens

        if payload.tools is not None:
            # Tool descriptions without a "function" key are in the Responses API
            # format and need conversion for the Chat Completions API.
            api_params["tools"] = (
                self.format_tools_for_chat_completion(payload.tools)
                if "function" not in payload.tools[0]
                else payload.tools
            )

        if payload.tool_choice is not None and payload.force_call_tool is None:
            api_params["tool_choice"] = (
                "required" if payload.tool_choice in ("required", "any") else payload.tool_choice
            )

        if payload.force_call_tool is not None:
            api_params["tool_choice"] = {
                "type": "function",
                "function": {"name": payload.force_call_tool},
            }

        if payload.reasoning_effort is not None:
            api_params["reasoning_effort"] = payload.reasoning_effort

        if "tools" in api_params and payload.cache_tool_definition:
            # Marking the last tool definition as ephemeral enables provider-side
            # (Anthropic-style) caching of the tool block.
            api_params["tools"][-1]["cache_control"] = {"type": "ephemeral"}  # type: ignore

        if payload.cache_complete_prompt:
            # Indicating cache control on the last message enables caching of the complete prompt.
            api_params["messages"][-1]["content"][-1]["cache_control"] = {"type": "ephemeral"}

        response = self.client(**api_params, num_retries=5)

        return response  # type: ignore
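    # For reference, an assembled payload for a tool-calling request looks roughly
    # like this (model and tool names are illustrative):
    #   {"model": "openai/gpt-4.1-mini",
    #    "messages": [...],
    #    "max_completion_tokens": 2000,
    #    "tools": [{"type": "function", "function": {"name": "click", ...}}],
    #    "tool_choice": "required"}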

    def _parse_response(self, response: OpenAIChatCompletion) -> LLMOutput:
        """Parses the raw completion into thinking content, an env action, and tool calls."""
        think_output = self._extract_thinking_content_from_response(response)
        tool_calls = self._extract_tool_calls_from_response(response)

        if self.action_space_as_tools:
            env_action = self._extract_env_actions_from_toolcalls(tool_calls)  # type: ignore
        else:
            env_action = self._extract_env_actions_from_text_response(response)
        return LLMOutput(
            raw_response=response,
            think=think_output,
            action=env_action,
            tool_calls=tool_calls,
        )

    def _extract_thinking_content_from_response(
        self, response: OpenAIChatCompletion, wrap_tag="think"
    ):
        """Extracts the content from the message, including reasoning if available.

        When the LLM produces both 'text' and 'reasoning' in the same message, the
        reasoning is wrapped in <think>...</think> tags for easy identification.
        Note: the wrapping of 'thinking' content may not be needed and may be reconsidered.

        Args:
            response: The message object or dict containing content and reasoning.
            wrap_tag: The tag name to wrap reasoning content (default: "think").

        Returns:
            str: The extracted content with reasoning wrapped in the specified tags.
        """
        message = response.choices[0].message
        if not isinstance(message, dict):
            message = message.to_dict()

        reasoning_content = message.get("reasoning", None)
        msg_content = message.get("text", "")  # works for OpenRouter
        if reasoning_content:
            # Wrap reasoning in <think> tags with a trailing newline for clarity.
            reasoning_content = f"<{wrap_tag}>{reasoning_content}</{wrap_tag}>\n"
            logging.debug("Extracting content from response.choices[i].message.reasoning")
        else:
            reasoning_content = ""
        # 'content' may be None when the model only returns tool calls.
        return f"{reasoning_content}{msg_content}{message.get('content') or ''}"
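    # For example, a message with reasoning "I should click the button" and content
    # "click('5')" would be rendered as (action text is illustrative):
    #   "<think>I should click the button</think>\nclick('5')"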

    def _extract_tool_calls_from_response(self, response: OpenAIChatCompletion) -> ToolCalls | None:
        """Extracts tool calls from the response."""
        message = response.choices[0].message.to_dict()
        tool_calls = message.get("tool_calls", None)
        if tool_calls is None:
            return None
        tool_call_list = []
        for tc in tool_calls:  # type: ignore
            tool_call_list.append(
                ToolCall(
                    name=tc["function"]["name"],
                    arguments=json.loads(tc["function"]["arguments"]),
                    raw_call=tc,
                )
            )
            if self.use_only_first_toolcall:
                # Keep only the first tool call when configured to do so.
                break
        return ToolCalls(tool_calls=tool_call_list, raw_calls=response)  # type: ignore
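    # Each raw tool call follows the OpenAI Chat Completions shape, roughly:
    #   {"id": "call_abc", "type": "function",
    #    "function": {"name": "click", "arguments": "{\"bid\": \"12\"}"}}
    # (the id, name, and arguments above are illustrative).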

    def _extract_env_actions_from_toolcalls(self, toolcalls: ToolCalls) -> Any | None:
        """Converts the tool calls into AgentLab-formatted env actions."""
        if not toolcalls:
            return None

        actions = [
            AgentlabAction.convert_toolcall_to_agentlab_action_format(call) for call in toolcalls
        ]
        actions = (
            AgentlabAction.convert_multiactions_to_agentlab_action_format(actions)
            if len(actions) > 1
            else actions[0]
        )
        return actions

    def _extract_env_actions_from_text_response(
        self, response: OpenAIChatCompletion
    ) -> str | None:
        """Extracts environment actions from the text response.

        Used when the action space is not given as tools.
        TODO: Add support to pass the action space as a prompt in LiteLLM.
        See: https://docs.litellm.ai/docs/completion/function_call#function-calling-for-models-wout-function-calling-support
        """
        return None

    @staticmethod
    def format_tools_for_chat_completion(tools):
        """Converts Responses API tool descriptions to the Chat Completions format.

        This is needed because actionset.to_tool_description() in bgym only returns
        descriptions in a format valid for the OpenAI Responses API.

        Args:
            tools: List of tool descriptions to format for the Chat Completion API.

        Returns:
            Formatted tools list compatible with the OpenAI Chat Completion API,
            or None if tools is None.
        """
        formatted_tools = None
        if tools is not None:
            formatted_tools = [
                {
                    "type": tool["type"],
                    "function": {k: tool[k] for k in ("name", "description", "parameters")},
                }
                for tool in tools
            ]
        return formatted_tools
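    # For illustration (the tool name is made up), a Responses-API-style description
    #   {"type": "function", "name": "click", "description": "...", "parameters": {...}}
    # is rewrapped for the Chat Completions API as
    #   {"type": "function",
    #    "function": {"name": "click", "description": "...", "parameters": {...}}}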


class LiteLLMAPIMessageBuilder(OpenAIChatCompletionAPIMessageBuilder):
    """Message builder for the LiteLLM API, extending OpenAIChatCompletionAPIMessageBuilder."""

    def prepare_message(self, use_only_first_toolcall: bool = False) -> List[Message]:
        """Prepare the message for the OpenAI API."""
        content = []
        for item in self.content:
            content.append(self.convert_content_to_expected_format(item))
        output = [{"role": self.role, "content": content}]
        return output if self.role != "tool" else self.handle_tool_call(use_only_first_toolcall)

    def handle_tool_call(self, use_only_first_toolcall: bool = False) -> List[Message]:
        """Handle the tool call response from the last raw response."""
        if self.responded_tool_calls is None:
            raise ValueError("No tool calls found in responded_tool_calls")
        output = []
        raw_call = self.responded_tool_calls.raw_calls.choices[0].message  # type: ignore
        if use_only_first_toolcall:
            raw_call.tool_calls = raw_call.tool_calls[:1]
        output.append(raw_call)  # add the raw assistant message with tool calls to the output
        for fn_call in self.responded_tool_calls:
            raw_call = fn_call.raw_call
            assert (
                "image" not in fn_call.tool_response
            ), "Image output is not supported in function call responses."
            # A tool response dict has the keys "role", "tool_call_id", and "content".
            tool_call_response = {
                "name": raw_call["function"]["name"],  # required by OpenRouter
                "role": "tool",
                "tool_call_id": raw_call["id"],
                "content": self.convert_content_to_expected_format(fn_call.tool_response)["text"],
            }
            output.append(tool_call_response)

        return output
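    # The returned list pairs the assistant message that issued the tool calls with
    # one {"role": "tool", ...} entry per call, e.g. (ids and content illustrative):
    #   [<assistant message with tool_calls>,
    #    {"name": "click", "role": "tool", "tool_call_id": "call_abc", "content": "ok"}]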


@dataclass
class LiteLLMModelArgs(BaseModelArgs):
    """Serializable arguments for LiteLLMModel."""

    api = "openai"  # tool description format used by actionset.to_tool_description() in bgym
    base_url: Optional[str] = None
    api_key: Optional[str] = None
    use_only_first_toolcall: bool = False

    def make_model(self):
        return LiteLLMModel(
            model_name=self.model_name,
            base_url=self.base_url,
            api_key=self.api_key,
            max_tokens=self.max_new_tokens,
            temperature=self.temperature,
            use_only_first_toolcall=self.use_only_first_toolcall,
        )

    def get_message_builder(self) -> Type[MessageBuilder]:
        """Returns a message builder for the LiteLLMModel."""
        return LiteLLMAPIMessageBuilder
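
# A minimal direct-usage sketch (the model id is an assumption; any
# LiteLLM-routable id with its provider API key set should work):
#
#   args = LiteLLMModelArgs(model_name="openai/gpt-4.1-mini", max_new_tokens=256)
#   model = args.make_model()
#   builder = args.get_message_builder()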

if __name__ == "__main__":
    """
    Some simple tests to run the LiteLLMModel with different models.
    """

    import os

    from agentlab.agents.tool_use_agent import DEFAULT_PROMPT_CONFIG, ToolUseAgentArgs
    from agentlab.experiments.study import Study
    from agentlab.llm.litellm_api import LiteLLMModelArgs

    os.environ["LITELLM_LOG"] = "WARNING"

    def get_agent(model_name: str) -> ToolUseAgentArgs:
        return ToolUseAgentArgs(
            model_args=LiteLLMModelArgs(
                model_name=model_name,
                max_new_tokens=2000,
                temperature=None,
            ),
            config=DEFAULT_PROMPT_CONFIG,
        )

    models = [
        "openai/gpt-4.1",
        "openai/gpt-4.1-mini",
        "openai/gpt-4.1-nano",
        "openai/o3-2025-04-16",
        "anthropic/claude-3-7-sonnet-20250219",
        "anthropic/claude-sonnet-4-20250514",
        ## Add more models to test.
    ]
    agent_args = [get_agent(model) for model in models]

    study = Study(agent_args, "miniwob_tiny_test", logging_level_stdout=logging.WARNING)
    study.run(
        n_jobs=5,
        parallel_backend="ray",
        strict_reproducibility=False,
        n_relaunch=3,
    )