|
| 1 | +from qdrant_client import QdrantClient, models |
| 2 | +from typing import List, Dict, Any, Optional, Union |
| 3 | +import os |
| 4 | +from dotenv import load_dotenv |
| 5 | + |
| 6 | +load_dotenv() |
| 7 | + |
| 8 | + |
class QdrantDB:
    """Thin wrapper around ``QdrantClient`` for a hybrid-search collection.

    The collection managed here carries three named vectors per point:
    ``dense``, ``colbert`` (multivector) and ``sparse``.
    """

    def __init__(
        self, url: str = "http://localhost:6333", api_key: Optional[str] = None
    ):
        """
        Initialize the Qdrant client.

        Args:
            url: Qdrant server URL.
            api_key: API key for Qdrant Cloud or an authenticated instance.
        """
        self.client = QdrantClient(url=url, api_key=api_key)

    def create_collection(
        self, collection_name: str, dense_dim: int = 1536, colbert_dim: int = 128
    ):
        """
        Create a collection configured with dense, ColBERT and sparse vectors.

        If a collection with this name already exists it is left untouched and
        only a notice is printed.

        Args:
            collection_name: Name of the collection to create.
            dense_dim: Dimension of the dense vector (default: 1536, matching
                OpenAI's small embedding model).
            colbert_dim: Dimension of each ColBERT token vector (default: 128).
        """
        if self.client.collection_exists(collection_name):
            print(f"Collection '{collection_name}' already exists.")
            return

        dense_params = models.VectorParams(
            size=dense_dim, distance=models.Distance.COSINE
        )
        colbert_params = models.VectorParams(
            size=colbert_dim,
            distance=models.Distance.COSINE,
            # One point stores many token vectors; they are compared with MaxSim.
            multivector_config=models.MultiVectorConfig(
                comparator=models.MultiVectorComparator.MAX_SIM
            ),
            # NOTE(review): per the Qdrant docs, m=0 disables HNSW graph
            # construction for this vector (typical when it is only used for
            # reranking) -- confirm this is the intent.
            hnsw_config=models.HnswConfigDiff(m=0),
        )

        self.client.create_collection(
            collection_name=collection_name,
            vectors_config={"dense": dense_params, "colbert": colbert_params},
            sparse_vectors_config={"sparse": models.SparseVectorParams()},
        )
        print(f"Collection '{collection_name}' created.")

    def upsert(self, collection_name: str, points: List[Dict[str, Any]]):
        """
        Upsert points into a collection.

        Args:
            collection_name: Name of the target collection.
            points: Dictionaries, each containing:
                - id: Unique identifier (int or str).
                - vector: Dictionary with 'dense', 'colbert' and 'sparse' keys
                  mapped to the corresponding vector values.
                - payload: Optional dictionary of metadata (defaults to {}).

        Raises:
            ValueError: If any point lacks an 'id' or 'vector' key. The whole
                batch is validated first, so nothing is built or sent when
                one point is malformed.
        """
        # Materialize once so the batch can be validated and then converted,
        # even if the caller passed a one-shot iterable.
        batch = list(points)

        for point in batch:
            if "id" not in point or "vector" not in point:
                raise ValueError("Each point must contain 'id' and 'vector' keys.")

        point_structs = [
            models.PointStruct(
                id=point["id"],
                vector=point["vector"],
                payload=point.get("payload", {}),
            )
            for point in batch
        ]

        # upload_points defaults to wait=False, which returns before the server
        # has applied the writes. Wait explicitly so the success message below
        # is accurate and the points are immediately searchable.
        self.client.upload_points(
            collection_name=collection_name, points=point_structs, wait=True
        )
        print(
            f"Successfully upserted {len(point_structs)} points to '{collection_name}'."
        )
0 commit comments