|
15 | 15 |
|
16 | 16 | import yaml |
17 | 17 | from langchain_litellm import ChatLiteLLM |
| 18 | +from litellm import get_model_info |
18 | 19 | from sqlalchemy import select |
19 | 20 | from sqlalchemy.ext.asyncio import AsyncSession |
20 | 21 |
|
|
62 | 63 | } |
63 | 64 |
|
64 | 65 |
|
def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None:
    """Attach a ``profile`` dict to ``llm`` with model context-window metadata.

    Best-effort mutation: looks up *model_string* in litellm's model-info map
    and, when a usable ``max_input_tokens`` is reported, stores it (plus the
    model identifier) on ``llm.profile`` for downstream token budgeting.

    Args:
        llm: The ``ChatLiteLLM`` instance to annotate (mutated in place).
        model_string: The litellm model identifier passed to ``get_model_info``.

    Returns:
        None. If the model is unknown to litellm, ``llm`` is left unchanged.
    """
    # Keep the try narrow: only the litellm lookup (which raises for models
    # absent from its cost map) and the dict access are expected to fail.
    # The original wrapped the whole body, which would also hide genuine
    # bugs in the profile construction below.
    try:
        info = get_model_info(model_string)
        max_input_tokens = info.get("max_input_tokens")
    except Exception:
        # Unmapped/custom models: profiling is optional, so skip silently.
        return
    # Only attach a profile when litellm reports a positive integer limit;
    # None / 0 / non-int values mean "unknown" and are ignored.
    if isinstance(max_input_tokens, int) and max_input_tokens > 0:
        llm.profile = {
            "max_input_tokens": max_input_tokens,
            # Upper bound mirrors the limit itself — no separate cap is known.
            "max_input_tokens_upper": max_input_tokens,
            "token_count_model": model_string,
            "token_count_models": [model_string],
        }
| 80 | + |
| 81 | + |
65 | 82 | @dataclass |
66 | 83 | class AgentConfig: |
67 | 84 | """ |
@@ -366,7 +383,9 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None: |
366 | 383 | if llm_config.get("litellm_params"): |
367 | 384 | litellm_kwargs.update(llm_config["litellm_params"]) |
368 | 385 |
|
369 | | - return ChatLiteLLM(**litellm_kwargs) |
| 386 | + llm = ChatLiteLLM(**litellm_kwargs) |
| 387 | + _attach_model_profile(llm, model_string) |
| 388 | + return llm |
370 | 389 |
|
371 | 390 |
|
372 | 391 | def create_chat_litellm_from_agent_config( |
@@ -419,4 +438,6 @@ def create_chat_litellm_from_agent_config( |
419 | 438 | if agent_config.litellm_params: |
420 | 439 | litellm_kwargs.update(agent_config.litellm_params) |
421 | 440 |
|
422 | | - return ChatLiteLLM(**litellm_kwargs) |
| 441 | + llm = ChatLiteLLM(**litellm_kwargs) |
| 442 | + _attach_model_profile(llm, model_string) |
| 443 | + return llm |
0 commit comments