66try :
77 from langchain_litellm import ChatLiteLLM , ChatLiteLLMRouter
88 from litellm import Router
9+
910 LANGCHAIN_AVAILABLE = True
1011except ImportError :
1112 LANGCHAIN_AVAILABLE = False
@@ -59,7 +60,7 @@ def __init__(
5960 router : Optional [Any ] = None , # Router from litellm
6061 temperature : float = 0.0 ,
6162 max_tokens : Optional [int ] = None ,
62- ** kwargs
63+ ** kwargs ,
6364 ):
6465 if not LANGCHAIN_AVAILABLE :
6566 raise ImportError (
@@ -74,28 +75,27 @@ def __init__(
7475
7576 # Initialize the appropriate LangChain client
7677 if router :
77- logger .info (f"Initializing LangChainLiteLLMClient with router for model: { model } " )
78+ logger .info (
79+ f"Initializing LangChainLiteLLMClient with router for model: { model } "
80+ )
7881 self .llm = ChatLiteLLMRouter (
7982 router = router ,
8083 model_name = model ,
8184 temperature = temperature ,
8285 max_tokens = max_tokens ,
83- ** kwargs
86+ ** kwargs ,
8487 )
8588 self .is_router = True
8689 else :
8790 logger .info (f"Initializing LangChainLiteLLMClient for model: { model } " )
8891 self .llm = ChatLiteLLM (
89- model = model ,
90- temperature = temperature ,
91- max_tokens = max_tokens ,
92- ** kwargs
92+ model = model , temperature = temperature , max_tokens = max_tokens , ** kwargs
9393 )
9494 self .is_router = False
9595
9696 def _filter_kwargs (self , kwargs : Dict [str , Any ]) -> Dict [str , Any ]:
9797 """Filter out ACE-specific parameters that shouldn't go to LangChain."""
98- ace_specific_params = {' refinement_round' , ' max_refinement_rounds' }
98+ ace_specific_params = {" refinement_round" , " max_refinement_rounds" }
9999 return {k : v for k , v in kwargs .items () if k not in ace_specific_params }
100100
101101 def complete (self , prompt : str , ** kwargs ) -> LLMResponse :
@@ -121,7 +121,7 @@ def complete(self, prompt: str, **kwargs) -> LLMResponse:
121121 }
122122
123123 # Add usage information if available
124- if hasattr (response , ' usage_metadata' ) and response .usage_metadata :
124+ if hasattr (response , " usage_metadata" ) and response .usage_metadata :
125125 metadata ["usage" ] = {
126126 "prompt_tokens" : response .usage_metadata .get ("input_tokens" ),
127127 "completion_tokens" : response .usage_metadata .get ("output_tokens" ),
@@ -131,12 +131,11 @@ def complete(self, prompt: str, **kwargs) -> LLMResponse:
131131 # Add router information if using router
132132 if self .is_router :
133133 metadata ["router" ] = True
134- metadata ["model_used" ] = response .response_metadata .get ("model_name" , self .model )
134+ metadata ["model_used" ] = response .response_metadata .get (
135+ "model_name" , self .model
136+ )
135137
136- return LLMResponse (
137- text = response .content ,
138- raw = metadata
139- )
138+ return LLMResponse (text = response .content , raw = metadata )
140139
141140 except Exception as e :
142141 logger .error (f"Error in LangChain completion: { e } " )
@@ -165,7 +164,7 @@ async def acomplete(self, prompt: str, **kwargs) -> LLMResponse:
165164 }
166165
167166 # Add usage information if available
168- if hasattr (response , ' usage_metadata' ) and response .usage_metadata :
167+ if hasattr (response , " usage_metadata" ) and response .usage_metadata :
169168 metadata ["usage" ] = {
170169 "prompt_tokens" : response .usage_metadata .get ("input_tokens" ),
171170 "completion_tokens" : response .usage_metadata .get ("output_tokens" ),
@@ -175,12 +174,11 @@ async def acomplete(self, prompt: str, **kwargs) -> LLMResponse:
175174 # Add router information if using router
176175 if self .is_router :
177176 metadata ["router" ] = True
178- metadata ["model_used" ] = response .response_metadata .get ("model_name" , self .model )
177+ metadata ["model_used" ] = response .response_metadata .get (
178+ "model_name" , self .model
179+ )
179180
180- return LLMResponse (
181- text = response .content ,
182- raw = metadata
183- )
181+ return LLMResponse (text = response .content , raw = metadata )
184182
185183 except Exception as e :
186184 logger .error (f"Error in async LangChain completion: { e } " )
@@ -201,7 +199,7 @@ def complete_with_stream(self, prompt: str, **kwargs) -> Iterator[str]:
201199
202200 try :
203201 for chunk in self .llm .stream (prompt , ** filtered_kwargs ):
204- if hasattr (chunk , ' content' ) and chunk .content :
202+ if hasattr (chunk , " content" ) and chunk .content :
205203 yield chunk .content
206204 except Exception as e :
207205 logger .error (f"Error in LangChain streaming: { e } " )
@@ -222,8 +220,8 @@ async def acomplete_with_stream(self, prompt: str, **kwargs) -> AsyncIterator[st
222220
223221 try :
224222 async for chunk in self .llm .astream (prompt , ** filtered_kwargs ):
225- if hasattr (chunk , ' content' ) and chunk .content :
223+ if hasattr (chunk , " content" ) and chunk .content :
226224 yield chunk .content
227225 except Exception as e :
228226 logger .error (f"Error in async LangChain streaming: { e } " )
229- raise
227+ raise
0 commit comments