|
5 | 5 | with health monitoring, failover, and support for wrapping other approaches. |
6 | 6 | """ |
7 | 7 | import logging |
8 | | -from typing import Tuple, Optional |
| 8 | +import threading |
| 9 | +from typing import Tuple, Optional, Dict |
9 | 10 | from optillm.plugins.proxy.config import ProxyConfig |
10 | 11 | from optillm.plugins.proxy.client import ProxyClient |
11 | 12 | from optillm.plugins.proxy.approach_handler import ApproachHandler |
|
21 | 22 | # Global proxy client cache to maintain state between requests |
22 | 23 | _proxy_client_cache = {} |
23 | 24 |
|
# Global cache for system message support per provider-model combination.
# Keyed as "<base_identifier>:<model>" with bool values (True = the model
# accepts a system role). Guarded by _cache_lock — reentrant so a thread
# already holding it can safely re-enter cache lookups — to ensure only one
# probe request is issued per key under concurrent traffic.
_system_message_support_cache: Dict[str, bool] = {}
_cache_lock = threading.RLock()
| 28 | + |
| 29 | +def _test_system_message_support(proxy_client, model: str) -> bool: |
| 30 | + """ |
| 31 | + Test if a model supports system messages by making a minimal test request. |
| 32 | + Returns True if supported, False otherwise. |
| 33 | + """ |
| 34 | + try: |
| 35 | + # Try a minimal system message request |
| 36 | + test_response = proxy_client.chat.completions.create( |
| 37 | + model=model, |
| 38 | + messages=[ |
| 39 | + {"role": "system", "content": "test"}, |
| 40 | + {"role": "user", "content": "hi"} |
| 41 | + ], |
| 42 | + max_tokens=1, # Minimal token generation |
| 43 | + temperature=0 |
| 44 | + ) |
| 45 | + return True |
| 46 | + except Exception as e: |
| 47 | + error_msg = str(e).lower() |
| 48 | + # Check for known system message rejection patterns |
| 49 | + if any(pattern in error_msg for pattern in [ |
| 50 | + "developer instruction", |
| 51 | + "system message", |
| 52 | + "not enabled", |
| 53 | + "not supported" |
| 54 | + ]): |
| 55 | + logger.info(f"Model {model} does not support system messages: {str(e)[:100]}") |
| 56 | + return False |
| 57 | + else: |
| 58 | + # If it's a different error, assume system messages are supported |
| 59 | + # but something else went wrong (rate limit, timeout, etc.) |
| 60 | + logger.debug(f"System message test failed for {model}, assuming supported: {str(e)[:100]}") |
| 61 | + return True |
| 62 | + |
def _get_system_message_support(proxy_client, model: str) -> bool:
    """
    Look up — or lazily determine and memoize — whether *model* accepts
    system messages.

    Results are stored in the module-level cache under a key combining the
    client's base identifier and the model name. The whole check-then-probe
    sequence runs under _cache_lock so concurrent requests for the same key
    trigger at most one probe.
    """
    base = getattr(proxy_client, '_base_identifier', 'default')
    cache_key = f"{base}:{model}"

    with _cache_lock:
        supported = _system_message_support_cache.get(cache_key)
        if supported is None:  # values are strictly bool, so None == miss
            logger.debug(f"Testing system message support for {model}")
            supported = _test_system_message_support(proxy_client, model)
            _system_message_support_cache[cache_key] = supported
        return supported
| 77 | + |
| 78 | +def _format_messages_for_model(system_prompt: str, initial_query: str, |
| 79 | + supports_system_messages: bool) -> list: |
| 80 | + """ |
| 81 | + Format messages based on whether the model supports system messages. |
| 82 | + """ |
| 83 | + if supports_system_messages: |
| 84 | + return [ |
| 85 | + {"role": "system", "content": system_prompt}, |
| 86 | + {"role": "user", "content": initial_query} |
| 87 | + ] |
| 88 | + else: |
| 89 | + # Merge system prompt into user message |
| 90 | + if system_prompt.strip(): |
| 91 | + combined_message = f"{system_prompt}\n\nUser: {initial_query}" |
| 92 | + else: |
| 93 | + combined_message = initial_query |
| 94 | + |
| 95 | + return [{"role": "user", "content": combined_message}] |
| 96 | + |
24 | 97 | def run(system_prompt: str, initial_query: str, client, model: str, |
25 | 98 | request_config: dict = None) -> Tuple[str, int]: |
26 | 99 | """ |
@@ -119,14 +192,21 @@ def run(system_prompt: str, initial_query: str, client, model: str, |
119 | 192 | logger.info(f"Proxy routing approach/plugin: {potential_approach}") |
120 | 193 | return result |
121 | 194 |
|
122 | | - # Direct proxy execution |
| 195 | + # Direct proxy execution with dynamic system message support detection |
123 | 196 | logger.info(f"Direct proxy routing for model: {model}") |
| 197 | + |
| 198 | + # Test and cache system message support for this model |
| 199 | + supports_system_messages = _get_system_message_support(proxy_client, model) |
| 200 | + |
| 201 | + # Format messages based on system message support |
| 202 | + messages = _format_messages_for_model(system_prompt, initial_query, supports_system_messages) |
| 203 | + |
| 204 | + if not supports_system_messages: |
| 205 | + logger.info(f"Using fallback message formatting for {model} (no system message support)") |
| 206 | + |
124 | 207 | response = proxy_client.chat.completions.create( |
125 | 208 | model=model, |
126 | | - messages=[ |
127 | | - {"role": "system", "content": system_prompt}, |
128 | | - {"role": "user", "content": initial_query} |
129 | | - ], |
| 209 | + messages=messages, |
130 | 210 | **(request_config or {}) |
131 | 211 | ) |
132 | 212 |
|
|
0 commit comments