Commit 57c6cbf

Added custom is_finished_parser logic to Google Vertex AI customization guide (#1728)

# Title: Added a custom completion detection parser for Gemini models

## Description

This PR updates the Ragas model-customization how-to guide, adding proper completion-detection support for Google's Vertex AI Gemini models. Currently, Ragas systematically raises `LLMDidNotFinishException` with Gemini models because it does not correctly interpret Gemini's completion signals.

1 parent 34a7db2 · commit 57c6cbf

File tree

1 file changed: +38 −3 lines
docs/howtos/customizations/customize_models.md

````diff
@@ -70,11 +70,12 @@ import google.auth
 from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings
 from ragas.llms import LangchainLLMWrapper
 from ragas.embeddings import LangchainEmbeddingsWrapper
+from langchain_core.outputs import LLMResult, ChatGeneration

 config = {
     "project_id": "<your-project-id>",
-    "chat_model_id": "gemini-1.0-pro-002",
-    "embedding_model_id": "textembedding-gecko",
+    "chat_model_id": "gemini-1.5-pro-002",
+    "embedding_model_id": "text-embedding-005",
 }

 # authenticate to GCP
@@ -89,7 +90,41 @@ vertextai_embeddings = VertexAIEmbeddings(
     credentials=creds, model_name=config["embedding_model_id"]
 )

-vertextai_llm = LangchainLLMWrapper(vertextai_llm)
+# Create a custom is_finished_parser to capture Gemini generation completion signals
+def gemini_is_finished_parser(response: LLMResult) -> bool:
+    is_finished_list = []
+    for g in response.flatten():
+        resp = g.generations[0][0]
+
+        # Check generation_info first
+        if resp.generation_info is not None:
+            finish_reason = resp.generation_info.get("finish_reason")
+            if finish_reason is not None:
+                is_finished_list.append(
+                    finish_reason in ["STOP", "MAX_TOKENS"]
+                )
+                continue
+
+        # Check response_metadata as fallback
+        if isinstance(resp, ChatGeneration) and resp.message is not None:
+            metadata = resp.message.response_metadata
+            if metadata.get("finish_reason"):
+                is_finished_list.append(
+                    metadata["finish_reason"] in ["STOP", "MAX_TOKENS"]
+                )
+            elif metadata.get("stop_reason"):
+                is_finished_list.append(
+                    metadata["stop_reason"] in ["STOP", "MAX_TOKENS"]
+                )
+
+    # If no finish reason found, default to True
+    if not is_finished_list:
+        is_finished_list.append(True)
+
+    return all(is_finished_list)
+
+
+vertextai_llm = LangchainLLMWrapper(vertextai_llm, is_finished_parser=gemini_is_finished_parser)
 vertextai_embeddings = LangchainEmbeddingsWrapper(vertextai_embeddings)
 ```
 Yay! Now are you ready to use ragas with Google VertexAI endpoints
````
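For readers who want to sanity-check the parser's decision logic without GCP credentials, here is a minimal standalone sketch. `FakeMessage`, `FakeGeneration`, and the `is_finished` helper are hypothetical stand-ins (not part of langchain_core or Ragas) that mimic just enough of the `LLMResult`/`ChatGeneration` shape to exercise the finish-reason handling from the patch.

```python
# Standalone sketch of the finish-reason logic in the patch above.
# FakeMessage / FakeGeneration are hypothetical stand-ins for
# langchain_core's ChatGeneration message objects; the real guide
# operates on langchain_core.outputs types instead.
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class FakeMessage:
    response_metadata: dict = field(default_factory=dict)

@dataclass
class FakeGeneration:
    generation_info: Optional[dict] = None
    message: Optional[FakeMessage] = None

def is_finished(generations: list) -> bool:
    """Mirror of gemini_is_finished_parser over an already-flattened list."""
    is_finished_list = []
    for resp in generations:
        # Check generation_info first
        if resp.generation_info is not None:
            finish_reason = resp.generation_info.get("finish_reason")
            if finish_reason is not None:
                is_finished_list.append(finish_reason in ["STOP", "MAX_TOKENS"])
                continue
        # Fall back to response_metadata on the message
        if resp.message is not None:
            metadata = resp.message.response_metadata
            if metadata.get("finish_reason"):
                is_finished_list.append(metadata["finish_reason"] in ["STOP", "MAX_TOKENS"])
            elif metadata.get("stop_reason"):
                is_finished_list.append(metadata["stop_reason"] in ["STOP", "MAX_TOKENS"])
    # No signal found anywhere: default to finished
    if not is_finished_list:
        is_finished_list.append(True)
    return all(is_finished_list)

# A "STOP" finish reason counts as finished...
print(is_finished([FakeGeneration(generation_info={"finish_reason": "SAFETY"})]))  # → False
# ...while e.g. a safety block does not.
print(is_finished([FakeGeneration(generation_info={"finish_reason": "STOP"})]))    # → True
```

This makes the default explicit: a truncated `MAX_TOKENS` response is still treated as "finished" (so Ragas keeps the partial output), while reasons like `SAFETY` or `RECITATION` cause the wrapper to raise.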
