55from shared .models import ModelMetadata
66from shared .pricing_profiles import apply_pricing_overrides
77from shared .sources import _make_auth_headers , DEFAULT_TIMEOUT
8- from shared .tags import generate_model_tags
8+ from shared .tags import generate_model_tags , normalize_tags
99
1010logger = logging .getLogger (__name__ )
1111
@@ -179,6 +179,57 @@ async def fetch_litellm_models(client: httpx.AsyncClient, base_url: str, api_key
179179 raise
180180
181181
def _collect_litellm_tags(model: dict) -> list[str]:
    """Collect tags across LiteLLM payload fields.

    Tags are gathered, in order, from ``litellm_params["tags"]``,
    ``model_info["tags"]`` and the root-level ``tags`` key. Each tag is
    converted with ``str()`` and lowercased; ``None`` entries are dropped and
    duplicates are kept.

    A section (``litellm_params`` / ``model_info``) or a ``tags`` field that
    is missing or explicitly ``None`` contributes no tags.

    Args:
        model: A single model entry as a dict.

    Returns:
        The lowercased tags in source order.
    """
    sources = (
        # ``or {}`` also covers a key that is present but set to None, which
        # ``dict.get(key, {})`` alone would not.
        (model.get("litellm_params") or {}).get("tags"),
        (model.get("model_info") or {}).get("tags"),
        model.get("tags"),
    )

    combined: list = []
    for source in sources:
        if not source:
            continue
        if isinstance(source, str):
            # A bare string is one tag; extending would split it into
            # individual characters.
            combined.append(source)
        else:
            combined.extend(source)

    return [str(tag).lower() for tag in combined if tag is not None]
189+
190+
191+ def _extract_tag_value (tags : list [str ], prefix : str ) -> str | None :
192+ """Extract the first tag value for a prefix."""
193+ for tag in tags :
194+ if tag .startswith (prefix ):
195+ return tag [len (prefix ):]
196+ return None
197+
198+
async def list_routing_group_deployments(config) -> list[dict]:
    """Return LiteLLM models tagged as routing groups.

    Fetches the model list from the LiteLLM instance at
    ``config.litellm_base_url`` and keeps every model that carries a
    ``routing_group:<name>`` tag.

    Args:
        config: Object exposing ``litellm_base_url`` and ``litellm_api_key``.

    Returns:
        One dict per tagged model with the keys ``group``, ``provider``,
        ``model_id``, ``model_name``, ``model_info_id``, ``created_by`` and
        ``tags``. Returns an empty list when no base URL is configured.
        ``group``, ``provider`` and ``model_id`` are read from the collected
        tags, which ``_collect_litellm_tags`` lowercases, so these values are
        lowercase too.
    """
    if not config.litellm_base_url:
        return []

    async with httpx.AsyncClient() as client:
        litellm_models = await fetch_litellm_models(
            client,
            config.litellm_base_url,
            config.litellm_api_key,
        )

    entries: list[dict] = []
    for model in litellm_models:
        tags = _collect_litellm_tags(model)
        # Compare against None, not truthiness: a bare "routing_group:" tag
        # yields an empty group name and is still reported.
        group_name = _extract_tag_value(tags, "routing_group:")
        if group_name is None:
            continue

        # ``or {}`` guards against a model_info key that is present but None.
        model_info = model.get("model_info") or {}
        entries.append(
            {
                "group": group_name,
                "provider": _extract_tag_value(tags, "provider:") or "",
                "model_id": _extract_tag_value(tags, "model:") or "",
                "model_name": model.get("model_name"),
                "model_info_id": model_info.get("id"),
                "created_by": model_info.get("created_by"),
                "tags": tags,
            }
        )

    return entries
231+
232+
182233async def push_model_to_litellm (
183234 client : httpx .AsyncClient ,
184235 base_url : str ,
@@ -187,6 +238,10 @@ async def push_model_to_litellm(
187238 model ,
188239 config = None ,
189240 session = None ,
241+ model_name_override : str | None = None ,
242+ extra_tags : list [str ] | None = None ,
243+ created_by : str = "updater" ,
244+ strip_unique_id : bool = False ,
190245):
191246 """Push a single model to LiteLLM."""
192247 # Build litellm_params
@@ -291,6 +346,13 @@ async def push_model_to_litellm(
291346 tags = [t for t in tags if t != "capability:completion" ]
292347 tags .append ("mode:chat" )
293348
349+ if strip_unique_id :
350+ tags = [t for t in tags if not t .startswith ("unique_id:" )]
351+
352+ if extra_tags :
353+ tags .extend (normalize_tags (extra_tags ))
354+ tags = normalize_tags (tags )
355+
294356 litellm_params ["tags" ] = tags
295357 model_info ["tags" ] = tags
296358
@@ -302,11 +364,11 @@ async def push_model_to_litellm(
302364 # Mark as created/updated by updater with timestamp
303365 from datetime import datetime , UTC
304366 current_time = datetime .now (UTC )
305- model_info ["created_by" ] = "updater"
367+ model_info ["created_by" ] = created_by
306368 model_info ["updated_at" ] = current_time .isoformat ()
307369
308370 # Build display name
309- display_name = model .get_display_name (apply_prefix = True )
371+ display_name = model_name_override or model .get_display_name (apply_prefix = True )
310372
311373 # Push to LiteLLM
312374 url = f"{ base_url .rstrip ('/' )} /model/new"
@@ -571,3 +633,88 @@ def _merge_pricing_fields(target: dict, source: dict) -> None:
571633 continue
572634 if "cost" in key or key == "tiered_pricing" :
573635 target [key ] = value
636+
637+
638+ async def push_routing_groups_to_litellm (session , config , group_id : int | None = None ) -> dict :
639+ """Push routing groups to LiteLLM as model groups."""
640+ if not config .litellm_base_url :
641+ raise RuntimeError ("LiteLLM destination not configured" )
642+
643+ from shared .crud import get_routing_groups , get_routing_group , get_model_by_provider_and_name , get_provider_by_id
644+
645+ if group_id is None :
646+ groups = await get_routing_groups (session )
647+ groups = [await get_routing_group (session , g .id ) for g in groups ]
648+ else :
649+ group = await get_routing_group (session , group_id )
650+ groups = [group ] if group else []
651+
652+ groups = [g for g in groups if g is not None ]
653+ stats = {"groups" : len (groups ), "added" : 0 , "deleted" : 0 , "missing_models" : 0 , "errors" : 0 }
654+
655+ async with httpx .AsyncClient () as client :
656+ litellm_models = await fetch_litellm_models (client , config .litellm_base_url , config .litellm_api_key )
657+
658+ for group in groups :
659+ group_tag = f"routing_group:{ group .name } "
660+ group_tag_lower = group_tag .lower ()
661+
662+ for m in litellm_models :
663+ tags = m .get ("litellm_params" , {}).get ("tags" , [])
664+ model_info_tags = m .get ("model_info" , {}).get ("tags" , [])
665+ root_tags = m .get ("tags" , [])
666+ combined_tags = [str (t ).lower () for t in (tags or []) + (model_info_tags or []) + (root_tags or [])]
667+ if group_tag_lower not in combined_tags :
668+ continue
669+ if m .get ("model_info" , {}).get ("created_by" ) != "routing_group" :
670+ continue
671+ model_id = m .get ("model_info" , {}).get ("id" )
672+ if not model_id :
673+ continue
674+ try :
675+ await delete_model_from_litellm (
676+ client ,
677+ config .litellm_base_url ,
678+ config .litellm_api_key ,
679+ model_id ,
680+ )
681+ stats ["deleted" ] += 1
682+ except Exception as exc :
683+ stats ["errors" ] += 1
684+ logger .warning ("Failed deleting routing group entry %s: %s" , model_id , exc )
685+
686+ for target in sorted (group .targets , key = lambda t : (t .priority , t .id )):
687+ provider = target .provider or await get_provider_by_id (session , target .provider_id )
688+ if not provider :
689+ stats ["missing_models" ] += 1
690+ continue
691+ model = await get_model_by_provider_and_name (session , provider .id , target .model_id )
692+ if not model :
693+ stats ["missing_models" ] += 1
694+ continue
695+ try :
696+ await push_model_to_litellm (
697+ client ,
698+ config .litellm_base_url ,
699+ config .litellm_api_key ,
700+ provider ,
701+ model ,
702+ config = config ,
703+ session = session ,
704+ model_name_override = group .name ,
705+ extra_tags = [group_tag ],
706+ created_by = "routing_group" ,
707+ strip_unique_id = True ,
708+ )
709+ stats ["added" ] += 1
710+ except Exception as exc :
711+ stats ["errors" ] += 1
712+ logger .warning (
713+ "Failed pushing routing target %s/%s for group %s: %s" ,
714+ provider .name ,
715+ model .model_id ,
716+ group .name ,
717+ exc ,
718+ )
719+
720+ return stats
0 commit comments