
Commit 04c3170

Merge pull request #85: Database upgrade for partial success of long running job.
2 parents fb45900 + 983afce commit 04c3170

3 files changed, +300 -74 lines

app/core/model_endpoints.py

Lines changed: 149 additions & 2 deletions
@@ -1,4 +1,4 @@
-import os, json, asyncio, requests, boto3
+import os, json, asyncio, requests, boto3, re, datetime as dt
 from typing import List, Tuple, Dict
 from typing import TypedDict
 from botocore.config import Config
@@ -42,6 +42,134 @@ def key(name: str) -> Tuple[float, str]:
     return sorted(out, key=key, reverse=True)
 
 
+# ────────────────────────────────────────────────────────────────
+# OpenAI helpers
+# ────────────────────────────────────────────────────────────────
+def list_openai_models() -> List[str]:
+    """Return curated list of latest OpenAI text generation models, sorted by release date (newest first)."""
+    # Based on research - only most relevant latest text-to-text models
+    models = [
+        "gpt-4.1",        # Latest GPT-4.1 series (April 2025)
+        "gpt-4.1-mini",
+        "gpt-4.1-nano",
+        "o3",             # Latest reasoning models (April 2025)
+        "o4-mini",
+        "o3-mini",        # January 2025
+        "o1",             # December 2024
+        "gpt-4o",         # November 2024
+        "gpt-4o-mini",    # July 2024
+        "gpt-4-turbo",    # April 2024
+        "gpt-3.5-turbo"   # Legacy but still widely used
+    ]
+    return models
+
+
+async def _probe_openai(model_name: str) -> Tuple[str, bool]:
+    """Test if OpenAI model responds within timeout."""
+    try:
+        import openai
+        client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+
+        await asyncio.wait_for(
+            asyncio.to_thread(
+                client.chat.completions.create,
+                model=model_name,
+                messages=[{"role": "user", "content": "Health check"}],
+                max_tokens=5
+            ),
+            timeout=5
+        )
+        return model_name, True
+    except Exception:
+        return model_name, False
+
+
+async def health_openai(models: List[str], concurrency: int = 5) -> Tuple[List[str], List[str]]:
+    """Health check OpenAI models with concurrency control."""
+    if not os.getenv("OPENAI_API_KEY"):
+        return [], models  # All disabled if no API key
+
+    sem = asyncio.Semaphore(concurrency)
+
+    async def bound(model: str):
+        async with sem:
+            return await _probe_openai(model)
+
+    results = await asyncio.gather(*(bound(m) for m in models))
+    enabled = [m for m, ok in results if ok]
+    disabled = [m for m, ok in results if not ok]
+    return enabled, disabled
+
+
+# ────────────────────────────────────────────────────────────────
+# Gemini helpers
+# ────────────────────────────────────────────────────────────────
+def list_gemini_models() -> List[str]:
+    """Return curated list of latest Gemini text generation models, sorted by release date (newest first)."""
+    # Based on research - only most relevant latest text-to-text models
+    models = [
+        "gemini-2.5-pro",         # June 2025 - most powerful thinking model
+        "gemini-2.5-flash",       # June 2025 - best price-performance
+        "gemini-2.5-flash-lite",  # June 2025 - cost-efficient
+        "gemini-2.0-flash",       # February 2025 - next-gen features
+        "gemini-2.0-flash-lite",  # February 2025 - low latency
+        "gemini-1.5-pro",         # September 2024 - complex reasoning
+        "gemini-1.5-flash",       # September 2024 - fast & versatile
+        "gemini-1.5-flash-8b"     # October 2024 - lightweight
+    ]
+    return models
+
+
+def _extract_gemini_date(model: str) -> dt.datetime:
+    """Extract release date from Gemini model name for sorting."""
+    # Gemini models follow pattern: gemini-{version}-{variant}
+    if "2.5" in model:
+        return dt.datetime(2025, 6, 1)       # June 2025
+    elif "2.0" in model:
+        return dt.datetime(2025, 2, 1)       # February 2025
+    elif "1.5" in model:
+        if "flash-8b" in model:
+            return dt.datetime(2024, 10, 1)  # October 2024
+        return dt.datetime(2024, 9, 1)       # September 2024
+    return dt.datetime.min
+
+
+async def _probe_gemini(model_name: str) -> Tuple[str, bool]:
+    """Test if Gemini model responds within timeout."""
+    try:
+        from google import genai
+        client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
+
+        await asyncio.wait_for(
+            asyncio.to_thread(
+                client.models.generate_content,
+                model=model_name,
+                contents="Health check"
+            ),
+            timeout=5
+        )
+        return model_name, True
+    except Exception:
+        return model_name, False
+
+
+async def health_gemini(models: List[str], concurrency: int = 5) -> Tuple[List[str], List[str]]:
+    """Health check Gemini models with concurrency control."""
+    if not os.getenv("GEMINI_API_KEY"):
+        return [], models  # All disabled if no API key
+
+    sem = asyncio.Semaphore(concurrency)
+
+    async def bound(model: str):
+        async with sem:
+            return await _probe_gemini(model)
+
+    results = await asyncio.gather(*(bound(m) for m in models))
+    enabled = [m for m, ok in results if ok]
+    disabled = [m for m, ok in results if not ok]
+    return enabled, disabled
+
+
 # ────────────────────────────────────────────────────────────────
 # Bedrock helpers
 # ────────────────────────────────────────────────────────────────
@@ -196,17 +324,36 @@ async def bound(p: _CaiiPair):
 # Single orchestrator used by the api endpoint
 # ────────────────────────────────────────────────────────────────
 async def collect_model_catalog() -> Dict[str, Dict[str, List[str]]]:
-    # Bedrock first
+    """Collect and health-check models from all providers."""
+
+    # Bedrock
     bedrock_all = list_bedrock_models()
     bedrock_enabled, bedrock_disabled = await health_bedrock(bedrock_all)
 
+    # OpenAI
+    openai_all = list_openai_models()
+    openai_enabled, openai_disabled = await health_openai(openai_all)
+
+    # Gemini
+    gemini_all = list_gemini_models()
+    gemini_enabled, gemini_disabled = await health_gemini(gemini_all)
+
     catalog: Dict[str, Dict[str, List[str]]] = {
         "aws_bedrock": {
            "enabled": bedrock_enabled,
            "disabled": bedrock_disabled,
+        },
+        "openai": {
+            "enabled": openai_enabled,
+            "disabled": openai_disabled,
+        },
+        "google_gemini": {
+            "enabled": gemini_enabled,
+            "disabled": gemini_disabled,
         }
     }
 
+    # CAII (only on-cluster)
     if os.getenv("CDSW_PROJECT_ID", "local") != "local":
         caii_all = list_caii_models()
         caii_enabled, caii_disabled = await health_caii(caii_all)
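
For orientation, a minimal sketch of how a caller might consume the catalog returned above; the asyncio.run entry point and the printing loop are illustrative assumptions, not part of this commit:

import asyncio

from app.core.model_endpoints import collect_model_catalog

async def main():
    catalog = await collect_model_catalog()
    for provider, status in catalog.items():
        # Each provider maps to {"enabled": [...], "disabled": [...]}
        print(provider, "->", len(status["enabled"]), "enabled,",
              len(status["disabled"]), "disabled")

asyncio.run(main())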

app/core/model_handlers.py

Lines changed: 41 additions & 5 deletions
@@ -22,6 +22,12 @@
 class UnifiedModelHandler:
     """Unified handler for all model types using Bedrock's converse API"""
 
+    # Add timeout constants
+    OPENAI_CONNECT_TIMEOUT = 5.0
+    OPENAI_READ_TIMEOUT = 3600.0  # 1 hour, same as AWS Bedrock
+
+    GEMINI_TIMEOUT = 3600.0  # 1 hour timeout for Gemini
+
     def __init__(self, model_id: str, bedrock_client=None, model_params: Optional[ModelParameters] = None, inference_type = "aws_bedrock", caii_endpoint:Optional[str]=None, custom_p = False):
         """
         Initialize the model handler
@@ -43,6 +49,10 @@ def __init__(self, model_id: str, bedrock_client=None, model_params: Optional[Mo
         self.MAX_RETRIES = 2
         self.BASE_DELAY = 3    # Initial delay of 3 seconds
         self.MULTIPLIER = 1.5  # AWS Step Functions multiplier
+
+        # Add timeout configuration
+        self.CONNECT_TIMEOUT = 5    # 5 seconds connect timeout
+        self.READ_TIMEOUT = 3600    # 1 hour read timeout (same as AWS Bedrock)
 
     def _exponential_backoff(self, retry_count: int) -> None:
         """AWS Step Functions style backoff: 3s -> 4.5s -> 6.75s"""
@@ -251,7 +261,7 @@ def _handle_bedrock_request(self, prompt: str, retry_with_reduced_tokens: bool):
             if error_code == 'ValidationException':
                 if 'model identifier is invalid' in error_message:
                     raise InvalidModelError(self.model_id, error_message)
-                elif "on-demand throughput isnt supported" in error_message:
+                elif "on-demand throughput isn't supported" in error_message:
                     print("Hi")
                     raise InvalidModelError(self.model_id, error_message)
 
@@ -295,9 +305,20 @@ def _handle_bedrock_request(self, prompt: str, retry_with_reduced_tokens: bool):
     # ---------- OpenAI -------------------------------------------------------
     def _handle_openai_request(self, prompt: str):
         try:
+            import httpx
+            from openai import OpenAI
+
+            # Configure timeout for OpenAI client (OpenAI v1.57.2)
+            timeout_config = httpx.Timeout(
+                connect=self.OPENAI_CONNECT_TIMEOUT,
+                read=self.OPENAI_READ_TIMEOUT,
+                write=10.0,
+                pool=5.0
+            )
+
             client = OpenAI(
-                api_key=os.getenv("OPENAI_API_KEY"),
-                base_url=os.getenv("OPENAI_API_BASE", None) or None,
+                api_key=os.getenv('OPENAI_API_KEY'),
+                timeout=timeout_config
             )
             completion = client.chat.completions.create(
                 model=self.model_id,
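
The four-field httpx.Timeout splits the deadline by phase: fail fast when a connection cannot be opened, but tolerate hour-long reads from a long-running generation job. A minimal standalone sketch of the same shape (the URL is a placeholder):

import httpx

# Fast connect, very long read - the same split this commit configures.
timeout_config = httpx.Timeout(connect=5.0, read=3600.0, write=10.0, pool=5.0)

with httpx.Client(timeout=timeout_config) as client:
    resp = client.get("https://example.com/health")  # placeholder endpoint
    print(resp.status_code)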
@@ -329,6 +350,9 @@ def _handle_gemini_request(self, prompt: str):
                 "temperature": self.model_params.temperature,
                 "top_p": self.model_params.top_p,
             },
+            request_options={
+                "timeout": self.GEMINI_TIMEOUT  # Use the dedicated Gemini timeout constant
+            }
         )
         text = resp.text
         return self._extract_json_from_text(text) if not self.custom_p else text
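
This call shape matches the google.generativeai SDK, where request_options carries a per-request timeout in seconds. A minimal sketch under that assumption (API key, model name, and prompt are placeholders):

import google.generativeai as genai  # assumption: the SDK this handler wraps

genai.configure(api_key="YOUR_KEY")  # placeholder key
model = genai.GenerativeModel("gemini-1.5-flash")  # placeholder model
resp = model.generate_content(
    "Health check",
    request_options={"timeout": 3600.0},  # seconds, like GEMINI_TIMEOUT
)
print(resp.text)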
@@ -337,17 +361,29 @@
 
 
     def _handle_caii_request(self, prompt: str):
-        """Original CAII implementation"""
+        """CAII implementation with proper timeout configuration (uses OpenAI SDK)"""
         try:
-            #API_KEY = json.load(open("/tmp/jwt"))["access_token"]
+            import httpx
+            from openai import OpenAI
+
             API_KEY = _get_caii_token()
             MODEL_ID = self.model_id
             caii_endpoint = self.caii_endpoint
 
             caii_endpoint = caii_endpoint.removesuffix('/chat/completions')
+
+            # Configure timeout for CAII client (same as OpenAI since it uses OpenAI SDK v1.57.2)
+            timeout_config = httpx.Timeout(
+                connect=self.OPENAI_CONNECT_TIMEOUT,
+                read=self.OPENAI_READ_TIMEOUT,
+                write=10.0,
+                pool=5.0
+            )
+
             client_ca = OpenAI(
                 base_url=caii_endpoint,
                 api_key=API_KEY,
+                timeout=timeout_config  # Use the comprehensive timeout configuration
             )
 
             completion = client_ca.chat.completions.create(
