Skip to content

Commit acf2e75

Browse files
committed
initial implementation
1 parent 59bec83 commit acf2e75

File tree

9 files changed

+1415
-0
lines changed

9 files changed

+1415
-0
lines changed

optillm/plugins/proxy/README.md

Lines changed: 422 additions & 0 deletions
Large diffs are not rendered by default.

optillm/plugins/proxy/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
"""OptiLLM Proxy Plugin - Load balancing and failover for LLM providers"""
2+
3+
from .config import ProxyConfig
4+
from .client import ProxyClient
5+
from .routing import RouterFactory
6+
from .health import HealthChecker
7+
8+
__all__ = ['ProxyConfig', 'ProxyClient', 'RouterFactory', 'HealthChecker']
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
"""
2+
Dynamic handler for approaches and plugins - no hardcoding.
3+
"""
4+
import importlib
5+
import importlib.util
6+
import logging
7+
import inspect
8+
from typing import Optional, Tuple, Dict, Any
9+
from pathlib import Path
10+
11+
logger = logging.getLogger(__name__)
12+
13+
class ApproachHandler:
    """Resolve a name to an optillm approach or plugin and execute it.

    Handlers are discovered lazily on the first call to :meth:`handle` and
    cached in two dictionaries: one for built-in approaches (imported from a
    fixed table of module paths) and one for plugins (every ``*.py`` file in
    the ``optillm/plugins`` directory that exposes ``SLUG`` and ``run``).
    """

    # Keys of ``request_config`` that are never forwarded as extra keyword
    # arguments, regardless of how the handler names its parameters.
    _RESERVED_KWARGS = frozenset(
        {'model', 'messages', 'system_prompt', 'initial_query'}
    )

    def __init__(self):
        self._approaches_cache = {}
        self._plugins_cache = {}
        self._discovered = False

    def handle(self, name: str, system_prompt: str, initial_query: str,
               client, model: str, request_config: dict = None) -> Optional[Tuple[str, int]]:
        """Run the approach or plugin registered under *name*.

        Approaches take precedence over plugins when both caches contain the
        same name.

        Returns:
            Whatever the handler returns (annotated as ``(response, tokens)``),
            or ``None`` when *name* is neither a known approach nor a plugin.

        Raises:
            Any exception raised by the handler itself is propagated.
        """
        # Lazy discovery: importing every approach/plugin is deferred until
        # the first request actually needs it.
        if not self._discovered:
            self._discover_handlers()
            self._discovered = True

        if name in self._approaches_cache:
            logger.info(f"Routing approach '{name}' through proxy")
            return self._execute_handler(
                self._approaches_cache[name],
                system_prompt, initial_query, client, model, request_config
            )

        if name in self._plugins_cache:
            logger.info(f"Routing plugin '{name}' through proxy")
            return self._execute_handler(
                self._plugins_cache[name],
                system_prompt, initial_query, client, model, request_config
            )

        logger.debug(f"'{name}' not recognized as approach or plugin")
        return None

    def _discover_handlers(self):
        """Populate both caches and log how many handlers were found."""
        self._discover_approaches()
        self._discover_plugins()

        logger.info(f"Discovered {len(self._approaches_cache)} approaches, "
                    f"{len(self._plugins_cache)} plugins")

    def _discover_approaches(self):
        """Import the built-in approaches listed in the table below.

        An approach whose module cannot be imported, or whose entry point is
        missing, is skipped with a debug log rather than aborting discovery.
        """
        # name -> (module path, entry-point function or None for class-based)
        approach_modules = {
            'mcts': ('optillm.mcts', 'chat_with_mcts'),
            'bon': ('optillm.bon', 'best_of_n_sampling'),
            'moa': ('optillm.moa', 'mixture_of_agents'),
            'rto': ('optillm.rto', 'round_trip_optimization'),
            'self_consistency': ('optillm.self_consistency', 'advanced_self_consistency_approach'),
            'pvg': ('optillm.pvg', 'inference_time_pv_game'),
            'z3': ('optillm.z3_solver', None),  # Special case
            'rstar': ('optillm.rstar', None),  # Special case
            'cot_reflection': ('optillm.cot_reflection', 'cot_reflection'),
            'plansearch': ('optillm.plansearch', 'plansearch'),
            'leap': ('optillm.leap', 'leap'),
            're2': ('optillm.reread', 're2_approach'),
            'cepo': ('optillm.cepo.cepo', 'cepo'),  # CEPO approach
        }

        for name, (module_path, func_name) in approach_modules.items():
            try:
                module = importlib.import_module(module_path)
                handler = self._build_approach_handler(name, module, func_name)
            except (ImportError, AttributeError) as e:
                logger.debug(f"Could not load approach '{name}': {e}")
                continue

            if handler is not None:
                self._approaches_cache[name] = handler

    @staticmethod
    def _build_approach_handler(name, module, func_name):
        """Return the callable for one approach, or ``None`` if it has none.

        Class-based approaches (``z3``, ``rstar``) are wrapped in a function
        with the common ``(system_prompt, initial_query, client, model)``
        calling convention. The class is bound through this function's own
        scope, so each wrapper keeps its own class.

        Raises:
            AttributeError: if the expected class or function is missing.
        """
        if name == 'z3':
            solver_class = getattr(module, 'Z3SymPySolverSystem')

            def run_z3(system_prompt, initial_query, client, model, **kwargs):
                # Extra keyword arguments are accepted but not used.
                return solver_class(system_prompt, client, model).process_query(initial_query)

            return run_z3

        if name == 'rstar':
            rstar_class = getattr(module, 'RStar')

            def run_rstar(system_prompt, initial_query, client, model, **kwargs):
                # Extra keyword arguments go to the RStar constructor.
                return rstar_class(system_prompt, client, model, **kwargs).solve(initial_query)

            return run_rstar

        if not func_name:
            return None

        # NOTE: 'cepo' is registered like any other function-based approach;
        # no CepoConfig is constructed or supplied here.
        return getattr(module, func_name)

    def _discover_plugins(self):
        """Load every plugin module found in the ``optillm/plugins`` directory.

        A module is registered under its ``SLUG`` when it defines both
        ``SLUG`` and ``run``. Files are processed in sorted order so that, if
        two plugins declare the same ``SLUG``, the winner is deterministic.
        Failures are logged at debug level and never propagate.
        """
        try:
            # Imported here rather than at module level, as in the original.
            import optillm

            package_dir = Path(optillm.__file__).parent / 'plugins'
            plugin_files = sorted(package_dir.glob('*.py')) if package_dir.exists() else []
        except Exception as e:
            logger.debug(f"Error discovering plugins: {e}")
            return

        for plugin_file in plugin_files:
            module_name = plugin_file.stem

            # Skip the package initialiser and this proxy plugin itself.
            # Comparing the stem (not a substring of the whole path) avoids
            # dropping unrelated files whose path merely contains these words.
            if module_name in ('__init__', 'proxy_plugin'):
                continue

            try:
                spec = importlib.util.spec_from_file_location(module_name, plugin_file)
                if not (spec and spec.loader):
                    continue

                module = importlib.util.module_from_spec(spec)
                # Executes the plugin's top-level code.
                spec.loader.exec_module(module)

                if hasattr(module, 'SLUG') and hasattr(module, 'run'):
                    self._plugins_cache[getattr(module, 'SLUG')] = getattr(module, 'run')

            except Exception as e:
                logger.debug(f"Could not load plugin from {plugin_file}: {e}")

    def _execute_handler(self, handler, system_prompt: str, initial_query: str,
                         client, model: str, request_config: dict = None) -> Tuple[str, int]:
        """Call *handler* with arguments adapted to its signature.

        The four standard arguments are always passed positionally. In
        addition:

        * if the handler accepts ``**kwargs``, the entries of
          *request_config* are forwarded as keyword arguments, except reserved
          keys and keys that would collide with a parameter already bound
          positionally (which would raise ``TypeError``);
        * if the handler declares a ``request_config`` parameter, the whole
          dictionary is passed under that name and cannot be overridden by an
          entry of the same name inside *request_config*.

        Raises:
            Any exception from signature inspection or from the handler is
            logged and re-raised unchanged.
        """
        try:
            params = inspect.signature(handler).parameters

            args = [system_prompt, initial_query, client, model]
            kwargs = {}

            accepts_var_kwargs = any(
                p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
            )
            if accepts_var_kwargs and request_config:
                positional = [
                    p for p in params.values()
                    if p.kind in (inspect.Parameter.POSITIONAL_ONLY,
                                  inspect.Parameter.POSITIONAL_OR_KEYWORD)
                ]
                # Names consumed by the positional arguments above; passing
                # them again by keyword would be a "multiple values" error.
                already_bound = {
                    p.name for p in positional[:len(args)]
                    if p.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD
                }
                blocked = self._RESERVED_KWARGS | already_bound
                kwargs.update(
                    {k: v for k, v in request_config.items() if k not in blocked}
                )

            # Set last so the full config always wins over a same-named entry.
            if 'request_config' in params:
                kwargs['request_config'] = request_config

            return handler(*args, **kwargs)

        except Exception as e:
            logger.error(f"Error executing handler: {e}")
            raise

optillm/plugins/proxy/client.py

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
"""
2+
ProxyClient implementation for load balancing across multiple LLM providers.
3+
"""
4+
import time
5+
import logging
6+
import random
7+
from typing import Dict, List, Any, Optional
8+
from openai import OpenAI, AzureOpenAI
9+
from .routing import RouterFactory
10+
from .health import HealthChecker
11+
12+
logger = logging.getLogger(__name__)
13+
14+
class Provider:
    """Configuration, lazily created client and runtime state of one provider.

    Recognised ``config`` keys:

    * ``name``, ``base_url``, ``api_key`` - required (``KeyError`` if absent).
    * ``weight`` - defaults to ``1``.
    * ``fallback_only`` - defaults to ``False``.
    * ``model_map`` - mapping of requested model name to provider-specific
      name; missing or null means no mapping.
    * ``api_version`` - API version passed to ``AzureOpenAI``; defaults to
      ``DEFAULT_AZURE_API_VERSION``.
    * ``latency_window`` - how many recent latencies to keep; defaults to
      ``DEFAULT_LATENCY_WINDOW`` and is clamped to at least 1.
    """

    DEFAULT_AZURE_API_VERSION = "2024-02-01"
    DEFAULT_LATENCY_WINDOW = 10

    def __init__(self, config: Dict):
        self.name = config['name']
        self.base_url = config['base_url']
        self.api_key = config['api_key']
        self.weight = config.get('weight', 1)
        self.fallback_only = config.get('fallback_only', False)
        # ``or`` also covers a key that is present but explicitly null.
        self.model_map = config.get('model_map') or {}
        self.api_version = config.get('api_version') or self.DEFAULT_AZURE_API_VERSION
        self.latency_window = max(
            1, int(config.get('latency_window') or self.DEFAULT_LATENCY_WINDOW)
        )
        self._client = None
        self.is_healthy = True
        self.last_error = None
        self.latencies = []  # Most recent request latencies, oldest first

    def __repr__(self) -> str:
        # The API key is deliberately left out so it cannot leak into logs.
        return (f"{type(self).__name__}(name={self.name!r}, "
                f"base_url={self.base_url!r}, is_healthy={self.is_healthy!r})")

    @property
    def client(self):
        """Return the provider's client, creating it on first access.

        An ``AzureOpenAI`` client is used when ``base_url`` contains the
        substring ``azure`` (case-insensitive); otherwise a standard
        ``OpenAI`` client pointed at ``base_url``.
        """
        if self._client is None:
            if 'azure' in self.base_url.lower():
                self._client = AzureOpenAI(
                    api_key=self.api_key,
                    azure_endpoint=self.base_url,
                    api_version=self.api_version
                )
            else:
                self._client = OpenAI(
                    api_key=self.api_key,
                    base_url=self.base_url
                )
        return self._client

    def map_model(self, model: str) -> str:
        """Return the provider-specific name for *model*, or *model* itself."""
        return self.model_map.get(model, model)

    def track_latency(self, latency: float):
        """Record *latency*, keeping only the newest ``latency_window`` values."""
        self.latencies.append(latency)
        overflow = len(self.latencies) - self.latency_window
        if overflow > 0:
            del self.latencies[:overflow]

    def avg_latency(self) -> float:
        """Return the mean of the recorded latencies, or ``0.0`` if none."""
        if not self.latencies:
            return 0.0
        return sum(self.latencies) / len(self.latencies)
62+
63+
class ProxyClient:
    """OpenAI-compatible client that proxies to multiple providers.

    Exposes ``client.chat.completions.create(**kwargs)``. A request is tried
    against the healthy non-fallback providers (chosen by the configured
    router), then against the ``fallback_only`` providers, and finally
    against the optional ``fallback_client``.

    Args:
        config: Proxy configuration. The sections read here are
            ``providers``, ``routing`` (``strategy``, ``health_check``) and
            ``monitoring`` (``track_latency``, ``track_errors``). A section
            that is missing or null is treated as empty.
        fallback_client: Optional client used as a last resort; it must offer
            ``chat.completions.create``.
    """

    def __init__(self, config: Dict, fallback_client=None):
        self.config = config
        self.fallback_client = fallback_client

        # ``or`` (instead of a ``get`` default) also tolerates sections that
        # are present in the config but explicitly null.
        self.providers = [Provider(p) for p in (config.get('providers') or [])]

        # Fallback-only providers are excluded from normal routing.
        self.active_providers = [p for p in self.providers if not p.fallback_only]
        self.fallback_providers = [p for p in self.providers if p.fallback_only]

        routing = config.get('routing') or {}
        strategy = routing.get('strategy', 'round_robin')
        self.router = RouterFactory.create(strategy, self.active_providers)

        health_config = routing.get('health_check') or {}
        self.health_checker = HealthChecker(
            providers=self.providers,
            enabled=health_config.get('enabled', True),
            interval=health_config.get('interval', 30),
            timeout=health_config.get('timeout', 5)
        )
        # Health checking is started as part of construction.
        self.health_checker.start()

        monitoring = config.get('monitoring') or {}
        self.track_latency = monitoring.get('track_latency', True)
        self.track_errors = monitoring.get('track_errors', True)

        # Namespace mirroring the OpenAI client: ``self.chat.completions``.
        self.chat = self._Chat(self)

    class _Chat:
        """Holder for the ``completions`` namespace."""

        def __init__(self, proxy_client):
            self.proxy_client = proxy_client
            self.completions = proxy_client._Completions(proxy_client)

    class _Completions:
        """Implements ``create`` with load balancing and failover."""

        # Marks "no provider produced a response"; a sentinel is needed
        # because a provider's response could itself be any value.
        _NO_RESPONSE = object()

        def __init__(self, proxy_client):
            self.proxy_client = proxy_client

        def create(self, **kwargs):
            """Create a completion, failing over between providers.

            Order of attempts:

            1. healthy non-fallback providers, one at a time as selected by
               the router;
            2. ``fallback_only`` providers, whether step 1 had no healthy
               provider or every healthy provider failed;
            3. the ``fallback_client``, called with the original *kwargs*.

            Returns:
                The first successful response, unmodified.

            Raises:
                Exception: when every attempt failed; the message lists
                ``(provider name, error text)`` for each failure.
            """
            proxy = self.proxy_client
            model = kwargs.get('model', 'unknown')
            attempted_providers = set()
            errors = []

            healthy_providers = [p for p in proxy.active_providers if p.is_healthy]
            if not healthy_providers:
                logger.warning("No healthy providers, trying fallback providers")

            response = self._try_providers(
                healthy_providers, model, kwargs, attempted_providers, errors
            )
            if response is not self._NO_RESPONSE:
                return response

            if healthy_providers and proxy.fallback_providers:
                logger.warning("All healthy providers failed, trying fallback providers")

            response = self._try_providers(
                proxy.fallback_providers, model, kwargs, attempted_providers, errors
            )
            if response is not self._NO_RESPONSE:
                return response

            # All providers failed, try fallback client
            if proxy.fallback_client is not None:
                logger.warning("All proxy providers failed, using fallback client")
                try:
                    return proxy.fallback_client.chat.completions.create(**kwargs)
                except Exception as e:
                    errors.append(("fallback_client", str(e)))

            # Complete failure
            error_msg = f"All providers failed. Errors: {errors}"
            logger.error(error_msg)
            raise Exception(error_msg)

        def _try_providers(self, candidates, model, kwargs, attempted, errors):
            """Try *candidates* one by one until one returns a response.

            *attempted* and *errors* are updated in place so that state is
            shared between successive tiers of the same request.

            Returns:
                The provider's response, or ``_NO_RESPONSE`` when the
                candidates are exhausted.
            """
            proxy = self.proxy_client

            while True:
                available = [p for p in candidates if p not in attempted]
                if not available:
                    return self._NO_RESPONSE

                provider = proxy.router.select(available)
                # Stop if the router yields nothing, or hands back a provider
                # that was already tried (which would otherwise loop forever).
                if not provider or provider in attempted:
                    return self._NO_RESPONSE

                attempted.add(provider)

                try:
                    # Same request, with the model name mapped for this provider.
                    request_kwargs = dict(kwargs, model=provider.map_model(model))

                    # Monotonic clock: unaffected by system clock adjustments.
                    start_time = time.perf_counter()

                    logger.debug(f"Routing to {provider.name}")
                    response = provider.client.chat.completions.create(**request_kwargs)

                    latency = time.perf_counter() - start_time
                    if proxy.track_latency:
                        provider.track_latency(latency)

                    logger.info(f"Request succeeded via {provider.name} in {latency:.2f}s")
                    return response

                except Exception as e:
                    logger.error(f"Provider {provider.name} failed: {e}")
                    errors.append((provider.name, str(e)))

                    # NOTE(review): any exception marks the provider unhealthy,
                    # including errors caused by the request itself - confirm
                    # this is intended.
                    if proxy.track_errors:
                        provider.is_healthy = False
                        provider.last_error = str(e)

0 commit comments

Comments
 (0)