1
- # Copyright 2024 Google LLC
1
+ # Copyright 2025 Google LLC
2
2
3
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
4
# you may not use this file except in compliance with the License.
23
23
from llama_index .core .prompts import PromptType
24
24
from llama_index .core .retrievers import CustomPGRetriever , VectorContextRetriever
25
25
from llama_index .core .schema import NodeWithScore , QueryBundle , TextNode
26
- from llama_index .core .vector_stores .types import VectorStore
26
+ from llama_index .core .vector_stores .types import BasePydanticVectorStore
27
27
from pydantic import BaseModel
28
28
29
29
from .graph_utils import extract_gql , fix_gql_syntax
@@ -78,7 +78,6 @@ def __init__(
78
78
graph_store : SpannerPropertyGraphStore ,
79
79
llm : Optional [LLM ] = None ,
80
80
text_to_gql_prompt : Optional [PromptTemplate ] = None ,
81
- response_template : Optional [str ] = None ,
82
81
gql_validator : Optional [Callable [[str ], bool ]] = None ,
83
82
include_raw_response_as_metadata : Optional [bool ] = False ,
84
83
max_gql_fix_retries : Optional [int ] = 1 ,
@@ -93,7 +92,6 @@ def __init__(
93
92
graph_store: The SpannerPropertyGraphStore to query.
94
93
llm: The LLM to use.
95
94
text_to_gql_prompt: The prompt to use for generating the GQL query.
96
- response_template: The template to use for formatting the response.
97
95
gql_validator: A function to validate the GQL query.
98
96
include_raw_response_as_metadata: If true, includes the raw response as
99
97
metadata.
@@ -179,7 +177,7 @@ def calculate_score_for_predicted_response(
179
177
gql_response_score = self .llm .predict (
180
178
GQL_RESPONSE_SCORING_TEMPLATE , question = question , retrieved_context = response
181
179
)
182
- return gql_response_score
180
+ return float ( gql_response_score . strip ())
183
181
184
182
def retrieve_from_graph (
185
183
self , query_bundle : schema .QueryBundle
@@ -208,16 +206,19 @@ def retrieve_from_graph(
208
206
209
207
# 2. Verify gql query using LLM
210
208
if self .verify_gql :
211
- verify_response = self .llm .predict (
212
- GQL_VERIFY_PROMPT ,
213
- question = question ,
214
- generated_gql = generated_gql ,
215
- schema = schema_str ,
216
- format_instructions = GQL_VERIFY_PROMPT .output_parser .format_string ,
217
- )
209
+ if GQL_VERIFY_PROMPT .output_parser :
210
+ verify_response = self .llm .predict (
211
+ GQL_VERIFY_PROMPT ,
212
+ question = question ,
213
+ generated_gql = generated_gql ,
214
+ schema = schema_str ,
215
+ format_instructions = GQL_VERIFY_PROMPT .output_parser .format ,
216
+ )
218
217
219
- output_parser = verify_gql_output_parser .parse (verify_response )
220
- verified_gql = fix_gql_syntax (output_parser .verified_gql )
218
+ output_parser = verify_gql_output_parser .parse (verify_response )
219
+ verified_gql = fix_gql_syntax (output_parser .verified_gql )
220
+ else :
221
+ raise ValueError ("GQL_VERIFY_PROMPT is missing its output_parser." )
221
222
else :
222
223
verified_gql = generated_gql
223
224
@@ -259,7 +260,7 @@ def retrieve_from_graph(
259
260
async def aretrieve_from_graph (
260
261
self , query_bundle : QueryBundle
261
262
) -> List [NodeWithScore ]:
262
- return await self .retrieve_from_graph (query_bundle )
263
+ return self .retrieve_from_graph (query_bundle )
263
264
264
265
265
266
class SpannerGraphCustomRetriever (CustomPGRetriever ):
@@ -269,13 +270,12 @@ def init(
269
270
self ,
270
271
## vector context retriever params
271
272
embed_model : Optional [BaseEmbedding ] = None ,
272
- vector_store : Optional [VectorStore ] = None ,
273
+ vector_store : Optional [BasePydanticVectorStore ] = None ,
273
274
similarity_top_k : int = 4 ,
274
275
path_depth : int = 2 ,
275
276
## text-to-gql params
276
277
llm_text_to_gql : Optional [LLM ] = None ,
277
278
text_to_gql_prompt : Optional [PromptTemplate ] = None ,
278
- response_template : Optional [str ] = None ,
279
279
gql_validator : Optional [Callable [[str ], bool ]] = None ,
280
280
include_raw_response_as_metadata : Optional [bool ] = False ,
281
281
max_gql_fix_retries : Optional [int ] = 1 ,
@@ -297,7 +297,6 @@ def init(
297
297
path_depth: The depth of the path to retrieve.
298
298
llm_text_to_gql: The LLM to use for text to GQL conversion.
299
299
text_to_gql_prompt: The prompt to use for generating the GQL query.
300
- response_template: The template to use for formatting the response.
301
300
gql_validator: A function to validate the GQL query.
302
301
include_raw_response_as_metadata: Whether to include the raw response as
303
302
metadata.
@@ -311,6 +310,12 @@ def init(
311
310
llmranker_top_n: The number of top nodes to return.
312
311
**kwargs: Additional keyword arguments.
313
312
"""
313
+
314
+ if not isinstance (self ._graph_store , SpannerPropertyGraphStore ):
315
+ raise TypeError (
316
+ "SpannerGraphCustomRetriever requires a SpannerPropertyGraphStore."
317
+ )
318
+
314
319
self .llm = llm_text_to_gql or Settings .llm
315
320
if self .llm is None :
316
321
raise ValueError ("`llm for Text to GQL` cannot be none" )
@@ -328,7 +333,6 @@ def init(
328
333
graph_store = self ._graph_store ,
329
334
llm = llm_text_to_gql ,
330
335
text_to_gql_prompt = text_to_gql_prompt ,
331
- response_template = response_template ,
332
336
gql_validator = gql_validator ,
333
337
include_raw_response_as_metadata = include_raw_response_as_metadata ,
334
338
max_gql_fix_retries = max_gql_fix_retries ,
@@ -342,7 +346,7 @@ def init(
342
346
top_n = llmranker_top_n ,
343
347
)
344
348
345
- def generate_synthesized_response (self , question : str , response : str ) -> float :
349
+ def generate_synthesized_response (self , question : str , response : str ) -> str :
346
350
gql_synthesized_response = self .llm .predict (
347
351
GQL_SYNTHESIS_RESPONSE_TEMPLATE ,
348
352
question = question ,
0 commit comments