Skip to content

Commit 924c093

Browse files
committed
feat: Qdrant support
Signed-off-by: Anush008 <anushshetty90@gmail.com>
1 parent 66852c0 commit 924c093

File tree

7 files changed

+523
-271
lines changed

7 files changed

+523
-271
lines changed

aios/config/config.yaml

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ llms:
5454
# backend: "huggingface"
5555
# max_gpu_memory: {0: "48GB"} # GPU memory allocation
5656
# eval_device: "cuda:0" # Device for model evaluation
57-
57+
5858
# vLLM Models
5959
# To use vLLM as the backend, install vllm and run the vLLM server; see https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
6060
# An example command to run the vllm server is:
@@ -72,10 +72,18 @@ llms:
7272

7373
memory:
7474
log_mode: "console" # choose from [console, file]
75-
75+
7676
storage:
7777
root_dir: "root"
7878
use_vector_db: true
79+
# Vector DB backend: "chroma" (default) or "qdrant". If this key is empty or missing, the VECTOR_DB_BACKEND env var is used instead.
80+
vector_db_backend: "chroma"
81+
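# Qdrant connection (used when vector_db_backend is "qdrant"). The env vars QDRANT_HOST, QDRANT_PORT and QDRANT_API_KEY are only consulted when the corresponding key below is omitted; a value set here always takes precedence.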
82+
qdrant_host: "localhost"
83+
qdrant_port: 6333
84+
qdrant_api_key: ""
85+
# Embedding model used by Qdrant's FastEmbed integration. If this key is empty or missing, the QDRANT_EMBEDDING_MODEL env var is used instead.
86+
qdrant_model_name: "sentence-transformers/all-MiniLM-L6-v2"
7987

8088
tool:
8189
mcp_server_script_path: "aios/tool/mcp_server.py"
@@ -85,8 +93,8 @@ scheduler:
8593

8694
agent_factory:
8795
log_mode: "console" # choose from [console, file]
88-
max_workers: 64
89-
96+
max_workers: 64
97+
9098
server:
9199
host: "localhost"
92100
port: 8000

aios/llm_core/routing.py

Lines changed: 90 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1-
from enum import Enum
21
from typing import List, Dict, Any
32
import chromadb
43
from chromadb.utils import embedding_functions
4+
from qdrant_client import QdrantClient, models
5+
from aios.config.config_manager import config as global_config
56
import json
67
import numpy as np
7-
from typing import List, Dict, Any
88
from tqdm import tqdm
99
from collections import defaultdict
1010

11-
import json
11+
import uuid
1212

1313
from threading import Lock
1414

@@ -52,8 +52,8 @@ class RouterStrategy:
5252

5353
class SequentialRouting:
5454
"""
55-
The SequentialRouting class implements a round-robin selection strategy for load-balancing LLM requests.
56-
It iterates through a list of selected language models and returns their corresponding index based on
55+
The SequentialRouting class implements a round-robin selection strategy for load-balancing LLM requests.
56+
It iterates through a list of selected language models and returns their corresponding index based on
5757
the request count.
5858
5959
This strategy ensures that multiple models are utilized in sequence, distributing queries evenly across the available configurations.
@@ -98,24 +98,24 @@ def get_model_idxs(self, selected_llm_lists: List[List[Dict[str, Any]]], queries
9898
"""
9999
# current = self.selected_llms[self.idx]
100100
model_idxs = []
101-
101+
102102
available_models = [llm.name for llm in self.llm_configs]
103-
103+
104104
n_queries = len(queries)
105-
105+
106106
for i in range(n_queries):
107107
selected_llm_list = selected_llm_lists[i]
108-
108+
109109
if not selected_llm_list or len(selected_llm_list) == 0:
110110
model_idxs.append(0)
111111
continue
112-
112+
113113
model_idx = -1
114114
for selected_llm in selected_llm_list:
115115
if selected_llm["name"] in available_models:
116116
model_idx = available_models.index(selected_llm["name"])
117117
break
118-
118+
119119
model_idxs.append(model_idx)
120120

121121
return model_idxs
@@ -138,7 +138,7 @@ def get_token_lengths(queries: List[List[Dict[str, Any]]]):
138138
return [token_counter(model="gpt-4o-mini", messages=query) for query in queries]
139139

140140
def messages_to_query(messages: List[Dict[str, str]],
141-
strategy: str = "last_user") -> str:
141+
strategy: str = "last_user") -> str:
142142
"""
143143
Convert OpenAI ChatCompletion-style messages into a single query string.
144144
strategy:
@@ -201,43 +201,66 @@ def __init__(self,
201201
model_name: str = "all-MiniLM-L6-v2",
202202
persist_directory: str = "llm_router",
203203
bootstrap_url: str | None = None):
204+
storage_cfg = global_config.get_storage_config() or {}
205+
backend = (storage_cfg.get("vector_db_backend") or os.environ.get("VECTOR_DB_BACKEND") or "chroma").lower()
206+
204207
self._persist_root = os.path.join(os.path.dirname(__file__), persist_directory)
205208
os.makedirs(self._persist_root, exist_ok=True)
206209

207-
self.client = chromadb.PersistentClient(path=self._persist_root)
208-
self.embedding_function = embedding_functions.DefaultEmbeddingFunction()
210+
self.backend = backend
211+
if backend == "qdrant":
212+
host = storage_cfg.get("qdrant_host", os.environ.get("QDRANT_HOST", "localhost"))
213+
port = int(storage_cfg.get("qdrant_port", os.environ.get("QDRANT_PORT", 6333)))
214+
api_key = storage_cfg.get("qdrant_api_key", os.environ.get("QDRANT_API_KEY"))
215+
self.qdrant = QdrantClient(host=host, port=port, api_key=api_key)
216+
self.model_name = storage_cfg.get("qdrant_model_name") or os.environ.get("QDRANT_EMBEDDING_MODEL", model_name)
217+
self.collection_name = "historical_queries"
218+
self._ensure_qdrant_collection()
219+
else:
220+
self.client = chromadb.PersistentClient(path=self._persist_root)
221+
self.embedding_function = embedding_functions.DefaultEmbeddingFunction()
222+
self.collection = self._get_or_create_collection("historical_queries")
209223

210-
# Always create/get collections up‑front so we can inspect counts.
211-
# self.train_collection = self._get_or_create_collection("train_queries")
212-
# self.val_collection = self._get_or_create_collection("val_queries")
213-
# self.test_collection = self._get_or_create_collection("test_queries")
214-
self.collection = self._get_or_create_collection("historical_queries")
215-
216224
# If DB is empty and we have a bootstrap URL – populate it.
217-
if bootstrap_url and self.collection.count() == 0:
218-
self._bootstrap_from_drive(bootstrap_url)
219-
220-
# .................................................................
221-
# Chroma helpers
222-
# .................................................................
225+
if bootstrap_url:
226+
if backend == "qdrant":
227+
count = self._qdrant_count()
228+
if count == 0:
229+
self._bootstrap_from_drive(bootstrap_url)
230+
else:
231+
if self.collection.count() == 0:
232+
self._bootstrap_from_drive(bootstrap_url)
223233

224234
def _get_or_create_collection(self, name: str):
235+
if self.backend == "qdrant":
236+
return None
225237
try:
226238
return self.client.get_collection(name=name, embedding_function=self.embedding_function)
227239
except Exception:
228240
return self.client.create_collection(name=name, embedding_function=self.embedding_function)
229241

230-
# .................................................................
231-
# Bootstrap logic – download + ingest
232-
# .................................................................
242+
def _ensure_qdrant_collection(self):
243+
if not self.qdrant.collection_exists(self.collection_name):
244+
dim = self.qdrant.get_embedding_size(self.model_name)
245+
self.qdrant.create_collection(
246+
self.collection_name,
247+
vectors_config=models.VectorParams(size=dim, distance=models.Distance.COSINE),
248+
)
249+
250+
def _qdrant_count(self) -> int:
251+
try:
252+
count = self.qdrant.count(self.collection_name).count
253+
return count
254+
except Exception:
255+
return 0
233256

234257
def _bootstrap_from_drive(self, url_or_id: str):
235258
print("\n[SmartRouting] Bootstrapping ChromaDB from Google Drive…\n")
236259

237260
with tempfile.TemporaryDirectory() as tmp:
238261
# NB: gdown accepts both share links and raw IDs.
239262
local_path = os.path.join(tmp, "bootstrap.json")
240-
263+
241264
gdown.download(url_or_id, local_path, quiet=False, fuzzy=True)
242265

243266
# Expect JSONL with {"query": ..., "split": "train"|"val"|"test", ...}
@@ -249,16 +272,16 @@ def _bootstrap_from_drive(self, url_or_id: str):
249272

250273
print("[SmartRouting] Bootstrap complete – collections populated.\n")
251274

252-
# .................................................................
253-
# Public data API
254-
# .................................................................
255-
256275
def add_data(self, data: List[Dict[str, Any]]):
257-
collection = self.collection
258-
queries, metadatas, ids = [], [], []
276+
if self.backend == "qdrant":
277+
queries, metadatas, ids = [], [], []
278+
else:
279+
collection = self.collection
280+
queries, metadatas, ids = [], [], []
281+
259282
correct_count = total_count = 0
260283

261-
for idx, item in enumerate(tqdm(data, desc=f"Ingesting historical queries")):
284+
for idx, item in enumerate(tqdm(data, desc="Ingesting historical queries")):
262285
query = item["query"]
263286
model_metadatas = item["outputs"]
264287
for model_metadata in model_metadatas:
@@ -275,15 +298,39 @@ def add_data(self, data: List[Dict[str, Any]]):
275298
metadatas.append(meta)
276299
ids.append(f"{idx}")
277300

278-
collection.add(documents=queries, metadatas=metadatas, ids=ids)
279-
print(f"[SmartRouting]: {total_count} historical queries ingested.")
301+
if self.backend == "qdrant":
302+
docs = [models.Document(text=q, model=self.model_name) for q in queries]
303+
if docs:
304+
# Deterministic UUIDv5 ids; store original in payload
305+
q_ids = [str(uuid.uuid5(uuid.NAMESPACE_URL, i)) for i in ids]
306+
for i, meta in enumerate(metadatas):
307+
meta["original_id"] = ids[i]
308+
self.qdrant.upload_collection(
309+
collection_name=self.collection_name,
310+
vectors=docs,
311+
ids=q_ids,
312+
payload=metadatas,
313+
)
314+
else:
315+
collection.add(documents=queries, metadatas=metadatas, ids=ids)
316+
print(f"[SmartRouting]: {total_count} historical queries ingested.")
280317

281-
# ..................................................................
282318
def query_similar(self, query: str | List[str], n_results: int = 16):
319+
if self.backend == "qdrant":
320+
qtext = query if isinstance(query, str) else query[0]
321+
results = self.qdrant.query_points(
322+
collection_name=self.collection_name,
323+
query=models.Document(text=qtext, model=self.model_name),
324+
limit=n_results,
325+
).points
326+
return {
327+
"ids": [[(r.payload or {}).get("original_id", str(r.id)) for r in results]],
328+
"metadatas": [[(r.payload or {}) for r in results]],
329+
"documents": [[""] * len(results)],
330+
}
283331
collection = self.collection
284332
return collection.query(query_texts=query if isinstance(query, list) else [query], n_results=n_results)
285333

286-
# ..................................................................
287334
def predict(self, query: str | List[str], model_configs: List[Dict[str, Any]], n_similar: int = 16):
288335
similar = self.query_similar(query, n_results=n_similar)
289336
perf_mat, len_mat = [], []
@@ -355,7 +402,7 @@ def get_model_idxs(self, selected_llm_lists: List[List[Dict[str, Any]]], queries
355402

356403
input_lens = get_token_lengths(queries)
357404
chosen_indices: list[int] = []
358-
405+
359406
converted_queries = [messages_to_query(query) for query in queries]
360407

361408
for q, q_len, candidate_cfgs in zip(converted_queries, input_lens, selected_llm_lists):
@@ -376,7 +423,7 @@ def get_model_idxs(self, selected_llm_lists: List[List[Dict[str, Any]]], queries
376423

377424
# Map back to global llm_configs index
378425
sel_name = candidate_cfgs[sel_local_idx]["name"]
379-
426+
380427
sel_idx = self.available_models.index(sel_name)
381428
chosen_indices.append(sel_idx)
382429

0 commit comments

Comments
 (0)