Skip to content

Commit c4ae26a

Browse files
authored
fix(platform): Also handle client_name inside _acompletion (#808)
1 parent 631709b commit c4ae26a

File tree

2 files changed

+81
-0
lines changed

2 files changed

+81
-0
lines changed

src/any_llm/providers/platform/platform.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,10 @@ async def _acompletion(
103103
params: CompletionParams,
104104
**kwargs: Any,
105105
) -> ChatCompletion | AsyncIterator[ChatCompletionChunk]:
106+
client_name = kwargs.pop("client_name", None)
107+
if self.client_name is None:
108+
self.client_name = client_name
109+
106110
start_time = time.perf_counter()
107111

108112
# List of providers that don't support stream_options and automatically return token usage.

tests/unit/providers/test_platform_provider.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1158,3 +1158,80 @@ async def test_usage_event_includes_version_header(
11581158

11591159
assert "User-Agent" in headers
11601160
assert headers["User-Agent"] == f"python-any-llm/{__version__}"
1161+
1162+
1163+
@pytest.mark.asyncio
1164+
@patch("any_llm_platform_client.AnyLLMPlatformClient.get_decrypted_provider_key")
1165+
@patch("any_llm.providers.platform.platform.post_completion_usage_event")
1166+
async def test_acompletion_handles_client_name_in_kwargs(
1167+
mock_post_usage: AsyncMock,
1168+
mock_get_decrypted_provider_key: Mock,
1169+
any_llm_key: str,
1170+
mock_decrypted_provider_key: DecryptedProviderKey,
1171+
mock_completion: ChatCompletion,
1172+
) -> None:
1173+
"""Test that _acompletion correctly handles client_name passed in kwargs."""
1174+
mock_get_decrypted_provider_key.return_value = mock_decrypted_provider_key
1175+
1176+
provider_instance = PlatformProvider(api_key=any_llm_key)
1177+
provider_instance.provider = OpenaiProvider
1178+
provider_instance.provider._acompletion = AsyncMock(return_value=mock_completion) # type: ignore[method-assign]
1179+
1180+
# Create completion params
1181+
params = CompletionParams(
1182+
model_id="gpt-4",
1183+
messages=[{"role": "user", "content": "Hello"}],
1184+
stream=False,
1185+
)
1186+
1187+
client_name = "test-client-from-kwargs"
1188+
1189+
# Call _acompletion with client_name in kwargs
1190+
await provider_instance._acompletion(params, client_name=client_name)
1191+
1192+
# Verify self.client_name was updated
1193+
assert provider_instance.client_name == client_name
1194+
1195+
# Verify post_completion_usage_event was called with the correct client_name
1196+
mock_post_usage.assert_called_once()
1197+
call_args = mock_post_usage.call_args
1198+
assert call_args.kwargs["client_name"] == client_name
1199+
1200+
1201+
@pytest.mark.asyncio
1202+
@patch("any_llm_platform_client.AnyLLMPlatformClient.get_decrypted_provider_key")
1203+
@patch("any_llm.providers.platform.platform.post_completion_usage_event")
1204+
async def test_acompletion_does_not_overwrite_existing_client_name(
1205+
mock_post_usage: AsyncMock,
1206+
mock_get_decrypted_provider_key: Mock,
1207+
any_llm_key: str,
1208+
mock_decrypted_provider_key: DecryptedProviderKey,
1209+
mock_completion: ChatCompletion,
1210+
) -> None:
1211+
"""Test that _acompletion does not overwrite an existing client_name if one is already set."""
1212+
mock_get_decrypted_provider_key.return_value = mock_decrypted_provider_key
1213+
1214+
initial_client_name = "initial-client"
1215+
provider_instance = PlatformProvider(api_key=any_llm_key, client_name=initial_client_name)
1216+
provider_instance.provider = OpenaiProvider
1217+
provider_instance.provider._acompletion = AsyncMock(return_value=mock_completion) # type: ignore[method-assign]
1218+
1219+
# Create completion params
1220+
params = CompletionParams(
1221+
model_id="gpt-4",
1222+
messages=[{"role": "user", "content": "Hello"}],
1223+
stream=False,
1224+
)
1225+
1226+
new_client_name = "new-client-from-kwargs"
1227+
1228+
# Call _acompletion with a new client_name in kwargs
1229+
await provider_instance._acompletion(params, client_name=new_client_name)
1230+
1231+
# Verify self.client_name was NOT updated
1232+
assert provider_instance.client_name == initial_client_name
1233+
1234+
# Verify post_completion_usage_event was called with the INITIAL client_name
1235+
mock_post_usage.assert_called_once()
1236+
call_args = mock_post_usage.call_args
1237+
assert call_args.kwargs["client_name"] == initial_client_name

0 commit comments

Comments
 (0)