diff --git a/src/langchain_google_spanner/graph_qa.py b/src/langchain_google_spanner/graph_qa.py index 0b48f05a..fb0f7d2e 100644 --- a/src/langchain_google_spanner/graph_qa.py +++ b/src/langchain_google_spanner/graph_qa.py @@ -71,6 +71,12 @@ class VerifyGqlOutput(BaseModel): INTERMEDIATE_STEPS_KEY = "intermediate_steps" +class InvalidGQLGenerationError(ValueError): + def __init__(self, message, intermediate_steps=None): + self.intermediate_steps = intermediate_steps + super().__init__(message) + + class SpannerGraphQAChain(Chain): """Chain for question-answering against a Spanner Graph database by generating GQL statements from natural language questions. @@ -268,7 +274,9 @@ def execute_with_retry( finally: retries += 1 - raise ValueError("The generated gql query is invalid") + raise InvalidGQLGenerationError( + "The generated gql query is invalid", intermediate_steps + ) def log_invalid_query( self, @@ -322,7 +330,9 @@ def _call( _run_manager, intermediate_steps, question, verified_gql ) if not final_gql: - raise ValueError("No GQL was generated.") + raise InvalidGQLGenerationError( + "No GQL was generated.", intermediate_steps + ) _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose) _run_manager.on_text( str(context), color="green", end="\n", verbose=self.verbose