Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ dependencies = [
"ddgs>=9.5.5",
"typer>=0.16.1",
"trafilatura>=1.6.1,<1.7",
"truststore>=0.10.4,<1.0",
"selectolax>=0.4.0,<0.5",
"langchain>=1.0.3",
"langchain-chroma>=1.0.0",
Expand Down
3 changes: 2 additions & 1 deletion src/ursa/agents/acquisition_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from ursa.agents.base import BaseAgent
from ursa.agents.rag_agent import RAGAgent
from ursa.util.http import build_httpx_client
from ursa.util.parse import (
_derive_filename_from_cd_or_url,
_download_stream_to,
Expand Down Expand Up @@ -113,7 +114,7 @@ def _download(url: str, dest_path: str, timeout: int = 20) -> str:
def describe_image(image: Image.Image) -> str:
if OpenAI is None:
return ""
client = OpenAI()
client = OpenAI(http_client=build_httpx_client())
buf = BytesIO()
image.save(buf, format="PNG")
import base64
Expand Down
12 changes: 5 additions & 7 deletions src/ursa/agents/arxiv_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,7 @@

from ursa.agents.base import BaseAgent
from ursa.agents.rag_agent import RAGAgent

try:
from openai import OpenAI
except Exception:
pass
from ursa.util.http import build_httpx_client


class PaperMetadata(TypedDict):
Expand All @@ -39,12 +35,14 @@ class PaperState(TypedDict, total=False):


def describe_image(image: Image.Image) -> str:
if "OpenAI" not in globals():
try:
from openai import OpenAI
except ImportError:
print(
"Vision transformer for summarizing images currently only implemented for OpenAI API."
)
return ""
client = OpenAI()
client = OpenAI(http_client=build_httpx_client())

buffered = BytesIO()
image.save(buffered, format="PNG")
Expand Down
2 changes: 2 additions & 0 deletions src/ursa/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
deep_merge_dicts,
dict_diff,
)
from ursa.util.http import inject_truststore_into_ssl

set_parsing_settings(docstring_parse_attribute_docstrings=True)

Expand Down Expand Up @@ -71,6 +72,7 @@ def resolve_config(cfg) -> UrsaConfig:


def main(args=None):
inject_truststore_into_ssl()
parser = build_parser()
cfg = parser.parse_args(args=args)
ursa_config = resolve_config(cfg)
Expand Down
35 changes: 32 additions & 3 deletions src/ursa/cli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@
from tempfile import TemporaryDirectory
from typing import Any, Literal

import httpx
import yaml
from jsonargparse import Namespace
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, field_serializer

from ursa.util.http import (
build_httpx_async_client,
build_httpx_client,
httpx_verify_value,
)
from ursa.util.mcp import ServerParameters, _serialize_server_config

LoggingLevel = Literal[
Expand All @@ -38,14 +42,39 @@ class ModelConfig(BaseModel):
ssl_verify: bool = True
"""Flag for verifying SSL certs. during API access"""

def _provider(self) -> str:
return self.model.split(":", 1)[0]

@staticmethod
def _merge_provider_kwargs(
kwargs: dict[str, Any], key: str, extra: dict[str, Any]
) -> None:
current = kwargs.get(key)
if current is None:
kwargs[key] = extra
return
if isinstance(current, dict):
kwargs[key] = {**extra, **current}

@property
def kwargs(self) -> dict:
"""Return a dict suitable for init_chat_model/init_embedding_model
Removes parameters set to `None`
"""
kwargs = {k: v for k, v in self.model_dump().items() if v is not None}
if kwargs.pop("ssl_verify", None) is False:
kwargs["http_client"] = httpx.Client(verify=False)
ssl_verify = kwargs.pop("ssl_verify", True)
provider = self._provider()
if provider in {"openai", "azure_openai"}:
kwargs["http_client"] = build_httpx_client(verify=ssl_verify)
kwargs["http_async_client"] = build_httpx_async_client(
verify=ssl_verify
)
elif provider == "ollama":
self._merge_provider_kwargs(
kwargs,
"client_kwargs",
{"verify": httpx_verify_value(verify=ssl_verify)},
)
if api_key_env := kwargs.pop("api_key_env", None):
kwargs["api_key"] = environ.get(api_key_env, None)
return kwargs
Expand Down
58 changes: 58 additions & 0 deletions src/ursa/util/http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from __future__ import annotations

import ssl
from typing import Any

import httpx
import truststore

_truststore_injected = False


def truststore_ssl_context() -> ssl.SSLContext:
"""Return an SSL context backed by the system trust store."""
return truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)


def httpx_verify_value(*, verify: bool = True) -> bool | ssl.SSLContext:
"""Return the HTTPX verify value used across direct clients and kwargs."""
if verify is False:
return False
return truststore_ssl_context()


def build_httpx_client(*, verify: bool = True, **kwargs: Any) -> httpx.Client:
"""Build an HTTPX client with truststore-based verification by default."""
return httpx.Client(verify=httpx_verify_value(verify=verify), **kwargs)


def build_httpx_async_client(
*, verify: bool = True, **kwargs: Any
) -> httpx.AsyncClient:
"""Build an async HTTPX client with truststore-based verification."""
return httpx.AsyncClient(verify=httpx_verify_value(verify=verify), **kwargs)


def build_mcp_httpx_async_client(
*,
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
auth: httpx.Auth | None = None,
) -> httpx.AsyncClient:
"""Build an async HTTPX client for MCP HTTP transports using truststore."""
kwargs: dict[str, Any] = {"follow_redirects": True}
if headers is not None:
kwargs["headers"] = headers
if timeout is not None:
kwargs["timeout"] = timeout
if auth is not None:
kwargs["auth"] = auth
return build_httpx_async_client(**kwargs)


def inject_truststore_into_ssl() -> None:
"""Inject truststore into ssl for application entrypoints."""
global _truststore_injected
if not _truststore_injected:
truststore.inject_into_ssl()
_truststore_injected = True
7 changes: 6 additions & 1 deletion src/ursa/util/logo_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from rich.panel import Panel
from rich.text import Text

from ursa.util.http import build_httpx_client

# Reuse a small thread pool so callers can "fire-and-continue" with one line.
_EXEC = ThreadPoolExecutor(max_workers=2, thread_name_prefix="logo-gen")

Expand Down Expand Up @@ -773,7 +775,10 @@ def generate_logo_sync(
for k in ("api_key", "base_url", "organization"):
if k in image_provider_kwargs and image_provider_kwargs[k]:
client_kwargs[k] = image_provider_kwargs[k]
client = OpenAI(**client_kwargs)
client = OpenAI(
http_client=build_httpx_client(),
**client_kwargs,
)

final_size = _normalize_size(size, aspect, mode)
# Scenes tend to look odd with transparent backgrounds; force opaque.
Expand Down
7 changes: 6 additions & 1 deletion src/ursa/util/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
)
from pydantic import BaseModel, BeforeValidator, ValidationError

from ursa.util.http import build_mcp_httpx_async_client


def validate_server_parameters(config: dict):
if not isinstance(config, dict):
Expand Down Expand Up @@ -67,10 +69,13 @@ def start_mcp_client(
for server, config in server_configs.items():
if not isinstance(config, BaseModel):
config = validate_server_parameters(dict(**config))
client_config[server] = {
connection = {
**config.model_dump(),
"transport": transport(config),
}
if isinstance(config, (SseServerParameters, StreamableHttpParameters)):
connection["httpx_client_factory"] = build_mcp_httpx_async_client
client_config[server] = connection
return MultiServerMCPClient(client_config)


Expand Down
3 changes: 3 additions & 0 deletions src/ursa_dashboard/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
import os
import sys

from ursa.util.http import inject_truststore_into_ssl


def main(argv: list[str] | None = None) -> int:
inject_truststore_into_ssl()
ap = argparse.ArgumentParser(prog="ursa-dashboard")
ap.add_argument(
"--host", default=os.environ.get("URSA_DASHBOARD_HOST", "127.0.0.1")
Expand Down
3 changes: 3 additions & 0 deletions src/ursa_dashboard/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import typer

from ursa.util.http import inject_truststore_into_ssl

app = typer.Typer(help="Ursa Dashboard Runner")


Expand All @@ -9,6 +11,7 @@ def main(
port: int = typer.Option(8080, help="The port to bind to."),
):
"""Launch the Ursa Web Dashboard."""
inject_truststore_into_ssl()
try:
import uvicorn

Expand Down
3 changes: 3 additions & 0 deletions src/ursa_dashboard/worker_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from pathlib import Path
from typing import Any

from ursa.util.http import inject_truststore_into_ssl


def _normalize_model(model: str) -> str:
# Examples use "openai:gpt-5-mini".
Expand Down Expand Up @@ -64,6 +66,7 @@ def _maybe_run_async(result):


def main() -> int:
inject_truststore_into_ssl()
ap = argparse.ArgumentParser()
ap.add_argument("--agent-id", required=True)
ap.add_argument("--run-id", required=True)
Expand Down
22 changes: 22 additions & 0 deletions tests/cli/test_cli_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,31 @@ def test_model_config_kwargs_includes_extra():
assert kwargs["model"] == "openai:gpt-5"
assert kwargs["max_completion_tokens"] == 1024
assert "http_client" in kwargs # ssl_verify False triggers custom client
assert "http_async_client" in kwargs
assert kwargs["timeout"] == 30


def test_model_config_openai_uses_truststore_client():
cfg = ModelConfig(model="openai:gpt-5", max_completion_tokens=1024)

kwargs = cfg.kwargs

assert kwargs["model"] == "openai:gpt-5"
assert "http_client" in kwargs
assert "http_async_client" in kwargs


def test_model_config_ollama_uses_client_kwargs():
cfg = ModelConfig(model="ollama:nomic-embed-text:latest")

kwargs = cfg.kwargs

assert kwargs["model"] == "ollama:nomic-embed-text:latest"
assert "http_client" not in kwargs
assert "http_async_client" not in kwargs
assert kwargs["client_kwargs"]["verify"] is not False


def test_api_key_env(monkeypatch, tmp_path):
monkeypatch.setenv("TEST_ENV_API_KEY", "super-secret-key")
parser = build_parser()
Expand Down
42 changes: 42 additions & 0 deletions tests/util/test_mcp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from mcp.client.session_group import (
SseServerParameters,
StreamableHttpParameters,
)

from ursa.util import mcp as mcp_mod


def test_start_mcp_client_adds_httpx_factory_for_sse(monkeypatch):
captured = {}

class DummyClient:
def __init__(self, connections):
captured["connections"] = connections

monkeypatch.setattr(mcp_mod, "MultiServerMCPClient", DummyClient)

mcp_mod.start_mcp_client({
"demo": SseServerParameters(url="https://example.com/sse")
})

conn = captured["connections"]["demo"]
assert conn["transport"] == "sse"
assert conn["httpx_client_factory"] is mcp_mod.build_mcp_httpx_async_client


def test_start_mcp_client_adds_httpx_factory_for_streamable_http(monkeypatch):
captured = {}

class DummyClient:
def __init__(self, connections):
captured["connections"] = connections

monkeypatch.setattr(mcp_mod, "MultiServerMCPClient", DummyClient)

mcp_mod.start_mcp_client({
"demo": StreamableHttpParameters(url="https://example.com/mcp")
})

conn = captured["connections"]["demo"]
assert conn["transport"] == "streamable_http"
assert conn["httpx_client_factory"] is mcp_mod.build_mcp_httpx_async_client
Loading
Loading