Skip to content

Commit 4b5dee5

Browse files
Add routing group LiteLLM status view
1 parent 7e86c6c commit 4b5dee5

File tree

14 files changed

+1028
-129
lines changed

14 files changed

+1028
-129
lines changed
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
"""Add provider max_requests_per_hour.
2+
3+
Revision ID: 007
4+
Revises: 006
5+
Create Date: 2026-01-05 00:00:00.000000
6+
7+
"""
8+
from typing import Sequence, Union
9+
10+
from alembic import op
11+
import sqlalchemy as sa
12+
13+
14+
# revision identifiers, used by Alembic.
15+
revision: str = "007"
16+
down_revision: Union[str, None] = "006"
17+
branch_labels: Union[str, Sequence[str], None] = None
18+
depends_on: Union[str, Sequence[str], None] = None
19+
20+
21+
def upgrade() -> None:
    """Add the nullable integer column ``max_requests_per_hour`` to ``providers``."""
    rate_limit_column = sa.Column(
        "max_requests_per_hour",
        sa.Integer,
        nullable=True,
    )
    op.add_column("providers", rate_limit_column)
23+
24+
25+
def downgrade() -> None:
    """Remove ``providers.max_requests_per_hour`` again.

    A no-op downgrade would leave the column behind, so a later ``upgrade()``
    would fail when it tries to add a column that already exists. Batch mode
    is used so the drop also works on SQLite: there Alembic rebuilds the table
    without the column instead of relying on ``ALTER TABLE ... DROP COLUMN``.
    """
    with op.batch_alter_table("providers") as batch_op:
        batch_op.drop_column("max_requests_per_hour")

backend/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
"""Backend sync worker service."""
22

3-
__version__ = "0.6.21"
3+
__version__ = "0.6.24"

backend/litellm_client.py

Lines changed: 150 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from shared.models import ModelMetadata
66
from shared.pricing_profiles import apply_pricing_overrides
77
from shared.sources import _make_auth_headers, DEFAULT_TIMEOUT
8-
from shared.tags import generate_model_tags
8+
from shared.tags import generate_model_tags, normalize_tags
99

1010
logger = logging.getLogger(__name__)
1111

@@ -179,6 +179,57 @@ async def fetch_litellm_models(client: httpx.AsyncClient, base_url: str, api_key
179179
raise
180180

181181

182+
def _collect_litellm_tags(model: dict) -> list[str]:
183+
"""Collect tags across LiteLLM payload fields."""
184+
tags = model.get("litellm_params", {}).get("tags", [])
185+
model_info_tags = model.get("model_info", {}).get("tags", [])
186+
root_tags = model.get("tags", [])
187+
combined = list(tags or []) + list(model_info_tags or []) + list(root_tags or [])
188+
return [str(tag).lower() for tag in combined if tag is not None]
189+
190+
191+
def _extract_tag_value(tags: list[str], prefix: str) -> str | None:
192+
"""Extract the first tag value for a prefix."""
193+
for tag in tags:
194+
if tag.startswith(prefix):
195+
return tag[len(prefix):]
196+
return None
197+
198+
199+
async def list_routing_group_deployments(config) -> list[dict]:
    """Return LiteLLM models tagged as routing groups.

    Fetches the models from the configured LiteLLM instance and keeps those
    that carry a ``routing_group:<name>`` tag.

    Args:
        config: Object exposing ``litellm_base_url`` and ``litellm_api_key``.

    Returns:
        One dict per tagged model with the keys ``group``, ``provider``,
        ``model_id``, ``model_name``, ``model_info_id``, ``created_by`` and
        ``tags``. Empty when no LiteLLM base URL is configured. Values derived
        from tags are lower-case because ``_collect_litellm_tags`` lower-cases
        every tag.
    """
    if not config.litellm_base_url:
        return []

    async with httpx.AsyncClient() as client:
        litellm_models = await fetch_litellm_models(
            client,
            config.litellm_base_url,
            config.litellm_api_key,
        )

    entries: list[dict] = []
    for model in litellm_models:
        tags = _collect_litellm_tags(model)
        group_name = _extract_tag_value(tags, "routing_group:")
        if group_name is None:
            continue

        # ``model_info`` can be absent or null; treat both as empty so the
        # lookups below cannot raise AttributeError.
        model_info = model.get("model_info") or {}
        entries.append(
            {
                "group": group_name,
                "provider": _extract_tag_value(tags, "provider:") or "",
                "model_id": _extract_tag_value(tags, "model:") or "",
                "model_name": model.get("model_name"),
                "model_info_id": model_info.get("id"),
                "created_by": model_info.get("created_by"),
                "tags": tags,
            }
        )

    return entries
231+
232+
182233
async def push_model_to_litellm(
183234
client: httpx.AsyncClient,
184235
base_url: str,
@@ -187,6 +238,10 @@ async def push_model_to_litellm(
187238
model,
188239
config=None,
189240
session=None,
241+
model_name_override: str | None = None,
242+
extra_tags: list[str] | None = None,
243+
created_by: str = "updater",
244+
strip_unique_id: bool = False,
190245
):
191246
"""Push a single model to LiteLLM."""
192247
# Build litellm_params
@@ -291,6 +346,13 @@ async def push_model_to_litellm(
291346
tags = [t for t in tags if t != "capability:completion"]
292347
tags.append("mode:chat")
293348

349+
if strip_unique_id:
350+
tags = [t for t in tags if not t.startswith("unique_id:")]
351+
352+
if extra_tags:
353+
tags.extend(normalize_tags(extra_tags))
354+
tags = normalize_tags(tags)
355+
294356
litellm_params["tags"] = tags
295357
model_info["tags"] = tags
296358

@@ -302,11 +364,11 @@ async def push_model_to_litellm(
302364
# Mark as created/updated by updater with timestamp
303365
from datetime import datetime, UTC
304366
current_time = datetime.now(UTC)
305-
model_info["created_by"] = "updater"
367+
model_info["created_by"] = created_by
306368
model_info["updated_at"] = current_time.isoformat()
307369

308370
# Build display name
309-
display_name = model.get_display_name(apply_prefix=True)
371+
display_name = model_name_override or model.get_display_name(apply_prefix=True)
310372

311373
# Push to LiteLLM
312374
url = f"{base_url.rstrip('/')}/model/new"
@@ -571,3 +633,88 @@ def _merge_pricing_fields(target: dict, source: dict) -> None:
571633
continue
572634
if "cost" in key or key == "tiered_pricing":
573635
target[key] = value
636+
637+
638+
async def push_routing_groups_to_litellm(session, config, group_id: int | None = None) -> dict:
639+
"""Push routing groups to LiteLLM as model groups."""
640+
if not config.litellm_base_url:
641+
raise RuntimeError("LiteLLM destination not configured")
642+
643+
from shared.crud import get_routing_groups, get_routing_group, get_model_by_provider_and_name, get_provider_by_id
644+
645+
if group_id is None:
646+
groups = await get_routing_groups(session)
647+
groups = [await get_routing_group(session, g.id) for g in groups]
648+
else:
649+
group = await get_routing_group(session, group_id)
650+
groups = [group] if group else []
651+
652+
groups = [g for g in groups if g is not None]
653+
stats = {"groups": len(groups), "added": 0, "deleted": 0, "missing_models": 0, "errors": 0}
654+
655+
async with httpx.AsyncClient() as client:
656+
litellm_models = await fetch_litellm_models(client, config.litellm_base_url, config.litellm_api_key)
657+
658+
for group in groups:
659+
group_tag = f"routing_group:{group.name}"
660+
group_tag_lower = group_tag.lower()
661+
662+
for m in litellm_models:
663+
tags = m.get("litellm_params", {}).get("tags", [])
664+
model_info_tags = m.get("model_info", {}).get("tags", [])
665+
root_tags = m.get("tags", [])
666+
combined_tags = [str(t).lower() for t in (tags or []) + (model_info_tags or []) + (root_tags or [])]
667+
if group_tag_lower not in combined_tags:
668+
continue
669+
if m.get("model_info", {}).get("created_by") != "routing_group":
670+
continue
671+
model_id = m.get("model_info", {}).get("id")
672+
if not model_id:
673+
continue
674+
try:
675+
await delete_model_from_litellm(
676+
client,
677+
config.litellm_base_url,
678+
config.litellm_api_key,
679+
model_id,
680+
)
681+
stats["deleted"] += 1
682+
except Exception as exc:
683+
stats["errors"] += 1
684+
logger.warning("Failed deleting routing group entry %s: %s", model_id, exc)
685+
686+
for target in sorted(group.targets, key=lambda t: (t.priority, t.id)):
687+
provider = target.provider or await get_provider_by_id(session, target.provider_id)
688+
if not provider:
689+
stats["missing_models"] += 1
690+
continue
691+
model = await get_model_by_provider_and_name(session, provider.id, target.model_id)
692+
if not model:
693+
stats["missing_models"] += 1
694+
continue
695+
try:
696+
await push_model_to_litellm(
697+
client,
698+
config.litellm_base_url,
699+
config.litellm_api_key,
700+
provider,
701+
model,
702+
config=config,
703+
session=session,
704+
model_name_override=group.name,
705+
extra_tags=[group_tag],
706+
created_by="routing_group",
707+
strip_unique_id=True,
708+
)
709+
stats["added"] += 1
710+
except Exception as exc:
711+
stats["errors"] += 1
712+
logger.warning(
713+
"Failed pushing routing target %s/%s for group %s: %s",
714+
provider.name,
715+
model.model_id,
716+
group.name,
717+
exc,
718+
)
719+
720+
return stats

frontend/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
"""Frontend API and UI service."""
22

3-
__version__ = "0.6.21"
3+
__version__ = "0.6.24"

frontend/routes/providers.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,16 @@ def _parse_pricing_override(input_cost: str | None, output_cost: str | None) ->
6262
return pricing or None
6363

6464

65+
def _parse_optional_int(raw: str | int | None) -> int | None:
66+
"""Parse optional integer inputs."""
67+
if raw in (None, ""):
68+
return None
69+
try:
70+
return int(raw)
71+
except (TypeError, ValueError):
72+
return None
73+
74+
6575
@router.get("")
6676
@router.get("/")
6777
async def list_providers(session: AsyncSession = Depends(get_session)):
@@ -83,6 +93,7 @@ async def list_providers(session: AsyncSession = Depends(get_session)):
8393
"access_groups": p.access_groups_list,
8494
"sync_enabled": p.sync_enabled,
8595
"sync_interval_seconds": p.sync_interval_seconds,
96+
"max_requests_per_hour": p.max_requests_per_hour,
8697
"pricing_profile": p.pricing_profile,
8798
"pricing_override": p.pricing_override_dict,
8899
"created_at": p.created_at.isoformat(),
@@ -308,6 +319,7 @@ async def add_provider(
308319
access_groups: str | None = Form(None),
309320
sync_enabled: bool | None = Form(True),
310321
sync_interval_seconds: int | None = Form(None),
322+
max_requests_per_hour: str | None = Form(None),
311323
auto_detect_fim: bool | None = Form(True),
312324
pricing_profile: str | None = Form(None),
313325
pricing_input_cost_per_token: str | None = Form(None),
@@ -335,6 +347,7 @@ async def add_provider(
335347
access_groups=_parse_csv_list(access_groups),
336348
sync_enabled=sync_enabled_val,
337349
sync_interval_seconds=sync_interval_seconds,
350+
max_requests_per_hour=_parse_optional_int(max_requests_per_hour),
338351
auto_detect_fim=auto_detect_fim_val,
339352
pricing_profile=_normalize_optional_str(pricing_profile),
340353
pricing_override=_parse_pricing_override(
@@ -402,6 +415,7 @@ async def update_provider_endpoint(
402415
access_groups: str | None = Form(None),
403416
sync_enabled: bool | None = Form(None),
404417
sync_interval_seconds: int | None = Form(None),
418+
max_requests_per_hour: str | None = Form(None),
405419
auto_detect_fim: bool | None = Form(None),
406420
pricing_profile: str | None = Form(None),
407421
pricing_input_cost_per_token: str | None = Form(None),
@@ -428,6 +442,7 @@ async def update_provider_endpoint(
428442
access_groups=_parse_csv_list(access_groups),
429443
sync_enabled=_parse_bool(sync_enabled),
430444
sync_interval_seconds=sync_interval_seconds,
445+
max_requests_per_hour=_parse_optional_int(max_requests_per_hour),
431446
auto_detect_fim=_parse_bool(auto_detect_fim),
432447
pricing_profile=_normalize_optional_str(pricing_profile),
433448
pricing_override=_parse_pricing_override(

0 commit comments

Comments
 (0)