Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions .github/workflows/test-ollama.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,20 +61,23 @@ jobs:
echo "Waiting for Kernel… ($i/${KERNEL_MAX_WAIT})"; sleep 1
done

- name: Run tests
- name: Run Ollama tests
run: |
source .venv/bin/activate
mkdir -p test_results
mapfile -t TESTS < <(find . -type f -path "*/llm/ollama/*" -name "*.py")
if [ "${#TESTS[@]}" -eq 0 ]; then

# Run Ollama tests
mapfile -t OLLAMA_TESTS < <(find . -type f -path "*/llm/ollama/*" -name "*.py")
if [ "${#OLLAMA_TESTS[@]}" -eq 0 ]; then
echo "⚠️ No llm/ollama tests found – skipping."
exit 0
else
for t in "${OLLAMA_TESTS[@]}"; do
echo "▶️ Running Ollama test: $t"
python -m unittest "$t" 2>&1 | tee -a test_results/ollama_tests.log
echo "----------------------------------------"
done
fi
for t in "${TESTS[@]}"; do
echo "▶️ Running $t"
python -m unittest "$t" 2>&1 | tee -a test_results/ollama_tests.log
echo "----------------------------------------"
done


- name: Upload logs
if: always()
Expand Down
55 changes: 55 additions & 0 deletions .github/workflows/test-qdrant.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
name: Test Qdrant Integration

on:
  pull_request:
    branches: [ "main" ]
  push:
    branches: [ "main" ]

permissions:
  contents: read
  actions: write

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
  cancel-in-progress: true

jobs:
  test-qdrant:
    runs-on: ubuntu-latest
    # Guard against a hung Qdrant service or model download blocking the queue.
    timeout-minutes: 30

    services:
      qdrant:
        image: qdrant/qdrant:latest
        ports:
          - 6333:6333

    env:
      VECTOR_DB_BACKEND: qdrant
      QDRANT_HOST: localhost
      QDRANT_PORT: 6333
      QDRANT_EMBEDDING_MODEL: sentence-transformers/all-MiniLM-L6-v2

    steps:
      - uses: actions/checkout@v4

      - name: Set up Python 3.10
        uses: actions/setup-python@v5
        with:
          python-version: "3.10"

      - name: Install dependencies
        run: |
          python -m pip install uv
          uv venv
          source .venv/bin/activate
          uv pip install -r requirements.txt

      - name: Run Qdrant integration tests
        run: |
          source .venv/bin/activate
          mkdir -p test_results

          echo "Running Qdrant integration tests"
          # Capture output into test_results/ like the other integration
          # workflows (the previous version created the directory but never
          # wrote to it). GitHub's default bash shell runs with -o pipefail,
          # so a unittest failure still fails the step through the pipe.
          python -m unittest tests.modules.test_qdrant_integration 2>&1 | tee test_results/qdrant_tests.log
16 changes: 12 additions & 4 deletions aios/config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ llms:
# backend: "huggingface"
# max_gpu_memory: {0: "48GB"} # GPU memory allocation
# eval_device: "cuda:0" # Device for model evaluation

# vLLM Models
# To use vllm as backend, you need to install vllm and run the vllm server https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
# An example command to run the vllm server is:
Expand All @@ -72,10 +72,18 @@ llms:

memory:
log_mode: "console" # choose from [console, file]

storage:
root_dir: "root"
use_vector_db: true
# vector DB backend: chroma | qdrant
vector_db_backend: "chroma"
# Qdrant connection (used when backend == qdrant). Can also use env vars QDRANT_HOST, QDRANT_PORT, QDRANT_API_KEY
qdrant_host: "localhost"
qdrant_port: 6333
qdrant_api_key: ""
# Embedding model for Qdrant fastembed integration
qdrant_model_name: "sentence-transformers/all-MiniLM-L6-v2"

tool:
mcp_server_script_path: "aios/tool/mcp_server.py"
Expand All @@ -85,8 +93,8 @@ scheduler:

agent_factory:
log_mode: "console" # choose from [console, file]
max_workers: 64
max_workers: 64

server:
host: "localhost"
port: 8000
144 changes: 100 additions & 44 deletions aios/llm_core/routing.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from enum import Enum
from typing import List, Dict, Any
import chromadb
from chromadb.utils import embedding_functions
from qdrant_client import QdrantClient, models
from aios.config.config_manager import config as global_config
import json
import numpy as np
from typing import List, Dict, Any
from tqdm import tqdm
from collections import defaultdict

import json
import uuid

from threading import Lock

Expand Down Expand Up @@ -52,8 +52,8 @@ class RouterStrategy:

class SequentialRouting:
"""
The SequentialRouting class implements a round-robin selection strategy for load-balancing LLM requests.
It iterates through a list of selected language models and returns their corresponding index based on
The SequentialRouting class implements a round-robin selection strategy for load-balancing LLM requests.
It iterates through a list of selected language models and returns their corresponding index based on
the request count.

This strategy ensures that multiple models are utilized in sequence, distributing queries evenly across the available configurations.
Expand Down Expand Up @@ -98,24 +98,24 @@ def get_model_idxs(self, selected_llm_lists: List[List[Dict[str, Any]]], queries
"""
# current = self.selected_llms[self.idx]
model_idxs = []

available_models = [llm.name for llm in self.llm_configs]

n_queries = len(queries)

for i in range(n_queries):
selected_llm_list = selected_llm_lists[i]

if not selected_llm_list or len(selected_llm_list) == 0:
model_idxs.append(0)
continue

model_idx = -1
for selected_llm in selected_llm_list:
if selected_llm["name"] in available_models:
model_idx = available_models.index(selected_llm["name"])
break

model_idxs.append(model_idx)

return model_idxs
Expand All @@ -138,7 +138,7 @@ def get_token_lengths(queries: List[List[Dict[str, Any]]]):
return [token_counter(model="gpt-4o-mini", messages=query) for query in queries]

def messages_to_query(messages: List[Dict[str, str]],
strategy: str = "last_user") -> str:
strategy: str = "last_user") -> str:
"""
Convert OpenAI ChatCompletion-style messages into a single query string.
strategy:
Expand Down Expand Up @@ -201,43 +201,74 @@ def __init__(self,
model_name: str = "all-MiniLM-L6-v2",
persist_directory: str = "llm_router",
bootstrap_url: str | None = None):
storage_cfg = global_config.get_storage_config() or {}
backend = (os.environ.get("VECTOR_DB_BACKEND") or storage_cfg.get("vector_db_backend") or "chroma").lower()

self._persist_root = os.path.join(os.path.dirname(__file__), persist_directory)
os.makedirs(self._persist_root, exist_ok=True)

self.client = chromadb.PersistentClient(path=self._persist_root)
self.embedding_function = embedding_functions.DefaultEmbeddingFunction()
self.backend = backend
if backend == "qdrant":
host = storage_cfg.get("qdrant_host", os.environ.get("QDRANT_HOST", "localhost"))
port = int(storage_cfg.get("qdrant_port", os.environ.get("QDRANT_PORT", 6333)))
api_key = storage_cfg.get("qdrant_api_key", os.environ.get("QDRANT_API_KEY"))
self.qdrant = QdrantClient(host=host, port=port, api_key=api_key)
self.model_name = storage_cfg.get("qdrant_model_name") or os.environ.get("QDRANT_EMBEDDING_MODEL", model_name)
self.collection_name = "historical_queries"
self._ensure_qdrant_collection()
self.client = None
self.embedding_function = None
self.collection = None
else:
self.client = chromadb.PersistentClient(path=self._persist_root)
self.embedding_function = embedding_functions.DefaultEmbeddingFunction()
self.collection = self._get_or_create_collection("historical_queries")
self.qdrant = None
self.model_name = model_name
self.collection_name = "historical_queries"

# Always create/get collections up‑front so we can inspect counts.
# self.train_collection = self._get_or_create_collection("train_queries")
# self.val_collection = self._get_or_create_collection("val_queries")
# self.test_collection = self._get_or_create_collection("test_queries")
self.collection = self._get_or_create_collection("historical_queries")

# If DB is empty and we have a bootstrap URL – populate it.
if bootstrap_url and self.collection.count() == 0:
self._bootstrap_from_drive(bootstrap_url)

# .................................................................
# Chroma helpers
# .................................................................
if bootstrap_url:
if backend == "qdrant":
count = self._qdrant_count()
if count == 0:
self._bootstrap_from_drive(bootstrap_url)
else:
if self.collection and self.collection.count() == 0:
self._bootstrap_from_drive(bootstrap_url)

def _get_or_create_collection(self, name: str):
if self.backend == "qdrant":
return None
if self.client is None:
return None
try:
return self.client.get_collection(name=name, embedding_function=self.embedding_function)
except Exception:
return self.client.create_collection(name=name, embedding_function=self.embedding_function)

# .................................................................
# Bootstrap logic – download + ingest
# .................................................................
def _ensure_qdrant_collection(self):
if not self.qdrant.collection_exists(self.collection_name):
dim = self.qdrant.get_embedding_size(self.model_name)
self.qdrant.create_collection(
self.collection_name,
vectors_config=models.VectorParams(size=dim, distance=models.Distance.COSINE),
)

def _qdrant_count(self) -> int:
try:
count = self.qdrant.count(self.collection_name).count
return count
except Exception:
return 0

def _bootstrap_from_drive(self, url_or_id: str):
print("\n[SmartRouting] Bootstrapping ChromaDB from Google Drive…\n")

with tempfile.TemporaryDirectory() as tmp:
# NB: gdown accepts both share links and raw IDs.
local_path = os.path.join(tmp, "bootstrap.json")

gdown.download(url_or_id, local_path, quiet=False, fuzzy=True)

# Expect JSONL with {"query": ..., "split": "train"|"val"|"test", ...}
Expand All @@ -249,20 +280,20 @@ def _bootstrap_from_drive(self, url_or_id: str):

print("[SmartRouting] Bootstrap complete – collections populated.\n")

# .................................................................
# Public data API
# .................................................................

def add_data(self, data: List[Dict[str, Any]]):
collection = self.collection
queries, metadatas, ids = [], [], []
if self.backend == "qdrant":
queries, metadatas, ids = [], [], []
else:
collection = self.collection
queries, metadatas, ids = [], [], []

correct_count = total_count = 0

for idx, item in enumerate(tqdm(data, desc=f"Ingesting historical queries")):
for idx, item in enumerate(tqdm(data, desc="Ingesting historical queries")):
query = item["query"]
model_metadatas = item["outputs"]
for model_metadata in model_metadatas:
model_metadata.pop("prediction")
model_metadata.pop("prediction", None)
meta = {
"input_token_length": item["input_token_length"],
"models": json.dumps(model_metadatas), # store raw list
Expand All @@ -275,15 +306,40 @@ def add_data(self, data: List[Dict[str, Any]]):
metadatas.append(meta)
ids.append(f"{idx}")

collection.add(documents=queries, metadatas=metadatas, ids=ids)
print(f"[SmartRouting]: {total_count} historical queries ingested.")
if self.backend == "qdrant":
docs = [models.Document(text=q, model=self.model_name) for q in queries]
if docs:
# Deterministic UUIDv5 ids; store original in payload
q_ids = [str(uuid.uuid5(uuid.NAMESPACE_URL, i)) for i in ids]
for i, meta in enumerate(metadatas):
meta["original_id"] = ids[i]
self.qdrant.upload_collection(
collection_name=self.collection_name,
vectors=docs,
ids=q_ids,
payload=metadatas,
)
else:
if queries and metadatas and ids: # Only add if we have data
collection.add(documents=queries, metadatas=metadatas, ids=ids)
print(f"[SmartRouting]: {total_count} historical queries ingested.")

# ..................................................................
def query_similar(self, query: str | List[str], n_results: int = 16):
if self.backend == "qdrant":
qtext = query if isinstance(query, str) else query[0]
results = self.qdrant.query_points(
collection_name=self.collection_name,
query=models.Document(text=qtext, model=self.model_name),
limit=n_results,
).points
return {
"ids": [[(r.payload or {}).get("original_id", str(r.id)) for r in results]],
"metadatas": [[(r.payload or {}) for r in results]],
"documents": [[""] * len(results)],
}
collection = self.collection
return collection.query(query_texts=query if isinstance(query, list) else [query], n_results=n_results)

# ..................................................................
def predict(self, query: str | List[str], model_configs: List[Dict[str, Any]], n_similar: int = 16):
similar = self.query_similar(query, n_results=n_similar)
perf_mat, len_mat = [], []
Expand Down Expand Up @@ -355,7 +411,7 @@ def get_model_idxs(self, selected_llm_lists: List[List[Dict[str, Any]]], queries

input_lens = get_token_lengths(queries)
chosen_indices: list[int] = []

converted_queries = [messages_to_query(query) for query in queries]

for q, q_len, candidate_cfgs in zip(converted_queries, input_lens, selected_llm_lists):
Expand All @@ -376,7 +432,7 @@ def get_model_idxs(self, selected_llm_lists: List[List[Dict[str, Any]]], queries

# Map back to global llm_configs index
sel_name = candidate_cfgs[sel_local_idx]["name"]

sel_idx = self.available_models.index(sel_name)
chosen_indices.append(sel_idx)

Expand Down
Loading
Loading