1- from enum import Enum
21from typing import List , Dict , Any
32import chromadb
43from chromadb .utils import embedding_functions
4+ from qdrant_client import QdrantClient , models
5+ from aios .config .config_manager import config as global_config
56import json
67import numpy as np
7- from typing import List , Dict , Any
88from tqdm import tqdm
99from collections import defaultdict
1010
11- import json
11+ import uuid
1212
1313from threading import Lock
1414
@@ -52,8 +52,8 @@ class RouterStrategy:
5252
5353class SequentialRouting :
5454 """
55- The SequentialRouting class implements a round-robin selection strategy for load-balancing LLM requests.
56- It iterates through a list of selected language models and returns their corresponding index based on
55+ The SequentialRouting class implements a round-robin selection strategy for load-balancing LLM requests.
56+ It iterates through a list of selected language models and returns their corresponding index based on
5757 the request count.
5858
5959 This strategy ensures that multiple models are utilized in sequence, distributing queries evenly across the available configurations.
@@ -98,24 +98,24 @@ def get_model_idxs(self, selected_llm_lists: List[List[Dict[str, Any]]], queries
9898 """
9999 # current = self.selected_llms[self.idx]
100100 model_idxs = []
101-
101+
102102 available_models = [llm .name for llm in self .llm_configs ]
103-
103+
104104 n_queries = len (queries )
105-
105+
106106 for i in range (n_queries ):
107107 selected_llm_list = selected_llm_lists [i ]
108-
108+
109109 if not selected_llm_list or len (selected_llm_list ) == 0 :
110110 model_idxs .append (0 )
111111 continue
112-
112+
113113 model_idx = - 1
114114 for selected_llm in selected_llm_list :
115115 if selected_llm ["name" ] in available_models :
116116 model_idx = available_models .index (selected_llm ["name" ])
117117 break
118-
118+
119119 model_idxs .append (model_idx )
120120
121121 return model_idxs
@@ -138,7 +138,7 @@ def get_token_lengths(queries: List[List[Dict[str, Any]]]):
138138 return [token_counter (model = "gpt-4o-mini" , messages = query ) for query in queries ]
139139
140140def messages_to_query (messages : List [Dict [str , str ]],
141- strategy : str = "last_user" ) -> str :
141+ strategy : str = "last_user" ) -> str :
142142 """
143143 Convert OpenAI ChatCompletion-style messages into a single query string.
144144 strategy:
@@ -201,43 +201,66 @@ def __init__(self,
201201 model_name : str = "all-MiniLM-L6-v2" ,
202202 persist_directory : str = "llm_router" ,
203203 bootstrap_url : str | None = None ):
204+ storage_cfg = global_config .get_storage_config () or {}
205+ backend = (storage_cfg .get ("vector_db_backend" ) or os .environ .get ("VECTOR_DB_BACKEND" ) or "chroma" ).lower ()
206+
204207 self ._persist_root = os .path .join (os .path .dirname (__file__ ), persist_directory )
205208 os .makedirs (self ._persist_root , exist_ok = True )
206209
207- self .client = chromadb .PersistentClient (path = self ._persist_root )
208- self .embedding_function = embedding_functions .DefaultEmbeddingFunction ()
210+ self .backend = backend
211+ if backend == "qdrant" :
212+ host = storage_cfg .get ("qdrant_host" , os .environ .get ("QDRANT_HOST" , "localhost" ))
213+ port = int (storage_cfg .get ("qdrant_port" , os .environ .get ("QDRANT_PORT" , 6333 )))
214+ api_key = storage_cfg .get ("qdrant_api_key" , os .environ .get ("QDRANT_API_KEY" ))
215+ self .qdrant = QdrantClient (host = host , port = port , api_key = api_key )
216+ self .model_name = storage_cfg .get ("qdrant_model_name" ) or os .environ .get ("QDRANT_EMBEDDING_MODEL" , model_name )
217+ self .collection_name = "historical_queries"
218+ self ._ensure_qdrant_collection ()
219+ else :
220+ self .client = chromadb .PersistentClient (path = self ._persist_root )
221+ self .embedding_function = embedding_functions .DefaultEmbeddingFunction ()
222+ self .collection = self ._get_or_create_collection ("historical_queries" )
209223
210- # Always create/get collections up‑front so we can inspect counts.
211- # self.train_collection = self._get_or_create_collection("train_queries")
212- # self.val_collection = self._get_or_create_collection("val_queries")
213- # self.test_collection = self._get_or_create_collection("test_queries")
214- self .collection = self ._get_or_create_collection ("historical_queries" )
215-
216224 # If DB is empty and we have a bootstrap URL – populate it.
217- if bootstrap_url and self .collection .count () == 0 :
218- self ._bootstrap_from_drive (bootstrap_url )
219-
220- # .................................................................
221- # Chroma helpers
222- # .................................................................
225+ if bootstrap_url :
226+ if backend == "qdrant" :
227+ count = self ._qdrant_count ()
228+ if count == 0 :
229+ self ._bootstrap_from_drive (bootstrap_url )
230+ else :
231+ if self .collection .count () == 0 :
232+ self ._bootstrap_from_drive (bootstrap_url )
223233
224234 def _get_or_create_collection (self , name : str ):
235+ if self .backend == "qdrant" :
236+ return None
225237 try :
226238 return self .client .get_collection (name = name , embedding_function = self .embedding_function )
227239 except Exception :
228240 return self .client .create_collection (name = name , embedding_function = self .embedding_function )
229241
230- # .................................................................
231- # Bootstrap logic – download + ingest
232- # .................................................................
242+ def _ensure_qdrant_collection (self ):
243+ if not self .qdrant .collection_exists (self .collection_name ):
244+ dim = self .qdrant .get_embedding_size (self .model_name )
245+ self .qdrant .create_collection (
246+ self .collection_name ,
247+ vectors_config = models .VectorParams (size = dim , distance = models .Distance .COSINE ),
248+ )
249+
250+ def _qdrant_count (self ) -> int :
251+ try :
252+ count = self .qdrant .count (self .collection_name ).count
253+ return count
254+ except Exception :
255+ return 0
233256
234257 def _bootstrap_from_drive (self , url_or_id : str ):
235258 print ("\n [SmartRouting] Bootstrapping ChromaDB from Google Drive…\n " )
236259
237260 with tempfile .TemporaryDirectory () as tmp :
238261 # NB: gdown accepts both share links and raw IDs.
239262 local_path = os .path .join (tmp , "bootstrap.json" )
240-
263+
241264 gdown .download (url_or_id , local_path , quiet = False , fuzzy = True )
242265
243266 # Expect JSONL with {"query": ..., "split": "train"|"val"|"test", ...}
@@ -249,16 +272,16 @@ def _bootstrap_from_drive(self, url_or_id: str):
249272
250273 print ("[SmartRouting] Bootstrap complete – collections populated.\n " )
251274
252- # .................................................................
253- # Public data API
254- # .................................................................
255-
256275 def add_data (self , data : List [Dict [str , Any ]]):
257- collection = self .collection
258- queries , metadatas , ids = [], [], []
276+ if self .backend == "qdrant" :
277+ queries , metadatas , ids = [], [], []
278+ else :
279+ collection = self .collection
280+ queries , metadatas , ids = [], [], []
281+
259282 correct_count = total_count = 0
260283
261- for idx , item in enumerate (tqdm (data , desc = f "Ingesting historical queries" )):
284+ for idx , item in enumerate (tqdm (data , desc = "Ingesting historical queries" )):
262285 query = item ["query" ]
263286 model_metadatas = item ["outputs" ]
264287 for model_metadata in model_metadatas :
@@ -275,15 +298,39 @@ def add_data(self, data: List[Dict[str, Any]]):
275298 metadatas .append (meta )
276299 ids .append (f"{ idx } " )
277300
278- collection .add (documents = queries , metadatas = metadatas , ids = ids )
279- print (f"[SmartRouting]: { total_count } historical queries ingested." )
301+ if self .backend == "qdrant" :
302+ docs = [models .Document (text = q , model = self .model_name ) for q in queries ]
303+ if docs :
304+ # Deterministic UUIDv5 ids; store original in payload
305+ q_ids = [str (uuid .uuid5 (uuid .NAMESPACE_URL , i )) for i in ids ]
306+ for i , meta in enumerate (metadatas ):
307+ meta ["original_id" ] = ids [i ]
308+ self .qdrant .upload_collection (
309+ collection_name = self .collection_name ,
310+ vectors = docs ,
311+ ids = q_ids ,
312+ payload = metadatas ,
313+ )
314+ else :
315+ collection .add (documents = queries , metadatas = metadatas , ids = ids )
316+ print (f"[SmartRouting]: { total_count } historical queries ingested." )
280317
281- # ..................................................................
282318 def query_similar (self , query : str | List [str ], n_results : int = 16 ):
319+ if self .backend == "qdrant" :
320+ qtext = query if isinstance (query , str ) else query [0 ]
321+ results = self .qdrant .query_points (
322+ collection_name = self .collection_name ,
323+ query = models .Document (text = qtext , model = self .model_name ),
324+ limit = n_results ,
325+ ).points
326+ return {
327+ "ids" : [[(r .payload or {}).get ("original_id" , str (r .id )) for r in results ]],
328+ "metadatas" : [[(r .payload or {}) for r in results ]],
329+ "documents" : [["" ] * len (results )],
330+ }
283331 collection = self .collection
284332 return collection .query (query_texts = query if isinstance (query , list ) else [query ], n_results = n_results )
285333
286- # ..................................................................
287334 def predict (self , query : str | List [str ], model_configs : List [Dict [str , Any ]], n_similar : int = 16 ):
288335 similar = self .query_similar (query , n_results = n_similar )
289336 perf_mat , len_mat = [], []
@@ -355,7 +402,7 @@ def get_model_idxs(self, selected_llm_lists: List[List[Dict[str, Any]]], queries
355402
356403 input_lens = get_token_lengths (queries )
357404 chosen_indices : list [int ] = []
358-
405+
359406 converted_queries = [messages_to_query (query ) for query in queries ]
360407
361408 for q , q_len , candidate_cfgs in zip (converted_queries , input_lens , selected_llm_lists ):
@@ -376,7 +423,7 @@ def get_model_idxs(self, selected_llm_lists: List[List[Dict[str, Any]]], queries
376423
377424 # Map back to global llm_configs index
378425 sel_name = candidate_cfgs [sel_local_idx ]["name" ]
379-
426+
380427 sel_idx = self .available_models .index (sel_name )
381428 chosen_indices .append (sel_idx )
382429
0 commit comments