diff --git a/app/temp/qdrant_search_tool.py b/app/temp/qdrant_search_tool.py index 66e7b5c..a2c8e32 100644 --- a/app/temp/qdrant_search_tool.py +++ b/app/temp/qdrant_search_tool.py @@ -54,17 +54,13 @@ class QdrantVectorSearchTool(BaseTool): qdrant_api_key: Authentication key for Qdrant """ - model_config: ClassVar[dict[str, bool]] = { - "arbitrary_types_allowed": True - } # Add ClassVar annotation + model_config: ClassVar[dict[str, bool]] = {"arbitrary_types_allowed": True} # Add ClassVar annotation client: QdrantClient = None async_client: AsyncQdrantClient = None openai_client: Any = None # Added for lazy initialization openai_async_client: Any = None # Added for lazy initialization name: str = "QdrantVectorSearchTool" - description: str = ( - "A tool to search the Qdrant database for relevant information on internal documents." - ) + description: str = "A tool to search the Qdrant database for relevant information on internal documents." args_schema: type[BaseModel] = QdrantToolSchema query: str | None = None filter_by: str | None = None @@ -146,11 +142,7 @@ def _run( # Create filter if filter parameters are provided search_filter = None if filter_by and filter_value: - search_filter = Filter( - must=[ - FieldCondition(key=filter_by, match=MatchValue(value=filter_value)) - ] - ) + search_filter = Filter(must=[FieldCondition(key=filter_by, match=MatchValue(value=filter_value))]) # Search in Qdrant using the built-in query method query_vector = ( @@ -192,9 +184,7 @@ def _vectorize_query_sync(self, query: str, embedding_model: str) -> list[float] from openai import Client # Define error messages as constants - openai_api_key_not_set_error_msg = ( - "OPENAI_API_KEY environment variable is not set." - ) + openai_api_key_not_set_error_msg = "OPENAI_API_KEY environment variable is not set." # Lazy initialization of the sync client if not self.openai_client: api_key = os.getenv("OPENAI_API_KEY") @@ -240,17 +230,11 @@ async def _arun( # Create filter if filter parameters are provided search_filter = None if filter_by and filter_value: - search_filter = Filter( - must=[ - FieldCondition(key=filter_by, match=MatchValue(value=filter_value)) - ] - ) + search_filter = Filter(must=[FieldCondition(key=filter_by, match=MatchValue(value=filter_value))]) # Search in Qdrant using the built-in query method query_vector = ( - await self._vectorize_query_async( - query, embedding_model="text-embedding-3-large" - ) + await self._vectorize_query_async(query, embedding_model="text-embedding-3-large") if not self.custom_embedding_fn else self.custom_embedding_fn(query) ) @@ -275,9 +259,7 @@ async def _arun( return json.dumps(results, indent=2) - async def _vectorize_query_async( - self, query: str, embedding_model: str - ) -> list[float]: + async def _vectorize_query_async(self, query: str, embedding_model: str) -> list[float]: """Default async vectorization function with openai. Args: @@ -290,9 +272,7 @@ async def _vectorize_query_async( from openai import AsyncClient # Define error messages as constants - openai_api_key_not_set_error_msg = ( - "OPENAI_API_KEY environment variable is not set." - ) + openai_api_key_not_set_error_msg = "OPENAI_API_KEY environment variable is not set." # Lazy initialization of the async client if not self.openai_async_client: api_key = os.getenv("OPENAI_API_KEY")