Skip to content

style: format code with Ruff Formatter #10

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 8 additions & 28 deletions app/temp/qdrant_search_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,13 @@ class QdrantVectorSearchTool(BaseTool):
qdrant_api_key: Authentication key for Qdrant
"""

model_config: ClassVar[dict[str, bool]] = {
"arbitrary_types_allowed": True
} # Add ClassVar annotation
model_config: ClassVar[dict[str, bool]] = {"arbitrary_types_allowed": True} # Add ClassVar annotation
client: QdrantClient = None
async_client: AsyncQdrantClient = None
openai_client: Any = None # Added for lazy initialization
openai_async_client: Any = None # Added for lazy initialization
name: str = "QdrantVectorSearchTool"
description: str = (
"A tool to search the Qdrant database for relevant information on internal documents."
)
description: str = "A tool to search the Qdrant database for relevant information on internal documents."
args_schema: type[BaseModel] = QdrantToolSchema
query: str | None = None
filter_by: str | None = None
Expand Down Expand Up @@ -146,11 +142,7 @@ def _run(
# Create filter if filter parameters are provided
search_filter = None
if filter_by and filter_value:
search_filter = Filter(
must=[
FieldCondition(key=filter_by, match=MatchValue(value=filter_value))
]
)
search_filter = Filter(must=[FieldCondition(key=filter_by, match=MatchValue(value=filter_value))])

# Search in Qdrant using the built-in query method
query_vector = (
Expand Down Expand Up @@ -192,9 +184,7 @@ def _vectorize_query_sync(self, query: str, embedding_model: str) -> list[float]
from openai import Client

# Define error messages as constants
openai_api_key_not_set_error_msg = (
"OPENAI_API_KEY environment variable is not set."
)
openai_api_key_not_set_error_msg = "OPENAI_API_KEY environment variable is not set."
# Lazy initialization of the sync client
if not self.openai_client:
api_key = os.getenv("OPENAI_API_KEY")
Expand Down Expand Up @@ -240,17 +230,11 @@ async def _arun(
# Create filter if filter parameters are provided
search_filter = None
if filter_by and filter_value:
search_filter = Filter(
must=[
FieldCondition(key=filter_by, match=MatchValue(value=filter_value))
]
)
search_filter = Filter(must=[FieldCondition(key=filter_by, match=MatchValue(value=filter_value))])

# Search in Qdrant using the built-in query method
query_vector = (
await self._vectorize_query_async(
query, embedding_model="text-embedding-3-large"
)
await self._vectorize_query_async(query, embedding_model="text-embedding-3-large")
if not self.custom_embedding_fn
else self.custom_embedding_fn(query)
)
Expand All @@ -275,9 +259,7 @@ async def _arun(

return json.dumps(results, indent=2)

async def _vectorize_query_async(
self, query: str, embedding_model: str
) -> list[float]:
async def _vectorize_query_async(self, query: str, embedding_model: str) -> list[float]:
"""Default async vectorization function with openai.

Args:
Expand All @@ -290,9 +272,7 @@ async def _vectorize_query_async(
from openai import AsyncClient

# Define error messages as constants
openai_api_key_not_set_error_msg = (
"OPENAI_API_KEY environment variable is not set."
)
openai_api_key_not_set_error_msg = "OPENAI_API_KEY environment variable is not set."
# Lazy initialization of the async client
if not self.openai_async_client:
api_key = os.getenv("OPENAI_API_KEY")
Expand Down
Loading