Skip to content

Commit 6c5a64b

Browse files
committed
feat: add secure passthrough header forwarding with forward_headers and extra_blocked_headers
1 parent 2157c09 commit 6c5a64b

File tree

14 files changed

+1200
-103
lines changed

14 files changed

+1200
-103
lines changed

docs/docs/providers/inference/remote_passthrough.mdx

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,12 @@ Passthrough inference provider for connecting to any external inference service
3535
| `network.timeout.read` | `float \| None` | No | | Read timeout in seconds. |
3636
| `network.headers` | `dict[str, str] \| None` | No | | Additional HTTP headers to include in all requests. |
3737
| `base_url` | `HttpUrl \| None` | No | | The URL for the passthrough endpoint |
38+
| `forward_headers` | `dict[str, str] \| None` | No | | Mapping of X-LlamaStack-Provider-Data keys to outbound HTTP header names. Only listed keys are forwarded — all others are ignored (default-deny). Values are forwarded verbatim; include any required prefix in the client payload (e.g. 'Bearer sk-xxx' not 'sk-xxx' when targeting Authorization). Header name values should use canonical HTTP casing (e.g. 'Authorization', 'X-Tenant-ID'). Keys with a __ prefix and core security-sensitive headers (for example Host, Content-Type, Transfer-Encoding, Cookie) are rejected at config parse time. When this field is set and auth comes from forwarded headers rather than a static api_key, the caller must include the required keys in X-LlamaStack-Provider-Data on every request. Example: {"maas_api_token": "Authorization"} |
39+
| `extra_blocked_headers` | `list[str]` | No | [] | Additional outbound header names to block in forward_headers. Names are matched case-insensitively and added to the core blocked list. This can tighten policy but cannot unblock core security-sensitive headers. |
3840

3941
## Sample Configuration
4042

4143
```yaml
4244
base_url: ${env.PASSTHROUGH_URL}
43-
api_key: ${env.PASSTHROUGH_API_KEY}
45+
api_key: ${env.PASSTHROUGH_API_KEY:=}
4446
```

docs/docs/providers/safety/remote_passthrough.mdx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ Passthrough safety provider that forwards moderation calls to a downstream HTTP
1616
|-------|------|----------|---------|-------------|
1717
| `base_url` | `HttpUrl` | No | | Base URL of the downstream safety service (e.g. https://safety.example.com/v1) |
1818
| `api_key` | `SecretStr \| None` | No | | API key for the downstream safety service. If set, takes precedence over provider data. |
19-
| `forward_headers` | `dict[str, str]` | No | {} | Mapping of provider data keys to outbound HTTP header names. Only keys listed here are forwarded from X-LlamaStack-Provider-Data to the downstream service. Example: {"maas_api_token": "Authorization"} |
19+
| `forward_headers` | `dict[str, str] \| None` | No | | Mapping of provider data keys to outbound HTTP header names. Only keys listed here are forwarded from X-LlamaStack-Provider-Data to the downstream service. Keys with a __ prefix and core security-sensitive headers (for example Host, Content-Type, Transfer-Encoding, Cookie) are rejected at config parse time. Example: {"maas_api_token": "Authorization"} |
20+
| `extra_blocked_headers` | `list[str]` | No | [] | Additional outbound header names to block in forward_headers. Names are matched case-insensitively and added to the core blocked list. This can tighten policy but cannot unblock core security-sensitive headers. |
2021

2122
## Sample Configuration
2223

src/llama_stack/providers/remote/inference/passthrough/__init__.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,29 @@
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
66

7-
from pydantic import BaseModel, SecretStr
7+
from pydantic import BaseModel, ConfigDict, SecretStr
88

99
from .config import PassthroughImplConfig
1010

1111

1212
class PassthroughProviderDataValidator(BaseModel):
13-
passthrough_url: str
14-
passthrough_api_key: SecretStr
13+
# Lives here because the framework resolves provider_data_validator by module path,
14+
# and the registry entry points to this package root.
15+
#
16+
# extra="allow" because forward_headers key names (e.g. "maas_api_token") are
17+
# deployer-defined at config time — they can't be declared as typed fields.
18+
# Without it, Pydantic drops them before build_forwarded_headers() can read them.
19+
model_config = ConfigDict(extra="allow")
20+
21+
passthrough_url: str | None = None
22+
passthrough_api_key: SecretStr | None = None
1523

1624

1725
async def get_adapter_impl(config: PassthroughImplConfig, _deps):
1826
from .passthrough import PassthroughInferenceAdapter
1927

20-
assert isinstance(config, PassthroughImplConfig), f"Unexpected config type: {type(config)}"
28+
if not isinstance(config, PassthroughImplConfig):
29+
raise ValueError(f"Unexpected config type: {type(config)}")
2130
impl = PassthroughInferenceAdapter(config)
2231
await impl.initialize()
2332
return impl

src/llama_stack/providers/remote/inference/passthrough/config.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66

77
from typing import Any
88

9-
from pydantic import Field, HttpUrl
9+
from pydantic import Field, HttpUrl, model_validator
1010

11+
from llama_stack.providers.utils.forward_headers import validate_forward_headers_config
1112
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
1213
from llama_stack_api import json_schema_type
1314

@@ -18,12 +19,50 @@ class PassthroughImplConfig(RemoteInferenceProviderConfig):
1819
default=None,
1920
description="The URL for the passthrough endpoint",
2021
)
22+
forward_headers: dict[str, str] | None = Field(
23+
default=None,
24+
description=(
25+
"Mapping of X-LlamaStack-Provider-Data keys to outbound HTTP header names. "
26+
"Only listed keys are forwarded — all others are ignored (default-deny). "
27+
"Values are forwarded verbatim; include any required prefix in the client payload "
28+
"(e.g. 'Bearer sk-xxx' not 'sk-xxx' when targeting Authorization). "
29+
"Header name values should use canonical HTTP casing (e.g. 'Authorization', 'X-Tenant-ID'). "
30+
"Keys with a __ prefix and core security-sensitive headers (for example Host, "
31+
"Content-Type, Transfer-Encoding, Cookie) are rejected at config parse time. "
32+
"When this field is set and auth comes from forwarded headers rather than a static api_key, "
33+
"the caller must include the required keys in X-LlamaStack-Provider-Data on every request. "
34+
'Example: {"maas_api_token": "Authorization"}'
35+
),
36+
)
37+
extra_blocked_headers: list[str] = Field(
38+
default_factory=list,
39+
description=(
40+
"Additional outbound header names to block in forward_headers. "
41+
"Names are matched case-insensitively and added to the core blocked list. "
42+
"This can tighten policy but cannot unblock core security-sensitive headers."
43+
),
44+
)
45+
46+
@model_validator(mode="after")
def validate_forward_headers(self) -> "PassthroughImplConfig":
    """Check the forward_headers mapping once the model has been built.

    Hands ``forward_headers`` and ``extra_blocked_headers`` to the shared
    ``validate_forward_headers_config`` helper and returns the model
    unchanged. Per the field descriptions, reserved ``__``-prefixed keys and
    blocked header names are rejected at config parse time; the helper is not
    visible here, so the exact exception it raises is presumably ValueError
    (confirm against providers/utils/forward_headers).
    """
    # Runs as an "after" validator so both fields are already parsed and
    # can be checked against each other.
    mapping, blocked = self.forward_headers, self.extra_blocked_headers
    validate_forward_headers_config(mapping, blocked)
    return self
2150

2251
@classmethod
2352
def sample_run_config(
24-
cls, base_url: HttpUrl | None = "${env.PASSTHROUGH_URL}", api_key: str = "${env.PASSTHROUGH_API_KEY}", **kwargs
53+
cls,
54+
base_url: HttpUrl | None = "${env.PASSTHROUGH_URL}",
55+
api_key: str = "${env.PASSTHROUGH_API_KEY:=}",
56+
forward_headers: dict[str, str] | None = None,
57+
extra_blocked_headers: list[str] | None = None,
58+
**kwargs,
2559
) -> dict[str, Any]:
26-
return {
60+
config: dict[str, Any] = {
2761
"base_url": base_url,
2862
"api_key": api_key,
2963
}
64+
if forward_headers:
65+
config["forward_headers"] = forward_headers
66+
if extra_blocked_headers:
67+
config["extra_blocked_headers"] = extra_blocked_headers
68+
return config

src/llama_stack/providers/remote/inference/passthrough/passthrough.py

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from openai import AsyncOpenAI
1010

1111
from llama_stack.core.request_headers import NeedsRequestProviderData
12+
from llama_stack.log import get_logger
13+
from llama_stack.providers.utils.forward_headers import build_forwarded_headers
1214
from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream
1315
from llama_stack_api import (
1416
Inference,
@@ -24,6 +26,8 @@
2426

2527
from .config import PassthroughImplConfig
2628

29+
logger = get_logger(__name__, category="inference")
30+
2731

2832
class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
2933
def __init__(self, config: PassthroughImplConfig) -> None:
@@ -74,37 +78,70 @@ async def should_refresh_models(self) -> bool:
7478
def _get_openai_client(self) -> AsyncOpenAI:
7579
"""Get an AsyncOpenAI client configured for the downstream server."""
7680
base_url = self._get_passthrough_url()
77-
api_key = self._get_passthrough_api_key()
81+
request_headers = self._build_request_headers()
7882

83+
# api_key="" means the SDK adds no Authorization header of its own;
84+
# auth comes entirely from request_headers (forwarded or static api_key).
85+
# This avoids the "passthrough" sentinel that would send a spurious
86+
# Authorization: Bearer passthrough to every downstream, even when
87+
# forward_headers only targets non-auth headers like X-Tenant-ID.
7988
return AsyncOpenAI(
8089
base_url=f"{base_url.rstrip('/')}/v1",
81-
api_key=api_key,
90+
api_key="",
91+
default_headers=request_headers or None,
8292
)
8393

94+
def _build_request_headers(self) -> dict[str, str]:
    """Assemble the outbound headers for a downstream request.

    Forwarded provider-data headers are built first via
    ``build_forwarded_headers``; if an API key is available (static config
    credential, or the per-request ``passthrough_api_key``) it replaces any
    forwarded Authorization header, whatever its casing.

    Returns:
        The header dict to pass to the client; may be empty.
    """
    provider_data = self.get_request_provider_data()
    mapping = self.config.forward_headers
    outbound = build_forwarded_headers(provider_data, mapping)

    # A configured mapping that produced nothing usually means the caller
    # omitted the expected keys from X-LlamaStack-Provider-Data.
    if mapping and not outbound:
        logger.warning(
            "forward_headers is configured but no matching keys found in provider data — "
            "outbound request may be unauthenticated"
        )

    api_key = self._get_passthrough_api_key_or_none(provider_data)
    if not api_key:
        return outbound

    # Drop every case variant of Authorization before writing the canonical
    # one, so the API key cannot coexist with a forwarded credential.
    merged = {name: value for name, value in outbound.items() if name.lower() != "authorization"}
    merged["Authorization"] = f"Bearer {api_key}"
    return merged
112+
113+
def _get_passthrough_api_key_or_none(self, provider_data: object | None = None) -> str | None:
114+
"""Return the static or per-request API key, or None if not configured."""
115+
if self.config.auth_credential is not None:
116+
configured_api_key = self.config.auth_credential.get_secret_value()
117+
if configured_api_key:
118+
return configured_api_key
119+
120+
if provider_data is None:
121+
provider_data = self.get_request_provider_data()
122+
passthrough_api_key = getattr(provider_data, "passthrough_api_key", None)
123+
if passthrough_api_key is not None:
124+
if hasattr(passthrough_api_key, "get_secret_value"):
125+
provider_data_api_key = passthrough_api_key.get_secret_value()
126+
else:
127+
provider_data_api_key = str(passthrough_api_key)
128+
if provider_data_api_key:
129+
return provider_data_api_key
130+
131+
return None
132+
84133
def _get_passthrough_url(self) -> str:
85134
"""Get the passthrough URL from config or provider data."""
86135
if self.config.base_url is not None:
87136
return str(self.config.base_url)
88137

89138
provider_data = self.get_request_provider_data()
90-
if provider_data is None:
139+
if provider_data is None or provider_data.passthrough_url is None:
91140
raise ValueError(
92141
'Pass url of the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_url": <your passthrough url>}'
93142
)
94143
return provider_data.passthrough_url
95144

96-
def _get_passthrough_api_key(self) -> str:
97-
"""Get the passthrough API key from config or provider data."""
98-
if self.config.auth_credential is not None:
99-
return self.config.auth_credential.get_secret_value()
100-
101-
provider_data = self.get_request_provider_data()
102-
if provider_data is None:
103-
raise ValueError(
104-
'Pass API Key for the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_api_key": <your api key>}'
105-
)
106-
return provider_data.passthrough_api_key.get_secret_value()
107-
108145
async def openai_completion(
109146
self,
110147
params: OpenAICompletionRequestWithExtraBody,

src/llama_stack/providers/remote/safety/passthrough/config.py

Lines changed: 26 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,28 +6,16 @@
66

77
from typing import Any
88

9-
from pydantic import BaseModel, ConfigDict, Field, HttpUrl, SecretStr, field_validator
9+
from pydantic import BaseModel, ConfigDict, Field, HttpUrl, SecretStr, model_validator
1010

11+
from llama_stack.providers.utils.forward_headers import validate_forward_headers_config
1112
from llama_stack_api import json_schema_type
1213

13-
_BLOCKED_HEADERS = frozenset(
14-
{
15-
"host",
16-
"content-type",
17-
"content-length",
18-
"transfer-encoding",
19-
"connection",
20-
"upgrade",
21-
"te",
22-
"trailer",
23-
"cookie",
24-
"set-cookie",
25-
}
26-
)
27-
2814

2915
class PassthroughProviderDataValidator(BaseModel):
30-
# allow arbitrary keys so forward_headers can access them
16+
# extra="allow" because forward_headers key names (e.g. "maas_api_token") are
17+
# deployer-defined at config time — they can't be declared as typed fields.
18+
# Without it, Pydantic drops them before build_forwarded_headers() can read them.
3119
model_config = ConfigDict(extra="allow")
3220

3321
passthrough_api_key: SecretStr | None = Field(
@@ -46,34 +34,37 @@ class PassthroughSafetyConfig(BaseModel):
4634
default=None,
4735
description="API key for the downstream safety service. If set, takes precedence over provider data.",
4836
)
49-
forward_headers: dict[str, str] = Field(
50-
default_factory=dict,
37+
forward_headers: dict[str, str] | None = Field(
38+
default=None,
5139
description=(
5240
"Mapping of provider data keys to outbound HTTP header names. "
53-
"Only keys listed here are forwarded from X-LlamaStack-Provider-Data "
54-
'to the downstream service. Example: {"maas_api_token": "Authorization"}'
41+
"Only keys listed here are forwarded from X-LlamaStack-Provider-Data to the downstream service. "
42+
"Keys with a __ prefix and core security-sensitive headers (for example Host, "
43+
"Content-Type, Transfer-Encoding, Cookie) are rejected at config parse time. "
44+
'Example: {"maas_api_token": "Authorization"}'
45+
),
46+
)
47+
extra_blocked_headers: list[str] = Field(
48+
default_factory=list,
49+
description=(
50+
"Additional outbound header names to block in forward_headers. "
51+
"Names are matched case-insensitively and added to the core blocked list. "
52+
"This can tighten policy but cannot unblock core security-sensitive headers."
5553
),
5654
)
5755

58-
@field_validator("forward_headers")
59-
@classmethod
60-
def validate_forward_headers(cls, v: dict[str, str]) -> dict[str, str]:
61-
errors: list[str] = []
62-
for provider_key, header_name in v.items():
63-
if provider_key.startswith("__"):
64-
errors.append(f"provider key '{provider_key}' uses reserved __ prefix")
65-
if header_name.lower() in _BLOCKED_HEADERS:
66-
errors.append(f"header '{header_name}' is blocked (security-sensitive)")
67-
if errors:
68-
raise ValueError(f"invalid forward_headers: {'; '.join(errors)}")
69-
return v
56+
@model_validator(mode="after")
def validate_forward_headers(self) -> "PassthroughSafetyConfig":
    """Check the forward_headers mapping once the model has been built.

    Hands ``forward_headers`` and ``extra_blocked_headers`` to the shared
    ``validate_forward_headers_config`` helper and returns the model
    unchanged. Per the field descriptions, reserved ``__``-prefixed keys and
    blocked header names are rejected at config parse time; the helper is not
    visible here, so the exact exception it raises is presumably ValueError
    (confirm against providers/utils/forward_headers).
    """
    # Runs as an "after" validator so both fields are already parsed and
    # can be checked against each other.
    mapping, blocked = self.forward_headers, self.extra_blocked_headers
    validate_forward_headers_config(mapping, blocked)
    return self
7060

7161
@classmethod
7262
def sample_run_config(
7363
cls,
7464
base_url: str = "${env.PASSTHROUGH_SAFETY_URL}",
7565
api_key: str = "${env.PASSTHROUGH_SAFETY_API_KEY:=}",
7666
forward_headers: dict[str, str] | None = None,
67+
extra_blocked_headers: list[str] | None = None,
7768
**kwargs: Any,
7869
) -> dict[str, Any]:
7970
config: dict[str, Any] = {
@@ -82,4 +73,6 @@ def sample_run_config(
8273
}
8374
if forward_headers:
8475
config["forward_headers"] = forward_headers
76+
if extra_blocked_headers:
77+
config["extra_blocked_headers"] = extra_blocked_headers
8578
return config

src/llama_stack/providers/remote/safety/passthrough/passthrough.py

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@
99
from typing import Any
1010

1111
import httpx
12-
from pydantic import SecretStr
1312

1413
from llama_stack.core.request_headers import NeedsRequestProviderData
14+
from llama_stack.log import get_logger
15+
from llama_stack.providers.utils.forward_headers import build_forwarded_headers
16+
17+
logger = get_logger(__name__, category="safety")
1518
from llama_stack_api import (
1619
GetShieldRequest,
1720
ModerationObject,
@@ -69,37 +72,27 @@ def _get_api_key(self) -> str | None:
6972

7073
def _build_forward_headers(self) -> dict[str, str]:
7174
"""Build outbound headers from provider data using the forward_headers mapping."""
72-
if not self.config.forward_headers:
73-
return {}
74-
7575
provider_data = self.get_request_provider_data()
76-
if provider_data is None:
77-
return {}
78-
79-
headers: dict[str, str] = {}
80-
raw = provider_data.model_dump()
81-
for provider_key, header_name in self.config.forward_headers.items():
82-
value = raw.get(provider_key)
83-
if value is not None:
84-
# unwrap SecretStr so we forward the real value, not '**********'
85-
if isinstance(value, SecretStr):
86-
value = value.get_secret_value()
87-
# strip control chars that could enable header injection
88-
sanitized = str(value).replace("\r", "").replace("\n", "")
89-
headers[header_name] = sanitized
90-
return headers
76+
forwarded = build_forwarded_headers(provider_data, self.config.forward_headers)
77+
if self.config.forward_headers and not forwarded:
78+
logger.warning(
79+
"forward_headers is configured but no matching keys found in provider data — "
80+
"outbound request may be unauthenticated"
81+
)
82+
return forwarded
9183

9284
def _build_request_headers(self) -> dict[str, str]:
93-
"""Combine auth + forwarded headers for the downstream request."""
94-
headers: dict[str, str] = {"Content-Type": "application/json"}
85+
"""Combine auth + forwarded headers for the downstream request.
9586
96-
# forwarded headers go first so config api_key can't be overwritten
87+
Forwarded headers go first; static api_key overwrites Authorization if set.
88+
build_forwarded_headers() normalizes header names case-insensitively so
89+
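there are no duplicate Authorization variants in the forwarded dict.

NOTE(review): the static-key write below is an exact-case dict assignment
(headers["Authorization"]), so it only replaces a forwarded header whose name
is spelled exactly "Authorization". The inference passthrough adapter strips
every case variant before writing; confirm build_forwarded_headers() emits
canonical casing, or apply the same case-insensitive filter here.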
90+
"""
91+
headers: dict[str, str] = {"Content-Type": "application/json"}
9792
headers.update(self._build_forward_headers())
98-
9993
api_key = self._get_api_key()
10094
if api_key:
10195
headers["Authorization"] = f"Bearer {api_key}"
102-
10396
return headers
10497

10598
async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:

0 commit comments

Comments
 (0)