Skip to content

Commit 8f10805

Browse files
refactor(llm): move provider transport helpers into LLMProvider
Co-authored-by: openhands <openhands@all-hands.dev>
1 parent 50f6387 commit 8f10805

File tree

5 files changed

+117
-26
lines changed

5 files changed

+117
-26
lines changed

openhands-sdk/openhands/sdk/llm/llm.py

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -869,15 +869,15 @@ def _one_attempt(**retry_kwargs) -> ResponsesAPIResponse:
869869
typed_input: ResponseInputParam | str = (
870870
cast(ResponseInputParam, input_items) if input_items else ""
871871
)
872-
api_key_value = self._get_litellm_api_key_value()
873872
provider_info = self._get_litellm_provider_info()
874873

875874
ret = litellm_responses(
876-
**provider_info.as_litellm_call_kwargs(),
875+
**provider_info.as_litellm_call_kwargs(
876+
api_key=self._get_api_key_value()
877+
),
877878
input=typed_input,
878879
instructions=instructions,
879880
tools=resp_tools,
880-
api_key=api_key_value,
881881
api_base=self.base_url,
882882
api_version=self.api_version,
883883
timeout=self.timeout,
@@ -985,28 +985,21 @@ def _one_attempt(**retry_kwargs) -> ResponsesAPIResponse:
985985
# =========================================================================
986986

987987
def _get_litellm_provider_info(self) -> LLMProvider:
988-
cache_key = (self.model, self.base_url)
989-
if self._provider_info is None or self._provider_info_cache_key != cache_key:
990-
self._provider_info = LLMProvider.from_model(
991-
model=self.model, api_base=self.base_url
992-
)
993-
self._provider_info_cache_key = cache_key
988+
self._provider_info, self._provider_info_cache_key = LLMProvider.resolve_cached(
989+
model=self.model,
990+
api_base=self.base_url,
991+
cached_provider=self._provider_info,
992+
cached_key=self._provider_info_cache_key,
993+
)
994+
assert self._provider_info is not None
994995
return self._provider_info
995996

996-
def _get_litellm_api_key_value(self) -> str | None:
997-
api_key_value: str | None = None
998-
if self.api_key:
999-
assert isinstance(self.api_key, SecretStr)
1000-
api_key_value = self.api_key.get_secret_value()
1001-
1002-
# LiteLLM treats api_key for Bedrock as an AWS bearer token.
1003-
# Passing a non-Bedrock key (e.g. OpenAI/Anthropic) can cause Bedrock
1004-
# to reject the request with an "Invalid API Key format" error.
1005-
# For IAM/SigV4 auth (the default Bedrock path), do not forward api_key.
1006-
if api_key_value is not None and self._get_litellm_provider_info().is_bedrock:
997+
def _get_api_key_value(self) -> str | None:
998+
if self.api_key is None:
1007999
return None
10081000

1009-
return api_key_value
1001+
assert isinstance(self.api_key, SecretStr)
1002+
return self.api_key.get_secret_value()
10101003

10111004
def _transport_call(
10121005
self,
@@ -1041,13 +1034,13 @@ def _transport_call(
10411034
category=DeprecationWarning,
10421035
message="Accessing the 'model_fields' attribute.*",
10431036
)
1044-
api_key_value = self._get_litellm_api_key_value()
10451037
provider_info = self._get_litellm_provider_info()
10461038

10471039
# Some providers need renames handled in _normalize_call_kwargs.
10481040
ret = litellm_completion(
1049-
**provider_info.as_litellm_call_kwargs(),
1050-
api_key=api_key_value,
1041+
**provider_info.as_litellm_call_kwargs(
1042+
api_key=self._get_api_key_value()
1043+
),
10511044
api_base=self.base_url,
10521045
api_version=self.api_version,
10531046
timeout=self.timeout,

openhands-sdk/openhands/sdk/llm/utils/litellm_provider.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,24 @@ class LLMProvider:
2525
name: str | None
2626
resolved_api_base: str | None
2727

28+
@staticmethod
29+
def cache_key(*, model: str, api_base: str | None) -> tuple[str, str | None]:
30+
return (model, api_base)
31+
32+
@classmethod
33+
def resolve_cached(
34+
cls,
35+
*,
36+
model: str,
37+
api_base: str | None,
38+
cached_provider: LLMProvider | None,
39+
cached_key: tuple[str, str | None] | None,
40+
) -> tuple[LLMProvider, tuple[str, str | None]]:
41+
cache_key = cls.cache_key(model=model, api_base=api_base)
42+
if cached_provider is not None and cached_key == cache_key:
43+
return cached_provider, cache_key
44+
return cls.from_model(model=model, api_base=api_base), cache_key
45+
2846
@classmethod
2947
def from_model(cls, *, model: str, api_base: str | None) -> LLMProvider:
3048
"""Parse a model string using LiteLLM's provider inference logic."""
@@ -81,10 +99,22 @@ def canonical_name(self) -> str:
8199
def is_bedrock(self) -> bool:
82100
return self.name == "bedrock"
83101

84-
def as_litellm_call_kwargs(self) -> dict[str, str]:
102+
def api_key_for_litellm(self, api_key: str | None) -> str | None:
103+
# LiteLLM treats api_key for Bedrock as an AWS bearer token.
104+
# Passing a non-Bedrock key (e.g. OpenAI/Anthropic) can cause Bedrock
105+
# to reject the request with an "Invalid API Key format" error.
106+
# For IAM/SigV4 auth (the default Bedrock path), do not forward api_key.
107+
if api_key is not None and self.is_bedrock:
108+
return None
109+
return api_key
110+
111+
def as_litellm_call_kwargs(self, *, api_key: str | None = None) -> dict[str, str]:
85112
kwargs = {"model": self.model}
86113
if self.name is not None:
87114
kwargs["custom_llm_provider"] = self.name
115+
normalized_api_key = self.api_key_for_litellm(api_key)
116+
if normalized_api_key is not None:
117+
kwargs["api_key"] = normalized_api_key
88118
return kwargs
89119

90120
def infer_api_base(self) -> str | None:

tests/sdk/llm/test_api_key_validation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,9 @@ def test_bedrock_model_with_api_key_not_forwarded_to_litellm():
8585
api_key=SecretStr("sk-ant-not-a-bedrock-key"),
8686
)
8787
assert llm.api_key is not None
88-
assert llm._get_litellm_api_key_value() is None
88+
assert isinstance(llm.api_key, SecretStr)
89+
provider = llm._get_litellm_provider_info()
90+
assert provider.api_key_for_litellm(llm.api_key.get_secret_value()) is None
8991

9092

9193
def test_non_bedrock_model_with_valid_key():

tests/sdk/llm/test_litellm_provider.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,47 @@ def test_llm_provider_parses_bedrock_model():
3030
)
3131

3232

33+
def test_llm_provider_strips_api_key_for_bedrock_calls():
34+
provider = LLMProvider.from_model(
35+
model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0",
36+
api_base=None,
37+
)
38+
39+
assert provider.api_key_for_litellm("sk-ant-not-a-bedrock-key") is None
40+
assert provider.as_litellm_call_kwargs(api_key="sk-ant-not-a-bedrock-key") == {
41+
"model": "anthropic.claude-3-5-sonnet-20241022-v2:0",
42+
"custom_llm_provider": "bedrock",
43+
}
44+
45+
46+
def test_llm_provider_reuses_cached_instance_for_same_request():
47+
provider = LLMProvider.from_model(model="gpt-4o", api_base=None)
48+
49+
cached_provider, cached_key = LLMProvider.resolve_cached(
50+
model="gpt-4o",
51+
api_base=None,
52+
cached_provider=provider,
53+
cached_key=LLMProvider.cache_key(model="gpt-4o", api_base=None),
54+
)
55+
56+
assert cached_provider is provider
57+
assert cached_key == ("gpt-4o", None)
58+
59+
60+
def test_llm_provider_refreshes_cached_instance_when_request_changes():
61+
provider = LLMProvider.from_model(model="gpt-4o", api_base=None)
62+
63+
refreshed_provider, refreshed_key = LLMProvider.resolve_cached(
64+
model="proxy/test-renamed-model",
65+
api_base="http://localhost:8000",
66+
cached_provider=provider,
67+
cached_key=LLMProvider.cache_key(model="gpt-4o", api_base=None),
68+
)
69+
70+
assert refreshed_provider is not provider
71+
assert refreshed_key == ("proxy/test-renamed-model", "http://localhost:8000")
72+
73+
3374
def test_llm_provider_handles_unknown_model_without_provider():
3475
provider = LLMProvider.from_model(model="unknown-model", api_base=None)
3576

tests/sdk/llm/test_llm.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,31 @@ def test_llm_responses_forwards_extra_headers_to_litellm(mock_responses):
390390
assert kwargs.get("extra_headers") == headers
391391

392392

393+
@patch("openhands.sdk.llm.llm.litellm_completion")
394+
def test_llm_completion_does_not_forward_bedrock_api_key(mock_completion):
395+
mock_response = create_mock_litellm_response("ok")
396+
mock_completion.return_value = mock_response
397+
398+
llm = LLM(
399+
usage_id="test-llm",
400+
model="us.anthropic.claude-3-sonnet-20240229-v1:0",
401+
api_key=SecretStr("sk-ant-not-a-bedrock-key"),
402+
num_retries=0,
403+
)
404+
405+
provider_info = llm._get_litellm_provider_info()
406+
407+
messages = [Message(role="user", content=[TextContent(text="Hi")])]
408+
_ = llm.completion(messages=messages)
409+
410+
assert mock_completion.call_count == 1
411+
_, kwargs = mock_completion.call_args
412+
assert kwargs["model"] == provider_info.model
413+
if provider_info.name is not None:
414+
assert kwargs["custom_llm_provider"] == provider_info.name
415+
assert "api_key" not in kwargs
416+
417+
393418
@patch("openhands.sdk.llm.llm.litellm_completion")
394419
def test_llm_model_copy_recomputes_transport_provider_for_proxy_alias(mock_completion):
395420
mock_response = create_mock_litellm_response("ok")

0 commit comments

Comments
 (0)