@@ -129,7 +129,7 @@ class LangchainLLMWrapper(BaseRagasLLM):
129129
130130 def __init__ (
131131 self ,
132- langchain_llm : BaseLanguageModel ,
132+ langchain_llm : BaseLanguageModel [ BaseMessage ] ,
133133 run_config : t .Optional [RunConfig ] = None ,
134134 is_finished_parser : t .Optional [t .Callable [[LLMResult ], bool ]] = None ,
135135 cache : t .Optional [CacheInterface ] = None ,
@@ -198,29 +198,36 @@ def generate_text(
198198 callbacks : Callbacks = None ,
199199 ) -> LLMResult :
200200 # figure out the temperature to set
201+ old_temperature : float | None = None
201202 if temperature is None :
202203 temperature = self .get_temperature (n = n )
204+ if hasattr (self .langchain_llm , "temperature" ):
205+ self .langchain_llm .temperature = temperature # type: ignore
206+ old_temperature = temperature
203207
204208 if is_multiple_completion_supported (self .langchain_llm ):
205- return self .langchain_llm .generate_prompt (
209+ result = self .langchain_llm .generate_prompt (
206210 prompts = [prompt ],
207211 n = n ,
208- temperature = temperature ,
209212 stop = stop ,
210213 callbacks = callbacks ,
211214 )
212215 else :
213216 result = self .langchain_llm .generate_prompt (
214217 prompts = [prompt ] * n ,
215- temperature = temperature ,
216218 stop = stop ,
217219 callbacks = callbacks ,
218220 )
219221 # make LLMResult.generation appear as if it was n_completions
220222 # note that LLMResult.runs is still a list that represents each run
221223 generations = [[g [0 ] for g in result .generations ]]
222224 result .generations = generations
223- return result
225+
226+ # reset the temperature to the original value
227+ if old_temperature is not None :
228+ self .langchain_llm .temperature = old_temperature # type: ignore
229+
230+ return result
224231
225232 async def agenerate_text (
226233 self ,
@@ -230,29 +237,38 @@ async def agenerate_text(
230237 stop : t .Optional [t .List [str ]] = None ,
231238 callbacks : Callbacks = None ,
232239 ) -> LLMResult :
240+ # handle temperature
241+ old_temperature : float | None = None
233242 if temperature is None :
234243 temperature = self .get_temperature (n = n )
244+ if hasattr (self .langchain_llm , "temperature" ):
245+ self .langchain_llm .temperature = temperature # type: ignore
246+ old_temperature = temperature
235247
236- if is_multiple_completion_supported (self .langchain_llm ):
237- return await self .langchain_llm .agenerate_prompt (
248+ # handle n
249+ if hasattr (self .langchain_llm , "n" ):
250+ self .langchain_llm .n = n # type: ignore
251+ result = await self .langchain_llm .agenerate_prompt (
238252 prompts = [prompt ],
239- n = n ,
240- temperature = temperature ,
241253 stop = stop ,
242254 callbacks = callbacks ,
243255 )
244256 else :
245257 result = await self .langchain_llm .agenerate_prompt (
246258 prompts = [prompt ] * n ,
247- temperature = temperature ,
248259 stop = stop ,
249260 callbacks = callbacks ,
250261 )
251262 # make LLMResult.generation appear as if it was n_completions
252263 # note that LLMResult.runs is still a list that represents each run
253264 generations = [[g [0 ] for g in result .generations ]]
254265 result .generations = generations
255- return result
266+
267+ # reset the temperature to the original value
268+ if old_temperature is not None :
269+ self .langchain_llm .temperature = old_temperature # type: ignore
270+
271+ return result
256272
257273 def set_run_config (self , run_config : RunConfig ):
258274 self .run_config = run_config
0 commit comments