Skip to content

feat: Warn on insecure tool invocation with authentication #223

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion packages/toolbox-core/src/toolbox_core/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from inspect import Signature
from types import MappingProxyType
from typing import Any, Callable, Coroutine, Mapping, Optional, Sequence, Union
from warnings import warn

from aiohttp import ClientSession

Expand Down Expand Up @@ -118,6 +119,17 @@ def __init__(
# map of client headers to their value/callable/coroutine
self.__client_headers = client_headers

# ID tokens contain sensitive user information (claims). Transmitting
# these over HTTP exposes the data to interception and unauthorized
# access. Always use HTTPS to ensure secure communication and protect
# user privacy.
if (
required_authn_params or required_authz_tokens or client_headers
) and not self.__url.startswith("https://"):
warn(
"Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication."
)

@property
def _name(self) -> str:
return self.__name__
Expand Down Expand Up @@ -327,7 +339,8 @@ def add_auth_token_getters(
)

return self.__copy(
# create a read-only map for updated getters, params and tokens that are still required
# create read-only values for updated getters, params and tokens
# that are still required
auth_service_token_getters=MappingProxyType(new_getters),
required_authn_params=MappingProxyType(new_req_authn_params),
required_authz_tokens=tuple(new_req_authz_tokens),
Expand Down
173 changes: 155 additions & 18 deletions packages/toolbox-core/tests/test_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import inspect
from typing import AsyncGenerator, Callable, Mapping
from unittest.mock import AsyncMock, Mock
from warnings import catch_warnings, simplefilter

import pytest
import pytest_asyncio
Expand All @@ -27,6 +28,7 @@
from toolbox_core.tool import ToolboxTool, create_func_docstring, resolve_value

TEST_BASE_URL = "http://toolbox.example.com"
HTTPS_BASE_URL = "https://toolbox.example.com"
TEST_TOOL_NAME = "sample_tool"


Expand Down Expand Up @@ -195,7 +197,7 @@ async def test_tool_creation_callable_and_run(
Tests creating a ToolboxTool, checks callability, and simulates a run.
"""
tool_name = TEST_TOOL_NAME
base_url = TEST_BASE_URL
base_url = HTTPS_BASE_URL
invoke_url = f"{base_url}/api/tool/{tool_name}/invoke"

input_args = {"message": "hello world", "count": 5}
Expand Down Expand Up @@ -246,7 +248,7 @@ async def test_tool_run_with_pydantic_validation_error(
due to Pydantic validation *before* making an HTTP request.
"""
tool_name = TEST_TOOL_NAME
base_url = TEST_BASE_URL
base_url = HTTPS_BASE_URL
invoke_url = f"{base_url}/api/tool/{tool_name}/invoke"

with aioresponses() as m:
Expand Down Expand Up @@ -340,18 +342,25 @@ async def test_resolve_value_async_callable():

def test_tool_init_basic(http_session, sample_tool_params, sample_tool_description):
"""Tests basic tool initialization without headers or auth."""
tool_instance = ToolboxTool(
session=http_session,
base_url=TEST_BASE_URL,
name=TEST_TOOL_NAME,
description=sample_tool_description,
params=sample_tool_params,
required_authn_params={},
required_authz_tokens=[],
auth_service_token_getters={},
bound_params={},
client_headers={},
)
with catch_warnings(record=True) as record:
simplefilter("always")

tool_instance = ToolboxTool(
session=http_session,
base_url=HTTPS_BASE_URL,
name=TEST_TOOL_NAME,
description=sample_tool_description,
params=sample_tool_params,
required_authn_params={},
required_authz_tokens=[],
auth_service_token_getters={},
bound_params={},
client_headers={},
)
assert (
len(record) == 0
), f"ToolboxTool instantiation unexpectedly warned: {[f'{w.category.__name__}: {w.message}' for w in record]}"

assert tool_instance.__name__ == TEST_TOOL_NAME
assert inspect.iscoroutinefunction(tool_instance.__call__)
assert "message" in tool_instance.__signature__.parameters
Expand All @@ -367,7 +376,7 @@ def test_tool_init_with_client_headers(
"""Tests tool initialization *with* client headers."""
tool_instance = ToolboxTool(
session=http_session,
base_url=TEST_BASE_URL,
base_url=HTTPS_BASE_URL,
name=TEST_TOOL_NAME,
description=sample_tool_description,
params=sample_tool_params,
Expand Down Expand Up @@ -395,7 +404,7 @@ def test_tool_init_header_auth_conflict(
):
ToolboxTool(
session=http_session,
base_url=TEST_BASE_URL,
base_url=HTTPS_BASE_URL,
name="auth_conflict_tool",
description=sample_tool_description,
params=sample_tool_auth_params,
Expand All @@ -418,7 +427,7 @@ def test_tool_add_auth_token_getters_conflict_with_existing_client_header(
"""
tool_instance = ToolboxTool(
session=http_session,
base_url=TEST_BASE_URL,
base_url=HTTPS_BASE_URL,
name="tool_with_client_header",
description=sample_tool_description,
params=sample_tool_params,
Expand Down Expand Up @@ -454,7 +463,7 @@ def test_add_auth_token_getters_unused_token(
"""
tool_instance = ToolboxTool(
session=http_session,
base_url=TEST_BASE_URL,
base_url=HTTPS_BASE_URL,
name=TEST_TOOL_NAME,
description=sample_tool_description,
params=sample_tool_params,
Expand All @@ -469,3 +478,131 @@ def test_add_auth_token_getters_unused_token(

with pytest.raises(ValueError, match=expected_error_message):
tool_instance.add_auth_token_getters(unused_auth_getters)


# --- Test for the HTTP Warning ---
@pytest.mark.parametrize(
"trigger_condition_params",
[
{"client_headers": {"X-Some-Header": "value"}},
{"required_authn_params": {"param1": ["auth-service1"]}},
{"required_authz_tokens": ["auth-service2"]},
{
"client_headers": {"X-Some-Header": "value"},
"required_authn_params": {"param1": ["auth-service1"]},
},
{
"client_headers": {"X-Some-Header": "value"},
"required_authz_tokens": ["auth-service2"],
},
{
"required_authn_params": {"param1": ["auth-service1"]},
"required_authz_tokens": ["auth-service2"],
},
{
"client_headers": {"X-Some-Header": "value"},
"required_authn_params": {"param1": ["auth-service1"]},
"required_authz_tokens": ["auth-service2"],
},
],
ids=[
"client_headers_only",
"authn_params_only",
"authz_tokens_only",
"headers_and_authn",
"headers_and_authz",
"authn_and_authz",
"all_three_conditions",
],
)
def test_tool_init_http_warning_when_sensitive_info_over_http(
http_session: ClientSession,
sample_tool_params: list[ParameterSchema],
sample_tool_description: str,
trigger_condition_params: dict,
):
"""
Tests that a UserWarning is issued if client headers, auth params, or
auth tokens are present and the base_url is HTTP.
"""
expected_warning_message = (
"Sending ID token over HTTP. User data may be exposed. "
"Use HTTPS for secure communication."
)

init_kwargs = {
"session": http_session,
"base_url": TEST_BASE_URL,
"name": "http_warning_tool",
"description": sample_tool_description,
"params": sample_tool_params,
"required_authn_params": {},
"required_authz_tokens": [],
"auth_service_token_getters": {},
"bound_params": {},
"client_headers": {},
}
# Apply the specific conditions for this parametrized test
init_kwargs.update(trigger_condition_params)

with pytest.warns(UserWarning, match=expected_warning_message):
ToolboxTool(**init_kwargs)


def test_tool_init_no_http_warning_if_https(
http_session: ClientSession,
sample_tool_params: list[ParameterSchema],
sample_tool_description: str,
static_client_header: dict,
):
"""
Tests that NO UserWarning is issued if client headers are present but
the base_url is HTTPS.
"""
with catch_warnings(record=True) as record:
simplefilter("always")

ToolboxTool(
session=http_session,
base_url=HTTPS_BASE_URL,
name="https_tool",
description=sample_tool_description,
params=sample_tool_params,
required_authn_params={},
required_authz_tokens=[],
auth_service_token_getters={},
bound_params={},
client_headers=static_client_header,
)
assert (
len(record) == 0
), f"Expected no warnings, but got: {[f'{w.category.__name__}: {w.message}' for w in record]}"


def test_tool_init_no_http_warning_if_no_sensitive_info_on_http(
http_session: ClientSession,
sample_tool_params: list[ParameterSchema],
sample_tool_description: str,
):
"""
Tests that NO UserWarning is issued if the URL is HTTP but there are
no client headers, auth params, or auth tokens.
"""
with catch_warnings(record=True) as record:
simplefilter("always")

ToolboxTool(
session=http_session,
base_url=TEST_BASE_URL,
name="http_tool_no_sensitive",
description=sample_tool_description,
params=sample_tool_params,
required_authn_params={},
required_authz_tokens=[],
auth_service_token_getters={},
bound_params={},
client_headers={},
)
assert (
len(record) == 0
), f"Expected no warnings, but got: {[f'{w.category.__name__}: {w.message}' for w in record]}"