Skip to content

Commit ce91a8e

Browse files
GWeale and copybara-github
authored and committed
fix: Pass drop_params to LiteLLM completion API
This lets users specify `drop_params` when initializing `LiteLlm`, which will be forwarded to LiteLLM's `acompletion` or `completion` calls. Closes #1718 Co-authored-by: George Weale <[email protected]> PiperOrigin-RevId: 828058105
1 parent d4c63fc commit ce91a8e

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

src/google/adk/models/lite_llm.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -866,17 +866,20 @@ def __init__(self, model: str, **kwargs):
866866
model: The name of the LiteLlm model.
867867
**kwargs: Additional arguments to pass to the litellm completion api.
868868
"""
869+
drop_params = kwargs.pop("drop_params", None)
869870
super().__init__(model=model, **kwargs)
870871
# Warn if using Gemini via LiteLLM
871872
_warn_gemini_via_litellm(model)
872-
self._additional_args = kwargs
873+
self._additional_args = dict(kwargs)
873874
# preventing generation call with llm_client
874875
# and overriding messages, tools and stream which are managed internally
875876
self._additional_args.pop("llm_client", None)
876877
self._additional_args.pop("messages", None)
877878
self._additional_args.pop("tools", None)
878879
# public api called from runner determines to stream or not
879880
self._additional_args.pop("stream", None)
881+
if drop_params is not None:
882+
self._additional_args["drop_params"] = drop_params
880883

881884
async def generate_content_async(
882885
self, llm_request: LlmRequest, stream: bool = False

tests/unittests/models/test_litellm.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1519,6 +1519,23 @@ async def test_acompletion_additional_args(mock_acompletion, mock_client):
15191519
assert kwargs["api_base"] == "some://url"
15201520

15211521

1522+
@pytest.mark.asyncio
1523+
async def test_acompletion_with_drop_params(mock_acompletion, mock_client):
1524+
lite_llm_instance = LiteLlm(
1525+
model="test_model", llm_client=mock_client, drop_params=True
1526+
)
1527+
1528+
async for _ in lite_llm_instance.generate_content_async(
1529+
LLM_REQUEST_WITH_FUNCTION_DECLARATION
1530+
):
1531+
pass
1532+
1533+
mock_acompletion.assert_called_once()
1534+
1535+
_, kwargs = mock_acompletion.call_args
1536+
assert kwargs["drop_params"] is True
1537+
1538+
15221539
@pytest.mark.asyncio
15231540
async def test_completion_additional_args(mock_completion, mock_client):
15241541
lite_llm_instance = LiteLlm(
@@ -1561,6 +1578,28 @@ async def test_completion_additional_args(mock_completion, mock_client):
15611578
assert kwargs["api_base"] == "some://url"
15621579

15631580

1581+
@pytest.mark.asyncio
1582+
async def test_completion_with_drop_params(mock_completion, mock_client):
1583+
lite_llm_instance = LiteLlm(
1584+
model="test_model", llm_client=mock_client, drop_params=True
1585+
)
1586+
1587+
mock_completion.return_value = iter(STREAMING_MODEL_RESPONSE)
1588+
1589+
responses = [
1590+
response
1591+
async for response in lite_llm_instance.generate_content_async(
1592+
LLM_REQUEST_WITH_FUNCTION_DECLARATION, stream=True
1593+
)
1594+
]
1595+
assert len(responses) == 4
1596+
1597+
mock_completion.assert_called_once()
1598+
1599+
_, kwargs = mock_completion.call_args
1600+
assert kwargs["drop_params"] is True
1601+
1602+
15641603
@pytest.mark.asyncio
15651604
async def test_generate_content_async_stream(
15661605
mock_completion, lite_llm_instance

0 commit comments

Comments
 (0)