11from langchain_community .graphs import Neo4jGraph
22from dotenv import load_dotenv
33from langchain .schema import Document
4+ import json
45import logging
56import re
67import concurrent .futures
1819from langchain .pydantic_v1 import BaseModel
1920from langchain_google_vertexai import ChatVertexAI
2021from langchain_core .prompts import ChatPromptTemplate
21- from typing import List
22+ from typing import List , Dict
2223from langchain_core .pydantic_v1 import BaseModel , Field
2324import google .auth
2425from langchain_community .graphs .graph_document import Node
@@ -150,6 +151,44 @@ def map_to_base_relationship(rel: Any) -> Relationship:
150151 source = source , target = target , type = rel .type .replace (" " , "_" ).upper ()
151152 )
152153
154+ def _extract_relationships (
155+ raw_schema : Dict [Any , Any ],
156+ ) -> List [Relationship ]:
157+ # If there are validation errors
158+ if not raw_schema ["parsed" ]:
159+ try :
160+ argument_json = json .loads (
161+ raw_schema ["raw" ].additional_kwargs ["tool_calls" ][0 ]["function" ][
162+ "arguments"
163+ ]
164+ )
165+ relationships = []
166+ for rel in argument_json ["relationships" ]:
167+ # Mandatory props
168+ if (
169+ not rel .get ("start_node_id" )
170+ or not rel .get ("end_node_id" )
171+ or not rel .get ("type" )
172+ ):
173+ continue
174+ relationships .append (
175+ Relationship (
176+ source = Node (id = rel .get ("start_node_id" ).title (), type = rel .get ("start_node_type" , "Node" ).capitalize ()),
177+ target = Node (id = rel .get ("end_node_id" ).title (), type = rel .get ("end_node_type" , "Node" ).capitalize ()),
178+ type = rel ["type" ],
179+ )
180+ )
181+ except Exception : # If we can't parse JSON
182+ return []
183+ else : # If there are no validation errors use parsed pydantic object
184+ parsed_schema = raw_schema ["parsed" ]
185+ relationships = (
186+ [map_to_base_relationship (rel ) for rel in parsed_schema .relationships ]
187+ if parsed_schema .relationships
188+ else []
189+ )
190+ return relationships
191+
153192class GeminiLLMGraphTransformer :
154193 def __init__ (
155194 self ,
@@ -170,7 +209,7 @@ def __init__(
170209
171210 # Define chain
172211 schema = create_simple_model (allowed_nodes , allowed_relationships )
173- structured_llm = llm .with_structured_output (schema )
212+ structured_llm = llm .with_structured_output (schema , include_raw = True )
174213 self .chain = default_prompt | structured_llm
175214
176215
@@ -186,12 +225,7 @@ def process_response(self, document: Document) -> GraphDocument:
186225 # else:
187226 # nodes = []
188227 nodes = []
189- if raw_schema .relationships :
190- relationships = [
191- map_to_base_relationship (rel ) for rel in raw_schema .relationships
192- ]
193- else :
194- relationships = []
228+ relationships = _extract_relationships (raw_schema )
195229
196230 # Strict mode filtering
197231 if self .strict_mode and (self .allowed_nodes or self .allowed_relationships ):
0 commit comments