
Commit 219e467

Merge pull request #273 from ServiceNow/litellm
Add Litellm API integration
2 parents 2255590 + 971a3e9 commit 219e467

File tree

4 files changed (+519, -4 lines)

src/agentlab/llm/litellm_api.py

Lines changed: 326 additions & 0 deletions
@@ -0,0 +1,326 @@
import json
import logging
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, List, Optional, Type

import litellm
from litellm import completion
from openai.types.chat import ChatCompletion as OpenAIChatCompletion

from agentlab.llm.base_api import BaseModelArgs
from agentlab.llm.response_api import (
    AgentlabAction,
    APIPayload,
    BaseModelWithPricing,
    LLMOutput,
    Message,
    MessageBuilder,
    OpenAIChatCompletionAPIMessageBuilder,
    ToolCall,
    ToolCalls,
)

litellm.modify_params = True


class LiteLLMModel(BaseModelWithPricing):
    def __init__(
        self,
        model_name: str,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        temperature: float | None = None,
        max_tokens: int | None = 100,
        use_only_first_toolcall: bool = False,
    ):
        super().__init__(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        self.action_space_as_tools = True  # this should be a config
        client_args = {}
        if base_url is not None:
            client_args["base_url"] = base_url
        if api_key is not None:
            client_args["api_key"] = api_key
        self.client = partial(completion, **client_args)
        self.init_pricing_tracker(pricing_api="litellm")
        self.use_only_first_toolcall = use_only_first_toolcall
        try:
            self.litellm_info = litellm.get_model_info(model_name)
            # maybe log this in xray
        except Exception as e:
            logging.error(f"Failed to get litellm model info: {e}")

    def _call_api(self, payload: APIPayload) -> "OpenAIChatCompletion":
        """Calls the LiteLLM API with the given payload.

        Args:
            payload (APIPayload): The payload to send to the API.

        Returns:
            OpenAIChatCompletion: An object with the same keys as OpenAIChatCompletion.
        """
        messages = []
        for msg in payload.messages:  # type: ignore
            messages.extend(msg.prepare_message())
        api_params: Dict[str, Any] = {
            "model": self.model_name,
            "messages": messages,
        }
        if self.temperature is not None:
            api_params["temperature"] = self.temperature

        if self.max_tokens is not None:
            api_params["max_completion_tokens"] = self.max_tokens

        if payload.tools is not None:
            api_params["tools"] = (
                self.format_tools_for_chat_completion(payload.tools)
                if "function" not in payload.tools[0]  # convert if Responses-API tools
                else payload.tools
            )

        if payload.tool_choice is not None and payload.force_call_tool is None:
            api_params["tool_choice"] = (
                "required" if payload.tool_choice in ("required", "any") else payload.tool_choice
            )

        if payload.force_call_tool is not None:
            api_params["tool_choice"] = {
                "type": "function",
                "function": {"name": payload.force_call_tool},
            }

        if payload.reasoning_effort is not None:
            api_params["reasoning_effort"] = payload.reasoning_effort

        if "tools" in api_params and payload.cache_tool_definition:
            api_params["tools"][-1]["cache_control"] = {"type": "ephemeral"}  # type: ignore

        if payload.cache_complete_prompt:
            # Marking the last message with cache control enables caching of the complete prompt.
            api_params["messages"][-1]["content"][-1]["cache_control"] = {"type": "ephemeral"}

        response = self.client(**api_params, num_retries=5)

        return response  # type: ignore

    def _parse_response(self, response: "OpenAIChatCompletion") -> LLMOutput:
        think_output = self._extract_thinking_content_from_response(response)
        tool_calls = self._extract_tool_calls_from_response(response)

        if self.action_space_as_tools:
            env_action = self._extract_env_actions_from_toolcalls(tool_calls)  # type: ignore
        else:
            env_action = self._extract_env_actions_from_text_response(response)
        return LLMOutput(
            raw_response=response,
            think=think_output,
            action=env_action if env_action is not None else None,
            tool_calls=tool_calls if tool_calls is not None else None,
        )

    def _extract_thinking_content_from_response(
        self, response: OpenAIChatCompletion, wrap_tag="think"
    ):
        """Extracts the content from the message, including reasoning if available.

        When the LLM produces 'text' and 'reasoning' in the same message, the reasoning
        is wrapped in <think>...</think> tags for easy identification.
        Note: The wrapping of 'thinking' content may not be needed and may be reconsidered.

        Args:
            response: The message object or dict containing content and reasoning.
            wrap_tag: The tag name to wrap reasoning content (default: "think").

        Returns:
            str: The extracted content with reasoning wrapped in the specified tags.
        """
        message = response.choices[0].message
        if not isinstance(message, dict):
            message = message.to_dict()

        reasoning_content = message.get("reasoning", None)
        msg_content = message.get("text", "")  # works for OpenRouter
        if reasoning_content:
            # Wrap reasoning in <think> tags with a trailing newline for clarity
            reasoning_content = f"<{wrap_tag}>{reasoning_content}</{wrap_tag}>\n"
            logging.debug("Extracting content from response.choices[i].message.reasoning")
        else:
            reasoning_content = ""
        return f"{reasoning_content}{msg_content}{message.get('content', '')}"

    def _extract_tool_calls_from_response(self, response: OpenAIChatCompletion) -> ToolCalls | None:
        """Extracts tool calls from the response."""
        message = response.choices[0].message.to_dict()
        tool_calls = message.get("tool_calls", None)
        if tool_calls is None:
            return None
        tool_call_list = []
        for tc in tool_calls:  # type: ignore
            tool_call_list.append(
                ToolCall(
                    name=tc["function"]["name"],
                    arguments=json.loads(tc["function"]["arguments"]),
                    raw_call=tc,
                )
            )
            if self.use_only_first_toolcall:
                break
        return ToolCalls(tool_calls=tool_call_list, raw_calls=response)  # type: ignore

    def _extract_env_actions_from_toolcalls(self, toolcalls: ToolCalls) -> Any | None:
        """Extracts actions from the tool calls."""
        if not toolcalls:
            return None

        actions = [
            AgentlabAction.convert_toolcall_to_agentlab_action_format(call) for call in toolcalls
        ]
        actions = (
            AgentlabAction.convert_multiactions_to_agentlab_action_format(actions)
            if len(actions) > 1
            else actions[0]
        )
        return actions

    def _extract_env_actions_from_text_response(
        self, response: "OpenAIChatCompletion"
    ) -> str | None:
        """Extracts environment actions from the text response."""
        # Used when the action space is not given as tools.
        # TODO: Add support to pass the action space as a prompt in LiteLLM.
        # Check: https://docs.litellm.ai/docs/completion/function_call#function-calling-for-models-wout-function-calling-support
        pass

    @staticmethod
    def format_tools_for_chat_completion(tools):
        """Formats Responses-API tool descriptions for the OpenAI Chat Completion API.

        Why is this needed? actionset.to_tool_description() in bgym only returns a
        description format valid for the OpenAI Responses API.

        Args:
            tools: List of tool descriptions to format for the Chat Completion API.

        Returns:
            Formatted tools list compatible with the OpenAI Chat Completion API, or None if tools is None.
        """
        formatted_tools = None
        if tools is not None:
            formatted_tools = [
                {
                    "type": tool["type"],
                    "function": {k: tool[k] for k in ("name", "description", "parameters")},
                }
                for tool in tools
            ]
        return formatted_tools


class LiteLLMAPIMessageBuilder(OpenAIChatCompletionAPIMessageBuilder):
    """Message builder for the LiteLLM API, extending OpenAIChatCompletionAPIMessageBuilder."""

    def prepare_message(self, use_only_first_toolcall: bool = False) -> List[Message]:
        """Prepare the message in the OpenAI chat-completions format."""
        content = []
        for item in self.content:
            content.append(self.convert_content_to_expected_format(item))
        output = [{"role": self.role, "content": content}]
        return output if self.role != "tool" else self.handle_tool_call(use_only_first_toolcall)

    def handle_tool_call(self, use_only_first_toolcall: bool = False) -> List[Message]:
        """Handle the tool call response from the last raw response."""
        if self.responded_tool_calls is None:
            raise ValueError("No tool calls found in responded_tool_calls")
        output = []
        raw_call = self.responded_tool_calls.raw_calls.choices[0].message  # type: ignore
        if use_only_first_toolcall:
            raw_call.tool_calls = raw_call.tool_calls[:1]
        output.append(raw_call)  # add raw calls to output
        for fn_call in self.responded_tool_calls:
            raw_call = fn_call.raw_call
            assert (
                "image" not in fn_call.tool_response
            ), "Image output is not supported in function call responses."
            # a function_call_output dict has keys "role", "tool_call_id" and "content"
            tool_call_response = {
                "name": raw_call["function"]["name"],  # required with OpenRouter
                "role": "tool",
                "tool_call_id": raw_call["id"],
                "content": self.convert_content_to_expected_format(fn_call.tool_response)["text"],
            }
            output.append(tool_call_response)

        return output


@dataclass
class LiteLLMModelArgs(BaseModelArgs):
    """Serializable arguments for LiteLLMModel."""

    api = "openai"  # tool description format used by actionset.to_tool_description() in bgym
    base_url: Optional[str] = None
    api_key: Optional[str] = None
    use_only_first_toolcall: bool = False

    def make_model(self):
        return LiteLLMModel(
            model_name=self.model_name,
            base_url=self.base_url,
            api_key=self.api_key,
            max_tokens=self.max_new_tokens,
            temperature=self.temperature,
            use_only_first_toolcall=self.use_only_first_toolcall,
        )

    def get_message_builder(self) -> Type[MessageBuilder]:
        """Returns a message builder for the LiteLLMModel."""
        return LiteLLMAPIMessageBuilder


if __name__ == "__main__":
    """
    Some simple tests to run the LiteLLMModel with different models.
    """

    import os

    from agentlab.agents.tool_use_agent import DEFAULT_PROMPT_CONFIG, ToolUseAgentArgs
    from agentlab.experiments.study import Study
    from agentlab.llm.litellm_api import LiteLLMModelArgs

    os.environ["LITELLM_LOG"] = "WARNING"

    def get_agent(model_name: str) -> ToolUseAgentArgs:
        return ToolUseAgentArgs(
            model_args=LiteLLMModelArgs(
                model_name=model_name,
                max_new_tokens=2000,
                temperature=None,
            ),
            config=DEFAULT_PROMPT_CONFIG,
        )

    models = [
        "openai/gpt-4.1",
        "openai/gpt-4.1-mini",
        "openai/gpt-4.1-nano",
        "openai/o3-2025-04-16",
        "anthropic/claude-3-7-sonnet-20250219",
        "anthropic/claude-sonnet-4-20250514",
        ## Add more models to test.
    ]
    agent_args = [get_agent(model) for model in models]

    study = Study(agent_args, "miniwob_tiny_test", logging_level_stdout=logging.WARNING)
    study.run(
        n_jobs=5,
        parallel_backend="ray",
        strict_reproducibility=False,
        n_relaunch=3,
    )
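
To make the Responses-to-Chat-Completions tool conversion concrete, here is a minimal sketch of what format_tools_for_chat_completion produces. The click tool below is a hypothetical example for illustration, not part of the diff:

# Hypothetical Responses-API-style tool: name/description/parameters sit flat
# at the top level, as actionset.to_tool_description() in bgym emits them.
responses_tool = {
    "type": "function",
    "name": "click",
    "description": "Click an element identified by its bid.",
    "parameters": {
        "type": "object",
        "properties": {"bid": {"type": "string"}},
        "required": ["bid"],
    },
}

# The helper nests those keys under "function", which is the shape the
# Chat Completions API expects.
chat_tool = LiteLLMModel.format_tools_for_chat_completion([responses_tool])[0]
assert chat_tool["type"] == "function"
assert set(chat_tool["function"]) == {"name", "description", "parameters"}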

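For a sense of how the new classes plug in outside a full Study run, a minimal usage sketch, assuming provider credentials (e.g. OPENAI_API_KEY) are set in the environment; LiteLLM resolves keys per provider from "provider/model" names:

from agentlab.llm.litellm_api import LiteLLMModelArgs

model_args = LiteLLMModelArgs(
    model_name="openai/gpt-4.1-mini",
    max_new_tokens=500,
    temperature=None,  # leave unset for models that reject a temperature
)
model = model_args.make_model()                 # -> LiteLLMModel
builder_cls = model_args.get_message_builder()  # -> LiteLLMAPIMessageBuilder
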
src/agentlab/llm/response_api.py

Lines changed: 6 additions & 0 deletions
@@ -27,6 +27,7 @@
 3. Factory classes (inherits from BaseModelArgs) for creating instances of LLM Response models.
 """
 
+logger = logging.getLogger(__name__)
 
 ContentItem = Dict[str, Any]
 Message = Dict[str, Union[str, List[ContentItem]]]
@@ -388,10 +389,15 @@ class APIPayload:
     cache_complete_prompt: bool = (
         False  # If True, will cache the complete prompt in the last message.
     )
+    reasoning_effort: Literal["low", "medium", "high"] | None = None
 
     def __post_init__(self):
         if self.tool_choice and self.force_call_tool:
             raise ValueError("tool_choice and force_call_tool are mutually exclusive")
+        if self.reasoning_effort is not None:
+            logger.info(
+                "In agentlab, reasoning_effort is used by the LiteLLM API only. We will eventually shift to the LiteLLM API for all LLMs."
+            )
 
 
 # # Base class for all API Endpoints
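
A hedged sketch of the new reasoning_effort field in use; APIPayload's full signature is outside this hunk, so the messages argument below is assumed from its usage in litellm_api.py:

from agentlab.llm.response_api import APIPayload

messages = []  # placeholder: normally a list of MessageBuilder instances

# reasoning_effort is forwarded to the provider by LiteLLMModel._call_api;
# __post_init__ logs a note that only the LiteLLM backend consumes it.
payload = APIPayload(
    messages=messages,
    reasoning_effort="high",  # Literal["low", "medium", "high"] or None
)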
