1- from collections .abc import Callable
1+ from __future__ import annotations
2+
3+ import importlib
24import json
35import os
6+ from collections .abc import Callable
47from typing import Any
58
6-
7- try :
8- from qdrant_client import QdrantClient
9- from qdrant_client .http .models import FieldCondition , Filter , MatchValue
10-
11- QDRANT_AVAILABLE = True
12- except ImportError :
13- QDRANT_AVAILABLE = False
14- QdrantClient = Any # type: ignore[assignment,misc] # type placeholder
15- Filter = Any # type: ignore[assignment,misc]
16- FieldCondition = Any # type: ignore[assignment,misc]
17- MatchValue = Any # type: ignore[assignment,misc]
18-
199from crewai .tools import BaseTool , EnvVar
20- from pydantic import BaseModel , ConfigDict , Field
10+ from pydantic import BaseModel , ConfigDict , Field , model_validator
11+ from pydantic .types import ImportString
2112
2213
2314class QdrantToolSchema (BaseModel ):
24- """Input for QdrantTool."""
15+ query : str = Field (..., description = "Query to search in Qdrant DB." )
16+ filter_by : str | None = None
17+ filter_value : str | None = None
2518
26- query : str = Field (
27- ...,
28- description = "The query to search retrieve relevant information from the Qdrant database. Pass only the query, not the question." ,
29- )
30- filter_by : str | None = Field (
31- default = None ,
32- description = "Filter by properties. Pass only the properties, not the question." ,
33- )
34- filter_value : str | None = Field (
35- default = None ,
36- description = "Filter by value. Pass only the value, not the question." ,
37- )
3819
20+ class QdrantConfig (BaseModel ):
21+ """All Qdrant connection and search settings."""
3922
40- class QdrantVectorSearchTool (BaseTool ):
41- """Tool to query and filter results from a Qdrant database.
23+ qdrant_url : str
24+ qdrant_api_key : str | None = None
25+ collection_name : str
26+ limit : int = 3
27+ score_threshold : float = 0.35
28+ filter_conditions : list [tuple [str , Any ]] = Field (default_factory = list )
4229
43- This tool enables vector similarity search on internal documents stored in Qdrant,
44- with optional filtering capabilities.
4530
46- Attributes:
47- client: Configured QdrantClient instance
48- collection_name: Name of the Qdrant collection to search
49- limit: Maximum number of results to return
50- score_threshold: Minimum similarity score threshold
51- qdrant_url: Qdrant server URL
52- qdrant_api_key: Authentication key for Qdrant
53- """
31+ class QdrantVectorSearchTool (BaseTool ):
32+ """Vector search tool for Qdrant."""
5433
5534 model_config = ConfigDict (arbitrary_types_allowed = True )
56- client : QdrantClient = None # type: ignore[assignment]
35+
36+ # --- Metadata ---
5737 name : str = "QdrantVectorSearchTool"
58- description : str = "A tool to search the Qdrant database for relevant information on internal documents."
38+ description : str = "Search Qdrant vector DB for relevant documents."
5939 args_schema : type [BaseModel ] = QdrantToolSchema
60- query : str | None = None
61- filter_by : str | None = None
62- filter_value : str | None = None
63- collection_name : str | None = None
64- limit : int | None = Field (default = 3 )
65- score_threshold : float = Field (default = 0.35 )
66- qdrant_url : str = Field (
67- ...,
68- description = "The URL of the Qdrant server" ,
69- )
70- qdrant_api_key : str | None = Field (
71- default = None ,
72- description = "The API key for the Qdrant server" ,
73- )
74- custom_embedding_fn : Callable | None = Field (
75- default = None ,
76- description = "A custom embedding function to use for vectorization. If not provided, the default model will be used." ,
77- )
7840 package_dependencies : list [str ] = Field (default_factory = lambda : ["qdrant-client" ])
7941 env_vars : list [EnvVar ] = Field (
8042 default_factory = lambda : [
@@ -83,107 +45,81 @@ class QdrantVectorSearchTool(BaseTool):
8345 )
8446 ]
8547 )
86-
87- def __init__ (self , ** kwargs ):
88- super ().__init__ (** kwargs )
89- if QDRANT_AVAILABLE :
90- self .client = QdrantClient (
91- url = self .qdrant_url ,
92- api_key = self .qdrant_api_key if self .qdrant_api_key else None ,
48+ qdrant_config : QdrantConfig
49+ qdrant_package : ImportString [Any ] = Field (
50+ default = "qdrant_client" ,
51+ description = "Base package path for Qdrant. Will dynamically import client and models." ,
52+ )
53+ custom_embedding_fn : ImportString [Callable [[str ], list [float ]]] | None = Field (
54+ default = None ,
55+ description = "Optional embedding function or import path." ,
56+ )
57+ client : Any | None = None
58+
59+ @model_validator (mode = "after" )
60+ def _setup_qdrant (self ) -> QdrantVectorSearchTool :
61+ # Import the qdrant_package if it's a string
62+ if isinstance (self .qdrant_package , str ):
63+ self .qdrant_package = importlib .import_module (self .qdrant_package )
64+
65+ if not self .client :
66+ self .client = self .qdrant_package .QdrantClient (
67+ url = self .qdrant_config .qdrant_url ,
68+ api_key = self .qdrant_config .qdrant_api_key or None ,
9369 )
94- else :
95- import click
96-
97- if click .confirm (
98- "The 'qdrant-client' package is required to use the QdrantVectorSearchTool. "
99- "Would you like to install it?"
100- ):
101- import subprocess
102-
103- subprocess .run (["uv" , "add" , "qdrant-client" ], check = True ) # noqa: S607
104- else :
105- raise ImportError (
106- "The 'qdrant-client' package is required to use the QdrantVectorSearchTool. "
107- "Please install it with: uv add qdrant-client"
108- )
70+ return self
10971
11072 def _run (
11173 self ,
11274 query : str ,
11375 filter_by : str | None = None ,
114- filter_value : str | None = None ,
76+ filter_value : Any | None = None ,
11577 ) -> str :
116- """Execute vector similarity search on Qdrant.
117-
118- Args:
119- query: Search query to vectorize and match
120- filter_by: Optional metadata field to filter on
121- filter_value: Optional value to filter by
122-
123- Returns:
124- JSON string containing search results with metadata and scores
125-
126- Raises:
127- ImportError: If qdrant-client is not installed
128- ValueError: If Qdrant credentials are missing
129- """
130- if not self .qdrant_url :
131- raise ValueError ("QDRANT_URL is not set" )
132-
133- # Create filter if filter parameters are provided
134- search_filter = None
135- if filter_by and filter_value :
136- search_filter = Filter (
78+ """Perform vector similarity search."""
79+ filter_ = self .qdrant_package .http .models .Filter
80+ field_condition = self .qdrant_package .http .models .FieldCondition
81+ match_value = self .qdrant_package .http .models .MatchValue
82+ conditions = self .qdrant_config .filter_conditions .copy ()
83+ if filter_by and filter_value is not None :
84+ conditions .append ((filter_by , filter_value ))
85+
86+ search_filter = (
87+ filter_ (
13788 must = [
138- FieldCondition (key = filter_by , match = MatchValue (value = filter_value ))
89+ field_condition (key = k , match = match_value (value = v ))
90+ for k , v in conditions
13991 ]
14092 )
141-
142- # Search in Qdrant using the built-in query method
93+ if conditions
94+ else None
95+ )
14396 query_vector = (
144- self ._vectorize_query (query , embedding_model = "text-embedding-3-large" )
145- if not self .custom_embedding_fn
146- else self .custom_embedding_fn (query )
97+ self .custom_embedding_fn (query )
98+ if self .custom_embedding_fn
99+ else (
100+ lambda : __import__ ("openai" )
101+ .Client (api_key = os .getenv ("OPENAI_API_KEY" ))
102+ .embeddings .create (input = [query ], model = "text-embedding-3-large" )
103+ .data [0 ]
104+ .embedding
105+ )()
147106 )
148- search_results = self .client .query_points (
149- collection_name = self .collection_name , # type: ignore[arg-type]
107+ results = self .client .query_points (
108+ collection_name = self .qdrant_config . collection_name ,
150109 query = query_vector ,
151110 query_filter = search_filter ,
152- limit = self .limit , # type: ignore[arg-type]
153- score_threshold = self .score_threshold ,
111+ limit = self .qdrant_config . limit ,
112+ score_threshold = self .qdrant_config . score_threshold ,
154113 )
155114
156- # Format results similar to storage implementation
157- results = []
158- # Extract the list of ScoredPoint objects from the tuple
159- for point in search_results :
160- result = {
161- "metadata" : point [1 ][0 ].payload .get ("metadata" , {}),
162- "context" : point [1 ][0 ].payload .get ("text" , "" ),
163- "distance" : point [1 ][0 ].score ,
164- }
165- results .append (result )
166-
167- return json .dumps (results , indent = 2 )
168-
169- def _vectorize_query (self , query : str , embedding_model : str ) -> list [float ]:
170- """Default vectorization function with openai.
171-
172- Args:
173- query (str): The query to vectorize
174- embedding_model (str): The embedding model to use
175-
176- Returns:
177- list[float]: The vectorized query
178- """
179- import openai
180-
181- client = openai .Client (api_key = os .getenv ("OPENAI_API_KEY" ))
182- return (
183- client .embeddings .create (
184- input = [query ],
185- model = embedding_model ,
186- )
187- .data [0 ]
188- .embedding
115+ return json .dumps (
116+ [
117+ {
118+ "distance" : p .score ,
119+ "metadata" : p .payload .get ("metadata" , {}) if p .payload else {},
120+ "context" : p .payload .get ("text" , "" ) if p .payload else {},
121+ }
122+ for p in results .points
123+ ],
124+ indent = 2 ,
189125 )
0 commit comments