|
1 | | -import os, json, asyncio, requests, boto3 |
| 1 | +import os, json, asyncio, requests, boto3, re, datetime as dt |
2 | 2 | from typing import List, Tuple, Dict |
3 | 3 | from typing import TypedDict |
4 | 4 | from botocore.config import Config |
@@ -42,6 +42,134 @@ def key(name: str) -> Tuple[float, str]: |
42 | 42 | return sorted(out, key=key, reverse=True) |
43 | 43 |
|
44 | 44 |
|
| 45 | +# ──────────────────────────────────────────────────────────────── |
| 46 | +# OpenAI helpers |
| 47 | +# ──────────────────────────────────────────────────────────────── |
def list_openai_models() -> List[str]:
    """Return curated list of latest OpenAI text generation models, sorted by release date (newest first)."""
    # Hand-maintained (not fetched from the API): newest-first, text-to-text only.
    return [
        "gpt-4.1",         # GPT-4.1 series (April 2025)
        "gpt-4.1-mini",
        "gpt-4.1-nano",
        "o3",              # reasoning models (April 2025)
        "o4-mini",
        "o3-mini",         # January 2025
        "o1",              # December 2024
        "gpt-4o",          # November 2024
        "gpt-4o-mini",     # July 2024
        "gpt-4-turbo",     # April 2024
        "gpt-3.5-turbo",   # legacy but still widely used
    ]
| 65 | + |
| 66 | + |
async def _probe_openai(model_name: str) -> Tuple[str, bool]:
    """Probe a single OpenAI chat model with a tiny request.

    Returns ``(model_name, healthy)``. Every failure mode — missing SDK,
    bad/missing API key, unknown model, HTTP error, timeout — is reported
    as unhealthy instead of raised, so a whole catalog can be probed with
    ``asyncio.gather`` without special-casing.
    """
    try:
        import openai  # local import keeps the SDK optional at module load
        client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

        await asyncio.wait_for(
            asyncio.to_thread(  # SDK is sync; run it off the event loop
                client.chat.completions.create,
                model=model_name,
                messages=[{"role": "user", "content": "Health check"}],
                # Bug fix: `max_tokens` is deprecated and *rejected* by the
                # o-series reasoning models (o1/o3/o3-mini/o4-mini) in our
                # curated list, which made them always report unhealthy.
                # `max_completion_tokens` is accepted by all current chat models.
                max_completion_tokens=5,
            ),
            timeout=5,
        )
        return model_name, True
    except Exception:
        # NOTE: on timeout the worker thread keeps running in the background
        # (threads cannot be cancelled); acceptable for a one-shot probe.
        return model_name, False
| 85 | + |
| 86 | + |
async def health_openai(models: List[str], concurrency: int = 5) -> Tuple[List[str], List[str]]:
    """Probe *models* against the OpenAI API and split them into (enabled, disabled).

    Without an OPENAI_API_KEY in the environment no probe can succeed, so
    every model is reported as disabled up front.
    """
    if not os.getenv("OPENAI_API_KEY"):
        return [], models

    gate = asyncio.Semaphore(concurrency)

    async def probe_with_limit(name: str) -> Tuple[str, bool]:
        # Cap in-flight requests so a long catalog doesn't hammer the API.
        async with gate:
            return await _probe_openai(name)

    outcomes = await asyncio.gather(*[probe_with_limit(name) for name in models])
    healthy = [name for name, ok in outcomes if ok]
    unhealthy = [name for name, ok in outcomes if not ok]
    return healthy, unhealthy
| 102 | + |
| 103 | + |
| 104 | +# ──────────────────────────────────────────────────────────────── |
| 105 | +# Gemini helpers |
| 106 | +# ──────────────────────────────────────────────────────────────── |
def list_gemini_models() -> List[str]:
    """Return curated list of latest Gemini text generation models, sorted by release date (newest first)."""
    # Hand-maintained (not fetched from the API): newest-first, text-to-text only.
    return [
        "gemini-2.5-pro",         # June 2025 — most powerful thinking model
        "gemini-2.5-flash",       # June 2025 — best price-performance
        "gemini-2.5-flash-lite",  # June 2025 — cost-efficient
        "gemini-2.0-flash",       # February 2025 — next-gen features
        "gemini-2.0-flash-lite",  # February 2025 — low latency
        "gemini-1.5-pro",         # September 2024 — complex reasoning
        "gemini-1.5-flash",       # September 2024 — fast & versatile
        "gemini-1.5-flash-8b",    # October 2024 — lightweight
    ]
| 121 | + |
| 122 | + |
| 123 | +def _extract_gemini_date(model: str) -> dt.datetime: |
| 124 | + """Extract release date from Gemini model name for sorting.""" |
| 125 | + # Gemini models follow pattern: gemini-{version}-{variant} |
| 126 | + if "2.5" in model: |
| 127 | + return dt.datetime(2025, 6, 1) # June 2025 |
| 128 | + elif "2.0" in model: |
| 129 | + return dt.datetime(2025, 2, 1) # February 2025 |
| 130 | + elif "1.5" in model: |
| 131 | + if "flash-8b" in model: |
| 132 | + return dt.datetime(2024, 10, 1) # October 2024 |
| 133 | + return dt.datetime(2024, 9, 1) # September 2024 |
| 134 | + return dt.datetime.min |
| 135 | + |
| 136 | + |
async def _probe_gemini(model_name: str) -> Tuple[str, bool]:
    """Send a minimal generation request to *model_name*; return (name, healthy).

    Every failure mode — missing SDK, bad/missing API key, unknown model,
    slow response — collapses to ``healthy=False`` so batch health checks
    never raise.
    """
    try:
        from google import genai  # local import keeps the SDK optional

        client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
        request = asyncio.to_thread(  # SDK call is sync; keep the loop free
            client.models.generate_content,
            model=model_name,
            contents="Health check",
        )
        await asyncio.wait_for(request, timeout=5)
    except Exception:
        return model_name, False
    return model_name, True
| 154 | + |
| 155 | + |
async def health_gemini(models: List[str], concurrency: int = 5) -> Tuple[List[str], List[str]]:
    """Probe *models* against the Gemini API and split them into (enabled, disabled).

    Without a GEMINI_API_KEY in the environment no probe can succeed, so
    every model is reported as disabled up front.
    """
    if not os.getenv("GEMINI_API_KEY"):
        return [], models

    gate = asyncio.Semaphore(concurrency)

    async def probe_with_limit(name: str) -> Tuple[str, bool]:
        # Cap in-flight requests so a long catalog doesn't hammer the API.
        async with gate:
            return await _probe_gemini(name)

    outcomes = await asyncio.gather(*[probe_with_limit(name) for name in models])
    healthy = [name for name, ok in outcomes if ok]
    unhealthy = [name for name, ok in outcomes if not ok]
    return healthy, unhealthy
| 171 | + |
| 172 | + |
45 | 173 | # ──────────────────────────────────────────────────────────────── |
46 | 174 | # Bedrock helpers |
47 | 175 | # ──────────────────────────────────────────────────────────────── |
@@ -196,17 +324,36 @@ async def bound(p: _CaiiPair): |
196 | 324 | # Single orchestrator used by the api endpoint |
197 | 325 | # ──────────────────────────────────────────────────────────────── |
198 | 326 | async def collect_model_catalog() -> Dict[str, Dict[str, List[str]]]: |
199 | | - # Bedrock first |
| 327 | + """Collect and health-check models from all providers.""" |
| 328 | + |
| 329 | + # Bedrock |
200 | 330 | bedrock_all = list_bedrock_models() |
201 | 331 | bedrock_enabled, bedrock_disabled = await health_bedrock(bedrock_all) |
202 | 332 |
|
| 333 | + # OpenAI |
| 334 | + openai_all = list_openai_models() |
| 335 | + openai_enabled, openai_disabled = await health_openai(openai_all) |
| 336 | + |
| 337 | + # Gemini |
| 338 | + gemini_all = list_gemini_models() |
| 339 | + gemini_enabled, gemini_disabled = await health_gemini(gemini_all) |
| 340 | + |
203 | 341 | catalog: Dict[str, Dict[str, List[str]]] = { |
204 | 342 | "aws_bedrock": { |
205 | 343 | "enabled": bedrock_enabled, |
206 | 344 | "disabled": bedrock_disabled, |
| 345 | + }, |
| 346 | + "openai": { |
| 347 | + "enabled": openai_enabled, |
| 348 | + "disabled": openai_disabled, |
| 349 | + }, |
| 350 | + "google_gemini": { |
| 351 | + "enabled": gemini_enabled, |
| 352 | + "disabled": gemini_disabled, |
207 | 353 | } |
208 | 354 | } |
209 | 355 |
|
| 356 | + # CAII (only on-cluster) |
210 | 357 | if os.getenv("CDSW_PROJECT_ID", "local") != "local": |
211 | 358 | caii_all = list_caii_models() |
212 | 359 | caii_enabled, caii_disabled = await health_caii(caii_all) |
|
0 commit comments