-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathqdrant.py
More file actions
103 lines (90 loc) · 3.35 KB
/
qdrant.py
File metadata and controls
103 lines (90 loc) · 3.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
from LMStudioEmbedder import LMStudioEmbedder
import requests
import uuid
import os
# Configuration: URL of the Qdrant server, read from the QDRANT_URL environment
# variable when this module is imported. The lookup uses os.environ[...], so
# importing the module raises KeyError if the variable is not set.
QDRANT_URL = os.environ["QDRANT_URL"]
class QdrantStore:
    """Store and retrieve vectors in a single Qdrant collection.

    Constructing an instance connects to Qdrant and creates the collection
    if no collection with that name exists yet.
    """

    def __init__(
        self,
        collection_name: str,
        vector_size: int,
        qdrant_url: str = QDRANT_URL,
        qdrant_api_key: str | None = None,
        distance: Distance = Distance.COSINE,
    ):
        """Connect to Qdrant and make sure the target collection exists.

        Args:
            collection_name: Collection this store reads from and writes to.
            vector_size: Vector dimensionality used if the collection has to
                be created.
            qdrant_url: URL handed to ``QdrantClient``; defaults to the
                module-level ``QDRANT_URL``.
            qdrant_api_key: Optional API key handed to ``QdrantClient``.
            distance: Distance metric used if the collection has to be
                created.
        """
        self.collection_name = collection_name
        self.vector_size = vector_size
        self.client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
        self._ensure_collection(distance)

    def _ensure_collection(self, distance: Distance):
        """Create the collection if it does not already exist.

        NOTE: an existing collection is used as-is; its configured vector
        size and distance are not compared with ``self.vector_size`` or
        ``distance``.
        """
        existing = {c.name for c in self.client.get_collections().collections}
        if self.collection_name not in existing:
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=self.vector_size, distance=distance),
            )
            print(f"Created collection '{self.collection_name}'")
        else:
            print(f"Using existing collection '{self.collection_name}'")

    def upsert(
        self,
        vectors: list[list[float]],
        documents: list[str],
        metadata: list[dict] | None = None,
    ) -> int:
        """Write vectors with their source text and optional metadata.

        Every point gets a freshly generated UUID as its id, so calling this
        twice with the same documents adds new points rather than overwriting
        the earlier ones.

        Args:
            vectors: One embedding per document.
            documents: Source texts, stored under the ``"text"`` payload key.
            metadata: Optional per-document dicts merged into the payload.
                A ``"text"`` key in a metadata dict overrides the document
                text because the metadata is unpacked after it.

        Returns:
            The number of points written.

        Raises:
            ValueError: If ``vectors``, ``documents`` and ``metadata`` do not
                all have the same length.
        """
        if metadata is None:
            metadata = [{} for _ in documents]

        # zip() would silently drop the tail of the longer inputs, losing
        # documents without any error, so reject mismatched lengths up front.
        if not (len(vectors) == len(documents) == len(metadata)):
            raise ValueError(
                "vectors, documents and metadata must have the same length "
                f"(got {len(vectors)}, {len(documents)}, {len(metadata)})"
            )

        # Nothing to write: skip the round trip to the server.
        if not documents:
            return 0

        points = [
            PointStruct(
                id=str(uuid.uuid4()),
                vector=vector,
                payload={"text": doc, **meta},
            )
            for vector, doc, meta in zip(vectors, documents, metadata)
        ]
        self.client.upsert(collection_name=self.collection_name, points=points)
        print(f"Upserted {len(points)} points into '{self.collection_name}'")
        return len(points)

    def search(
        self,
        query_vector: list[float],
        top_k: int = 5,
        score_threshold: float | None = None,
    ):
        """Search the collection for vectors similar to ``query_vector``.

        Args:
            query_vector: Embedding to compare against the stored vectors.
            top_k: Passed to Qdrant as the result ``limit``.
            score_threshold: Optional score cut-off passed through to Qdrant;
                ``None`` applies no threshold on this side.

        Returns:
            The object returned by ``QdrantClient.query_points`` unchanged.
            NOTE(review): per the qdrant-client documentation this is a
            response wrapper whose ``.points`` attribute holds the scored
            points, not a bare list -- confirm against the installed version.
        """
        return self.client.query_points(
            collection_name=self.collection_name,
            query=query_vector,
            limit=top_k,
            score_threshold=score_threshold,
        )
def embed_documents_to_qdrant(
    documents: list[str],
    collection_name: str,
    embedder: LMStudioEmbedder | None = None,
    store: QdrantStore | None = None,
    batch_size: int = 32,
    metadata: list[dict] | None = None,
) -> int:
    """Embed documents in batches and upsert them into a Qdrant collection.

    Convenience function that wires LMStudioEmbedder and QdrantStore together.
    Creates default instances if none are provided.

    Args:
        documents: Texts to embed and store.
        collection_name: Collection used when a default ``QdrantStore`` is
            created. Ignored when ``store`` is supplied, since the store
            already carries its own collection name.
        embedder: Object providing ``embed(batch)`` and ``vector_size``;
            a default ``LMStudioEmbedder()`` is created when omitted.
        store: Target store; a default ``QdrantStore`` sized from
            ``embedder.vector_size`` is created when omitted.
        batch_size: Number of documents sent to the embedder per call.
        metadata: Optional per-document payload dicts, one per document.

    Returns:
        The value returned by ``store.upsert`` (the number of points
        written), or 0 when ``documents`` is empty.

    Raises:
        ValueError: If ``batch_size`` is less than 1, or ``metadata`` is
            given with a length different from ``documents``.
        RuntimeError: If the embedder returns a different number of vectors
            than documents it was given.
    """
    # A negative step would make range() yield nothing and the call would
    # "succeed" while storing no documents at all.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    total = len(documents)

    # Fail before spending time on embedding rather than after.
    if metadata is not None and len(metadata) != total:
        raise ValueError(
            f"metadata has {len(metadata)} entries but there are {total} documents"
        )

    # Compare against None so a caller-supplied object that happens to be
    # falsy is not silently replaced by a default instance.
    if embedder is None:
        embedder = LMStudioEmbedder()
    if store is None:
        store = QdrantStore(
            collection_name=collection_name,
            vector_size=embedder.vector_size,
        )

    # Nothing to embed or write; the store (and its collection) is still
    # set up above, as before.
    if total == 0:
        return 0

    all_vectors = []
    for i in range(0, total, batch_size):
        batch = documents[i : i + batch_size]
        print(f"Embedding batch {i // batch_size + 1} ({i}-{min(i + batch_size, total) - 1})...")
        all_vectors.extend(embedder.embed(batch))

    # Guard against vectors being paired with the wrong documents downstream.
    if len(all_vectors) != total:
        raise RuntimeError(
            f"embedder returned {len(all_vectors)} vectors for {total} documents"
        )

    return store.upsert(all_vectors, documents, metadata)