1- from enum import Enum
21from typing import List , Dict , Any
32import chromadb
43from chromadb .utils import embedding_functions
4+ from qdrant_client import QdrantClient , models
5+ from aios .config .config_manager import config as global_config
56import json
67import numpy as np
7- from typing import List , Dict , Any
88from tqdm import tqdm
99from collections import defaultdict
1010
11- import json
11+ import uuid
1212
1313from threading import Lock
1414
@@ -52,8 +52,8 @@ class RouterStrategy:
5252
5353class SequentialRouting :
5454 """
55- The SequentialRouting class implements a round-robin selection strategy for load-balancing LLM requests.
56- It iterates through a list of selected language models and returns their corresponding index based on
55+ The SequentialRouting class implements a round-robin selection strategy for load-balancing LLM requests.
56+ It iterates through a list of selected language models and returns their corresponding index based on
5757 the request count.
5858
5959 This strategy ensures that multiple models are utilized in sequence, distributing queries evenly across the available configurations.
@@ -98,24 +98,24 @@ def get_model_idxs(self, selected_llm_lists: List[List[Dict[str, Any]]], queries
9898 """
9999 # current = self.selected_llms[self.idx]
100100 model_idxs = []
101-
101+
102102 available_models = [llm .name for llm in self .llm_configs ]
103-
103+
104104 n_queries = len (queries )
105-
105+
106106 for i in range (n_queries ):
107107 selected_llm_list = selected_llm_lists [i ]
108-
108+
109109 if not selected_llm_list or len (selected_llm_list ) == 0 :
110110 model_idxs .append (0 )
111111 continue
112-
112+
113113 model_idx = - 1
114114 for selected_llm in selected_llm_list :
115115 if selected_llm ["name" ] in available_models :
116116 model_idx = available_models .index (selected_llm ["name" ])
117117 break
118-
118+
119119 model_idxs .append (model_idx )
120120
121121 return model_idxs
@@ -138,7 +138,7 @@ def get_token_lengths(queries: List[List[Dict[str, Any]]]):
138138 return [token_counter (model = "gpt-4o-mini" , messages = query ) for query in queries ]
139139
140140def messages_to_query (messages : List [Dict [str , str ]],
141- strategy : str = "last_user" ) -> str :
141+ strategy : str = "last_user" ) -> str :
142142 """
143143 Convert OpenAI ChatCompletion-style messages into a single query string.
144144 strategy:
@@ -201,43 +201,74 @@ def __init__(self,
201201 model_name : str = "all-MiniLM-L6-v2" ,
202202 persist_directory : str = "llm_router" ,
203203 bootstrap_url : str | None = None ):
204+ storage_cfg = global_config .get_storage_config () or {}
205+ backend = (os .environ .get ("VECTOR_DB_BACKEND" ) or storage_cfg .get ("vector_db_backend" ) or "chroma" ).lower ()
206+
204207 self ._persist_root = os .path .join (os .path .dirname (__file__ ), persist_directory )
205208 os .makedirs (self ._persist_root , exist_ok = True )
206209
207- self .client = chromadb .PersistentClient (path = self ._persist_root )
208- self .embedding_function = embedding_functions .DefaultEmbeddingFunction ()
210+ self .backend = backend
211+ if backend == "qdrant" :
212+ host = storage_cfg .get ("qdrant_host" , os .environ .get ("QDRANT_HOST" , "localhost" ))
213+ port = int (storage_cfg .get ("qdrant_port" , os .environ .get ("QDRANT_PORT" , 6333 )))
214+ api_key = storage_cfg .get ("qdrant_api_key" , os .environ .get ("QDRANT_API_KEY" ))
215+ self .qdrant = QdrantClient (host = host , port = port , api_key = api_key )
216+ self .model_name = storage_cfg .get ("qdrant_model_name" ) or os .environ .get ("QDRANT_EMBEDDING_MODEL" , model_name )
217+ self .collection_name = "historical_queries"
218+ self ._ensure_qdrant_collection ()
219+ self .client = None
220+ self .embedding_function = None
221+ self .collection = None
222+ else :
223+ self .client = chromadb .PersistentClient (path = self ._persist_root )
224+ self .embedding_function = embedding_functions .DefaultEmbeddingFunction ()
225+ self .collection = self ._get_or_create_collection ("historical_queries" )
226+ self .qdrant = None
227+ self .model_name = model_name
228+ self .collection_name = "historical_queries"
209229
210- # Always create/get collections up‑front so we can inspect counts.
211- # self.train_collection = self._get_or_create_collection("train_queries")
212- # self.val_collection = self._get_or_create_collection("val_queries")
213- # self.test_collection = self._get_or_create_collection("test_queries")
214- self .collection = self ._get_or_create_collection ("historical_queries" )
215-
216230 # If DB is empty and we have a bootstrap URL – populate it.
217- if bootstrap_url and self .collection .count () == 0 :
218- self ._bootstrap_from_drive (bootstrap_url )
219-
220- # .................................................................
221- # Chroma helpers
222- # .................................................................
231+ if bootstrap_url :
232+ if backend == "qdrant" :
233+ count = self ._qdrant_count ()
234+ if count == 0 :
235+ self ._bootstrap_from_drive (bootstrap_url )
236+ else :
237+ if self .collection and self .collection .count () == 0 :
238+ self ._bootstrap_from_drive (bootstrap_url )
223239
224240 def _get_or_create_collection (self , name : str ):
241+ if self .backend == "qdrant" :
242+ return None
243+ if self .client is None :
244+ return None
225245 try :
226246 return self .client .get_collection (name = name , embedding_function = self .embedding_function )
227247 except Exception :
228248 return self .client .create_collection (name = name , embedding_function = self .embedding_function )
229249
230- # .................................................................
231- # Bootstrap logic – download + ingest
232- # .................................................................
250+ def _ensure_qdrant_collection (self ):
251+ if not self .qdrant .collection_exists (self .collection_name ):
252+ dim = self .qdrant .get_embedding_size (self .model_name )
253+ self .qdrant .create_collection (
254+ self .collection_name ,
255+ vectors_config = models .VectorParams (size = dim , distance = models .Distance .COSINE ),
256+ )
257+
258+ def _qdrant_count (self ) -> int :
259+ try :
260+ count = self .qdrant .count (self .collection_name ).count
261+ return count
262+ except Exception :
263+ return 0
233264
234265 def _bootstrap_from_drive (self , url_or_id : str ):
235266 print ("\n [SmartRouting] Bootstrapping ChromaDB from Google Drive…\n " )
236267
237268 with tempfile .TemporaryDirectory () as tmp :
238269 # NB: gdown accepts both share links and raw IDs.
239270 local_path = os .path .join (tmp , "bootstrap.json" )
240-
271+
241272 gdown .download (url_or_id , local_path , quiet = False , fuzzy = True )
242273
243274 # Expect JSONL with {"query": ..., "split": "train"|"val"|"test", ...}
@@ -249,20 +280,20 @@ def _bootstrap_from_drive(self, url_or_id: str):
249280
250281 print ("[SmartRouting] Bootstrap complete – collections populated.\n " )
251282
252- # .................................................................
253- # Public data API
254- # .................................................................
255-
256283 def add_data (self , data : List [Dict [str , Any ]]):
257- collection = self .collection
258- queries , metadatas , ids = [], [], []
284+ if self .backend == "qdrant" :
285+ queries , metadatas , ids = [], [], []
286+ else :
287+ collection = self .collection
288+ queries , metadatas , ids = [], [], []
289+
259290 correct_count = total_count = 0
260291
261- for idx , item in enumerate (tqdm (data , desc = f "Ingesting historical queries" )):
292+ for idx , item in enumerate (tqdm (data , desc = "Ingesting historical queries" )):
262293 query = item ["query" ]
263294 model_metadatas = item ["outputs" ]
264295 for model_metadata in model_metadatas :
265- model_metadata .pop ("prediction" )
296+ model_metadata .pop ("prediction" , None )
266297 meta = {
267298 "input_token_length" : item ["input_token_length" ],
268299 "models" : json .dumps (model_metadatas ), # store raw list
@@ -275,15 +306,40 @@ def add_data(self, data: List[Dict[str, Any]]):
275306 metadatas .append (meta )
276307 ids .append (f"{ idx } " )
277308
278- collection .add (documents = queries , metadatas = metadatas , ids = ids )
279- print (f"[SmartRouting]: { total_count } historical queries ingested." )
309+ if self .backend == "qdrant" :
310+ docs = [models .Document (text = q , model = self .model_name ) for q in queries ]
311+ if docs :
312+ # Deterministic UUIDv5 ids; store original in payload
313+ q_ids = [str (uuid .uuid5 (uuid .NAMESPACE_URL , i )) for i in ids ]
314+ for i , meta in enumerate (metadatas ):
315+ meta ["original_id" ] = ids [i ]
316+ self .qdrant .upload_collection (
317+ collection_name = self .collection_name ,
318+ vectors = docs ,
319+ ids = q_ids ,
320+ payload = metadatas ,
321+ )
322+ else :
323+ if queries and metadatas and ids : # Only add if we have data
324+ collection .add (documents = queries , metadatas = metadatas , ids = ids )
325+ print (f"[SmartRouting]: { total_count } historical queries ingested." )
280326
281- # ..................................................................
282327 def query_similar (self , query : str | List [str ], n_results : int = 16 ):
328+ if self .backend == "qdrant" :
329+ qtext = query if isinstance (query , str ) else query [0 ]
330+ results = self .qdrant .query_points (
331+ collection_name = self .collection_name ,
332+ query = models .Document (text = qtext , model = self .model_name ),
333+ limit = n_results ,
334+ ).points
335+ return {
336+ "ids" : [[(r .payload or {}).get ("original_id" , str (r .id )) for r in results ]],
337+ "metadatas" : [[(r .payload or {}) for r in results ]],
338+ "documents" : [["" ] * len (results )],
339+ }
283340 collection = self .collection
284341 return collection .query (query_texts = query if isinstance (query , list ) else [query ], n_results = n_results )
285342
286- # ..................................................................
287343 def predict (self , query : str | List [str ], model_configs : List [Dict [str , Any ]], n_similar : int = 16 ):
288344 similar = self .query_similar (query , n_results = n_similar )
289345 perf_mat , len_mat = [], []
@@ -355,7 +411,7 @@ def get_model_idxs(self, selected_llm_lists: List[List[Dict[str, Any]]], queries
355411
356412 input_lens = get_token_lengths (queries )
357413 chosen_indices : list [int ] = []
358-
414+
359415 converted_queries = [messages_to_query (query ) for query in queries ]
360416
361417 for q , q_len , candidate_cfgs in zip (converted_queries , input_lens , selected_llm_lists ):
@@ -376,7 +432,7 @@ def get_model_idxs(self, selected_llm_lists: List[List[Dict[str, Any]]], queries
376432
377433 # Map back to global llm_configs index
378434 sel_name = candidate_cfgs [sel_local_idx ]["name" ]
379-
435+
380436 sel_idx = self .available_models .index (sel_name )
381437 chosen_indices .append (sel_idx )
382438
0 commit comments