Skip to content

Commit 0361517

Browse files
committed
feat: Qdrant support
Signed-off-by: Anush008 <anushshetty90@gmail.com>
1 parent 66852c0 commit 0361517

File tree

10 files changed

+751
-281
lines changed

10 files changed

+751
-281
lines changed

.github/workflows/test-ollama.yml

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,20 +61,23 @@ jobs:
6161
echo "Waiting for Kernel… ($i/${KERNEL_MAX_WAIT})"; sleep 1
6262
done
6363
64-
- name: Run tests
64+
- name: Run Ollama tests
6565
run: |
6666
source .venv/bin/activate
6767
mkdir -p test_results
68-
mapfile -t TESTS < <(find . -type f -path "*/llm/ollama/*" -name "*.py")
69-
if [ "${#TESTS[@]}" -eq 0 ]; then
68+
69+
# Run Ollama tests
70+
mapfile -t OLLAMA_TESTS < <(find . -type f -path "*/llm/ollama/*" -name "*.py")
71+
if [ "${#OLLAMA_TESTS[@]}" -eq 0 ]; then
7072
echo "⚠️ No llm/ollama tests found – skipping."
71-
exit 0
73+
else
74+
for t in "${OLLAMA_TESTS[@]}"; do
75+
echo "▶️ Running Ollama test: $t"
76+
python -m unittest "$t" 2>&1 | tee -a test_results/ollama_tests.log
77+
echo "----------------------------------------"
78+
done
7279
fi
73-
for t in "${TESTS[@]}"; do
74-
echo "▶️ Running $t"
75-
python -m unittest "$t" 2>&1 | tee -a test_results/ollama_tests.log
76-
echo "----------------------------------------"
77-
done
80+
7881

7982
- name: Upload logs
8083
if: always()

.github/workflows/test-qdrant.yml

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# CI workflow: runs the Qdrant integration tests against a Qdrant service
# container on every push / pull request targeting main.
name: Test Qdrant Integration

on:
  pull_request:
    branches: [ "main" ]
  push:
    branches: [ "main" ]

permissions:
  contents: read
  # NOTE(review): confirm that artifact upload really needs `actions: write`;
  # narrow this if it does not.
  actions: write

# Cancel an in-flight run when a newer commit arrives for the same PR / ref.
concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
  cancel-in-progress: true

jobs:
  test-qdrant:
    runs-on: ubuntu-latest

    # Qdrant server reachable from the job at localhost:6333.
    services:
      qdrant:
        image: qdrant/qdrant:latest
        ports:
          - 6333:6333

    # Select the Qdrant backend and point the code under test at the service.
    env:
      VECTOR_DB_BACKEND: qdrant
      QDRANT_HOST: localhost
      QDRANT_PORT: 6333
      QDRANT_EMBEDDING_MODEL: sentence-transformers/all-MiniLM-L6-v2

    steps:
      - uses: actions/checkout@v4

      - name: Set up Python 3.10
        uses: actions/setup-python@v5
        with:
          python-version: "3.10"

      - name: Install dependencies
        run: |
          python -m pip install uv
          uv venv
          source .venv/bin/activate
          uv pip install -r requirements.txt
          # Install additional dependencies for Qdrant tests
          uv pip install qdrant-client

      - name: Run Qdrant integration tests
        run: |
          # Without pipefail the pipeline's status is tee's (always 0), so a
          # failing test run would still leave this step green.
          set -o pipefail
          source .venv/bin/activate
          mkdir -p test_results

          echo "Running Qdrant integration tests"
          python -m unittest tests.modules.test_qdrant_integration -v 2>&1 | tee -a test_results/qdrant_integration_tests.log

      # Upload the log even when the test step failed.
      - name: Upload Qdrant logs
        if: always()
        uses: actions/upload-artifact@v4
        with:
          name: qdrant-logs
          path: |
            test_results/

aios/config/config.yaml

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ llms:
5454
# backend: "huggingface"
5555
# max_gpu_memory: {0: "48GB"} # GPU memory allocation
5656
# eval_device: "cuda:0" # Device for model evaluation
57-
57+
5858
# vLLM Models
5959
# To use vllm as backend, you need to install vllm and run the vllm server https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
6060
# An example command to run the vllm server is:
@@ -72,10 +72,18 @@ llms:
7272

7373
memory:
7474
log_mode: "console" # choose from [console, file]
75-
75+
7676
storage:
7777
root_dir: "root"
7878
use_vector_db: true
79+
# vector DB backend: chroma | qdrant
80+
vector_db_backend: "chroma"
81+
# Qdrant connection (used when backend == qdrant). The LLM router reads env vars QDRANT_HOST, QDRANT_PORT, QDRANT_API_KEY only when the matching key below is absent
82+
qdrant_host: "localhost"
83+
qdrant_port: 6333
84+
qdrant_api_key: ""
85+
# Embedding model for Qdrant fastembed integration
86+
qdrant_model_name: "sentence-transformers/all-MiniLM-L6-v2"
7987

8088
tool:
8189
mcp_server_script_path: "aios/tool/mcp_server.py"
@@ -85,8 +93,8 @@ scheduler:
8593

8694
agent_factory:
8795
log_mode: "console" # choose from [console, file]
88-
max_workers: 64
89-
96+
max_workers: 64
97+
9098
server:
9199
host: "localhost"
92100
port: 8000

aios/llm_core/routing.py

Lines changed: 100 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1-
from enum import Enum
21
from typing import List, Dict, Any
32
import chromadb
43
from chromadb.utils import embedding_functions
4+
from qdrant_client import QdrantClient, models
5+
from aios.config.config_manager import config as global_config
56
import json
67
import numpy as np
7-
from typing import List, Dict, Any
88
from tqdm import tqdm
99
from collections import defaultdict
1010

11-
import json
11+
import uuid
1212

1313
from threading import Lock
1414

@@ -52,8 +52,8 @@ class RouterStrategy:
5252

5353
class SequentialRouting:
5454
"""
55-
The SequentialRouting class implements a round-robin selection strategy for load-balancing LLM requests.
56-
It iterates through a list of selected language models and returns their corresponding index based on
55+
The SequentialRouting class implements a round-robin selection strategy for load-balancing LLM requests.
56+
It iterates through a list of selected language models and returns their corresponding index based on
5757
the request count.
5858
5959
This strategy ensures that multiple models are utilized in sequence, distributing queries evenly across the available configurations.
@@ -98,24 +98,24 @@ def get_model_idxs(self, selected_llm_lists: List[List[Dict[str, Any]]], queries
9898
"""
9999
# current = self.selected_llms[self.idx]
100100
model_idxs = []
101-
101+
102102
available_models = [llm.name for llm in self.llm_configs]
103-
103+
104104
n_queries = len(queries)
105-
105+
106106
for i in range(n_queries):
107107
selected_llm_list = selected_llm_lists[i]
108-
108+
109109
if not selected_llm_list or len(selected_llm_list) == 0:
110110
model_idxs.append(0)
111111
continue
112-
112+
113113
model_idx = -1
114114
for selected_llm in selected_llm_list:
115115
if selected_llm["name"] in available_models:
116116
model_idx = available_models.index(selected_llm["name"])
117117
break
118-
118+
119119
model_idxs.append(model_idx)
120120

121121
return model_idxs
@@ -138,7 +138,7 @@ def get_token_lengths(queries: List[List[Dict[str, Any]]]):
138138
return [token_counter(model="gpt-4o-mini", messages=query) for query in queries]
139139

140140
def messages_to_query(messages: List[Dict[str, str]],
141-
strategy: str = "last_user") -> str:
141+
strategy: str = "last_user") -> str:
142142
"""
143143
Convert OpenAI ChatCompletion-style messages into a single query string.
144144
strategy:
@@ -201,43 +201,74 @@ def __init__(self,
201201
model_name: str = "all-MiniLM-L6-v2",
202202
persist_directory: str = "llm_router",
203203
bootstrap_url: str | None = None):
204+
storage_cfg = global_config.get_storage_config() or {}
205+
backend = (os.environ.get("VECTOR_DB_BACKEND") or storage_cfg.get("vector_db_backend") or "chroma").lower()
206+
204207
self._persist_root = os.path.join(os.path.dirname(__file__), persist_directory)
205208
os.makedirs(self._persist_root, exist_ok=True)
206209

207-
self.client = chromadb.PersistentClient(path=self._persist_root)
208-
self.embedding_function = embedding_functions.DefaultEmbeddingFunction()
210+
self.backend = backend
211+
if backend == "qdrant":
212+
host = storage_cfg.get("qdrant_host", os.environ.get("QDRANT_HOST", "localhost"))
213+
port = int(storage_cfg.get("qdrant_port", os.environ.get("QDRANT_PORT", 6333)))
214+
api_key = storage_cfg.get("qdrant_api_key", os.environ.get("QDRANT_API_KEY"))
215+
self.qdrant = QdrantClient(host=host, port=port, api_key=api_key)
216+
self.model_name = storage_cfg.get("qdrant_model_name") or os.environ.get("QDRANT_EMBEDDING_MODEL", model_name)
217+
self.collection_name = "historical_queries"
218+
self._ensure_qdrant_collection()
219+
self.client = None
220+
self.embedding_function = None
221+
self.collection = None
222+
else:
223+
self.client = chromadb.PersistentClient(path=self._persist_root)
224+
self.embedding_function = embedding_functions.DefaultEmbeddingFunction()
225+
self.collection = self._get_or_create_collection("historical_queries")
226+
self.qdrant = None
227+
self.model_name = model_name
228+
self.collection_name = "historical_queries"
209229

210-
# Always create/get collections up‑front so we can inspect counts.
211-
# self.train_collection = self._get_or_create_collection("train_queries")
212-
# self.val_collection = self._get_or_create_collection("val_queries")
213-
# self.test_collection = self._get_or_create_collection("test_queries")
214-
self.collection = self._get_or_create_collection("historical_queries")
215-
216230
# If DB is empty and we have a bootstrap URL – populate it.
217-
if bootstrap_url and self.collection.count() == 0:
218-
self._bootstrap_from_drive(bootstrap_url)
219-
220-
# .................................................................
221-
# Chroma helpers
222-
# .................................................................
231+
if bootstrap_url:
232+
if backend == "qdrant":
233+
count = self._qdrant_count()
234+
if count == 0:
235+
self._bootstrap_from_drive(bootstrap_url)
236+
else:
237+
if self.collection and self.collection.count() == 0:
238+
self._bootstrap_from_drive(bootstrap_url)
223239

224240
def _get_or_create_collection(self, name: str):
    """Return the Chroma collection called ``name``, creating it if needed.

    Args:
        name: Name of the Chroma collection.

    Returns:
        The collection object, or ``None`` when the Qdrant backend is active
        or no Chroma client has been set up.
    """
    if self.backend == "qdrant" or self.client is None:
        return None
    # Let the client do get-or-create itself: wrapping get_collection() in a
    # broad ``except Exception`` treated every failure as "collection missing"
    # and then attempted a create, hiding the real error.
    return self.client.get_or_create_collection(
        name=name,
        embedding_function=self.embedding_function,
    )
229249

230-
# .................................................................
231-
# Bootstrap logic – download + ingest
232-
# .................................................................
250+
def _ensure_qdrant_collection(self):
    """Create the Qdrant collection ``self.collection_name`` if it is missing.

    The vector size is obtained from the client for ``self.model_name`` and
    the collection is configured with cosine distance. Nothing is done when
    the collection already exists.
    """
    client = self.qdrant
    target = self.collection_name
    if client.collection_exists(target):
        return
    embedding_dim = client.get_embedding_size(self.model_name)
    vector_params = models.VectorParams(
        size=embedding_dim,
        distance=models.Distance.COSINE,
    )
    client.create_collection(target, vectors_config=vector_params)
257+
258+
def _qdrant_count(self) -> int:
    """Return the number of points in the Qdrant collection.

    This is a best-effort lookup: any failure is reported as ``0`` so the
    caller can proceed, but the failure is logged rather than silently
    swallowed, since a ``0`` caused by an error is indistinguishable from a
    genuinely empty collection.

    Returns:
        The point count reported by the client, or ``0`` if the call failed.
    """
    import logging

    try:
        return self.qdrant.count(self.collection_name).count
    except Exception as exc:  # deliberate best-effort; see docstring
        logging.getLogger(__name__).warning(
            "[SmartRouting] Could not count points in Qdrant collection %r: %s",
            self.collection_name,
            exc,
        )
        return 0
233264

234265
def _bootstrap_from_drive(self, url_or_id: str):
235266
print("\n[SmartRouting] Bootstrapping ChromaDB from Google Drive…\n")
236267

237268
with tempfile.TemporaryDirectory() as tmp:
238269
# NB: gdown accepts both share links and raw IDs.
239270
local_path = os.path.join(tmp, "bootstrap.json")
240-
271+
241272
gdown.download(url_or_id, local_path, quiet=False, fuzzy=True)
242273

243274
# Expect JSONL with {"query": ..., "split": "train"|"val"|"test", ...}
@@ -249,20 +280,20 @@ def _bootstrap_from_drive(self, url_or_id: str):
249280

250281
print("[SmartRouting] Bootstrap complete – collections populated.\n")
251282

252-
# .................................................................
253-
# Public data API
254-
# .................................................................
255-
256283
def add_data(self, data: List[Dict[str, Any]]):
257-
collection = self.collection
258-
queries, metadatas, ids = [], [], []
284+
if self.backend == "qdrant":
285+
queries, metadatas, ids = [], [], []
286+
else:
287+
collection = self.collection
288+
queries, metadatas, ids = [], [], []
289+
259290
correct_count = total_count = 0
260291

261-
for idx, item in enumerate(tqdm(data, desc=f"Ingesting historical queries")):
292+
for idx, item in enumerate(tqdm(data, desc="Ingesting historical queries")):
262293
query = item["query"]
263294
model_metadatas = item["outputs"]
264295
for model_metadata in model_metadatas:
265-
model_metadata.pop("prediction")
296+
model_metadata.pop("prediction", None)
266297
meta = {
267298
"input_token_length": item["input_token_length"],
268299
"models": json.dumps(model_metadatas), # store raw list
@@ -275,15 +306,40 @@ def add_data(self, data: List[Dict[str, Any]]):
275306
metadatas.append(meta)
276307
ids.append(f"{idx}")
277308

278-
collection.add(documents=queries, metadatas=metadatas, ids=ids)
279-
print(f"[SmartRouting]: {total_count} historical queries ingested.")
309+
if self.backend == "qdrant":
310+
docs = [models.Document(text=q, model=self.model_name) for q in queries]
311+
if docs:
312+
# Deterministic UUIDv5 ids; store original in payload
313+
q_ids = [str(uuid.uuid5(uuid.NAMESPACE_URL, i)) for i in ids]
314+
for i, meta in enumerate(metadatas):
315+
meta["original_id"] = ids[i]
316+
self.qdrant.upload_collection(
317+
collection_name=self.collection_name,
318+
vectors=docs,
319+
ids=q_ids,
320+
payload=metadatas,
321+
)
322+
else:
323+
if queries and metadatas and ids: # Only add if we have data
324+
collection.add(documents=queries, metadatas=metadatas, ids=ids)
325+
print(f"[SmartRouting]: {total_count} historical queries ingested.")
280326

281-
# ..................................................................
282327
def query_similar(self, query: str | List[str], n_results: int = 16):
328+
if self.backend == "qdrant":
329+
qtext = query if isinstance(query, str) else query[0]
330+
results = self.qdrant.query_points(
331+
collection_name=self.collection_name,
332+
query=models.Document(text=qtext, model=self.model_name),
333+
limit=n_results,
334+
).points
335+
return {
336+
"ids": [[(r.payload or {}).get("original_id", str(r.id)) for r in results]],
337+
"metadatas": [[(r.payload or {}) for r in results]],
338+
"documents": [[""] * len(results)],
339+
}
283340
collection = self.collection
284341
return collection.query(query_texts=query if isinstance(query, list) else [query], n_results=n_results)
285342

286-
# ..................................................................
287343
def predict(self, query: str | List[str], model_configs: List[Dict[str, Any]], n_similar: int = 16):
288344
similar = self.query_similar(query, n_results=n_similar)
289345
perf_mat, len_mat = [], []
@@ -355,7 +411,7 @@ def get_model_idxs(self, selected_llm_lists: List[List[Dict[str, Any]]], queries
355411

356412
input_lens = get_token_lengths(queries)
357413
chosen_indices: list[int] = []
358-
414+
359415
converted_queries = [messages_to_query(query) for query in queries]
360416

361417
for q, q_len, candidate_cfgs in zip(converted_queries, input_lens, selected_llm_lists):
@@ -376,7 +432,7 @@ def get_model_idxs(self, selected_llm_lists: List[List[Dict[str, Any]]], queries
376432

377433
# Map back to global llm_configs index
378434
sel_name = candidate_cfgs[sel_local_idx]["name"]
379-
435+
380436
sel_idx = self.available_models.index(sel_name)
381437
chosen_indices.append(sel_idx)
382438

0 commit comments

Comments
 (0)