
Commit 68a949e

Merge pull request #45 from cloudera/dev
Model ID in order of latest release
2 parents: 7ca43cc + ddf7238 · commit 68a949e

File tree

7 files changed: +113 −37 lines


.gitignore

Lines changed: 2 additions & 1 deletion
@@ -68,4 +68,5 @@ coverage.xml
 .nox/
 .pytest_cache
 #old code
-app/frontend
+app/frontend/
+app/launch_streamlit.py

app/core/config.py

Lines changed: 2 additions & 1 deletion
@@ -294,4 +294,5 @@ def caii_check(caii_endpoint):
     caii_endpoint = caii_endpoint + "/models"
     response = requests.get(caii_endpoint, headers=headers, timeout=3)  # Will raise RequestException if fails

-    return response
+    return response
+
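For reference, caii_check appends /models to the supplied endpoint and performs a short GET, letting requests raise on failure. A minimal sketch of how a caller might exercise it, assuming a hypothetical endpoint URL and that the function is imported from app/core/config.py:

import requests
from app.core.config import caii_check

try:
    # Hypothetical CAII endpoint; the real URL comes from the studio configuration
    response = caii_check("https://caii.example.com/namespaces/serving-default/endpoints/demo")
    response.raise_for_status()
    print("CAII endpoint reachable:", response.status_code)
except requests.exceptions.RequestException as exc:
    print("CAII endpoint check failed:", exc)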

app/main.py

Lines changed: 102 additions & 30 deletions
@@ -149,7 +149,7 @@ def restart_application():
         runtime_identifier=runtime_identifier
     )

-#*************Comment this when running locally********************************************
+# #*************Comment this when running locally********************************************

 # Add these models
 class StudioUpgradeStatus(BaseModel):
@@ -472,6 +472,40 @@ async def create_custom_prompt(request: CustomPromptRequest):
         return {"generated_prompt":prompt_gen}
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
+
+
+def sort_unique_models(model_list):
+    def get_sort_key(model_name):
+        # Strip any provider prefix
+        name = model_name.split('.')[-1] if '.' in model_name else model_name
+        parts = name.split('-')
+
+        # Extract version
+        version = '0'
+        for part in parts:
+            if part.startswith('v') and any(c.isdigit() for c in part):
+                version = part[1:]
+                if ':' in version:
+                    version = version.split(':')[0]
+
+        # Extract date
+        date = '00000000'
+        for part in parts:
+            if len(part) == 8 and part.isdigit():
+                date = part
+
+        return (float(version), date)
+
+    # Remove duplicates while preserving original names
+    unique_models = set()
+    filtered_models = []
+    for model in model_list:
+        base_name = model.split('.')[-1] if '.' in model else model
+        if base_name not in unique_models:
+            unique_models.add(base_name)
+            filtered_models.append(model)
+
+    return sorted(filtered_models, key=get_sort_key, reverse=True)

 @app.get("/model/model_ID", include_in_schema=True)
 async def get_model_id():
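To illustrate the new ordering, here is a small sketch of sort_unique_models applied to a mixed list of inference-profile IDs and plain model IDs. The IDs below are examples only; actual availability depends on the AWS account and region:

# Illustrative input: a prefixed inference-profile ID plus plain model IDs,
# including a duplicate base name that should be dropped
candidates = [
    "meta.llama3-1-70b-instruct-v1:0",
    "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
    "anthropic.claude-3-5-sonnet-20241022-v2:0",   # same base name as the line above
    "mistral.mistral-large-2407-v1:0",
]

print(sort_unique_models(candidates))
# The claude v2 entry (version 2, date 20241022) sorts ahead of the v1 entries,
# and only the first occurrence of the duplicated claude ID is kept.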
@@ -483,37 +517,75 @@ async def get_model_id():
             "max_attempts": 2,
             "mode": "standard",
         },
-    )
-    client_s = boto3.client(service_name="bedrock", region_name=region, config=retry_config)
-    response = client_s.list_foundation_models()
-    all_models = response['modelSummaries']
-
-    mod_list = [m['modelId']
-                for m in all_models
-                if 'ON_DEMAND' in m['inferenceTypesSupported']
-                and 'TEXT' in m['inputModalities']
-                and 'TEXT' in m['outputModalities']
-                and m['providerName'] in ['Anthropic', 'Meta', 'Mistral AI']]
+    )
+

-    inference_models = client_s.list_inference_profiles()
-    inference_mod_list = [m['inferenceProfileId'] for m in inference_models['inferenceProfileSummaries']
-                          if ("meta" in m['inferenceProfileId']) or ("anthropic" in m['inferenceProfileId']) or ("mistral" in m['inferenceProfileId']) ]
-    bedrock_list = inference_mod_list + mod_list
+
+    try:
+        client_s = boto3.client(service_name="bedrock", region_name=region, config=retry_config)
+
+        # Get foundation models
+        response = client_s.list_foundation_models()
+        all_models = response['modelSummaries']
+
+        mod_list = [m['modelId']
+                    for m in all_models
+                    if 'ON_DEMAND' in m['inferenceTypesSupported']
+                    and 'TEXT' in m['inputModalities']
+                    and 'TEXT' in m['outputModalities']
+                    and m['providerName'] in ['Anthropic', 'Meta', 'Mistral AI']]
+
+        # Get inference profiles with comprehensive error handling
+        try:
+            inference_models = client_s.list_inference_profiles()
+            inference_mod_list = []
+            if 'inferenceProfileSummaries' in inference_models:
+                inference_mod_list = [
+                    m['inferenceProfileId']
+                    for m in inference_models['inferenceProfileSummaries']
+                    if any(provider in m['inferenceProfileId'].lower()
+                           for provider in ['meta', 'anthropic', 'mistral'])
+                ]
+        except client_s.exceptions.ResourceNotFoundException:
+            inference_mod_list = []
+        except client_s.exceptions.ValidationException as e:
+            print(f"Validation error: {str(e)}")
+            inference_mod_list = []
+        except client_s.exceptions.AccessDeniedException as e:
+            print(f"Access denied: {str(e)}")
+            inference_mod_list = []
+        except client_s.exceptions.ThrottlingException as e:
+            print(f"Request throttled: {str(e)}")
+            inference_mod_list = []
+        except client_s.exceptions.InternalServerException as e:
+            print(f"Bedrock internal error: {str(e)}")
+            inference_mod_list = []
+
+        # Combine and sort the lists
+        bedrock_list = sort_unique_models(inference_mod_list + mod_list)
+
+        models = {
+            "aws_bedrock": bedrock_list,
+            "CAII": ['meta/llama-3_1-8b-instruct', 'mistralai/mistral-7b-instruct-v0.3']
+        }
+
+        return {"models": models}

-    # mod_list_wp = {}
-    # for m in all_models:
-    #     if ('ON_DEMAND' in m['inferenceTypesSupported']
-    #         and 'TEXT' in m['inputModalities'] and 'TEXT' in m['outputModalities']):
-    #         provider = m['providerName']
-    #         if provider not in mod_list_wp:
-    #             mod_list_wp[provider] = []
-    #         mod_list_wp[provider].append(m['modelId'])
-
-    models = {"aws_bedrock":bedrock_list ,
-              "CAII": ['meta/llama-3_1-8b-instruct', 'mistralai/mistral-7b-instruct-v0.3']
-    }
-
-    return {"models":models}
+    except client_s.exceptions.ValidationException as e:
+        print(f"Validation error: {str(e)}")
+        raise
+    except client_s.exceptions.AccessDeniedException as e:
+        print(f"Access denied: {str(e)}")
+        raise
+    except client_s.exceptions.ThrottlingException as e:
+        print(f"Request throttled: {str(e)}")
+        raise
+    except client_s.exceptions.InternalServerException as e:
+        print(f"Bedrock internal error: {str(e)}")
+        raise
+    except Exception as e:
+        print(f"Unexpected error occurred: {str(e)}")
+        raise

 @app.get("/use-cases", include_in_schema=True)
 async def get_use_cases():
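To see the ordering end to end, the endpoint can be queried once the API is up. A minimal sketch, assuming the FastAPI app is served locally on port 8100 (the actual host and port depend on the deployment):

import requests

# Hypothetical local address; substitute whatever your deployment exposes
resp = requests.get("http://localhost:8100/model/model_ID", timeout=10)
resp.raise_for_status()

models = resp.json()["models"]
print("Bedrock models, newest first:", models["aws_bedrock"][:5])
print("CAII models:", models["CAII"])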
File renamed without changes.
File renamed without changes.
File renamed without changes.

app/services/synthesis_service.py

Lines changed: 7 additions & 5 deletions
@@ -126,7 +126,7 @@ def process_single_topic(self, topic: str, model_handler: any, request: SynthesisRequest
             schema=request.schema,
             custom_prompt=request.custom_prompt,
         )
-
+        #print("prompt :", prompt)
         batch_qa_pairs = None
         try:
             batch_qa_pairs = model_handler.generate_response(prompt)
@@ -167,16 +167,16 @@ def process_single_topic(self, topic: str, model_handler: any, request: SynthesisRequest
                     questions_remaining -= len(valid_pairs)
                     omit_questions = omit_questions[-100:]  # Keep last 100 questions
                     self.logger.info(f"Successfully generated {len(valid_pairs)} questions in batch for topic {topic}")
-
+                    print("invalid_count:", invalid_count, '\n', "batch_size: ", batch_size, '\n', "valid_pairs: ", len(valid_pairs))
                     # If all pairs were valid, skip fallback
-                    if invalid_count == 0:
+                    if invalid_count <= 0:
                         continue

                 else:
                     # Fall back to single processing for remaining or failed questions
                     self.logger.info(f"Falling back to single processing for remaining questions in topic {topic}")
                     remaining_batch = invalid_count
-                    print("remaining_batch:", remaining_batch)
+                    print("remaining_batch:", remaining_batch, '\n', "batch_size: ", batch_size, '\n', "valid_pairs: ", len(valid_pairs))
                     for _ in range(remaining_batch):
                         if questions_remaining <= 0:
                             break
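As a rough illustration of the fallback arithmetic, assuming invalid_count is the number of pairs in a batch that failed validation (which is how the surrounding code reads), the loop issues exactly that many single-question retries, and the <= 0 guard also covers the case where validation leaves nothing to regenerate:

# Illustrative numbers only
batch_size = 10     # pairs requested in one batched call
valid_count = 7     # pairs that passed validation
invalid_count = batch_size - valid_count   # 3 pairs left to regenerate

if invalid_count <= 0:
    print("batch fully valid, no single-question fallback needed")
else:
    for attempt in range(invalid_count):
        print(f"single-question fallback {attempt + 1} of {invalid_count}")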
@@ -289,12 +289,14 @@ async def generate_examples(self, request: SynthesisRequest , job_name = None, i
                 chunks = processor.process_document(path)
                 topics.extend(chunks)
             #topics = topics[0:5]
+            print("total chunks: ", len(topics))
             if request.num_questions<=len(topics):
                 topics = topics[0:request.num_questions]
                 num_questions = 1
+                print("num_questions :", num_questions)
             else:
                 num_questions = math.ceil(request.num_questions/len(topics))
-                print(num_questions)
+                #print(num_questions)
             total_count = request.num_questions
         else:
             if request.topics:
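The per-topic question count then works out as in this small worked sketch (numbers are illustrative):

import math

requested_questions = 25
document_chunks = 10

if requested_questions <= document_chunks:
    # One question per chunk, using only as many chunks as questions
    chunks_used = requested_questions
    questions_per_topic = 1
else:
    # Spread questions across all chunks, rounding up per chunk
    chunks_used = document_chunks
    questions_per_topic = math.ceil(requested_questions / document_chunks)   # ceil(25 / 10) = 3

print(chunks_used, questions_per_topic)   # 10 chunks, 3 questions per chunk; total_count keeps the requested 25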
