Skip to content

Commit 2ee27ef

Browse files
feat: improvements on QdrantVectorSearchTool
* Implement improvements on QdrantVectorSearchTool - Allow search filters to be set at the constructor level - Fix issue that prevented multiple records from being returned * Implement improvements on QdrantVectorSearchTool - Allow search filters to be set at the constructor level - Fix issue that prevented multiple records from being returned --------- Co-authored-by: Greyson LaLonde <greyson.r.lalonde@gmail.com>
1 parent f6e13eb commit 2ee27ef

File tree

1 file changed

+83
-147
lines changed

1 file changed

+83
-147
lines changed
Lines changed: 83 additions & 147 deletions
Original file line numberDiff line numberDiff line change
@@ -1,80 +1,42 @@
1-
from collections.abc import Callable
1+
from __future__ import annotations
2+
3+
import importlib
24
import json
35
import os
6+
from collections.abc import Callable
47
from typing import Any
58

6-
7-
try:
8-
from qdrant_client import QdrantClient
9-
from qdrant_client.http.models import FieldCondition, Filter, MatchValue
10-
11-
QDRANT_AVAILABLE = True
12-
except ImportError:
13-
QDRANT_AVAILABLE = False
14-
QdrantClient = Any # type: ignore[assignment,misc] # type placeholder
15-
Filter = Any # type: ignore[assignment,misc]
16-
FieldCondition = Any # type: ignore[assignment,misc]
17-
MatchValue = Any # type: ignore[assignment,misc]
18-
199
from crewai.tools import BaseTool, EnvVar
20-
from pydantic import BaseModel, ConfigDict, Field
10+
from pydantic import BaseModel, ConfigDict, Field, model_validator
11+
from pydantic.types import ImportString
2112

2213

2314
class QdrantToolSchema(BaseModel):
24-
"""Input for QdrantTool."""
15+
query: str = Field(..., description="Query to search in Qdrant DB.")
16+
filter_by: str | None = None
17+
filter_value: str | None = None
2518

26-
query: str = Field(
27-
...,
28-
description="The query to search retrieve relevant information from the Qdrant database. Pass only the query, not the question.",
29-
)
30-
filter_by: str | None = Field(
31-
default=None,
32-
description="Filter by properties. Pass only the properties, not the question.",
33-
)
34-
filter_value: str | None = Field(
35-
default=None,
36-
description="Filter by value. Pass only the value, not the question.",
37-
)
3819

20+
class QdrantConfig(BaseModel):
21+
"""All Qdrant connection and search settings."""
3922

40-
class QdrantVectorSearchTool(BaseTool):
41-
"""Tool to query and filter results from a Qdrant database.
23+
qdrant_url: str
24+
qdrant_api_key: str | None = None
25+
collection_name: str
26+
limit: int = 3
27+
score_threshold: float = 0.35
28+
filter_conditions: list[tuple[str, Any]] = Field(default_factory=list)
4229

43-
This tool enables vector similarity search on internal documents stored in Qdrant,
44-
with optional filtering capabilities.
4530

46-
Attributes:
47-
client: Configured QdrantClient instance
48-
collection_name: Name of the Qdrant collection to search
49-
limit: Maximum number of results to return
50-
score_threshold: Minimum similarity score threshold
51-
qdrant_url: Qdrant server URL
52-
qdrant_api_key: Authentication key for Qdrant
53-
"""
31+
class QdrantVectorSearchTool(BaseTool):
32+
"""Vector search tool for Qdrant."""
5433

5534
model_config = ConfigDict(arbitrary_types_allowed=True)
56-
client: QdrantClient = None # type: ignore[assignment]
35+
36+
# --- Metadata ---
5737
name: str = "QdrantVectorSearchTool"
58-
description: str = "A tool to search the Qdrant database for relevant information on internal documents."
38+
description: str = "Search Qdrant vector DB for relevant documents."
5939
args_schema: type[BaseModel] = QdrantToolSchema
60-
query: str | None = None
61-
filter_by: str | None = None
62-
filter_value: str | None = None
63-
collection_name: str | None = None
64-
limit: int | None = Field(default=3)
65-
score_threshold: float = Field(default=0.35)
66-
qdrant_url: str = Field(
67-
...,
68-
description="The URL of the Qdrant server",
69-
)
70-
qdrant_api_key: str | None = Field(
71-
default=None,
72-
description="The API key for the Qdrant server",
73-
)
74-
custom_embedding_fn: Callable | None = Field(
75-
default=None,
76-
description="A custom embedding function to use for vectorization. If not provided, the default model will be used.",
77-
)
7840
package_dependencies: list[str] = Field(default_factory=lambda: ["qdrant-client"])
7941
env_vars: list[EnvVar] = Field(
8042
default_factory=lambda: [
@@ -83,107 +45,81 @@ class QdrantVectorSearchTool(BaseTool):
8345
)
8446
]
8547
)
86-
87-
def __init__(self, **kwargs):
88-
super().__init__(**kwargs)
89-
if QDRANT_AVAILABLE:
90-
self.client = QdrantClient(
91-
url=self.qdrant_url,
92-
api_key=self.qdrant_api_key if self.qdrant_api_key else None,
48+
qdrant_config: QdrantConfig
49+
qdrant_package: ImportString[Any] = Field(
50+
default="qdrant_client",
51+
description="Base package path for Qdrant. Will dynamically import client and models.",
52+
)
53+
custom_embedding_fn: ImportString[Callable[[str], list[float]]] | None = Field(
54+
default=None,
55+
description="Optional embedding function or import path.",
56+
)
57+
client: Any | None = None
58+
59+
@model_validator(mode="after")
60+
def _setup_qdrant(self) -> QdrantVectorSearchTool:
61+
# Import the qdrant_package if it's a string
62+
if isinstance(self.qdrant_package, str):
63+
self.qdrant_package = importlib.import_module(self.qdrant_package)
64+
65+
if not self.client:
66+
self.client = self.qdrant_package.QdrantClient(
67+
url=self.qdrant_config.qdrant_url,
68+
api_key=self.qdrant_config.qdrant_api_key or None,
9369
)
94-
else:
95-
import click
96-
97-
if click.confirm(
98-
"The 'qdrant-client' package is required to use the QdrantVectorSearchTool. "
99-
"Would you like to install it?"
100-
):
101-
import subprocess
102-
103-
subprocess.run(["uv", "add", "qdrant-client"], check=True) # noqa: S607
104-
else:
105-
raise ImportError(
106-
"The 'qdrant-client' package is required to use the QdrantVectorSearchTool. "
107-
"Please install it with: uv add qdrant-client"
108-
)
70+
return self
10971

11072
def _run(
11173
self,
11274
query: str,
11375
filter_by: str | None = None,
114-
filter_value: str | None = None,
76+
filter_value: Any | None = None,
11577
) -> str:
116-
"""Execute vector similarity search on Qdrant.
117-
118-
Args:
119-
query: Search query to vectorize and match
120-
filter_by: Optional metadata field to filter on
121-
filter_value: Optional value to filter by
122-
123-
Returns:
124-
JSON string containing search results with metadata and scores
125-
126-
Raises:
127-
ImportError: If qdrant-client is not installed
128-
ValueError: If Qdrant credentials are missing
129-
"""
130-
if not self.qdrant_url:
131-
raise ValueError("QDRANT_URL is not set")
132-
133-
# Create filter if filter parameters are provided
134-
search_filter = None
135-
if filter_by and filter_value:
136-
search_filter = Filter(
78+
"""Perform vector similarity search."""
79+
filter_ = self.qdrant_package.http.models.Filter
80+
field_condition = self.qdrant_package.http.models.FieldCondition
81+
match_value = self.qdrant_package.http.models.MatchValue
82+
conditions = self.qdrant_config.filter_conditions.copy()
83+
if filter_by and filter_value is not None:
84+
conditions.append((filter_by, filter_value))
85+
86+
search_filter = (
87+
filter_(
13788
must=[
138-
FieldCondition(key=filter_by, match=MatchValue(value=filter_value))
89+
field_condition(key=k, match=match_value(value=v))
90+
for k, v in conditions
13991
]
14092
)
141-
142-
# Search in Qdrant using the built-in query method
93+
if conditions
94+
else None
95+
)
14396
query_vector = (
144-
self._vectorize_query(query, embedding_model="text-embedding-3-large")
145-
if not self.custom_embedding_fn
146-
else self.custom_embedding_fn(query)
97+
self.custom_embedding_fn(query)
98+
if self.custom_embedding_fn
99+
else (
100+
lambda: __import__("openai")
101+
.Client(api_key=os.getenv("OPENAI_API_KEY"))
102+
.embeddings.create(input=[query], model="text-embedding-3-large")
103+
.data[0]
104+
.embedding
105+
)()
147106
)
148-
search_results = self.client.query_points(
149-
collection_name=self.collection_name, # type: ignore[arg-type]
107+
results = self.client.query_points(
108+
collection_name=self.qdrant_config.collection_name,
150109
query=query_vector,
151110
query_filter=search_filter,
152-
limit=self.limit, # type: ignore[arg-type]
153-
score_threshold=self.score_threshold,
111+
limit=self.qdrant_config.limit,
112+
score_threshold=self.qdrant_config.score_threshold,
154113
)
155114

156-
# Format results similar to storage implementation
157-
results = []
158-
# Extract the list of ScoredPoint objects from the tuple
159-
for point in search_results:
160-
result = {
161-
"metadata": point[1][0].payload.get("metadata", {}),
162-
"context": point[1][0].payload.get("text", ""),
163-
"distance": point[1][0].score,
164-
}
165-
results.append(result)
166-
167-
return json.dumps(results, indent=2)
168-
169-
def _vectorize_query(self, query: str, embedding_model: str) -> list[float]:
170-
"""Default vectorization function with openai.
171-
172-
Args:
173-
query (str): The query to vectorize
174-
embedding_model (str): The embedding model to use
175-
176-
Returns:
177-
list[float]: The vectorized query
178-
"""
179-
import openai
180-
181-
client = openai.Client(api_key=os.getenv("OPENAI_API_KEY"))
182-
return (
183-
client.embeddings.create(
184-
input=[query],
185-
model=embedding_model,
186-
)
187-
.data[0]
188-
.embedding
115+
return json.dumps(
116+
[
117+
{
118+
"distance": p.score,
119+
"metadata": p.payload.get("metadata", {}) if p.payload else {},
120+
"context": p.payload.get("text", "") if p.payload else {},
121+
}
122+
for p in results.points
123+
],
124+
indent=2,
189125
)

0 commit comments

Comments
 (0)