Skip to content

Commit ccabcba

Browse files
authored
feat: Add any-llm-args to config. (#905)
* feat: Add `any-llm-args` to config.
* fix: lint
1 parent a3b46d5 commit ccabcba

File tree

7 files changed

+112
-15
lines changed

7 files changed

+112
-15
lines changed

src/any_agent/config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,13 @@ class CalendarEvent(BaseModel):
236236
Refer to [any-llm Completion API Docs](https://mozilla-ai.github.io/any-llm/api/completion/) for more info.
237237
"""
238238

239+
any_llm_args: MutableMapping[str, Any] | None = None
240+
"""Pass arguments to `AnyLLM.create()` when using integrations backed by any-llm.
241+
242+
Use this for provider/client initialization options that are not completion-time
243+
generation params (which should be passed via `model_args`).
244+
"""
245+
239246
output_type: type[BaseModel] | None = None
240247
"""Control the output schema from calling `run`. By default, the agent will return a type str.
241248

src/any_agent/frameworks/llama_index.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def __init__(
102102
max_retries: int = 10,
103103
api_key: str | None = None,
104104
api_base: str | None = None,
105+
any_llm_args: dict[str, Any] | None = None,
105106
**kwargs: Any,
106107
) -> None:
107108
additional_kwargs = additional_kwargs or {}
@@ -116,10 +117,14 @@ def __init__(
116117
)
117118

118119
self._parse_model(model)
120+
llm_create_kwargs: dict[str, Any] = dict(any_llm_args or {})
121+
if api_key is not None:
122+
llm_create_kwargs["api_key"] = api_key
123+
if api_base is not None:
124+
llm_create_kwargs["api_base"] = api_base
119125
self._client = AnyLLM.create(
120126
provider=self._provider,
121-
api_key=api_key,
122-
api_base=api_base,
127+
**llm_create_kwargs,
123128
)
124129

125130
def _parse_model(self, model: str) -> None:
@@ -512,14 +517,18 @@ def _get_model(self, agent_config: AgentConfig) -> "LLM":
512517

513518
model_id = agent_config.model_id
514519

520+
model_kwargs: dict[str, Any] = {
521+
"model": model_id,
522+
"api_key": agent_config.api_key,
523+
"api_base": agent_config.api_base,
524+
"additional_kwargs": additional_kwargs,
525+
}
526+
if model_type is DEFAULT_MODEL_TYPE and agent_config.any_llm_args is not None:
527+
model_kwargs["any_llm_args"] = agent_config.any_llm_args
528+
515529
return cast(
516530
"LLM",
517-
model_type(
518-
model=model_id,
519-
api_key=agent_config.api_key,
520-
api_base=agent_config.api_base,
521-
additional_kwargs=additional_kwargs, # type: ignore[arg-type]
522-
),
531+
model_type(**model_kwargs),
523532
)
524533

525534
async def _load_agent(self) -> None:

src/any_agent/frameworks/openai.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,19 @@ def __init__(
8080
model: str,
8181
base_url: str | None = None,
8282
api_key: str | None = None,
83+
any_llm_args: dict[str, Any] | None = None,
8384
):
8485
provider, model_id = AnyLLM.split_model_provider(model)
8586
self.model = model
8687
self.base_url = base_url
8788
self.api_key = api_key
88-
self.llm = AnyLLM.create(provider=provider, api_key=api_key, api_base=base_url)
89+
llm_create_kwargs: dict[str, Any] = dict(any_llm_args or {})
90+
if api_key is not None:
91+
llm_create_kwargs["api_key"] = api_key
92+
if base_url is not None:
93+
llm_create_kwargs["api_base"] = base_url
94+
95+
self.llm = AnyLLM.create(provider=provider, **llm_create_kwargs)
8996
self.model_id = model_id
9097

9198
async def get_response(
@@ -399,11 +406,14 @@ def _get_model(
399406
base_url = agent_config.api_base or cast(
400407
"str | None", model_args.get("api_base")
401408
)
402-
return model_type(
403-
model=agent_config.model_id,
404-
base_url=base_url,
405-
api_key=agent_config.api_key,
406-
)
409+
model_kwargs: dict[str, Any] = {
410+
"model": agent_config.model_id,
411+
"base_url": base_url,
412+
"api_key": agent_config.api_key,
413+
}
414+
if model_type is DEFAULT_MODEL_TYPE and agent_config.any_llm_args is not None:
415+
model_kwargs["any_llm_args"] = agent_config.any_llm_args
416+
return model_type(**model_kwargs)
407417

408418
async def _load_agent(self) -> None:
409419
"""Load the OpenAI agent with the given configuration."""

src/any_agent/frameworks/tinyagent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def __init__(self, config: AgentConfig) -> None:
101101
self.uses_openai = provider_name == LLMProvider.OPENAI
102102

103103
# Create the LLM instance using the AnyLLM class pattern
104-
llm_kwargs: dict[str, Any] = {}
104+
llm_kwargs: dict[str, Any] = dict(self.config.any_llm_args or {})
105105
if self.config.api_key:
106106
llm_kwargs["api_key"] = self.config.api_key
107107
if self.config.api_base:

tests/unit/frameworks/test_llama_index.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,34 @@ def test_run_llama_index_agent_custom_args() -> None:
7070
)
7171
agent.run("foo", timeout=10)
7272
agent_mock.run.assert_called_once_with("foo", timeout=10)
73+
74+
75+
def test_load_llama_index_agent_forwards_any_llm_args() -> None:
    """Verify that `AgentConfig.any_llm_args` reaches the default llama-index model type.

    Patches the default agent/model types with mocks, creates the agent, and
    asserts the model constructor received `any_llm_args` alongside the usual
    model/api kwargs.
    """
    mocked_model_type = MagicMock()
    mocked_agent_type = MagicMock()
    mocked_agent_type.return_value = MagicMock()
    forwarded_args = {"timeout": 17, "max_retries": 3}

    # Imported lazily so the patch targets resolve the same module object.
    from llama_index.core.tools import FunctionTool

    with (
        patch("any_agent.frameworks.llama_index.DEFAULT_AGENT_TYPE", mocked_agent_type),
        patch("any_agent.frameworks.llama_index.DEFAULT_MODEL_TYPE", mocked_model_type),
        patch.object(FunctionTool, "from_defaults"),
    ):
        AnyAgent.create(
            AgentFramework.LLAMA_INDEX,
            AgentConfig(
                model_id="gemini/gemini-2.0-flash",
                instructions="You are a helpful assistant",
                any_llm_args=forwarded_args,
            ),
        )

    # The default model type must be constructed exactly once, with the
    # config's any_llm_args forwarded untouched.
    mocked_model_type.assert_called_once_with(
        model="gemini/gemini-2.0-flash",
        api_key=None,
        api_base=None,
        additional_kwargs={},
        any_llm_args=forwarded_args,
    )

tests/unit/frameworks/test_openai.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,31 @@ def test_openai_with_api_key() -> None:
8484
)
8585

8686

87+
def test_openai_forwards_any_llm_args() -> None:
    """Verify that `AgentConfig.any_llm_args` is forwarded to the OpenAI model type.

    Patches the framework's Agent and default model type with mocks, builds the
    agent, and checks the model constructor call included `any_llm_args`.
    """
    agent_mock = MagicMock()
    model_type_mock = MagicMock()
    forwarded_args = {"timeout": 42, "headers": {"x-test": "1"}}

    with (
        patch("any_agent.frameworks.openai.Agent", agent_mock),
        patch("any_agent.frameworks.openai.DEFAULT_MODEL_TYPE", model_type_mock),
    ):
        AnyAgent.create(
            AgentFramework.OPENAI,
            AgentConfig(
                model_id="mistral:mistral-small-latest",
                any_llm_args=forwarded_args,
            ),
        )

    # Exactly one model construction, carrying the config's any_llm_args.
    model_type_mock.assert_called_once_with(
        model="mistral:mistral-small-latest",
        base_url=None,
        api_key=None,
        any_llm_args=forwarded_args,
    )
110+
111+
87112
def test_load_openai_with_mcp_server() -> None:
88113
mock_agent = MagicMock()
89114
mock_function_tool = MagicMock()

tests/unit/frameworks/test_tinyagent.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,21 @@ def test_uses_openai_handles_gateway_provider(
268268
assert agent.uses_openai is expected_uses_openai
269269

270270

271+
def test_tinyagent_forwards_any_llm_args_to_anyllm_create() -> None:
    """Verify TinyAgent passes `AgentConfig.any_llm_args` through to `AnyLLM.create`.

    The args should be splatted as keyword arguments onto the create call,
    after the provider positional argument.
    """
    forwarded_args = {"timeout": 99, "organization": "test-org"}
    expected_provider, _ = AnyLLM.split_model_provider(DEFAULT_SMALL_MODEL_ID)

    with patch("any_agent.frameworks.tinyagent.AnyLLM.create") as create_mock:
        TinyAgent(
            AgentConfig(
                model_id=DEFAULT_SMALL_MODEL_ID,
                any_llm_args=forwarded_args,
            )
        )

    # Provider is positional; any_llm_args become the create() kwargs.
    create_mock.assert_called_once_with(expected_provider, **forwarded_args)
284+
285+
271286
@pytest.mark.asyncio
272287
async def test_tool_result_appended_when_tool_not_found() -> None:
273288
"""Test that tool_result message is appended when a tool is not found.

0 commit comments

Comments
 (0)