Skip to content

Commit 2dfa5e8

Browse files
authored
feat(qdrant): support automatic setup for Qdrant (#577)
* cleanup(qdrant): remove query related logic for qdrant
* refactor(qdrant): prepare vector info ahead of time
* feat(qdrant): support automatic setup
1 parent b2d6de5 commit 2dfa5e8

File tree

6 files changed, +404 -273 lines changed

examples/image_search/main.py

Lines changed: 15 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -17,8 +17,10 @@
1717
from transformers import CLIPModel, CLIPProcessor
1818

1919

20-
QDRANT_GRPC_URL = os.getenv("QDRANT_GRPC_URL", "http://localhost:6334/")
20+
QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6334/")
21+
QDRANT_COLLECTION = "cocoindex_image_search"
2122
CLIP_MODEL_NAME = "openai/clip-vit-large-patch14"
23+
CLIP_MODEL_DIMENSION = 768
2224

2325

2426
@functools.cache
@@ -40,7 +42,9 @@ def embed_query(text: str) -> list[float]:
4042

4143

4244
@cocoindex.op.function(cache=True, behavior_version=1, gpu=True)
43-
def embed_image(img_bytes: bytes) -> cocoindex.Vector[cocoindex.Float32, Literal[384]]:
45+
def embed_image(
46+
img_bytes: bytes,
47+
) -> cocoindex.Vector[cocoindex.Float32, Literal[CLIP_MODEL_DIMENSION]]:
4448
"""
4549
Convert image to embedding using CLIP model.
4650
"""
@@ -56,7 +60,7 @@ def embed_image(img_bytes: bytes) -> cocoindex.Vector[cocoindex.Float32, Literal
5660
@cocoindex.flow_def(name="ImageObjectEmbedding")
5761
def image_object_embedding_flow(
5862
flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope
59-
):
63+
) -> None:
6064
data_scope["images"] = flow_builder.add_source(
6165
cocoindex.sources.LocalFile(
6266
path="img", included_patterns=["*.jpg", "*.jpeg", "*.png"], binary=True
@@ -73,14 +77,17 @@ def image_object_embedding_flow(
7377
filename=img["filename"],
7478
embedding=img["embedding"],
7579
)
80+
81+
qdrant_conn = cocoindex.add_auth_entry(
82+
"Qdrant", cocoindex.storages.QdrantConnection(url=QDRANT_URL)
83+
)
7684
img_embeddings.export(
7785
"img_embeddings",
7886
cocoindex.storages.Qdrant(
79-
collection_name="image_search",
80-
grpc_url=QDRANT_GRPC_URL,
87+
connection=qdrant_conn,
88+
collection_name=QDRANT_COLLECTION,
8189
),
8290
primary_key_fields=["id"],
83-
setup_by_user=True,
8491
)
8592

8693

@@ -103,7 +110,7 @@ def startup_event():
103110
load_dotenv()
104111
cocoindex.init()
105112
# Initialize Qdrant client
106-
app.state.qdrant_client = QdrantClient(url=QDRANT_GRPC_URL, prefer_grpc=True)
113+
app.state.qdrant_client = QdrantClient(url=QDRANT_URL, prefer_grpc=True)
107114
app.state.live_updater = cocoindex.FlowLiveUpdater(image_object_embedding_flow)
108115
app.state.live_updater.start()
109116

@@ -118,7 +125,7 @@ def search(
118125

119126
# Search in Qdrant
120127
search_results = app.state.qdrant_client.search(
121-
collection_name="image_search",
128+
collection_name=QDRANT_COLLECTION,
122129
query_vector=("embedding", query_embedding),
123130
limit=limit,
124131
)

examples/text_embedding_qdrant/main.py

Lines changed: 12 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -4,8 +4,8 @@
44
import cocoindex
55

66
# Define Qdrant connection constants
7-
QDRANT_GRPC_URL = "http://localhost:6334"
8-
QDRANT_COLLECTION = "cocoindex"
7+
QDRANT_URL = "http://localhost:6334"
8+
QDRANT_COLLECTION = "cocoindex_text_embedding"
99

1010

1111
@cocoindex.transform_flow()
@@ -26,7 +26,7 @@ def text_to_embedding(
2626
@cocoindex.flow_def(name="TextEmbeddingWithQdrant")
2727
def text_embedding_flow(
2828
flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope
29-
):
29+
) -> None:
3030
"""
3131
Define an example flow that embeds text into a vector database.
3232
"""
@@ -55,19 +55,22 @@ def text_embedding_flow(
5555
text_embedding=chunk["embedding"],
5656
)
5757

58+
qdrant_conn = cocoindex.add_auth_entry(
59+
"Qdrant", cocoindex.storages.QdrantConnection(url=QDRANT_URL)
60+
)
5861
doc_embeddings.export(
5962
"doc_embeddings",
6063
cocoindex.storages.Qdrant(
61-
collection_name=QDRANT_COLLECTION, grpc_url=QDRANT_GRPC_URL
64+
connection=qdrant_conn,
65+
collection_name=QDRANT_COLLECTION,
6266
),
6367
primary_key_fields=["id"],
64-
setup_by_user=True,
6568
)
6669

6770

68-
def _main():
71+
def _main() -> None:
6972
# Initialize Qdrant client
70-
client = QdrantClient(url=QDRANT_GRPC_URL, prefer_grpc=True)
73+
client = QdrantClient(url=QDRANT_URL, prefer_grpc=True)
7174

7275
# Run queries in a loop to demonstrate the query capabilities.
7376
while True:
@@ -87,6 +90,8 @@ def _main():
8790
for result in search_results:
8891
score = result.score
8992
payload = result.payload
93+
if payload is None:
94+
continue
9095
print(f"[{score:.3f}] {payload['filename']}")
9196
print(f" {payload['text']}")
9297
print("---")

python/cocoindex/storages.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -16,13 +16,20 @@ class Postgres(op.StorageSpec):
1616
table_name: str | None = None
1717

1818

19+
@dataclass
20+
class QdrantConnection:
21+
"""Connection spec for Qdrant."""
22+
23+
url: str
24+
api_key: str | None = None
25+
26+
1927
@dataclass
2028
class Qdrant(op.StorageSpec):
2129
"""Storage powered by Qdrant - https://qdrant.tech/."""
2230

31+
connection: AuthEntryReference[QdrantConnection]
2332
collection_name: str
24-
grpc_url: str = "http://localhost:6334/"
25-
api_key: str | None = None
2633

2734

2835
@dataclass

src/ops/factory_bases.rs

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -309,6 +309,12 @@ pub trait StorageFactoryBase: ExportTargetFactory + Send + Sync + 'static {
309309
Vec<(Self::Key, Self::SetupState)>,
310310
)>;
311311

312+
/// Deserialize the setup key from a JSON value.
313+
/// You can override this method to provide a custom deserialization logic, e.g. to perform backward compatible deserialization.
314+
fn deserialize_setup_key(key: serde_json::Value) -> Result<Self::Key> {
315+
Ok(serde_json::from_value(key)?)
316+
}
317+
312318
/// Will not be called if it's setup by user.
313319
/// It returns an error if the target only supports setup by user.
314320
async fn check_setup_status(
@@ -421,7 +427,7 @@ impl<T: StorageFactoryBase> ExportTargetFactory for T {
421427
existing_states: setup::CombinedState<serde_json::Value>,
422428
auth_registry: &Arc<AuthRegistry>,
423429
) -> Result<Box<dyn setup::ResourceSetupStatus>> {
424-
let key: T::Key = serde_json::from_value(key.clone())?;
430+
let key: T::Key = Self::deserialize_setup_key(key.clone())?;
425431
let desired_state: Option<T::SetupState> = desired_state
426432
.map(|v| serde_json::from_value(v.clone()))
427433
.transpose()?;
@@ -438,12 +444,12 @@ impl<T: StorageFactoryBase> ExportTargetFactory for T {
438444
}
439445

440446
fn describe_resource(&self, key: &serde_json::Value) -> Result<String> {
441-
let key: T::Key = serde_json::from_value(key.clone())?;
447+
let key: T::Key = Self::deserialize_setup_key(key.clone())?;
442448
StorageFactoryBase::describe_resource(self, &key)
443449
}
444450

445451
fn normalize_setup_key(&self, key: &serde_json::Value) -> Result<serde_json::Value> {
446-
let key: T::Key = serde_json::from_value(key.clone())?;
452+
let key: T::Key = Self::deserialize_setup_key(key.clone())?;
447453
Ok(serde_json::to_value(key)?)
448454
}
449455

src/ops/registration.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@ use super::{
33
storages,
44
};
55
use anyhow::Result;
6-
use std::sync::{Arc, LazyLock, RwLock, RwLockReadGuard};
6+
use std::sync::{LazyLock, RwLock, RwLockReadGuard};
77

88
fn register_executor_factories(registry: &mut ExecutorFactoryRegistry) -> Result<()> {
99
let reqwest_client = reqwest::Client::new();
@@ -17,7 +17,7 @@ fn register_executor_factories(registry: &mut ExecutorFactoryRegistry) -> Result
1717
functions::extract_by_llm::Factory.register(registry)?;
1818

1919
storages::postgres::Factory::default().register(registry)?;
20-
Arc::new(storages::qdrant::Factory::default()).register(registry)?;
20+
storages::qdrant::register(registry)?;
2121
storages::kuzu::register(registry, reqwest_client)?;
2222

2323
storages::neo4j::Factory::new().register(registry)?;

0 commit comments

Comments
 (0)