 from __future__ import annotations

-import asyncio
 import logging
 import typing as t
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from functools import partial

 from langchain_community.chat_models.vertexai import ChatVertexAI
 from langchain_community.llms import VertexAI
 from langchain_core.language_models import BaseLanguageModel
-from langchain_core.outputs import Generation, LLMResult
+from langchain_core.outputs import ChatGeneration, Generation, LLMResult
 from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI
 from langchain_openai.llms import AzureOpenAI, OpenAI
 from langchain_openai.llms.base import BaseOpenAI

+from ragas.exceptions import LLMDidNotFinishException
 from ragas.integrations.helicone import helicone_config
-from ragas.run_config import RunConfig, add_async_retry, add_retry
+from ragas.run_config import RunConfig, add_async_retry

 if t.TYPE_CHECKING:
     from langchain_core.callbacks import Callbacks
+    from langchain_core.messages import BaseMessage
     from langchain_core.prompt_values import PromptValue
     from llama_index.core.base.llms.base import BaseLLM

@@ -55,6 +55,12 @@ def get_temperature(self, n: int) -> float:
5555 """Return the temperature to use for completion based on n."""
5656 return 0.3 if n > 1 else 1e-8
5757
58+ def is_finished (self , response : LLMResult ) -> bool :
59+ logger .warning (
60+ f"is_finished not implemented for { self .__class__ .__name__ } . Will default to True."
61+ )
62+ return True
63+
5864 @abstractmethod
5965 def generate_text (
6066 self ,
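Note on the hunk above: the base-class is_finished is only a warn-and-pass default; concrete wrappers are expected to override it. A minimal runnable sketch of such an override, assuming the import path from this diff; the EchoLLM name and its toy generate methods are illustrative, not part of this change:

import typing as t

from langchain_core.outputs import Generation, LLMResult
from langchain_core.prompt_values import PromptValue

from ragas.llms.base import BaseRagasLLM  # import path assumed from this diff


class EchoLLM(BaseRagasLLM):
    # Toy wrapper: echoes the prompt back and tags it with a finish_reason,
    # so the override below has something real to inspect.

    def generate_text(
        self,
        prompt: PromptValue,
        n: int = 1,
        temperature: float = 1e-8,
        stop: t.Optional[t.List[str]] = None,
        callbacks: t.Any = None,
    ) -> LLMResult:
        gen = Generation(
            text=prompt.to_string(),
            generation_info={"finish_reason": "stop"},
        )
        return LLMResult(generations=[[gen]])

    async def agenerate_text(
        self,
        prompt: PromptValue,
        n: int = 1,
        temperature: t.Optional[float] = None,
        stop: t.Optional[t.List[str]] = None,
        callbacks: t.Any = None,
    ) -> LLMResult:
        return self.generate_text(prompt, n, temperature or 1e-8, stop, callbacks)

    def is_finished(self, response: LLMResult) -> bool:
        # Replace the warn-and-return-True default with a real check.
        return all(
            (g.generation_info or {}).get("finish_reason") == "stop"
            for gens in response.generations
            for g in gens
        )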
@@ -82,36 +88,27 @@ async def generate(
         temperature: t.Optional[float] = None,
         stop: t.Optional[t.List[str]] = None,
         callbacks: Callbacks = None,
-        is_async: bool = True,
     ) -> LLMResult:
         """Generate text using the given event loop."""

         if temperature is None:
             temperature = self.get_temperature(n)

-        if is_async:
-            agenerate_text_with_retry = add_async_retry(
-                self.agenerate_text, self.run_config
-            )
-            return await agenerate_text_with_retry(
-                prompt=prompt,
-                n=n,
-                temperature=temperature,
-                stop=stop,
-                callbacks=callbacks,
-            )
-        else:
-            loop = asyncio.get_event_loop()
-            generate_text_with_retry = add_retry(self.generate_text, self.run_config)
-            generate_text = partial(
-                generate_text_with_retry,
-                prompt=prompt,
-                n=n,
-                temperature=temperature,
-                stop=stop,
-                callbacks=callbacks,
-            )
-            return await loop.run_in_executor(None, generate_text)
+        agenerate_text_with_retry = add_async_retry(
+            self.agenerate_text, self.run_config
+        )
+        result = await agenerate_text_with_retry(
+            prompt=prompt,
+            n=n,
+            temperature=temperature,
+            stop=stop,
+            callbacks=callbacks,
+        )
+
+        # check that generation finished cleanly (no max_tokens truncation)
+        if not self.is_finished(result):
+            raise LLMDidNotFinishException()
+        return result


 class LangchainLLMWrapper(BaseRagasLLM):
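With the sync fallback removed, generate() is async-only and now raises when the finish check fails. A minimal sketch of the new failure mode at a call site, assuming an already-constructed BaseRagasLLM instance (the llm argument and prompt text are illustrative):

import asyncio

from langchain_core.prompt_values import StringPromptValue

from ragas.exceptions import LLMDidNotFinishException


async def check_generation(llm) -> str:
    # llm: any BaseRagasLLM instance (construction not shown here)
    prompt = StringPromptValue(text="Summarize: Ragas evaluates RAG pipelines.")
    try:
        result = await llm.generate(prompt, n=1)
        return result.generations[0][0].text
    except LLMDidNotFinishException:
        # raised when is_finished() reports truncation, e.g. the provider
        # stopped on max_tokens instead of a clean "stop"
        return "generation did not finish; consider raising max_tokens"


# asyncio.run(check_generation(llm))  # needs a concrete wrapper instance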
@@ -123,12 +120,57 @@ class LangchainLLMWrapper(BaseRagasLLM):
123120 """
124121
125122 def __init__ (
126- self , langchain_llm : BaseLanguageModel , run_config : t .Optional [RunConfig ] = None
123+ self ,
124+ langchain_llm : BaseLanguageModel ,
125+ run_config : t .Optional [RunConfig ] = None ,
126+ is_finished_parser : t .Optional [t .Callable [[LLMResult ], bool ]] = None ,
127127 ):
128128 self .langchain_llm = langchain_llm
129129 if run_config is None :
130130 run_config = RunConfig ()
131131 self .set_run_config (run_config )
132+ self .is_finished_parser = is_finished_parser
133+
134+ def is_finished (self , response : LLMResult ) -> bool :
135+ """
136+ Parse the response to check if the LLM finished by checking the finish_reason
137+ or stop_reason.
138+ """
139+ if self .is_finished_parser is not None :
140+ return self .is_finished_parser (response )
141+ # if no parser is provided default to our own
142+
143+ is_finished_list = []
144+ for g in response .flatten ():
145+ resp = g .generations [0 ][0 ]
146+ if resp .generation_info is not None :
147+ # generation_info is provided - so we parse that
148+
149+ # OpenAI uses "stop" to indicate that the generation is finished
150+ # and is stored in 'finish_reason' key in generation_info
151+ if resp .generation_info .get ("finish_reason" ) is not None :
152+ is_finished_list .append (
153+ resp .generation_info .get ("finish_reason" ) == "stop"
154+ )
155+ # provied more conditions here
156+ # https://github.com/explodinggradients/ragas/issues/1548
157+
158+ # if generation_info is empty, we parse the response_metadata
159+ # this is less reliable
160+ elif t .cast (ChatGeneration , resp ).message is not None :
161+ resp_message : BaseMessage = t .cast (ChatGeneration , resp ).message
162+ if resp_message .response_metadata .get ("finish_reason" ) is not None :
163+ is_finished_list .append (
164+ resp_message .response_metadata .get ("finish_reason" ) == "stop"
165+ )
166+ elif resp_message .response_metadata .get ("stop_reason" ) is not None :
167+ is_finished_list .append (
168+ resp_message .response_metadata .get ("stop_reason" ) == "end_turn"
169+ )
170+ # default to True
171+ else :
172+ is_finished_list .append (True )
173+ return all (is_finished_list )
132174
133175 def generate_text (
134176 self ,
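Usage note: callers who know their provider's metadata can bypass the heuristics above entirely via the new is_finished_parser argument. A sketch, assuming the import path from this diff; the strict_openai_finished helper and the model name are illustrative, while ChatOpenAI and LangchainLLMWrapper come from this file's own imports:

from langchain_core.outputs import LLMResult
from langchain_openai.chat_models import ChatOpenAI

from ragas.llms.base import LangchainLLMWrapper  # import path assumed


def strict_openai_finished(response: LLMResult) -> bool:
    # Accept only an explicit "stop"; "length" signals max_tokens truncation.
    return all(
        (g.generation_info or {}).get("finish_reason") == "stop"
        for gens in response.generations
        for g in gens
    )


llm = LangchainLLMWrapper(
    langchain_llm=ChatOpenAI(model="gpt-4o-mini"),  # model choice illustrative
    is_finished_parser=strict_openai_finished,
)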