@@ -149,7 +149,7 @@ def restart_application():
149149 runtime_identifier = runtime_identifier
150150)
151151
152- #*************Comment this when running locally********************************************
152+ # # *************Comment this when running locally********************************************
153153
154154# Add these models
155155class StudioUpgradeStatus (BaseModel ):
@@ -472,6 +472,40 @@ async def create_custom_prompt(request: CustomPromptRequest):
472472 return {"generated_prompt" :prompt_gen }
473473 except Exception as e :
474474 raise HTTPException (status_code = 500 , detail = str (e ))
475+
476+
def sort_unique_models(model_list):
    """Deduplicate model identifiers and sort them newest-first.

    Models are deduplicated on their base name (the part after the last
    '.'-separated provider/region prefix, e.g. "us.anthropic.claude-..."
    and "anthropic.claude-..." collapse to one entry), keeping the first
    occurrence's original identifier. Survivors are sorted in descending
    order by a (version, date) key parsed from the name.

    Args:
        model_list: iterable of model-id strings, e.g.
            "anthropic.claude-3-sonnet-20240229-v1:0".

    Returns:
        list[str]: unique original identifiers, highest version / most
        recent date first.
    """

    def get_sort_key(model_name):
        # Strip any provider/region prefix ("us.anthropic.foo" -> "foo").
        # split('.')[-1] already returns the whole string when no '.' exists.
        name = model_name.split('.')[-1]
        parts = name.split('-')

        # Version: last "-vN[:rev]" segment wins; the ":0" revision suffix
        # is dropped before numeric conversion.
        version = 0.0
        for part in parts:
            if part.startswith('v') and any(c.isdigit() for c in part):
                candidate = part[1:].split(':')[0]
                try:
                    version = float(candidate)
                except ValueError:
                    # Segment merely *looks* like a version (e.g. "vx1");
                    # ignore it rather than crashing the whole sort.
                    pass

        # Date: last 8-digit segment (YYYYMMDD). Lexicographic comparison
        # of this fixed-width format is chronological, so keep it a string.
        date = '00000000'
        for part in parts:
            if len(part) == 8 and part.isdigit():
                date = part

        return (version, date)

    # Remove duplicates on base name while preserving the first-seen
    # original (possibly prefixed) identifier.
    seen_base_names = set()
    filtered_models = []
    for model in model_list:
        base_name = model.split('.')[-1]
        if base_name not in seen_base_names:
            seen_base_names.add(base_name)
            filtered_models.append(model)

    return sorted(filtered_models, key=get_sort_key, reverse=True)
475509
476510@app .get ("/model/model_ID" , include_in_schema = True )
477511async def get_model_id ():
@@ -483,37 +517,75 @@ async def get_model_id():
483517 "max_attempts" : 2 ,
484518 "mode" : "standard" ,
485519 },
486- )
487- client_s = boto3 .client (service_name = "bedrock" , region_name = region , config = retry_config )
488- response = client_s .list_foundation_models ()
489- all_models = response ['modelSummaries' ]
490-
491- mod_list = [m ['modelId' ]
492- for m in all_models
493- if 'ON_DEMAND' in m ['inferenceTypesSupported' ]
494- and 'TEXT' in m ['inputModalities' ]
495- and 'TEXT' in m ['outputModalities' ]
496- and m ['providerName' ] in ['Anthropic' , 'Meta' , 'Mistral AI' ]]
520+ )
521+
497522
498- inference_models = client_s .list_inference_profiles ()
499- inference_mod_list = [m ['inferenceProfileId' ] for m in inference_models ['inferenceProfileSummaries' ]
500- if ("meta" in m ['inferenceProfileId' ]) or ("anthropic" in m ['inferenceProfileId' ]) or ("mistral" in m ['inferenceProfileId' ]) ]
501- bedrock_list = inference_mod_list + mod_list
523+
524+ try :
525+ client_s = boto3 .client (service_name = "bedrock" , region_name = region , config = retry_config )
526+
527+ # Get foundation models
528+ response = client_s .list_foundation_models ()
529+ all_models = response ['modelSummaries' ]
530+
531+ mod_list = [m ['modelId' ]
532+ for m in all_models
533+ if 'ON_DEMAND' in m ['inferenceTypesSupported' ]
534+ and 'TEXT' in m ['inputModalities' ]
535+ and 'TEXT' in m ['outputModalities' ]
536+ and m ['providerName' ] in ['Anthropic' , 'Meta' , 'Mistral AI' ]]
537+
538+ # Get inference profiles with comprehensive error handling
539+ try :
540+ inference_models = client_s .list_inference_profiles ()
541+ inference_mod_list = []
542+ if 'inferenceProfileSummaries' in inference_models :
543+ inference_mod_list = [
544+ m ['inferenceProfileId' ]
545+ for m in inference_models ['inferenceProfileSummaries' ]
546+ if any (provider in m ['inferenceProfileId' ].lower ()
547+ for provider in ['meta' , 'anthropic' , 'mistral' ])
548+ ]
549+ except client_s .exceptions .ResourceNotFoundException :
550+ inference_mod_list = []
551+ except client_s .exceptions .ValidationException as e :
552+ print (f"Validation error: { str (e )} " )
553+ inference_mod_list = []
554+ except client_s .exceptions .AccessDeniedException as e :
555+ print (f"Access denied: { str (e )} " )
556+ inference_mod_list = []
557+ except client_s .exceptions .ThrottlingException as e :
558+ print (f"Request throttled: { str (e )} " )
559+ inference_mod_list = []
560+ except client_s .exceptions .InternalServerException as e :
561+ print (f"Bedrock internal error: { str (e )} " )
562+ inference_mod_list = []
563+
564+ # Combine and sort the lists
565+ bedrock_list = sort_unique_models (inference_mod_list + mod_list )
566+
567+ models = {
568+ "aws_bedrock" : bedrock_list ,
569+ "CAII" : ['meta/llama-3_1-8b-instruct' , 'mistralai/mistral-7b-instruct-v0.3' ]
570+ }
571+
572+ return {"models" : models }
502573
503- # mod_list_wp = {}
504- # for m in all_models:
505- # if ('ON_DEMAND' in m['inferenceTypesSupported']
506- # and 'TEXT' in m['inputModalities'] and 'TEXT' in m['outputModalities']):
507- # provider = m['providerName']
508- # if provider not in mod_list_wp:
509- # mod_list_wp[provider] = []
510- # mod_list_wp[provider].append(m['modelId'])
511-
512- models = {"aws_bedrock" :bedrock_list ,
513- "CAII" : ['meta/llama-3_1-8b-instruct' , 'mistralai/mistral-7b-instruct-v0.3' ]
514- }
515-
516- return {"models" :models }
574+ except client_s .exceptions .ValidationException as e :
575+ print (f"Validation error: { str (e )} " )
576+ raise
577+ except client_s .exceptions .AccessDeniedException as e :
578+ print (f"Access denied: { str (e )} " )
579+ raise
580+ except client_s .exceptions .ThrottlingException as e :
581+ print (f"Request throttled: { str (e )} " )
582+ raise
583+ except client_s .exceptions .InternalServerException as e :
584+ print (f"Bedrock internal error: { str (e )} " )
585+ raise
586+ except Exception as e :
587+ print (f"Unexpected error occurred: { str (e )} " )
588+ raise
517589
518590@app .get ("/use-cases" , include_in_schema = True )
519591async def get_use_cases ():
0 commit comments