Skip to content

Commit 0c0e6bd

Browse files
Merge pull request #187 from neo4j-labs/googlellmvalidation
Add custom validation for Google Gemini LLM
2 parents d471ee6 + e0fef67 commit 0c0e6bd

File tree

1 file changed

+42
-8
lines changed

1 file changed

+42
-8
lines changed

backend/src/gemini_llm.py

Lines changed: 42 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
from langchain_community.graphs import Neo4jGraph
22
from dotenv import load_dotenv
33
from langchain.schema import Document
4+
import json
45
import logging
56
import re
67
import concurrent.futures
@@ -18,7 +19,7 @@
1819
from langchain.pydantic_v1 import BaseModel
1920
from langchain_google_vertexai import ChatVertexAI
2021
from langchain_core.prompts import ChatPromptTemplate
21-
from typing import List
22+
from typing import List, Dict
2223
from langchain_core.pydantic_v1 import BaseModel, Field
2324
import google.auth
2425
from langchain_community.graphs.graph_document import Node
@@ -150,6 +151,44 @@ def map_to_base_relationship(rel: Any) -> Relationship:
150151
source=source, target=target, type=rel.type.replace(" ", "_").upper()
151152
)
152153

154+
def _extract_relationships(
    raw_schema: Dict[Any, Any],
) -> List[Relationship]:
    """Build the list of relationships from a structured-output LLM response.

    ``raw_schema`` is the dict the chain returns when the structured LLM is
    created with ``include_raw=True``; this function reads its ``"parsed"``
    and ``"raw"`` entries.

    * When ``raw_schema["parsed"]`` is truthy, its ``relationships`` are
      converted with ``map_to_base_relationship``.
    * Otherwise (validation failed) the arguments of the first tool call on
      ``raw_schema["raw"]`` are decoded as JSON and every entry that carries
      ``start_node_id``, ``end_node_id`` and ``type`` is converted by hand.

    Args:
        raw_schema: Mapping with ``"parsed"`` and ``"raw"`` keys.

    Returns:
        The extracted relationships; an empty list when nothing usable could
        be recovered. The fallback path never raises on malformed model
        output.
    """
    parsed_schema = raw_schema["parsed"]
    if parsed_schema:
        # No validation errors: use the parsed pydantic object.
        if not parsed_schema.relationships:
            return []
        return [map_to_base_relationship(rel) for rel in parsed_schema.relationships]

    # Validation errors: recover what we can from the raw tool-call payload.
    try:
        argument_json = json.loads(
            raw_schema["raw"].additional_kwargs["tool_calls"][0]["function"][
                "arguments"
            ]
        )
        raw_relationships = argument_json["relationships"]
    except Exception:  # Best effort: missing tool call or undecodable JSON
        logging.warning(
            "Could not recover relationships from raw LLM output", exc_info=True
        )
        return []

    if not isinstance(raw_relationships, list):
        return []

    relationships = []
    for rel in raw_relationships:
        # Skip a malformed entry on its own instead of discarding the batch.
        if not isinstance(rel, dict):
            continue
        start_id = rel.get("start_node_id")
        end_id = rel.get("end_node_id")
        rel_type = rel.get("type")
        # Mandatory props
        if not (start_id and end_id and rel_type):
            continue
        # `or "Node"` also covers a node type that is present but null/empty.
        start_type = str(rel.get("start_node_type") or "Node")
        end_type = str(rel.get("end_node_type") or "Node")
        relationships.append(
            Relationship(
                source=Node(id=str(start_id).title(), type=start_type.capitalize()),
                target=Node(id=str(end_id).title(), type=end_type.capitalize()),
                # Same normalization as map_to_base_relationship so both
                # paths produce identically formatted relationship types.
                type=str(rel_type).replace(" ", "_").upper(),
            )
        )
    return relationships
191+
153192
class GeminiLLMGraphTransformer:
154193
def __init__(
155194
self,
@@ -170,7 +209,7 @@ def __init__(
170209

171210
# Define chain
172211
schema = create_simple_model(allowed_nodes, allowed_relationships)
173-
structured_llm = llm.with_structured_output(schema)
212+
structured_llm = llm.with_structured_output(schema, include_raw=True)
174213
self.chain = default_prompt | structured_llm
175214

176215

@@ -186,12 +225,7 @@ def process_response(self, document: Document) -> GraphDocument:
186225
# else:
187226
# nodes = []
188227
nodes =[]
189-
if raw_schema.relationships:
190-
relationships = [
191-
map_to_base_relationship(rel) for rel in raw_schema.relationships
192-
]
193-
else:
194-
relationships = []
228+
relationships = _extract_relationships(raw_schema)
195229

196230
# Strict mode filtering
197231
if self.strict_mode and (self.allowed_nodes or self.allowed_relationships):

0 commit comments

Comments
 (0)