Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions examples/image_search/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from qdrant_client import QdrantClient
from typing import Any

from PIL import Image
from transformers import CLIPModel, CLIPProcessor


QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6334/")
QDRANT_COLLECTION = "cocoindex_image_search"
QDRANT_COLLECTION = "ImageSearch"
CLIP_MODEL_NAME = "openai/clip-vit-large-patch14"
CLIP_MODEL_DIMENSION = 768

Expand Down Expand Up @@ -78,15 +79,9 @@ def image_object_embedding_flow(
embedding=img["embedding"],
)

qdrant_conn = cocoindex.add_auth_entry(
"Qdrant", cocoindex.storages.QdrantConnection(url=QDRANT_URL)
)
img_embeddings.export(
"img_embeddings",
cocoindex.storages.Qdrant(
connection=qdrant_conn,
collection_name=QDRANT_COLLECTION,
),
cocoindex.storages.Qdrant(collection_name=QDRANT_COLLECTION),
primary_key_fields=["id"],
)

Expand All @@ -106,7 +101,7 @@ def image_object_embedding_flow(

# --- CocoIndex initialization on startup ---
@app.on_event("startup")
def startup_event():
def startup_event() -> None:
load_dotenv()
cocoindex.init()
# Initialize Qdrant client
Expand All @@ -119,7 +114,7 @@ def startup_event():
def search(
q: str = Query(..., description="Search query"),
limit: int = Query(5, description="Number of results"),
):
) -> Any:
# Get the embedding for the query
query_embedding = embed_query(q)

Expand Down
11 changes: 2 additions & 9 deletions examples/text_embedding_qdrant/main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from dotenv import load_dotenv
from qdrant_client import QdrantClient
from qdrant_client.http.models import Filter, FieldCondition, MatchValue
import cocoindex

# Define Qdrant connection constants
QDRANT_URL = "http://localhost:6334"
QDRANT_COLLECTION = "cocoindex_text_embedding"
QDRANT_COLLECTION = "TextEmbedding"


@cocoindex.transform_flow()
Expand Down Expand Up @@ -55,15 +54,9 @@ def text_embedding_flow(
text_embedding=chunk["embedding"],
)

qdrant_conn = cocoindex.add_auth_entry(
"Qdrant", cocoindex.storages.QdrantConnection(url=QDRANT_URL)
)
doc_embeddings.export(
"doc_embeddings",
cocoindex.storages.Qdrant(
connection=qdrant_conn,
collection_name=QDRANT_COLLECTION,
),
cocoindex.storages.Qdrant(collection_name=QDRANT_COLLECTION),
primary_key_fields=["id"],
)

Expand Down
4 changes: 2 additions & 2 deletions python/cocoindex/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@ class Postgres(op.StorageSpec):
class QdrantConnection:
"""Connection spec for Qdrant."""

url: str
grpc_url: str
api_key: str | None = None


@dataclass
class Qdrant(op.StorageSpec):
"""Storage powered by Qdrant - https://qdrant.tech/."""

connection: AuthEntryReference[QdrantConnection]
collection_name: str
connection: AuthEntryReference[QdrantConnection] | None = None


@dataclass
Expand Down
31 changes: 21 additions & 10 deletions src/ops/storages/qdrant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ const DEFAULT_URL: &str = "http://localhost:6334/";

#[derive(Debug, Deserialize, Clone)]
pub struct ConnectionSpec {
url: String,
grpc_url: String,
api_key: Option<String>,
}

#[derive(Debug, Deserialize, Clone)]
struct Spec {
connection: spec::AuthEntryReference<ConnectionSpec>,
connection: Option<spec::AuthEntryReference<ConnectionSpec>>,
collection_name: String,
}

Expand Down Expand Up @@ -150,10 +150,14 @@ impl setup::ResourceSetupStatus for SetupStatus {
}

impl SetupStatus {
async fn apply(&self, collection_name: &String, qdrant_client: &Qdrant) -> Result<()> {
async fn apply_delete(&self, collection_name: &String, qdrant_client: &Qdrant) -> Result<()> {
if self.delete_collection {
qdrant_client.delete_collection(collection_name).await?;
}
Ok(())
}

async fn apply_create(&self, collection_name: &String, qdrant_client: &Qdrant) -> Result<()> {
if let Some(add_collection) = &self.add_collection {
let mut builder = CreateCollectionBuilder::new(collection_name);
if !add_collection.vectors.is_empty() {
Expand Down Expand Up @@ -382,10 +386,9 @@ impl StorageFactoryBase for Factory {
}
}

let connection = Some(d.spec.connection);
let export_context = Arc::new(ExportContext {
qdrant_client: self
.get_qdrant_client(&connection, &context.auth_registry)?,
.get_qdrant_client(&d.spec.connection, &context.auth_registry)?,
collection_name: d.spec.collection_name.clone(),
fields_info,
});
Expand All @@ -398,7 +401,7 @@ impl StorageFactoryBase for Factory {
Ok(TypedExportDataCollectionBuildOutput {
executors: executors.boxed(),
setup_key: CollectionKey {
connection,
connection: d.spec.connection,
collection_name: d.spec.collection_name,
},
desired_setup_state: SetupState {
Expand Down Expand Up @@ -489,12 +492,20 @@ impl StorageFactoryBase for Factory {
setup_status: Vec<TypedResourceSetupChangeItem<'async_trait, Self>>,
auth_registry: &Arc<AuthRegistry>,
) -> Result<()> {
for setup_change in setup_status.into_iter() {
for setup_change in setup_status.iter() {
let qdrant_client =
self.get_qdrant_client(&setup_change.key.connection, auth_registry)?;
setup_change
.setup_status
.apply_delete(&setup_change.key.collection_name, &qdrant_client)
.await?;
}
for setup_change in setup_status.iter() {
let qdrant_client =
self.get_qdrant_client(&setup_change.key.connection, auth_registry)?;
setup_change
.setup_status
.apply(&setup_change.key.collection_name, &qdrant_client)
.apply_create(&setup_change.key.collection_name, &qdrant_client)
.await?;
}
Ok(())
Expand All @@ -521,14 +532,14 @@ impl Factory {
let spec = auth_entry.as_ref().map_or_else(
|| {
Ok(ConnectionSpec {
url: DEFAULT_URL.to_string(),
grpc_url: DEFAULT_URL.to_string(),
api_key: None,
})
},
|auth_entry| auth_registry.get(auth_entry),
)?;
let client = Arc::new(
Qdrant::from_url(&spec.url)
Qdrant::from_url(&spec.grpc_url)
.api_key(spec.api_key)
.skip_compatibility_check()
.build()?,
Expand Down