
Commit 27a9056

community: Fix ChatLiteLLMRouter runtime issues (langchain-ai#28163)
**Description:** Fix ChatLiteLLMRouter constructor validation and the model_name parameter.
**Issue:** langchain-ai#19356, langchain-ai#27455, langchain-ai#28077
**Twitter handle:** @bburgin_0
1 parent 234d496 commit 27a9056
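
To show what the fix changes in practice, here is a minimal construction sketch assembled from the notebook diff below; the API key and endpoint values are placeholders:

```python
from litellm import Router

from langchain_community.chat_models import ChatLiteLLMRouter

model_list = [
    {
        "model_name": "gpt-35-turbo",  # model group name exposed by the router
        "litellm_params": {
            "model": "azure/gpt-35-turbo",
            "api_key": "<your-api-key>",  # placeholder
            "api_version": "2023-05-15",
            "api_base": "https://<your-endpoint>.openai.azure.com/",
        },
    },
]
litellm_router = Router(model_list=model_list)

# Per the linked issues, the constructor previously tripped over its own
# validation and always routed to the first entry in model_list; it now
# accepts an explicit model_name, validated against the router's model list.
chat = ChatLiteLLMRouter(router=litellm_router, model_name="gpt-35-turbo")
```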

File tree

3 files changed: +115 −71 lines


docs/docs/integrations/chat/litellm_router.ipynb

Lines changed: 5 additions & 4 deletions
```diff
@@ -63,17 +63,17 @@
 " },\n",
 " },\n",
 " {\n",
-" \"model_name\": \"gpt-4\",\n",
+" \"model_name\": \"gpt-35-turbo\",\n",
 " \"litellm_params\": {\n",
-" \"model\": \"azure/gpt-4-1106-preview\",\n",
+" \"model\": \"azure/gpt-35-turbo\",\n",
 " \"api_key\": \"<your-api-key>\",\n",
 " \"api_version\": \"2023-05-15\",\n",
 " \"api_base\": \"https://<your-endpoint>.openai.azure.com/\",\n",
 " },\n",
 " },\n",
 "]\n",
 "litellm_router = Router(model_list=model_list)\n",
-"chat = ChatLiteLLMRouter(router=litellm_router)"
+"chat = ChatLiteLLMRouter(router=litellm_router, model_name=\"gpt-35-turbo\")"
 ]
 },
 {
@@ -177,6 +177,7 @@
 "source": [
 "chat = ChatLiteLLMRouter(\n",
 " router=litellm_router,\n",
+" model_name=\"gpt-35-turbo\",\n",
 " streaming=True,\n",
 " verbose=True,\n",
 " callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),\n",
@@ -209,7 +210,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.9.13"
+"version": "3.11.9"
 }
 },
 "nbformat": 4,
```

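The second hunk above shows the streaming configuration gaining the same model_name argument. A hedged, consolidated sketch of that cell (import paths assume a current langchain-core; litellm_router is the Router built in the previous example):

```python
from langchain_core.callbacks import CallbackManager
from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

from langchain_community.chat_models import ChatLiteLLMRouter

chat = ChatLiteLLMRouter(
    router=litellm_router,  # Router built as in the earlier cell
    model_name="gpt-35-turbo",  # default model group, new in this commit
    streaming=True,
    verbose=True,
    callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
)
chat.invoke("Tell me a joke.")  # tokens print to stdout as they stream
```
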
libs/community/langchain_community/chat_models/litellm_router.py

Lines changed: 25 additions & 34 deletions
```diff
@@ -1,13 +1,6 @@
 """LiteLLM Router as LangChain Model."""

-from typing import (
-    Any,
-    AsyncIterator,
-    Iterator,
-    List,
-    Mapping,
-    Optional,
-)
+from typing import Any, AsyncIterator, Iterator, List, Mapping, Optional

 from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
@@ -17,24 +10,17 @@
     agenerate_from_stream,
     generate_from_stream,
 )
-from langchain_core.messages import (
-    AIMessageChunk,
-    BaseMessage,
-)
-from langchain_core.outputs import (
-    ChatGeneration,
-    ChatGenerationChunk,
-    ChatResult,
-)
+from langchain_core.messages import AIMessageChunk, BaseMessage
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult

 from langchain_community.chat_models.litellm import (
     ChatLiteLLM,
     _convert_delta_to_message_chunk,
     _convert_dict_to_message,
 )

-token_usage_key_name = "token_usage"
-model_extra_key_name = "model_extra"
+token_usage_key_name = "token_usage"  # nosec # incorrectly flagged as password
+model_extra_key_name = "model_extra"  # nosec # incorrectly flagged as password


 def get_llm_output(usage: Any, **params: Any) -> dict:
@@ -56,21 +42,14 @@ class ChatLiteLLMRouter(ChatLiteLLM):

     def __init__(self, *, router: Any, **kwargs: Any) -> None:
         """Construct Chat LiteLLM Router."""
-        super().__init__(**kwargs)
+        super().__init__(router=router, **kwargs)  # type: ignore
         self.router = router

     @property
     def _llm_type(self) -> str:
         return "LiteLLMRouter"

-    def _set_model_for_completion(self) -> None:
-        # use first model name (aka: model group),
-        # since we can only pass one to the router completion functions
-        self.model = self.router.model_list[0]["model_name"]
-
     def _prepare_params_for_router(self, params: Any) -> None:
-        params["model"] = self.model
-
         # allow the router to set api_base based on its model choice
         api_base_key_name = "api_base"
         if api_base_key_name in params and params[api_base_key_name] is None:
@@ -79,6 +58,22 @@ def _prepare_params_for_router(self, params: Any) -> None:
         # add metadata so router can fill it below
         params.setdefault("metadata", {})

+    def set_default_model(self, model_name: str) -> None:
+        """Set the default model to use for completion calls.
+
+        Sets `self.model` to `model_name` if it is in the litellm router's
+        (`self.router`) model list. This provides the default model to use
+        for completion calls if no `model` kwarg is provided.
+        """
+        model_list = self.router.model_list
+        if not model_list:
+            raise ValueError("model_list is None or empty.")
+        for entry in model_list:
+            if entry["model_name"] == model_name:
+                self.model = model_name
+                return
+        raise ValueError(f"Model {model_name} not found in model_list.")
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -96,7 +91,6 @@ def _generate(

         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
-        self._set_model_for_completion()
         self._prepare_params_for_router(params)

         response = self.router.completion(
@@ -115,7 +109,6 @@ def _stream(
         default_chunk_class = AIMessageChunk
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}
-        self._set_model_for_completion()
         self._prepare_params_for_router(params)

         for chunk in self.router.completion(messages=message_dicts, **params):
@@ -139,7 +132,6 @@ async def _astream(
         default_chunk_class = AIMessageChunk
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}
-        self._set_model_for_completion()
         self._prepare_params_for_router(params)

         async for chunk in await self.router.acompletion(
@@ -174,7 +166,6 @@ async def _agenerate(

         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
-        self._set_model_for_completion()
         self._prepare_params_for_router(params)

         response = await self.router.acompletion(
@@ -196,14 +187,14 @@ def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
             token_usage = output["token_usage"]
             if token_usage is not None:
                 # get dict from LiteLLM Usage class
-                for k, v in token_usage.dict().items():
-                    if k in overall_token_usage:
+                for k, v in token_usage.model_dump().items():
+                    if k in overall_token_usage and overall_token_usage[k] is not None:
                         overall_token_usage[k] += v
                     else:
                         overall_token_usage[k] = v
             if system_fingerprint is None:
                 system_fingerprint = output.get("system_fingerprint")
-        combined = {"token_usage": overall_token_usage, "model_name": self.model_name}
+        combined = {"token_usage": overall_token_usage, "model_name": self.model}
         if system_fingerprint:
             combined["system_fingerprint"] = system_fingerprint
         return combined
```
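
The new set_default_model method gives callers a validated way to change the default model group after construction. A short usage sketch following the method body above ("gpt-4o" is a hypothetical name absent from the example model list):

```python
chat = ChatLiteLLMRouter(router=litellm_router, model_name="gpt-35-turbo")

# Re-point the default used when no `model` kwarg is passed per call.
chat.set_default_model("gpt-35-turbo")

# Names missing from the router's model_list raise a ValueError.
try:
    chat.set_default_model("gpt-4o")  # hypothetical, not in model_list
except ValueError as err:
    print(err)  # Model gpt-4o not found in model_list.
```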

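The final hunk moves token accounting from Pydantic v1's `.dict()` to v2's `model_dump()` and skips accumulators that are None. A standalone sketch of that merge logic, with plain dicts standing in for LiteLLM's Usage objects:

```python
def combine_token_usage(usages: list[dict]) -> dict:
    """Sum usage counters key by key, mirroring _combine_llm_outputs."""
    overall: dict = {}
    for usage in usages:
        for k, v in usage.items():
            if k in overall and overall[k] is not None:
                overall[k] += v  # accumulate a known, non-None counter
            else:
                overall[k] = v  # first sighting, or overwrite a None
    return overall

print(combine_token_usage(
    [{"prompt_tokens": 10, "completion_tokens": None},
     {"prompt_tokens": 7, "completion_tokens": 5}]
))  # {'prompt_tokens': 17, 'completion_tokens': 5}
```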