diff --git a/examples/image_search/main.py b/examples/image_search/main.py index ecd5a08dc..85d5fcbdc 100644 --- a/examples/image_search/main.py +++ b/examples/image_search/main.py @@ -12,13 +12,14 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles from qdrant_client import QdrantClient +from typing import Any from PIL import Image from transformers import CLIPModel, CLIPProcessor QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6334/") -QDRANT_COLLECTION = "cocoindex_image_search" +QDRANT_COLLECTION = "ImageSearch" CLIP_MODEL_NAME = "openai/clip-vit-large-patch14" CLIP_MODEL_DIMENSION = 768 @@ -78,15 +79,9 @@ def image_object_embedding_flow( embedding=img["embedding"], ) - qdrant_conn = cocoindex.add_auth_entry( - "Qdrant", cocoindex.storages.QdrantConnection(url=QDRANT_URL) - ) img_embeddings.export( "img_embeddings", - cocoindex.storages.Qdrant( - connection=qdrant_conn, - collection_name=QDRANT_COLLECTION, - ), + cocoindex.storages.Qdrant(collection_name=QDRANT_COLLECTION), primary_key_fields=["id"], ) @@ -106,7 +101,7 @@ def image_object_embedding_flow( # --- CocoIndex initialization on startup --- @app.on_event("startup") -def startup_event(): +def startup_event() -> None: load_dotenv() cocoindex.init() # Initialize Qdrant client @@ -119,7 +114,7 @@ def startup_event(): def search( q: str = Query(..., description="Search query"), limit: int = Query(5, description="Number of results"), -): +) -> Any: # Get the embedding for the query query_embedding = embed_query(q) diff --git a/examples/text_embedding_qdrant/main.py b/examples/text_embedding_qdrant/main.py index d7a345f0f..b63ddfe64 100644 --- a/examples/text_embedding_qdrant/main.py +++ b/examples/text_embedding_qdrant/main.py @@ -1,11 +1,10 @@ from dotenv import load_dotenv from qdrant_client import QdrantClient -from qdrant_client.http.models import Filter, FieldCondition, MatchValue import cocoindex # Define Qdrant connection constants QDRANT_URL = "http://localhost:6334" -QDRANT_COLLECTION = "cocoindex_text_embedding" +QDRANT_COLLECTION = "TextEmbedding" @cocoindex.transform_flow() @@ -55,15 +54,9 @@ def text_embedding_flow( text_embedding=chunk["embedding"], ) - qdrant_conn = cocoindex.add_auth_entry( - "Qdrant", cocoindex.storages.QdrantConnection(url=QDRANT_URL) - ) doc_embeddings.export( "doc_embeddings", - cocoindex.storages.Qdrant( - connection=qdrant_conn, - collection_name=QDRANT_COLLECTION, - ), + cocoindex.storages.Qdrant(collection_name=QDRANT_COLLECTION), primary_key_fields=["id"], ) diff --git a/python/cocoindex/storages.py b/python/cocoindex/storages.py index 1f9873bf0..e77a8cdc3 100644 --- a/python/cocoindex/storages.py +++ b/python/cocoindex/storages.py @@ -20,7 +20,7 @@ class Postgres(op.StorageSpec): class QdrantConnection: """Connection spec for Qdrant.""" - url: str + grpc_url: str api_key: str | None = None @@ -28,8 +28,8 @@ class QdrantConnection: class Qdrant(op.StorageSpec): """Storage powered by Qdrant - https://qdrant.tech/.""" - connection: AuthEntryReference[QdrantConnection] collection_name: str + connection: AuthEntryReference[QdrantConnection] | None = None @dataclass diff --git a/src/ops/storages/qdrant.rs b/src/ops/storages/qdrant.rs index 88964f228..276003a85 100644 --- a/src/ops/storages/qdrant.rs +++ b/src/ops/storages/qdrant.rs @@ -24,13 +24,13 @@ const DEFAULT_URL: &str = "http://localhost:6334/"; #[derive(Debug, Deserialize, Clone)] pub struct ConnectionSpec { - url: String, + grpc_url: String, api_key: Option, } #[derive(Debug, Deserialize, Clone)] struct Spec { - connection: spec::AuthEntryReference, + connection: Option>, collection_name: String, } @@ -150,10 +150,14 @@ impl setup::ResourceSetupStatus for SetupStatus { } impl SetupStatus { - async fn apply(&self, collection_name: &String, qdrant_client: &Qdrant) -> Result<()> { + async fn apply_delete(&self, collection_name: &String, qdrant_client: &Qdrant) -> Result<()> { if self.delete_collection { qdrant_client.delete_collection(collection_name).await?; } + Ok(()) + } + + async fn apply_create(&self, collection_name: &String, qdrant_client: &Qdrant) -> Result<()> { if let Some(add_collection) = &self.add_collection { let mut builder = CreateCollectionBuilder::new(collection_name); if !add_collection.vectors.is_empty() { @@ -382,10 +386,9 @@ impl StorageFactoryBase for Factory { } } - let connection = Some(d.spec.connection); let export_context = Arc::new(ExportContext { qdrant_client: self - .get_qdrant_client(&connection, &context.auth_registry)?, + .get_qdrant_client(&d.spec.connection, &context.auth_registry)?, collection_name: d.spec.collection_name.clone(), fields_info, }); @@ -398,7 +401,7 @@ impl StorageFactoryBase for Factory { Ok(TypedExportDataCollectionBuildOutput { executors: executors.boxed(), setup_key: CollectionKey { - connection, + connection: d.spec.connection, collection_name: d.spec.collection_name, }, desired_setup_state: SetupState { @@ -489,12 +492,20 @@ impl StorageFactoryBase for Factory { setup_status: Vec>, auth_registry: &Arc, ) -> Result<()> { - for setup_change in setup_status.into_iter() { + for setup_change in setup_status.iter() { + let qdrant_client = + self.get_qdrant_client(&setup_change.key.connection, auth_registry)?; + setup_change + .setup_status + .apply_delete(&setup_change.key.collection_name, &qdrant_client) + .await?; + } + for setup_change in setup_status.iter() { let qdrant_client = self.get_qdrant_client(&setup_change.key.connection, auth_registry)?; setup_change .setup_status - .apply(&setup_change.key.collection_name, &qdrant_client) + .apply_create(&setup_change.key.collection_name, &qdrant_client) .await?; } Ok(()) @@ -521,14 +532,14 @@ impl Factory { let spec = auth_entry.as_ref().map_or_else( || { Ok(ConnectionSpec { - url: DEFAULT_URL.to_string(), + grpc_url: DEFAULT_URL.to_string(), api_key: None, }) }, |auth_entry| auth_registry.get(auth_entry), )?; let client = Arc::new( - Qdrant::from_url(&spec.url) + Qdrant::from_url(&spec.grpc_url) .api_key(spec.api_key) .skip_compatibility_check() .build()?,