11"""LiteLLM Router as LangChain Model."""
22
3- from typing import (
4- Any ,
5- AsyncIterator ,
6- Iterator ,
7- List ,
8- Mapping ,
9- Optional ,
10- )
3+ from typing import Any , AsyncIterator , Iterator , List , Mapping , Optional
114
 from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
 from langchain_core.language_models.chat_models import (
     agenerate_from_stream,
     generate_from_stream,
 )
-from langchain_core.messages import (
-    AIMessageChunk,
-    BaseMessage,
-)
-from langchain_core.outputs import (
-    ChatGeneration,
-    ChatGenerationChunk,
-    ChatResult,
-)
+from langchain_core.messages import AIMessageChunk, BaseMessage
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 
 from langchain_community.chat_models.litellm import (
     ChatLiteLLM,
     _convert_delta_to_message_chunk,
     _convert_dict_to_message,
 )
 
-token_usage_key_name = "token_usage"
-model_extra_key_name = "model_extra"
+token_usage_key_name = "token_usage"  # nosec # incorrectly flagged as password
+model_extra_key_name = "model_extra"  # nosec # incorrectly flagged as password
 
 
 def get_llm_output(usage: Any, **params: Any) -> dict:
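Context for the two "# nosec" additions above: bandit's B105 check (hardcoded_password_string) fires on string literals assigned to names matching its secret-like patterns, and these two constants apparently match even though their values are plain dict keys. A minimal, assumed illustration of the suppression mechanics (not output from a real bandit run):

# bandit reports B105 when a literal is assigned to a secret-looking name:
token = "token_usage"  # would be flagged as a hardcoded password

# an inline "# nosec" marker makes bandit skip findings on that line;
# the text after it is only a human-readable justification
token = "token_usage"  # nosec # incorrectly flagged as password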
@@ -56,21 +42,14 @@ class ChatLiteLLMRouter(ChatLiteLLM):
 
     def __init__(self, *, router: Any, **kwargs: Any) -> None:
         """Construct Chat LiteLLM Router."""
-        super().__init__(**kwargs)
+        super().__init__(router=router, **kwargs)  # type: ignore
         self.router = router
 
     @property
     def _llm_type(self) -> str:
         return "LiteLLMRouter"
 
-    def _set_model_for_completion(self) -> None:
-        # use first model name (aka: model group),
-        # since we can only pass one to the router completion functions
-        self.model = self.router.model_list[0]["model_name"]
-
     def _prepare_params_for_router(self, params: Any) -> None:
-        params["model"] = self.model
-
         # allow the router to set api_base based on its model choice
         api_base_key_name = "api_base"
         if api_base_key_name in params and params[api_base_key_name] is None:
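With _set_model_for_completion gone, _prepare_params_for_router no longer overwrites params["model"]; it only drops a None api_base and seeds metadata. A standalone sketch of that remaining behavior (the del statement sits in unchanged lines between these two hunks and is assumed from context; the param values are illustrative):

def prepare_params_for_router(params: dict) -> None:
    # drop an explicit api_base of None so the router can pick the
    # api_base matching whichever deployment it chooses
    if "api_base" in params and params["api_base"] is None:
        del params["api_base"]
    # make sure a metadata dict exists for the router to fill in
    params.setdefault("metadata", {})

params = {"model": "gpt-4o-group", "api_base": None}
prepare_params_for_router(params)
assert params == {"model": "gpt-4o-group", "metadata": {}}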
@@ -79,6 +58,22 @@ def _prepare_params_for_router(self, params: Any) -> None:
         # add metadata so router can fill it below
         params.setdefault("metadata", {})
 
+    def set_default_model(self, model_name: str) -> None:
+        """Set the default model to use for completion calls.
+
+        Sets `self.model` to `model_name` if it is in the litellm router's
+        (`self.router`) model list. This provides the default model to use
+        for completion calls if no `model` kwarg is provided.
+        """
+        model_list = self.router.model_list
+        if not model_list:
+            raise ValueError("model_list is None or empty.")
+        for entry in model_list:
+            if entry["model_name"] == model_name:
+                self.model = model_name
+                return
+        raise ValueError(f"Model {model_name} not found in model_list.")
+
     def _generate(
         self,
         messages: List[BaseMessage],
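A short usage sketch of the new construct-then-set-default flow, assuming a litellm Router with a single model group (the group name, deployment, and key are placeholders):

from litellm import Router

from langchain_community.chat_models import ChatLiteLLMRouter

# hypothetical one-group router; real lists usually hold several deployments
model_list = [
    {
        "model_name": "gpt-4o-group",
        "litellm_params": {"model": "openai/gpt-4o", "api_key": "sk-..."},
    }
]
router = Router(model_list=model_list)

chat = ChatLiteLLMRouter(router=router)
chat.set_default_model("gpt-4o-group")  # ok: the name exists in model_list
# chat.set_default_model("no-such-group")  # would raise ValueError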
@@ -96,7 +91,6 @@ def _generate(
 
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
-        self._set_model_for_completion()
         self._prepare_params_for_router(params)
 
         response = self.router.completion(
@@ -115,7 +109,6 @@ def _stream(
         default_chunk_class = AIMessageChunk
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}
-        self._set_model_for_completion()
         self._prepare_params_for_router(params)
 
         for chunk in self.router.completion(messages=message_dicts, **params):
@@ -139,7 +132,6 @@ async def _astream(
         default_chunk_class = AIMessageChunk
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}
-        self._set_model_for_completion()
         self._prepare_params_for_router(params)
 
         async for chunk in await self.router.acompletion(
@@ -174,7 +166,6 @@ async def _agenerate(
 
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
-        self._set_model_for_completion()
         self._prepare_params_for_router(params)
 
         response = await self.router.acompletion(
@@ -196,14 +187,14 @@ def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
             token_usage = output["token_usage"]
             if token_usage is not None:
                 # get dict from LiteLLM Usage class
-                for k, v in token_usage.dict().items():
-                    if k in overall_token_usage:
+                for k, v in token_usage.model_dump().items():
+                    if k in overall_token_usage and overall_token_usage[k] is not None:
                         overall_token_usage[k] += v
                     else:
                         overall_token_usage[k] = v
             if system_fingerprint is None:
                 system_fingerprint = output.get("system_fingerprint")
-        combined = {"token_usage": overall_token_usage, "model_name": self.model_name}
+        combined = {"token_usage": overall_token_usage, "model_name": self.model}
         if system_fingerprint:
             combined["system_fingerprint"] = system_fingerprint
         return combined
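Two notes on the _combine_llm_outputs changes: model_dump() is the pydantic v2 replacement for the deprecated .dict() on litellm's Usage objects, and the new "is not None" guard falls through to plain assignment when the running value is None, avoiding a TypeError from None += v. A minimal sketch of the API difference (the Usage fields shown are the common ones, not litellm's full class):

from pydantic import BaseModel


class Usage(BaseModel):
    # common token-usage fields; litellm's real Usage class carries more
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0


usage = Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15)
print(usage.model_dump())  # pydantic v2 API: {'prompt_tokens': 10, ...}
print(usage.dict())        # same dict, but deprecated under pydantic v2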