Commit 5641493
Add LiteLLM API integration and related tests
- Implement LiteLLMModel and LiteLLMAPIMessageBuilder for LiteLLM API interaction.
- Enhance APIPayload to include a reasoning_effort parameter.
- Introduce a get_pricing_litellm function for pricing-information retrieval.
- Create tests for multi-action, single-tool-call, and forced-tool-call scenarios.
1 parent 3c2422e commit 5641493

4 files changed (+512, -3 lines)
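A rough usage sketch of the pieces this commit adds (not from the commit itself; it relies only on the constructors and methods visible in the diff below, and the model id is an arbitrary example):

    from agentlab.llm.litellm_api import LiteLLMModelArgs

    model_args = LiteLLMModelArgs(
        model_name="openai/gpt-4.1-mini",  # any model id LiteLLM can route
        max_new_tokens=2000,
        temperature=None,
    )
    model = model_args.make_model()                 # LiteLLMModel wrapping litellm.completion
    builder_cls = model_args.get_message_builder()  # LiteLLMAPIMessageBuilder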

src/agentlab/llm/litellm_api.py

Lines changed: 322 additions & 0 deletions
@@ -0,0 +1,322 @@
import json
import logging
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, List, Optional, Type

import litellm
from litellm import completion
from openai.types.chat import ChatCompletion as OpenAIChatCompletion

from agentlab.llm.base_api import BaseModelArgs
from agentlab.llm.response_api import (
    AgentlabAction,
    APIPayload,
    BaseModelWithPricing,
    LLMOutput,
    Message,
    MessageBuilder,
    OpenAIChatCompletionAPIMessageBuilder,
    ToolCall,
    ToolCalls,
)

litellm.modify_params = True

class LiteLLMModel(BaseModelWithPricing):
    def __init__(
        self,
        model_name: str,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        temperature: float | None = None,
        max_tokens: int | None = 100,
        use_only_first_toolcall: bool = False,
    ):
        super().__init__(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        self.action_space_as_tools = True  # this should be a config
        client_args = {}
        if base_url is not None:
            client_args["base_url"] = base_url
        if api_key is not None:
            client_args["api_key"] = api_key
        self.client = partial(completion, **client_args)
        self.init_pricing_tracker(pricing_api="litellm")
        self.use_only_first_toolcall = use_only_first_toolcall
        try:
            self.litellm_info = litellm.get_model_info(model_name)
            # maybe log this in xray
        except Exception as e:
            logging.error(f"Failed to get litellm model info: {e}")
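    # Note (illustrative, not from this commit): litellm.get_model_info returns a
    # metadata dict with fields such as "max_tokens", "input_cost_per_token" and
    # "output_cost_per_token"; presumably this is what the "litellm" pricing
    # tracker initialized above consumes.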
    def _call_api(self, payload: APIPayload) -> "OpenAIChatCompletion":
        """Calls the LiteLLM API with the given payload. LiteLLM is a wrapper around OpenAI's API.

        Args:
            payload (APIPayload): The messages, tools, and sampling options for the call.

        Returns:
            OpenAIChatCompletion: An OpenAIChatCompletion-like object with the same keys.
        """
        input = []
        for msg in payload.messages:  # type: ignore
            input.extend(msg.prepare_message())
        api_params: Dict[str, Any] = {
            "model": self.model_name,
            "messages": input,
        }
        if self.temperature is not None:
            api_params["temperature"] = self.temperature

        if self.max_tokens is not None:
            api_params["max_completion_tokens"] = self.max_tokens

        if payload.tools is not None:
            api_params["tools"] = (
                self.format_tools_for_chat_completion(payload.tools)
                if "function" not in payload.tools[0]  # convert if tools are in Responses API format
                else payload.tools
            )

        if payload.tool_choice is not None and payload.force_call_tool is None:
            api_params["tool_choice"] = (
                "required" if payload.tool_choice in ("required", "any") else payload.tool_choice
            )

        if payload.force_call_tool is not None:
            api_params["tool_choice"] = {
                "type": "function",
                "function": {"name": payload.force_call_tool},
            }

        if payload.reasoning_effort is not None:
            api_params["reasoning_effort"] = payload.reasoning_effort

        if "tools" in api_params and payload.cache_tool_definition:
            api_params["tools"][-1]["cache_control"] = {"type": "ephemeral"}  # type: ignore

        if payload.cache_complete_prompt:
            # Indicating cache control for the last message enables caching of the complete prompt.
            api_params["messages"][-1]["content"][-1]["cache_control"] = {"type": "ephemeral"}

        response = self.client(**api_params, num_retries=5)

        return response  # type: ignore
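    # Illustrative example (hypothetical tool name, not from this commit): with
    # payload.force_call_tool = "click", the request above carries
    #   tool_choice = {"type": "function", "function": {"name": "click"}}
    # which is the Chat Completion convention for forcing one specific tool, while
    # tool_choice in ("required", "any") is normalized to the string "required".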
    def _parse_response(self, response: "OpenAIChatCompletion") -> LLMOutput:
        think_output = self._extract_thinking_content_from_response(response)
        tool_calls = self._extract_tool_calls_from_response(response)

        if self.action_space_as_tools:
            env_action = self._extract_env_actions_from_toolcalls(tool_calls)  # type: ignore
        else:
            env_action = self._extract_env_actions_from_text_response(response)
        return LLMOutput(
            raw_response=response,
            think=think_output,
            action=env_action,
            tool_calls=tool_calls,
        )
    def _extract_thinking_content_from_response(
        self, response: OpenAIChatCompletion, wrap_tag="think"
    ):
        """Extracts the content from the message, including reasoning if available.

        When the LLM produces 'text' and 'reasoning' in the same message, the reasoning
        is wrapped in <think>...</think> tags for easy identification of reasoning content.
        Note: The wrapping of 'thinking' content may not be needed and may be reconsidered.

        Args:
            response: The message object or dict containing content and reasoning.
            wrap_tag: The tag name to wrap reasoning content in (default: "think").

        Returns:
            str: The extracted content with reasoning wrapped in the specified tags.
        """
        message = response.choices[0].message
        if not isinstance(message, dict):
            message = message.to_dict()

        reasoning_content = message.get("reasoning", None)
        msg_content = message.get("text", "")  # works for OpenRouter
        if reasoning_content:
            # Wrap reasoning in <think> tags with a trailing newline for clarity
            reasoning_content = f"<{wrap_tag}>{reasoning_content}</{wrap_tag}>\n"
            logging.debug("Extracting content from response.choices[i].message.reasoning")
        else:
            reasoning_content = ""
        # 'content' may be explicitly None (e.g., for pure tool-call messages)
        return f"{reasoning_content}{msg_content}{message.get('content') or ''}"
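    # Illustrative result (hypothetical values): for reasoning "I should click."
    # and content "click('12')", the method returns
    #   "<think>I should click.</think>\nclick('12')"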
    def _extract_tool_calls_from_response(self, response: OpenAIChatCompletion) -> ToolCalls | None:
        """Extracts tool calls from the response."""
        message = response.choices[0].message.to_dict()
        tool_calls = message.get("tool_calls", None)
        if tool_calls is None:
            return None
        tool_call_list = []
        for tc in tool_calls:  # type: ignore
            tool_call_list.append(
                ToolCall(
                    name=tc["function"]["name"],
                    arguments=json.loads(tc["function"]["arguments"]),
                    raw_call=tc,
                )
            )
            if self.use_only_first_toolcall:
                break
        return ToolCalls(tool_calls=tool_call_list, raw_calls=response)  # type: ignore
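    # Illustrative example (hypothetical values): a raw Chat Completion tool call
    #   {"id": "call_1", "type": "function",
    #    "function": {"name": "click", "arguments": '{"bid": "12"}'}}
    # becomes ToolCall(name="click", arguments={"bid": "12"}, raw_call=<the dict above>).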
    def _extract_env_actions_from_toolcalls(self, toolcalls: ToolCalls) -> Any | None:
        """Extracts environment actions from the tool calls."""
        if not toolcalls:
            return None

        actions = [
            AgentlabAction.convert_toolcall_to_agentlab_action_format(call) for call in toolcalls
        ]
        actions = (
            AgentlabAction.convert_multiactions_to_agentlab_action_format(actions)
            if len(actions) > 1
            else actions[0]
        )
        return actions

    def _extract_env_actions_from_text_response(
        self, response: "OpenAIChatCompletion"
    ) -> str | None:
        """Extracts environment actions from the text response.

        Used when the action space is not given as tools.
        TODO: Add support for passing the action space as a prompt in LiteLLM.
        See: https://docs.litellm.ai/docs/completion/function_call#function-calling-for-models-wout-function-calling-support
        """
        pass
    @staticmethod
    def format_tools_for_chat_completion(tools):
        """Formats Responses-API-style tool descriptions for the OpenAI Chat Completion API.

        Why do we need this? Because actionset.to_tool_description() in bgym only
        returns a description format valid for the OpenAI Responses API.

        Args:
            tools: List of tool descriptions to format for the Chat Completion API.

        Returns:
            A formatted tools list compatible with the OpenAI Chat Completion API, or None if tools is None.
        """
        formatted_tools = None
        if tools is not None:
            formatted_tools = [
                {
                    "type": tool["type"],
                    "function": {k: tool[k] for k in ("name", "description", "parameters")},
                }
                for tool in tools
            ]
        return formatted_tools
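    # Illustrative example (hypothetical tool, not from this commit): a Responses-API
    # style description
    #   {"type": "function", "name": "click", "description": "Click an element.",
    #    "parameters": {"type": "object", "properties": {"bid": {"type": "string"}}}}
    # is rewrapped for the Chat Completion API as
    #   {"type": "function",
    #    "function": {"name": "click", "description": "Click an element.",
    #                 "parameters": {"type": "object", "properties": {"bid": {"type": "string"}}}}}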

class LiteLLMAPIMessageBuilder(OpenAIChatCompletionAPIMessageBuilder):
    """Message builder for the LiteLLM API, extending OpenAIChatCompletionAPIMessageBuilder."""

    def prepare_message(self, use_only_first_toolcall: bool = False) -> List[Message]:
        """Prepare the message for the OpenAI API."""
        content = []
        for item in self.content:
            content.append(self.convert_content_to_expected_format(item))
        output = [{"role": self.role, "content": content}]
        return output if self.role != "tool" else self.handle_tool_call(use_only_first_toolcall)

    def handle_tool_call(self, use_only_first_toolcall: bool = False) -> List[Message]:
        """Handle the tool call response from the last raw response."""
        if self.responded_tool_calls is None:
            raise ValueError("No tool calls found in responded_tool_calls")
        output = []
        raw_call = self.responded_tool_calls.raw_calls.choices[0].message  # type: ignore
        if use_only_first_toolcall:
            raw_call.tool_calls = raw_call.tool_calls[:1]
        output.append(raw_call)  # add the raw assistant message (with tool_calls) to the output
        for fn_call in self.responded_tool_calls:
            raw_call = fn_call.raw_call
            assert (
                "image" not in fn_call.tool_response
            ), "Image output is not supported in function call responses."
            # A function_call_output dict has the keys "role", "tool_call_id", and "content".
            tool_call_response = {
                "name": raw_call["function"]["name"],  # required with OpenRouter
                "role": "tool",
                "tool_call_id": raw_call["id"],
                "content": self.convert_content_to_expected_format(fn_call.tool_response)["text"],
            }
            output.append(tool_call_response)

        return output
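    # Illustrative output (hypothetical ids, not from this commit): for a single
    # responded tool call, the method returns the assistant message carrying the raw
    # tool_calls, followed by
    #   {"name": "click", "role": "tool", "tool_call_id": "call_1", "content": "..."}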

@dataclass
class LiteLLMModelArgs(BaseModelArgs):
    """Serializable arguments for LiteLLMModel."""

    api = "openai"  # tool description format used by actionset.to_tool_description() in bgym
    base_url: Optional[str] = None
    api_key: Optional[str] = None
    use_only_first_toolcall: bool = False

    def make_model(self):
        return LiteLLMModel(
            model_name=self.model_name,
            base_url=self.base_url,
            api_key=self.api_key,
            max_tokens=self.max_new_tokens,
            temperature=self.temperature,
            use_only_first_toolcall=self.use_only_first_toolcall,
        )

    def get_message_builder(self) -> Type[MessageBuilder]:
        """Returns the message builder class for LiteLLMModel."""
        return LiteLLMAPIMessageBuilder

if __name__ == "__main__":
    """
    Some simple tests to run the LiteLLMModel with different models.
    """

    from agentlab.agents.tool_use_agent import DEFAULT_PROMPT_CONFIG, ToolUseAgentArgs
    from agentlab.experiments.study import Study
    from agentlab.llm.litellm_api import LiteLLMModelArgs

    def get_agent(model_name: str) -> ToolUseAgentArgs:
        return ToolUseAgentArgs(
            model_args=LiteLLMModelArgs(
                model_name=model_name,
                max_new_tokens=2000,
                temperature=None,
            ),
            config=DEFAULT_PROMPT_CONFIG,
        )

    models = [
        "openai/gpt-4.1",
        "openai/gpt-4.1-mini",
        "openai/gpt-4.1-nano",
        "openai/o3-2025-04-16",
        "anthropic/claude-3-7-sonnet-20250219",
        "anthropic/claude-sonnet-4-20250514",
        # Add more models to test.
    ]
    agent_args = [get_agent(model) for model in models]

    study = Study(agent_args, "miniwob_tiny_test", logging_level_stdout=logging.WARNING)
    study.run(
        n_jobs=5,
        parallel_backend="ray",
        strict_reproducibility=False,
        n_relaunch=3,
    )

src/agentlab/llm/response_api.py

Lines changed: 4 additions & 0 deletions
@@ -27,6 +27,7 @@
 3. Factory classes (inherits from BaseModelArgs) for creating instances of LLM Response models.
 """

+logger = logging.getLogger(__name__)

 ContentItem = Dict[str, Any]
 Message = Dict[str, Union[str, List[ContentItem]]]
@@ -388,10 +389,13 @@ class APIPayload:
     cache_complete_prompt: bool = (
         False  # If True, will cache the complete prompt in the last message.
     )
+    reasoning_effort: Literal["low", "medium", "high"] | None = None

     def __post_init__(self):
         if self.tool_choice and self.force_call_tool:
             raise ValueError("tool_choice and force_call_tool are mutually exclusive")
+        if self.reasoning_effort is not None:
+            logger.info("In agentlab, reasoning_effort is used by the LiteLLM API only. We will eventually shift to the LiteLLM API for all LLMs.")

 # # Base class for all API Endpoints
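A minimal sketch of the new field in use (not from this commit; it assumes the other APIPayload fields default to None or empty, as the hunk suggests):

    from agentlab.llm.response_api import APIPayload

    payload = APIPayload(
        messages=[],                # normally a list of MessageBuilder instances
        tool_choice="any",          # LiteLLMModel normalizes this to "required"
        reasoning_effort="medium",  # new field; currently honored by the LiteLLM backend only
    )

Passing both tool_choice and force_call_tool would raise the ValueError checked in __post_init__ above.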
