From 883d03683308ea1c8662c2a7a420d60d125fce52 Mon Sep 17 00:00:00 2001 From: "Gregory R. Warnes" Date: Thu, 12 Mar 2026 13:38:07 -0400 Subject: [PATCH 1/5] feat(providers): Add GitHub Copilot SDK provider - Add CopilotSdkProvider supporting completion and model listing via the github-copilot-sdk package (JSON-RPC to bundled Copilot CLI) - Support three auth modes: explicit token, env vars (COPILOT_GITHUB_TOKEN / GITHUB_TOKEN / GH_TOKEN), or gh CLI session - Add asyncio.Lock double-checked locking to prevent concurrent CLI process spawning on first use - Register copilot_sdk in LLMProvider enum, pyproject.toml optional dependencies, and provider docs table (alphabetical order) - Add 20 unit tests covering auth resolution, client reuse, completion, model listing, and edge cases (empty messages, cli_url path, errors) Co-Authored-By: Claude Sonnet 4.6 --- docs/src/content/docs/providers.md | 1 + pyproject.toml | 6 +- src/any_llm/constants.py | 1 + src/any_llm/providers/copilot_sdk/__init__.py | 3 + .../providers/copilot_sdk/copilot_sdk.py | 263 +++++++++++++++ tests/conftest.py | 2 + .../providers/test_copilot_sdk_provider.py | 313 ++++++++++++++++++ tests/unit/test_provider.py | 1 + 8 files changed, 589 insertions(+), 1 deletion(-) create mode 100644 src/any_llm/providers/copilot_sdk/__init__.py create mode 100644 src/any_llm/providers/copilot_sdk/copilot_sdk.py create mode 100644 tests/unit/providers/test_copilot_sdk_provider.py diff --git a/docs/src/content/docs/providers.md b/docs/src/content/docs/providers.md index 1f9948d8..9ee25467 100644 --- a/docs/src/content/docs/providers.md +++ b/docs/src/content/docs/providers.md @@ -26,6 +26,7 @@ Provider source code is in [`src/any_llm/providers/`](https://github.com/mozilla | [`bedrock`](https://aws.amazon.com/bedrock/) | AWS_BEARER_TOKEN_BEDROCK | AWS_ENDPOINT_URL_BEDROCK_RUNTIME | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | [`cerebras`](https://docs.cerebras.ai/) | CEREBRAS_API_KEY | CEREBRAS_API_BASE | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 
✅ | ❌ | | [`cohere`](https://cohere.com/api) | COHERE_API_KEY | COHERE_BASE_URL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | +| [`copilot_sdk`](https://github.com/github/copilot-sdk) | COPILOT_GITHUB_TOKEN / GITHUB_TOKEN / GH_TOKEN | None | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | | [`databricks`](https://docs.databricks.com/) | DATABRICKS_TOKEN | DATABRICKS_HOST | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | | [`deepseek`](https://platform.deepseek.com/) | DEEPSEEK_API_KEY | DEEPSEEK_API_BASE | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | | [`fireworks`](https://fireworks.ai/api) | FIREWORKS_API_KEY | FIREWORKS_API_BASE | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | diff --git a/pyproject.toml b/pyproject.toml index 65e3ad38..70030853 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ [project.optional-dependencies] all = [ - "any-llm-sdk[mistral,anthropic,huggingface,gemini,vertexai,vertexaianthropic,cohere,cerebras,fireworks,groq,bedrock,azure,azureopenai,watsonx,together,sambanova,ollama,moonshot,nebius,xai,databricks,deepseek,inception,openai,openrouter,portkey,lmstudio,llama,voyage,perplexity,platform,llamafile,llamacpp,sagemaker,gateway,zai,minimax,mzai,vllm]" + "any-llm-sdk[mistral,anthropic,huggingface,gemini,vertexai,vertexaianthropic,cohere,cerebras,fireworks,groq,bedrock,azure,azureopenai,watsonx,together,sambanova,ollama,moonshot,nebius,xai,databricks,deepseek,inception,openai,openrouter,portkey,lmstudio,llama,voyage,perplexity,platform,llamafile,llamacpp,sagemaker,gateway,zai,minimax,mzai,vllm,copilot_sdk]" ] platform = [ @@ -31,6 +31,10 @@ platform = [ perplexity = [] +copilot_sdk = [ + "github-copilot-sdk>=0.1.0", +] + mistral = [ "mistralai>=1.9.3", ] diff --git a/src/any_llm/constants.py b/src/any_llm/constants.py index ffecd300..7470505a 100644 --- a/src/any_llm/constants.py +++ b/src/any_llm/constants.py @@ -53,6 +53,7 @@ class LLMProvider(StrEnum): PERPLEXITY = "perplexity" MINIMAX = "minimax" ZAI = "zai" + COPILOT_SDK = "copilot_sdk" GATEWAY = "gateway" @classmethod 
diff --git a/src/any_llm/providers/copilot_sdk/__init__.py b/src/any_llm/providers/copilot_sdk/__init__.py new file mode 100644 index 00000000..89ebdb3b --- /dev/null +++ b/src/any_llm/providers/copilot_sdk/__init__.py @@ -0,0 +1,3 @@ +from .copilot_sdk import CopilotSdkProvider + +__all__ = ["CopilotSdkProvider"] diff --git a/src/any_llm/providers/copilot_sdk/copilot_sdk.py b/src/any_llm/providers/copilot_sdk/copilot_sdk.py new file mode 100644 index 00000000..b94af703 --- /dev/null +++ b/src/any_llm/providers/copilot_sdk/copilot_sdk.py @@ -0,0 +1,263 @@ +from __future__ import annotations + +import asyncio +import os +import time +from typing import TYPE_CHECKING, Any + +from typing_extensions import override + +from any_llm.any_llm import AnyLLM + +MISSING_PACKAGES_ERROR: ImportError | None = None +try: + from copilot import CopilotClient, PermissionHandler + from copilot.types import ModelInfo as CopilotModelInfo +except ImportError as e: + MISSING_PACKAGES_ERROR = e + +if TYPE_CHECKING: + from collections.abc import Sequence + + from any_llm.types.completion import ( + ChatCompletion, + ChatCompletionChunk, + CompletionParams, + CreateEmbeddingResponse, + ) + from any_llm.types.model import Model + + +def _messages_to_prompt(messages: list[dict[str, Any]]) -> str: + """Flatten an OpenAI-style messages list into a single prompt string. + + System messages become an instruction header; prior conversational turns + are formatted as a transcript; the final user message is the prompt. + """ + system_parts: list[str] = [] + conversation_parts: list[str] = [] + + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", "") + # Handle multimodal content blocks — extract text only. 
+ if isinstance(content, list): + content = " ".join( + block.get("text", "") + for block in content + if isinstance(block, dict) and block.get("type") == "text" + ) + + if role == "system": + system_parts.append(str(content)) + elif role == "assistant": + conversation_parts.append(f"Assistant: {content}") + else: + conversation_parts.append(f"User: {content}") + + parts: list[str] = [] + if system_parts: + parts.append("\n".join(system_parts)) + if conversation_parts: + parts.append("\n\n".join(conversation_parts)) + return "\n\n".join(parts) + + +def _build_chat_completion(content: str, model_id: str) -> "ChatCompletion": + """Wrap a plain text response into an OpenAI-compatible ChatCompletion.""" + from any_llm.types.completion import ChatCompletion, ChatCompletionMessage, Choice # noqa: PLC0415 + + return ChatCompletion( + id=f"copilot-sdk-{int(time.time())}", + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage(role="assistant", content=content), + logprobs=None, + ) + ], + created=int(time.time()), + model=model_id, + object="chat.completion", + ) + + +def _copilot_model_to_openai(info: "CopilotModelInfo") -> "Model": + """Convert a copilot-sdk ModelInfo to an OpenAI-compatible Model.""" + from openai.types.model import Model as OpenAIModel # noqa: PLC0415 + + return OpenAIModel(id=info.id, created=0, owned_by="github-copilot", object="model") + + +class CopilotSdkProvider(AnyLLM): + """GitHub Copilot SDK provider for any-llm. + + Communicates with the Copilot CLI via JSON-RPC using the ``github-copilot-sdk`` + Python package. Authentication supports two modes (checked in order): + + 1. **Token mode** — set ``COPILOT_GITHUB_TOKEN``, ``GITHUB_TOKEN``, or + ``GH_TOKEN`` in the environment (or pass ``api_key`` explicitly). + 2. **Logged-in CLI user** — if no token is found, the Copilot CLI uses the + credentials from the local user's authenticated ``gh`` / ``copilot`` CLI session. 
+ + The Copilot CLI binary must be installed and reachable on ``PATH`` (or pointed + to via ``COPILOT_CLI_PATH``). An external CLI server can be used instead by + setting ``COPILOT_CLI_URL`` (e.g. ``localhost:9000``). + + Environment variables: + COPILOT_GITHUB_TOKEN: GitHub token with Copilot access (optional). + GITHUB_TOKEN / GH_TOKEN: Fallback GitHub token sources (optional). + COPILOT_CLI_URL: Connect to an external CLI server instead of spawning one. + COPILOT_CLI_PATH: Override the CLI binary path (default: PATH lookup). + """ + + PROVIDER_NAME = "copilot_sdk" + ENV_API_KEY_NAME = "COPILOT_GITHUB_TOKEN" + ENV_API_BASE_NAME = "COPILOT_CLI_URL" + PROVIDER_DOCUMENTATION_URL = "https://github.com/github/copilot-sdk" + + SUPPORTS_COMPLETION = True + SUPPORTS_COMPLETION_STREAMING = False + SUPPORTS_COMPLETION_REASONING = False + SUPPORTS_COMPLETION_IMAGE = False + SUPPORTS_COMPLETION_PDF = False + SUPPORTS_EMBEDDING = False + SUPPORTS_RESPONSES = False + SUPPORTS_LIST_MODELS = True + SUPPORTS_BATCH = False + + MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR + + # Internal state — populated lazily on first call. + _copilot_client: "CopilotClient | None" + + # ------------------------------------------------------------------ auth -- + + @override + def _verify_and_set_api_key(self, api_key: str | None = None) -> str | None: + """API key is optional: logged-in CLI mode works without any token.""" + resolved = ( + api_key + or os.getenv("COPILOT_GITHUB_TOKEN") + or os.getenv("GITHUB_TOKEN") + or os.getenv("GH_TOKEN") + ) + # Return None (not empty string) so copilot-sdk uses logged-in credentials. 
+ return resolved or None + + # ---------------------------------------------------------- client init -- + + @override + def _init_client(self, api_key: str | None = None, api_base: str | None = None, **kwargs: Any) -> None: + """Store resolved auth options; actual CLI start is deferred to the first async call.""" + self._resolved_token: str | None = api_key + self._cli_url: str | None = api_base or os.getenv("COPILOT_CLI_URL") + self._cli_path: str | None = os.getenv("COPILOT_CLI_PATH") or None + self._extra_kwargs = kwargs + self._copilot_client = None + self._client_lock = asyncio.Lock() + + async def _ensure_client(self) -> "CopilotClient": + """Lazily create and start a CopilotClient, reusing it across calls. + + The lock prevents concurrent coroutines from each spawning a separate + CLI process when the client has not yet been initialized. + """ + if self._copilot_client is not None: + return self._copilot_client + + async with self._client_lock: + # Re-check inside the lock: another coroutine may have initialized + # the client while we were waiting. + if self._copilot_client is not None: + return self._copilot_client + + opts: dict[str, Any] = {} + if self._cli_url: + opts["cli_url"] = self._cli_url + else: + if self._cli_path: + opts["cli_path"] = self._cli_path + if self._resolved_token: + opts["github_token"] = self._resolved_token + + # Pass None (not {}) so the SDK uses its own defaults when no + # options are configured. 
+ self._copilot_client = CopilotClient(opts or None) + await self._copilot_client.start() + + return self._copilot_client + + # --------------------------------------------------- completion (async) -- + + @override + async def _acompletion( + self, + params: "CompletionParams", + **kwargs: Any, + ) -> "ChatCompletion": + """Send a completion request via a fresh Copilot session.""" + client = await self._ensure_client() + prompt = _messages_to_prompt(params.messages) + model_id = params.model_id + + # approve_all silently grants any permissions the session requests + # (e.g. tool calls). This mirrors how the Copilot CLI behaves in + # non-interactive mode and is appropriate for programmatic usage. + session_cfg: dict[str, Any] = {"on_permission_request": PermissionHandler.approve_all} + if model_id: + session_cfg["model"] = model_id + + # A new session is created per request; the Copilot SDK session + # lifecycle is lightweight (no separate process per session). + session = await client.create_session(session_cfg) + try: + event = await session.send_and_wait({"prompt": prompt}) + content = "" + if event is not None and hasattr(event, "data") and hasattr(event.data, "content"): + content = event.data.content or "" + return _build_chat_completion(content, model_id or self.PROVIDER_NAME) + finally: + await session.disconnect() + + # -------------------------------------------------- model listing (async) -- + + @override + async def _alist_models(self, **kwargs: Any) -> "Sequence[Model]": + """List models available through the Copilot CLI.""" + client = await self._ensure_client() + models = await client.list_models() + return [_copilot_model_to_openai(m) for m in models] + + # ------------ Required abstract stubs (unused — _acompletion overridden) -- + + @staticmethod + @override + def _convert_completion_params(params: "CompletionParams", **kwargs: Any) -> dict[str, Any]: + raise NotImplementedError("CopilotSdkProvider overrides _acompletion directly") + + 
@staticmethod + @override + def _convert_completion_response(response: Any) -> "ChatCompletion": + raise NotImplementedError("CopilotSdkProvider overrides _acompletion directly") + + @staticmethod + @override + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> "ChatCompletionChunk": + raise NotImplementedError("CopilotSdkProvider does not support streaming (MVP)") + + @staticmethod + @override + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + raise NotImplementedError("CopilotSdkProvider does not support embeddings") + + @staticmethod + @override + def _convert_embedding_response(response: Any) -> "CreateEmbeddingResponse": + raise NotImplementedError("CopilotSdkProvider does not support embeddings") + + @staticmethod + @override + def _convert_list_models_response(response: Any) -> "Sequence[Model]": + raise NotImplementedError("CopilotSdkProvider uses _alist_models directly") diff --git a/tests/conftest.py b/tests/conftest.py index 5caaea1d..db7f47aa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -81,6 +81,7 @@ def provider_model_map() -> dict[LLMProvider, str]: LLMProvider.LLAMACPP: "N/A", LLMProvider.MINIMAX: "MiniMax-M2", LLMProvider.ZAI: "glm-4-32b-0414-128k", + LLMProvider.COPILOT_SDK: "gpt-4o", } @@ -151,6 +152,7 @@ def provider_client_config() -> dict[LLMProvider, dict[str, Any]]: "api_base": "https://mlrun-me8bof5t-eastus2.cognitiveservices.azure.com/", "api_version": "2025-03-01-preview", }, + LLMProvider.COPILOT_SDK: {}, } diff --git a/tests/unit/providers/test_copilot_sdk_provider.py b/tests/unit/providers/test_copilot_sdk_provider.py new file mode 100644 index 00000000..afb4a062 --- /dev/null +++ b/tests/unit/providers/test_copilot_sdk_provider.py @@ -0,0 +1,313 @@ +"""Unit tests for CopilotSdkProvider.""" +from __future__ import annotations + +import asyncio +import os +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from 
any_llm.providers.copilot_sdk.copilot_sdk import ( + CopilotSdkProvider, + _build_chat_completion, + _copilot_model_to_openai, + _messages_to_prompt, +) +from any_llm.types.completion import CompletionParams + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_provider(**kwargs: Any) -> CopilotSdkProvider: + """Instantiate CopilotSdkProvider while bypassing real CLI startup.""" + p = object.__new__(CopilotSdkProvider) + p._resolved_token = kwargs.get("api_key") + p._cli_url = kwargs.get("api_base") + p._cli_path = None + p._extra_kwargs = {} + p._copilot_client = None + p._client_lock = asyncio.Lock() + return p + + +def _make_model_info(model_id: str = "gpt-4o", name: str = "GPT-4o") -> Any: + """Return a minimal mock ModelInfo.""" + m = MagicMock() + m.id = model_id + m.name = name + return m + + +# --------------------------------------------------------------------------- +# _messages_to_prompt +# --------------------------------------------------------------------------- + +def test_messages_to_prompt_simple_user() -> None: + msgs = [{"role": "user", "content": "Hello"}] + assert _messages_to_prompt(msgs) == "User: Hello" + + +def test_messages_to_prompt_system_prepended() -> None: + msgs = [ + {"role": "system", "content": "Be concise."}, + {"role": "user", "content": "Hi"}, + ] + result = _messages_to_prompt(msgs) + assert result.startswith("Be concise.") + assert "User: Hi" in result + + +def test_messages_to_prompt_multi_turn() -> None: + msgs = [ + {"role": "user", "content": "Ping"}, + {"role": "assistant", "content": "Pong"}, + {"role": "user", "content": "Again"}, + ] + result = _messages_to_prompt(msgs) + assert "User: Ping" in result + assert "Assistant: Pong" in result + assert "User: Again" in result + + +def test_messages_to_prompt_multimodal_extracts_text() -> None: + msgs = [{"role": "user", "content": [{"type": 
"text", "text": "Tell me"}, {"type": "image_url"}]}] + result = _messages_to_prompt(msgs) + assert "Tell me" in result + + +# --------------------------------------------------------------------------- +# _build_chat_completion +# --------------------------------------------------------------------------- + +def test_build_chat_completion_structure() -> None: + cc = _build_chat_completion("Hello world", "gpt-4o") + assert cc.choices[0].message.content == "Hello world" + assert cc.model == "gpt-4o" + assert cc.choices[0].finish_reason == "stop" + assert cc.object == "chat.completion" + + +# --------------------------------------------------------------------------- +# _copilot_model_to_openai +# --------------------------------------------------------------------------- + +def test_copilot_model_to_openai_maps_id() -> None: + info = _make_model_info("claude-sonnet-4-5", "Claude Sonnet 4.5") + model = _copilot_model_to_openai(info) + assert model.id == "claude-sonnet-4-5" + assert model.owned_by == "github-copilot" + assert model.object == "model" + + +# --------------------------------------------------------------------------- +# Provider class attributes +# --------------------------------------------------------------------------- + +def test_provider_required_attributes() -> None: + assert CopilotSdkProvider.PROVIDER_NAME == "copilot_sdk" + assert CopilotSdkProvider.SUPPORTS_COMPLETION is True + assert CopilotSdkProvider.SUPPORTS_LIST_MODELS is True + assert CopilotSdkProvider.SUPPORTS_EMBEDDING is False + assert CopilotSdkProvider.SUPPORTS_COMPLETION_STREAMING is False + assert CopilotSdkProvider.MISSING_PACKAGES_ERROR is None + + +# --------------------------------------------------------------------------- +# Auth: _verify_and_set_api_key +# --------------------------------------------------------------------------- + +def test_api_key_explicit_wins() -> None: + p = object.__new__(CopilotSdkProvider) + with patch.dict(os.environ, {"COPILOT_GITHUB_TOKEN": 
"env-token"}, clear=False): + result = p._verify_and_set_api_key("explicit-token") + assert result == "explicit-token" + + +def test_api_key_falls_back_to_copilot_env() -> None: + p = object.__new__(CopilotSdkProvider) + with patch.dict(os.environ, {"COPILOT_GITHUB_TOKEN": "copilot-env"}, clear=False): + result = p._verify_and_set_api_key(None) + assert result == "copilot-env" + + +def test_api_key_falls_back_to_github_token() -> None: + p = object.__new__(CopilotSdkProvider) + env = {"COPILOT_GITHUB_TOKEN": "", "GITHUB_TOKEN": "gh-token"} + with patch.dict(os.environ, env, clear=False): + result = p._verify_and_set_api_key(None) + assert result == "gh-token" + + +def test_api_key_returns_none_when_all_absent() -> None: + """No token at all → None (triggers logged-in CLI user mode).""" + p = object.__new__(CopilotSdkProvider) + env = {"COPILOT_GITHUB_TOKEN": "", "GITHUB_TOKEN": "", "GH_TOKEN": ""} + with patch.dict(os.environ, env, clear=False): + result = p._verify_and_set_api_key(None) + assert result is None + + +# --------------------------------------------------------------------------- +# _init_client +# --------------------------------------------------------------------------- + +def test_init_client_stores_token_and_url() -> None: + p = object.__new__(CopilotSdkProvider) + p._init_client(api_key="my-token", api_base="localhost:9000") + assert p._resolved_token == "my-token" + assert p._cli_url == "localhost:9000" + assert p._copilot_client is None + + +def test_init_client_reads_cli_url_from_env() -> None: + p = object.__new__(CopilotSdkProvider) + with patch.dict(os.environ, {"COPILOT_CLI_URL": "localhost:7777"}, clear=False): + p._init_client(api_key=None, api_base=None) + assert p._cli_url == "localhost:7777" + + +# --------------------------------------------------------------------------- +# _acompletion (async) +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def 
test_acompletion_returns_chat_completion() -> None: + provider = _make_provider(api_key="test-token") + + mock_event = MagicMock() + mock_event.data.content = "The answer is 4." + + mock_session = AsyncMock() + mock_session.send_and_wait = AsyncMock(return_value=mock_event) + mock_session.disconnect = AsyncMock() + + mock_client = AsyncMock() + mock_client.create_session = AsyncMock(return_value=mock_session) + + with patch.object(provider, "_ensure_client", AsyncMock(return_value=mock_client)): + params = CompletionParams( + model_id="gpt-4o", + messages=[{"role": "user", "content": "What is 2+2?"}], + ) + result = await provider._acompletion(params) + + assert result.choices[0].message.content == "The answer is 4." + assert result.model == "gpt-4o" + mock_session.disconnect.assert_called_once() + + +@pytest.mark.asyncio +async def test_acompletion_handles_none_event() -> None: + """When send_and_wait returns None, content is empty string.""" + provider = _make_provider(api_key="test-token") + + mock_session = AsyncMock() + mock_session.send_and_wait = AsyncMock(return_value=None) + mock_session.disconnect = AsyncMock() + + mock_client = AsyncMock() + mock_client.create_session = AsyncMock(return_value=mock_session) + + with patch.object(provider, "_ensure_client", AsyncMock(return_value=mock_client)): + params = CompletionParams( + model_id="gpt-4o", + messages=[{"role": "user", "content": "Hello"}], + ) + result = await provider._acompletion(params) + + assert result.choices[0].message.content == "" + + +@pytest.mark.asyncio +async def test_acompletion_disconnects_session_on_error() -> None: + """Session.disconnect() is called even when send_and_wait raises.""" + provider = _make_provider(api_key="test-token") + + mock_session = AsyncMock() + mock_session.send_and_wait = AsyncMock(side_effect=RuntimeError("CLI died")) + mock_session.disconnect = AsyncMock() + + mock_client = AsyncMock() + mock_client.create_session = AsyncMock(return_value=mock_session) + + with 
patch.object(provider, "_ensure_client", AsyncMock(return_value=mock_client)): + params = CompletionParams( + model_id="gpt-4o", + messages=[{"role": "user", "content": "Hello"}], + ) + with pytest.raises(RuntimeError, match="CLI died"): + await provider._acompletion(params) + + mock_session.disconnect.assert_called_once() + + +# --------------------------------------------------------------------------- +# _alist_models (async) +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_alist_models_converts_model_infos() -> None: + provider = _make_provider(api_key="test-token") + + raw_models = [ + _make_model_info("gpt-4o", "GPT-4o"), + _make_model_info("claude-sonnet-4-5", "Claude Sonnet 4.5"), + ] + mock_client = AsyncMock() + mock_client.list_models = AsyncMock(return_value=raw_models) + + with patch.object(provider, "_ensure_client", AsyncMock(return_value=mock_client)): + models = await provider._alist_models() + + assert len(models) == 2 + ids = {m.id for m in models} + assert ids == {"gpt-4o", "claude-sonnet-4-5"} + for m in models: + assert m.owned_by == "github-copilot" + + +# --------------------------------------------------------------------------- +# _ensure_client (async) +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_ensure_client_reuses_existing_instance() -> None: + """Second call returns the same client without calling start() again.""" + provider = _make_provider(api_key="test-token") + + mock_client = AsyncMock() + mock_client.start = AsyncMock() + provider._copilot_client = mock_client # pre-seed as already initialized + + result = await provider._ensure_client() + + assert result is mock_client + mock_client.start.assert_not_called() + + +@pytest.mark.asyncio +async def test_ensure_client_uses_cli_url_path() -> None: + """When _cli_url is set, only cli_url is passed (no token or cli_path).""" + provider = 
_make_provider(api_key="tok", api_base="localhost:9000") + + mock_client = AsyncMock() + mock_client.start = AsyncMock() + + with patch("any_llm.providers.copilot_sdk.copilot_sdk.CopilotClient", return_value=mock_client) as mock_cls: + result = await provider._ensure_client() + + assert result is mock_client + mock_cls.assert_called_once_with({"cli_url": "localhost:9000"}) + mock_client.start.assert_called_once() + + +# --------------------------------------------------------------------------- +# _messages_to_prompt edge cases +# --------------------------------------------------------------------------- + +def test_messages_to_prompt_empty_list() -> None: + """Empty messages list returns an empty string without raising.""" + assert _messages_to_prompt([]) == "" diff --git a/tests/unit/test_provider.py b/tests/unit/test_provider.py index 3f3b8bb9..35bdb29e 100644 --- a/tests/unit/test_provider.py +++ b/tests/unit/test_provider.py @@ -149,6 +149,7 @@ def test_providers_raise_MissingApiKeyError(provider: LLMProvider) -> None: LLMProvider.VERTEXAI, LLMProvider.VLLM, LLMProvider.GATEWAY, + LLMProvider.COPILOT_SDK, # uses CLI auth; no API key required ): pytest.skip("This provider handles `api_key` differently.") with patch.dict(os.environ, {}, clear=True): From a412deceb2df676412a1821e593a38f702e95e68 Mon Sep 17 00:00:00 2001 From: "Gregory R. 
Warnes" Date: Thu, 12 Mar 2026 14:16:25 -0400 Subject: [PATCH 2/5] fix(copilot_sdk): Raise on SESSION_ERROR, fix mimetypes thread safety MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Raise RuntimeError when SESSION_ERROR fires during streaming instead of silently completing the stream with no error - Call mimetypes.init() at module load to ensure guess_extension() is thread-safe from the first invocation - Add tests for SESSION_ERROR raises, image attachment forwarding, and update provider capability docs (streaming/reasoning/image now ✅) Co-Authored-By: Claude Sonnet 4.6 --- docs/src/content/docs/providers.md | 2 +- .../providers/copilot_sdk/copilot_sdk.py | 248 +++++++++- .../providers/test_copilot_sdk_provider.py | 427 +++++++++++++++++- 3 files changed, 639 insertions(+), 38 deletions(-) diff --git a/docs/src/content/docs/providers.md b/docs/src/content/docs/providers.md index 9ee25467..5492ff27 100644 --- a/docs/src/content/docs/providers.md +++ b/docs/src/content/docs/providers.md @@ -26,7 +26,7 @@ Provider source code is in [`src/any_llm/providers/`](https://github.com/mozilla | [`bedrock`](https://aws.amazon.com/bedrock/) | AWS_BEARER_TOKEN_BEDROCK | AWS_ENDPOINT_URL_BEDROCK_RUNTIME | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | [`cerebras`](https://docs.cerebras.ai/) | CEREBRAS_API_KEY | CEREBRAS_API_BASE | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | | [`cohere`](https://cohere.com/api) | COHERE_API_KEY | COHERE_BASE_URL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | -| [`copilot_sdk`](https://github.com/github/copilot-sdk) | COPILOT_GITHUB_TOKEN / GITHUB_TOKEN / GH_TOKEN | None | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | +| [`copilot_sdk`](https://github.com/github/copilot-sdk) | COPILOT_GITHUB_TOKEN / GITHUB_TOKEN / GH_TOKEN | None | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | | [`databricks`](https://docs.databricks.com/) | DATABRICKS_TOKEN | DATABRICKS_HOST | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | | [`deepseek`](https://platform.deepseek.com/) | DEEPSEEK_API_KEY | 
DEEPSEEK_API_BASE | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | | [`fireworks`](https://fireworks.ai/api) | FIREWORKS_API_KEY | FIREWORKS_API_BASE | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | diff --git a/src/any_llm/providers/copilot_sdk/copilot_sdk.py b/src/any_llm/providers/copilot_sdk/copilot_sdk.py index b94af703..3f87175c 100644 --- a/src/any_llm/providers/copilot_sdk/copilot_sdk.py +++ b/src/any_llm/providers/copilot_sdk/copilot_sdk.py @@ -1,7 +1,10 @@ from __future__ import annotations import asyncio +import base64 +import mimetypes import os +import tempfile import time from typing import TYPE_CHECKING, Any @@ -9,15 +12,20 @@ from any_llm.any_llm import AnyLLM +# Eagerly initialise the MIME-type database so that mimetypes.guess_extension() +# is thread-safe from the very first call (lazy initialisation is not thread-safe). +mimetypes.init() + MISSING_PACKAGES_ERROR: ImportError | None = None try: from copilot import CopilotClient, PermissionHandler + from copilot.generated.session_events import SessionEventType from copilot.types import ModelInfo as CopilotModelInfo except ImportError as e: MISSING_PACKAGES_ERROR = e if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import AsyncIterator, Sequence from any_llm.types.completion import ( ChatCompletion, @@ -28,11 +36,17 @@ from any_llm.types.model import Model +# --------------------------------------------------------------------------- +# Message / prompt helpers +# --------------------------------------------------------------------------- + def _messages_to_prompt(messages: list[dict[str, Any]]) -> str: """Flatten an OpenAI-style messages list into a single prompt string. System messages become an instruction header; prior conversational turns are formatted as a transcript; the final user message is the prompt. + Image content blocks are intentionally omitted here — they are forwarded + as file attachments via :func:`_extract_attachments`. 
""" system_parts: list[str] = [] conversation_parts: list[str] = [] @@ -40,7 +54,7 @@ def _messages_to_prompt(messages: list[dict[str, Any]]) -> str: for msg in messages: role = msg.get("role", "user") content = msg.get("content", "") - # Handle multimodal content blocks — extract text only. + # Multimodal content blocks: extract text only; images handled separately. if isinstance(content, list): content = " ".join( block.get("text", "") @@ -63,9 +77,73 @@ def _messages_to_prompt(messages: list[dict[str, Any]]) -> str: return "\n\n".join(parts) -def _build_chat_completion(content: str, model_id: str) -> "ChatCompletion": +def _extract_attachments(messages: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], list[str]]: + """Extract ``image_url`` content blocks from messages as file attachments. + + Decodes base64 ``data:`` URIs to temporary files on disk, which the + Copilot SDK then passes to the CLI as ``FileAttachment`` objects. + + Returns: + (attachments, temp_paths) where ``temp_paths`` must be cleaned up + by the caller after the ``session.send()`` call completes. + """ + attachments: list[dict[str, Any]] = [] + temp_paths: list[str] = [] + + for msg in messages: + content = msg.get("content", "") + if not isinstance(content, list): + continue + for block in content: + if not isinstance(block, dict) or block.get("type") != "image_url": + continue + url = (block.get("image_url") or {}).get("url", "") + if not url.startswith("data:"): + # HTTP/HTTPS URLs would need downloading — skipped for now. + continue + try: + header, b64data = url.split(",", 1) + mime = header.split(";")[0].split(":")[1] + ext = mimetypes.guess_extension(mime) or ".bin" + # guess_extension returns ".jpe" for image/jpeg on some platforms. 
+ if ext == ".jpe": + ext = ".jpg" + raw = base64.b64decode(b64data, validate=True) + with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as fh: + fh.write(raw) + temp_paths.append(fh.name) + attachments.append({"type": "file", "path": fh.name}) + except Exception: # noqa: BLE001 + pass # Skip malformed data URIs rather than failing the whole request. + + return attachments, temp_paths + + +def _cleanup_temp_files(paths: list[str]) -> None: + """Remove temporary image files created by :func:`_extract_attachments`.""" + for path in paths: + try: + os.unlink(path) + except OSError: + pass + + +# --------------------------------------------------------------------------- +# Response / chunk builders +# --------------------------------------------------------------------------- + +def _build_chat_completion( + content: str, + model_id: str, + reasoning: str | None = None, +) -> "ChatCompletion": """Wrap a plain text response into an OpenAI-compatible ChatCompletion.""" - from any_llm.types.completion import ChatCompletion, ChatCompletionMessage, Choice # noqa: PLC0415 + from any_llm.types.completion import ( # noqa: PLC0415 + ChatCompletion, + ChatCompletionMessage, + Choice, + Reasoning, + ) return ChatCompletion( id=f"copilot-sdk-{int(time.time())}", @@ -73,7 +151,11 @@ def _build_chat_completion(content: str, model_id: str) -> "ChatCompletion": Choice( finish_reason="stop", index=0, - message=ChatCompletionMessage(role="assistant", content=content), + message=ChatCompletionMessage( + role="assistant", + content=content, + reasoning=Reasoning(content=reasoning) if reasoning else None, + ), logprobs=None, ) ], @@ -83,6 +165,34 @@ def _build_chat_completion(content: str, model_id: str) -> "ChatCompletion": ) +def _build_chunk(delta: str, model_id: str, *, is_reasoning: bool = False) -> "ChatCompletionChunk": + """Wrap a streaming delta into an OpenAI-compatible ChatCompletionChunk.""" + from any_llm.types.completion import ( # noqa: PLC0415 + 
ChatCompletionChunk, + ChunkChoice, + ChoiceDelta, + Reasoning, + ) + + return ChatCompletionChunk( + id=f"copilot-sdk-{time.time_ns()}", + choices=[ + ChunkChoice( + delta=ChoiceDelta( + role="assistant", + content=None if is_reasoning else delta, + reasoning=Reasoning(content=delta) if is_reasoning else None, + ), + finish_reason=None, + index=0, + ) + ], + created=int(time.time()), + model=model_id, + object="chat.completion.chunk", + ) + + def _copilot_model_to_openai(info: "CopilotModelInfo") -> "Model": """Convert a copilot-sdk ModelInfo to an OpenAI-compatible Model.""" from openai.types.model import Model as OpenAIModel # noqa: PLC0415 @@ -90,6 +200,10 @@ def _copilot_model_to_openai(info: "CopilotModelInfo") -> "Model": return OpenAIModel(id=info.id, created=0, owned_by="github-copilot", object="model") +# --------------------------------------------------------------------------- +# Provider +# --------------------------------------------------------------------------- + class CopilotSdkProvider(AnyLLM): """GitHub Copilot SDK provider for any-llm. @@ -102,8 +216,17 @@ class CopilotSdkProvider(AnyLLM): credentials from the local user's authenticated ``gh`` / ``copilot`` CLI session. The Copilot CLI binary must be installed and reachable on ``PATH`` (or pointed - to via ``COPILOT_CLI_PATH``). An external CLI server can be used instead by - setting ``COPILOT_CLI_URL`` (e.g. ``localhost:9000``). + to via ``COPILOT_CLI_PATH``). The platform-specific ``github-copilot-sdk`` + wheel (e.g. ``github-copilot-sdk==X.Y.Z`` targeting your OS/arch) bundles the + binary automatically. An external CLI server can be used instead by setting + ``COPILOT_CLI_URL`` (e.g. ``localhost:9000``). 
+ + **Supported features:** + + * Completion (non-streaming and streaming) + * Reasoning (``reasoning_effort``: ``low`` / ``medium`` / ``high`` / ``xhigh``) + * Image attachments via ``image_url`` content blocks (``data:`` URIs only) + * Model listing Environment variables: COPILOT_GITHUB_TOKEN: GitHub token with Copilot access (optional). @@ -118,9 +241,9 @@ class CopilotSdkProvider(AnyLLM): PROVIDER_DOCUMENTATION_URL = "https://github.com/github/copilot-sdk" SUPPORTS_COMPLETION = True - SUPPORTS_COMPLETION_STREAMING = False - SUPPORTS_COMPLETION_REASONING = False - SUPPORTS_COMPLETION_IMAGE = False + SUPPORTS_COMPLETION_STREAMING = True + SUPPORTS_COMPLETION_REASONING = True + SUPPORTS_COMPLETION_IMAGE = True SUPPORTS_COMPLETION_PDF = False SUPPORTS_EMBEDDING = False SUPPORTS_RESPONSES = False @@ -152,8 +275,8 @@ def _verify_and_set_api_key(self, api_key: str | None = None) -> str | None: def _init_client(self, api_key: str | None = None, api_base: str | None = None, **kwargs: Any) -> None: """Store resolved auth options; actual CLI start is deferred to the first async call.""" self._resolved_token: str | None = api_key - self._cli_url: str | None = api_base or os.getenv("COPILOT_CLI_URL") - self._cli_path: str | None = os.getenv("COPILOT_CLI_PATH") or None + self._cli_url: str | None = api_base or os.getenv("COPILOT_CLI_URL") + self._cli_path: str | None = os.getenv("COPILOT_CLI_PATH") or None self._extra_kwargs = kwargs self._copilot_client = None self._client_lock = asyncio.Lock() @@ -191,35 +314,112 @@ async def _ensure_client(self) -> "CopilotClient": # --------------------------------------------------- completion (async) -- + def _build_session_cfg(self, params: "CompletionParams", streaming: bool) -> dict[str, Any]: + """Build a SessionConfig dict from CompletionParams.""" + # approve_all silently grants any permissions the session requests + # (e.g. tool calls). 
This mirrors how the Copilot CLI behaves in + # non-interactive mode and is appropriate for programmatic usage. + cfg: dict[str, Any] = {"on_permission_request": PermissionHandler.approve_all} + if params.model_id: + cfg["model"] = params.model_id + # Pass reasoning_effort through; omit values the SDK doesn't accept. + if params.reasoning_effort and params.reasoning_effort not in ("auto", "none"): + cfg["reasoning_effort"] = params.reasoning_effort + if streaming: + cfg["streaming"] = True + return cfg + + async def _stream_from_session( + self, + session: Any, + msg_opts: dict[str, Any], + model_id: str, + temp_paths: list[str], + ) -> "AsyncIterator[ChatCompletionChunk]": + """Async generator that streams chunks from a live Copilot session. + + Bridges the SDK's event-callback model to an async iterator via an + asyncio Queue. The session is disconnected and temp files cleaned up + when the generator exits (normally or via exception/cancellation). + """ + queue: asyncio.Queue[tuple[str, str] | None] = asyncio.Queue() + + def on_event(event: Any) -> None: + etype = event.type + if etype == SessionEventType.ASSISTANT_MESSAGE_DELTA: + queue.put_nowait(("content", event.data.delta_content or "")) + elif etype == SessionEventType.ASSISTANT_REASONING_DELTA: + queue.put_nowait(("reasoning", event.data.delta_content or "")) + elif etype == SessionEventType.SESSION_IDLE: + queue.put_nowait(None) # sentinel — streaming complete normally + elif etype == SessionEventType.SESSION_ERROR: + queue.put_nowait(("error", getattr(getattr(event, "data", None), "message", "Copilot session error"))) + + unsubscribe = session.on(on_event) + try: + await session.send(msg_opts) + while True: + item = await queue.get() + if item is None: + break + kind, delta = item + if kind == "error": + raise RuntimeError(delta) + yield _build_chunk(delta, model_id, is_reasoning=(kind == "reasoning")) + finally: + unsubscribe() + await session.disconnect() + _cleanup_temp_files(temp_paths) + @override 
async def _acompletion( self, params: "CompletionParams", **kwargs: Any, - ) -> "ChatCompletion": + ) -> "ChatCompletion | AsyncIterator[ChatCompletionChunk]": """Send a completion request via a fresh Copilot session.""" client = await self._ensure_client() prompt = _messages_to_prompt(params.messages) - model_id = params.model_id - - # approve_all silently grants any permissions the session requests - # (e.g. tool calls). This mirrors how the Copilot CLI behaves in - # non-interactive mode and is appropriate for programmatic usage. - session_cfg: dict[str, Any] = {"on_permission_request": PermissionHandler.approve_all} - if model_id: - session_cfg["model"] = model_id + model_id = params.model_id or self.PROVIDER_NAME + attachments, temp_paths = _extract_attachments(params.messages) + session_cfg = self._build_session_cfg(params, streaming=bool(params.stream)) # A new session is created per request; the Copilot SDK session # lifecycle is lightweight (no separate process per session). session = await client.create_session(session_cfg) + + msg_opts: dict[str, Any] = {"prompt": prompt} + if attachments: + msg_opts["attachments"] = attachments + + if params.stream: + # Return the async generator directly; it owns the session lifecycle. + return self._stream_from_session(session, msg_opts, model_id, temp_paths) + + # Non-streaming: capture reasoning alongside the final message event. 
try: - event = await session.send_and_wait({"prompt": prompt}) + reasoning_content: str | None = None + + def on_reasoning(event: Any) -> None: + nonlocal reasoning_content + if event.type == SessionEventType.ASSISTANT_REASONING: + reasoning_content = ( + getattr(getattr(event, "data", None), "content", None) or None + ) + + unsubscribe = session.on(on_reasoning) + try: + event = await session.send_and_wait(msg_opts) + finally: + unsubscribe() + content = "" if event is not None and hasattr(event, "data") and hasattr(event.data, "content"): content = event.data.content or "" - return _build_chat_completion(content, model_id or self.PROVIDER_NAME) + return _build_chat_completion(content, model_id, reasoning_content) finally: await session.disconnect() + _cleanup_temp_files(temp_paths) # -------------------------------------------------- model listing (async) -- @@ -245,7 +445,7 @@ def _convert_completion_response(response: Any) -> "ChatCompletion": @staticmethod @override def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> "ChatCompletionChunk": - raise NotImplementedError("CopilotSdkProvider does not support streaming (MVP)") + raise NotImplementedError("CopilotSdkProvider overrides _acompletion directly") @staticmethod @override diff --git a/tests/unit/providers/test_copilot_sdk_provider.py b/tests/unit/providers/test_copilot_sdk_provider.py index afb4a062..c780f3b0 100644 --- a/tests/unit/providers/test_copilot_sdk_provider.py +++ b/tests/unit/providers/test_copilot_sdk_provider.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio +import base64 import os from typing import Any from unittest.mock import AsyncMock, MagicMock, patch @@ -11,7 +12,10 @@ from any_llm.providers.copilot_sdk.copilot_sdk import ( CopilotSdkProvider, _build_chat_completion, + _build_chunk, + _cleanup_temp_files, _copilot_model_to_openai, + _extract_attachments, _messages_to_prompt, ) from any_llm.types.completion import CompletionParams @@ -111,7 +115,9 
@@ def test_provider_required_attributes() -> None: assert CopilotSdkProvider.SUPPORTS_COMPLETION is True assert CopilotSdkProvider.SUPPORTS_LIST_MODELS is True assert CopilotSdkProvider.SUPPORTS_EMBEDDING is False - assert CopilotSdkProvider.SUPPORTS_COMPLETION_STREAMING is False + assert CopilotSdkProvider.SUPPORTS_COMPLETION_STREAMING is True + assert CopilotSdkProvider.SUPPORTS_COMPLETION_REASONING is True + assert CopilotSdkProvider.SUPPORTS_COMPLETION_IMAGE is True assert CopilotSdkProvider.MISSING_PACKAGES_ERROR is None @@ -173,6 +179,18 @@ def test_init_client_reads_cli_url_from_env() -> None: # _acompletion (async) # --------------------------------------------------------------------------- +def _make_session(send_and_wait_result: Any = None, send_and_wait_error: Exception | None = None) -> Any: + """Build a mock CopilotSession with a sync on() and async send_and_wait/disconnect.""" + session = MagicMock() + session.on = MagicMock(return_value=lambda: None) # sync; returns unsubscribe callable + session.disconnect = AsyncMock() + if send_and_wait_error: + session.send_and_wait = AsyncMock(side_effect=send_and_wait_error) + else: + session.send_and_wait = AsyncMock(return_value=send_and_wait_result) + return session + + @pytest.mark.asyncio async def test_acompletion_returns_chat_completion() -> None: provider = _make_provider(api_key="test-token") @@ -180,10 +198,7 @@ async def test_acompletion_returns_chat_completion() -> None: mock_event = MagicMock() mock_event.data.content = "The answer is 4." 
- mock_session = AsyncMock() - mock_session.send_and_wait = AsyncMock(return_value=mock_event) - mock_session.disconnect = AsyncMock() - + mock_session = _make_session(send_and_wait_result=mock_event) mock_client = AsyncMock() mock_client.create_session = AsyncMock(return_value=mock_session) @@ -204,10 +219,7 @@ async def test_acompletion_handles_none_event() -> None: """When send_and_wait returns None, content is empty string.""" provider = _make_provider(api_key="test-token") - mock_session = AsyncMock() - mock_session.send_and_wait = AsyncMock(return_value=None) - mock_session.disconnect = AsyncMock() - + mock_session = _make_session(send_and_wait_result=None) mock_client = AsyncMock() mock_client.create_session = AsyncMock(return_value=mock_session) @@ -226,10 +238,7 @@ async def test_acompletion_disconnects_session_on_error() -> None: """Session.disconnect() is called even when send_and_wait raises.""" provider = _make_provider(api_key="test-token") - mock_session = AsyncMock() - mock_session.send_and_wait = AsyncMock(side_effect=RuntimeError("CLI died")) - mock_session.disconnect = AsyncMock() - + mock_session = _make_session(send_and_wait_error=RuntimeError("CLI died")) mock_client = AsyncMock() mock_client.create_session = AsyncMock(return_value=mock_session) @@ -311,3 +320,395 @@ async def test_ensure_client_uses_cli_url_path() -> None: def test_messages_to_prompt_empty_list() -> None: """Empty messages list returns an empty string without raising.""" assert _messages_to_prompt([]) == "" + + +# --------------------------------------------------------------------------- +# _build_chat_completion — reasoning field +# --------------------------------------------------------------------------- + +def test_build_chat_completion_with_reasoning() -> None: + cc = _build_chat_completion("Answer", "gpt-4o", reasoning="Because math.") + assert cc.choices[0].message.reasoning is not None + assert cc.choices[0].message.reasoning.content == "Because math." 
+ + +def test_build_chat_completion_no_reasoning() -> None: + cc = _build_chat_completion("Answer", "gpt-4o") + assert cc.choices[0].message.reasoning is None + + +# --------------------------------------------------------------------------- +# _build_chunk +# --------------------------------------------------------------------------- + +def test_build_chunk_content_delta() -> None: + chunk = _build_chunk("hello", "gpt-4o", is_reasoning=False) + assert chunk.choices[0].delta.content == "hello" + assert chunk.choices[0].delta.reasoning is None + assert chunk.object == "chat.completion.chunk" + + +def test_build_chunk_reasoning_delta() -> None: + chunk = _build_chunk("step 1", "gpt-4o", is_reasoning=True) + assert chunk.choices[0].delta.content is None + assert chunk.choices[0].delta.reasoning is not None + assert chunk.choices[0].delta.reasoning.content == "step 1" + + +# --------------------------------------------------------------------------- +# _extract_attachments +# --------------------------------------------------------------------------- + +def _make_data_uri(mime: str = "image/jpeg", content: bytes = b"\xff\xd8\xff") -> str: + return f"data:{mime};base64,{base64.b64encode(content).decode()}" + + +def test_extract_attachments_base64_image() -> None: + msgs = [{"role": "user", "content": [ + {"type": "text", "text": "Look at this"}, + {"type": "image_url", "image_url": {"url": _make_data_uri()}}, + ]}] + attachments, temp_paths = _extract_attachments(msgs) + try: + assert len(attachments) == 1 + assert attachments[0]["type"] == "file" + assert os.path.exists(attachments[0]["path"]) + assert len(temp_paths) == 1 + finally: + _cleanup_temp_files(temp_paths) + + +def test_extract_attachments_http_url_skipped() -> None: + """HTTP image URLs are not supported (no download) and must be silently skipped.""" + msgs = [{"role": "user", "content": [ + {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}}, + ]}] + attachments, temp_paths = 
_extract_attachments(msgs) + assert attachments == [] + assert temp_paths == [] + + +def test_extract_attachments_no_images() -> None: + msgs = [{"role": "user", "content": "Plain text, no images."}] + attachments, temp_paths = _extract_attachments(msgs) + assert attachments == [] + assert temp_paths == [] + + +def test_extract_attachments_malformed_data_uri_skipped() -> None: + msgs = [{"role": "user", "content": [ + {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,NOT_VALID!!"}}, + ]}] + # Should not raise; malformed URIs are silently dropped. + attachments, temp_paths = _extract_attachments(msgs) + assert attachments == [] + + +def test_cleanup_temp_files_removes_files() -> None: + import tempfile + with tempfile.NamedTemporaryFile(delete=False) as fh: + path = fh.name + assert os.path.exists(path) + _cleanup_temp_files([path]) + assert not os.path.exists(path) + + +def test_cleanup_temp_files_tolerates_missing() -> None: + """Cleaning up a path that doesn't exist must not raise.""" + _cleanup_temp_files(["/tmp/does-not-exist-copilot-sdk-test-xyz"]) + + +# --------------------------------------------------------------------------- +# _acompletion — reasoning captured in non-streaming mode +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_acompletion_captures_reasoning() -> None: + """assistant.reasoning event content is surfaced in ChatCompletion.reasoning.""" + from copilot.generated.session_events import SessionEventType + + provider = _make_provider(api_key="test-token") + + mock_msg_event = MagicMock() + mock_msg_event.data.content = "42" + mock_msg_event.type = SessionEventType.ASSISTANT_MESSAGE + + mock_reasoning_event = MagicMock() + mock_reasoning_event.data.content = "Because 6×7=42." 
+ mock_reasoning_event.type = SessionEventType.ASSISTANT_REASONING + + def fake_on(callback: Any) -> Any: + """Immediately fire ASSISTANT_REASONING, then return a no-op unsubscribe.""" + callback(mock_reasoning_event) + return lambda: None + + mock_session = MagicMock() + mock_session.on = MagicMock(side_effect=fake_on) + mock_session.send_and_wait = AsyncMock(return_value=mock_msg_event) + mock_session.disconnect = AsyncMock() + + mock_client = AsyncMock() + mock_client.create_session = AsyncMock(return_value=mock_session) + + with patch.object(provider, "_ensure_client", AsyncMock(return_value=mock_client)): + params = CompletionParams( + model_id="gpt-4o", + messages=[{"role": "user", "content": "What is 6×7?"}], + ) + result = await provider._acompletion(params) + + assert result.choices[0].message.content == "42" + assert result.choices[0].message.reasoning is not None + assert result.choices[0].message.reasoning.content == "Because 6×7=42." + + +# --------------------------------------------------------------------------- +# _acompletion — reasoning_effort passed to session_cfg +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_acompletion_passes_reasoning_effort_to_session() -> None: + provider = _make_provider(api_key="test-token") + + mock_session = _make_session() + mock_client = AsyncMock() + mock_client.create_session = AsyncMock(return_value=mock_session) + + with patch.object(provider, "_ensure_client", AsyncMock(return_value=mock_client)): + params = CompletionParams( + model_id="gpt-4o", + messages=[{"role": "user", "content": "Think hard."}], + reasoning_effort="high", + ) + await provider._acompletion(params) + + call_cfg = mock_client.create_session.call_args[0][0] + assert call_cfg.get("reasoning_effort") == "high" + + +@pytest.mark.asyncio +async def test_acompletion_omits_auto_reasoning_effort() -> None: + """reasoning_effort='auto' must NOT be forwarded (SDK doesn't accept 
it).""" + provider = _make_provider(api_key="test-token") + + mock_session = _make_session() + mock_client = AsyncMock() + mock_client.create_session = AsyncMock(return_value=mock_session) + + with patch.object(provider, "_ensure_client", AsyncMock(return_value=mock_client)): + params = CompletionParams( + model_id="gpt-4o", + messages=[{"role": "user", "content": "Hi"}], + reasoning_effort="auto", + ) + await provider._acompletion(params) + + call_cfg = mock_client.create_session.call_args[0][0] + assert "reasoning_effort" not in call_cfg + + +# --------------------------------------------------------------------------- +# _acompletion — streaming +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_acompletion_streaming_yields_chunks() -> None: + from copilot.generated.session_events import SessionEventType + + provider = _make_provider(api_key="test-token") + + # Build fake events that the on() callback will receive. + def _evt(etype: SessionEventType, delta: str) -> MagicMock: + e = MagicMock() + e.type = etype + e.data.delta_content = delta + return e + + delta_events = [ + _evt(SessionEventType.ASSISTANT_MESSAGE_DELTA, "Hello"), + _evt(SessionEventType.ASSISTANT_MESSAGE_DELTA, " world"), + _evt(SessionEventType.SESSION_IDLE, ""), + ] + + registered_callback: list[Any] = [] + + def fake_on(cb: Any) -> Any: + registered_callback.append(cb) + return lambda: None + + mock_session = MagicMock() + mock_session.on = MagicMock(side_effect=fake_on) + mock_session.disconnect = AsyncMock() + + async def fake_send(opts: Any) -> str: + for evt in delta_events: + registered_callback[0](evt) + return "msg-id" + + mock_session.send = fake_send + + mock_client = AsyncMock() + mock_client.create_session = AsyncMock(return_value=mock_session) + + with patch.object(provider, "_ensure_client", AsyncMock(return_value=mock_client)): + params = CompletionParams( + model_id="gpt-4o", + messages=[{"role": "user", 
"content": "Hi"}], + stream=True, + ) + stream = await provider._acompletion(params) + chunks = [c async for c in stream] + + assert len(chunks) == 2 + assert chunks[0].choices[0].delta.content == "Hello" + assert chunks[1].choices[0].delta.content == " world" + mock_session.disconnect.assert_called_once() + + +@pytest.mark.asyncio +async def test_acompletion_streaming_reasoning_chunks() -> None: + from copilot.generated.session_events import SessionEventType + + provider = _make_provider(api_key="test-token") + + def _evt(etype: SessionEventType, delta: str) -> MagicMock: + e = MagicMock() + e.type = etype + e.data.delta_content = delta + return e + + delta_events = [ + _evt(SessionEventType.ASSISTANT_REASONING_DELTA, "step 1"), + _evt(SessionEventType.ASSISTANT_MESSAGE_DELTA, "answer"), + _evt(SessionEventType.SESSION_IDLE, ""), + ] + + registered_callback: list[Any] = [] + + def fake_on(cb: Any) -> Any: + registered_callback.append(cb) + return lambda: None + + mock_session = MagicMock() + mock_session.on = MagicMock(side_effect=fake_on) + mock_session.disconnect = AsyncMock() + + async def fake_send(opts: Any) -> str: + for evt in delta_events: + registered_callback[0](evt) + return "msg-id" + + mock_session.send = fake_send + + mock_client = AsyncMock() + mock_client.create_session = AsyncMock(return_value=mock_session) + + with patch.object(provider, "_ensure_client", AsyncMock(return_value=mock_client)): + params = CompletionParams( + model_id="gpt-4o", + messages=[{"role": "user", "content": "Think step by step."}], + stream=True, + ) + stream = await provider._acompletion(params) + chunks = [c async for c in stream] + + assert len(chunks) == 2 + # First chunk carries reasoning + assert chunks[0].choices[0].delta.reasoning is not None + assert chunks[0].choices[0].delta.reasoning.content == "step 1" + assert chunks[0].choices[0].delta.content is None + # Second chunk carries content + assert chunks[1].choices[0].delta.content == "answer" + assert 
chunks[1].choices[0].delta.reasoning is None + + +@pytest.mark.asyncio +async def test_acompletion_streaming_session_error_raises() -> None: + """SESSION_ERROR event must raise RuntimeError rather than completing silently.""" + from copilot.generated.session_events import SessionEventType + + provider = _make_provider(api_key="test-token") + + def _evt(etype: SessionEventType, msg: str = "") -> MagicMock: + e = MagicMock() + e.type = etype + e.data.message = msg + e.data.delta_content = "" + return e + + delta_events = [ + _evt(SessionEventType.ASSISTANT_MESSAGE_DELTA, ""), # one content delta first + _evt(SessionEventType.SESSION_ERROR, "CLI crashed"), + ] + # Override delta_content for the first event + delta_events[0].data.delta_content = "partial" + + registered_callback: list[Any] = [] + + def fake_on(cb: Any) -> Any: + registered_callback.append(cb) + return lambda: None + + mock_session = MagicMock() + mock_session.on = MagicMock(side_effect=fake_on) + mock_session.disconnect = AsyncMock() + + async def fake_send(opts: Any) -> str: + for evt in delta_events: + registered_callback[0](evt) + return "msg-id" + + mock_session.send = fake_send + + mock_client = AsyncMock() + mock_client.create_session = AsyncMock(return_value=mock_session) + + with patch.object(provider, "_ensure_client", AsyncMock(return_value=mock_client)): + params = CompletionParams( + model_id="gpt-4o", + messages=[{"role": "user", "content": "Hi"}], + stream=True, + ) + stream = await provider._acompletion(params) + with pytest.raises(RuntimeError): + async for _ in stream: + pass + + mock_session.disconnect.assert_called_once() + + +@pytest.mark.asyncio +async def test_acompletion_image_attachments_forwarded() -> None: + """Image attachments extracted from messages must be included in msg_opts sent to session.""" + import base64 + + provider = _make_provider(api_key="test-token") + + jpeg_data = base64.b64encode(b"\xff\xd8\xff").decode() + data_uri = f"data:image/jpeg;base64,{jpeg_data}" + 
+ mock_event = MagicMock() + mock_event.data.content = "I see an image." + mock_session = _make_session(send_and_wait_result=mock_event) + mock_client = AsyncMock() + mock_client.create_session = AsyncMock(return_value=mock_session) + + with patch.object(provider, "_ensure_client", AsyncMock(return_value=mock_client)): + params = CompletionParams( + model_id="gpt-4o", + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": "What is this?"}, + {"type": "image_url", "image_url": {"url": data_uri}}, + ], + }], + ) + result = await provider._acompletion(params) + + assert result.choices[0].message.content == "I see an image." + # send_and_wait must have received msg_opts with 'attachments' + call_kwargs = mock_session.send_and_wait.call_args[0][0] + assert "attachments" in call_kwargs + assert len(call_kwargs["attachments"]) == 1 + assert call_kwargs["attachments"][0]["type"] == "file" From 2ab9dd80c1efcb9ec8c0a4b95c2475d4cac98b14 Mon Sep 17 00:00:00 2001 From: "Gregory R. 
Warnes" Date: Thu, 12 Mar 2026 14:33:58 -0400 Subject: [PATCH 3/5] fix(copilot_sdk): Fix factory lookup and add provider documentation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add Copilot_sdkProvider alias to __init__.py so AnyLLM._create_provider() can find the class (factory uses provider_key.capitalize() → "Copilot_sdk") - Add "Provider Notes" section to providers.md with installation guide, authentication modes, configuration env vars, and usage examples - Regenerate provider table via generate_provider_table.py; copilot_sdk row now shows correct COPILOT_CLI_URL in the Base column Co-Authored-By: Claude Sonnet 4.6 --- docs/src/content/docs/providers.md | 80 ++++++++++++++++++- src/any_llm/providers/copilot_sdk/__init__.py | 6 +- 2 files changed, 84 insertions(+), 2 deletions(-) diff --git a/docs/src/content/docs/providers.md b/docs/src/content/docs/providers.md index 5492ff27..6d634f9b 100644 --- a/docs/src/content/docs/providers.md +++ b/docs/src/content/docs/providers.md @@ -26,7 +26,7 @@ Provider source code is in [`src/any_llm/providers/`](https://github.com/mozilla | [`bedrock`](https://aws.amazon.com/bedrock/) | AWS_BEARER_TOKEN_BEDROCK | AWS_ENDPOINT_URL_BEDROCK_RUNTIME | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | [`cerebras`](https://docs.cerebras.ai/) | CEREBRAS_API_KEY | CEREBRAS_API_BASE | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | | [`cohere`](https://cohere.com/api) | COHERE_API_KEY | COHERE_BASE_URL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | -| [`copilot_sdk`](https://github.com/github/copilot-sdk) | COPILOT_GITHUB_TOKEN / GITHUB_TOKEN / GH_TOKEN | None | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | +| [`copilot_sdk`](https://github.com/github/copilot-sdk) | COPILOT_GITHUB_TOKEN | COPILOT_CLI_URL | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | | [`databricks`](https://docs.databricks.com/) | DATABRICKS_TOKEN | DATABRICKS_HOST | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | | [`deepseek`](https://platform.deepseek.com/) | DEEPSEEK_API_KEY | DEEPSEEK_API_BASE | ❌ | ✅ 
| ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | | [`fireworks`](https://fireworks.ai/api) | FIREWORKS_API_KEY | FIREWORKS_API_BASE | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | @@ -60,3 +60,81 @@ Provider source code is in [`src/any_llm/providers/`](https://github.com/mozilla | [`xai`](https://x.ai/) | XAI_API_KEY | XAI_API_BASE | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | | [`zai`](https://docs.z.ai/guides/develop/python/introduction) | ZAI_API_KEY | ZAI_BASE_URL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | + +## Provider Notes + +### `copilot_sdk` — GitHub Copilot SDK + +The `copilot_sdk` provider communicates with GitHub Copilot models via the +[`github-copilot-sdk`](https://pypi.org/project/github-copilot-sdk/) Python package, +which bundles the Copilot CLI binary for your platform. + +#### Installation + +Install the platform-specific wheel: + +```bash +pip install any-llm-sdk[copilot_sdk] +``` + +> **Note**: `github-copilot-sdk` ships separate wheels per OS and CPU architecture +> (e.g. `macosx_arm64`, `linux_x86_64`). `pip` selects the correct wheel automatically +> on supported platforms. If installation fails, check [PyPI](https://pypi.org/project/github-copilot-sdk/#files) +> for available platform tags. + +#### Authentication + +Two modes are supported, checked in order: + +1. **Token mode** — set one of these environment variables: + + ```bash + export COPILOT_GITHUB_TOKEN="ghp_your_token" + # or + export GITHUB_TOKEN="ghp_your_token" + # or + export GH_TOKEN="ghp_your_token" + ``` + + Alternatively, pass `api_key` directly to `AnyLLM.create()`. + +2. **Logged-in CLI user** — if no token is set, the provider uses the credentials + from the local `gh` / `copilot` CLI session (run `gh auth login` first). No + environment variable is required in this mode. 
+ +#### Configuration + +| Environment Variable | Purpose | Default | +| --- | --- | --- | +| `COPILOT_GITHUB_TOKEN` | GitHub token with Copilot access | — | +| `GITHUB_TOKEN` / `GH_TOKEN` | Fallback token sources | — | +| `COPILOT_CLI_URL` | Connect to an external CLI server instead of spawning one (e.g. `localhost:9000`) | — | +| `COPILOT_CLI_PATH` | Override the Copilot CLI binary path | PATH lookup | + +#### Usage + +```python +from any_llm import AnyLLM + +# Token auth (or set COPILOT_GITHUB_TOKEN in environment) +llm = AnyLLM.create("copilot_sdk") + +# List available models +models = llm.list_models() + +# Completion with reasoning +response = llm.completion( + model="claude-sonnet-4-5", + messages=[{"role": "user", "content": "Explain async generators in Python."}], + reasoning_effort="high", +) +print(response.choices[0].message.content) + +# Streaming +for chunk in llm.completion( + model="gpt-4o", + messages=[{"role": "user", "content": "Hello!"}], + stream=True, +): + print(chunk.choices[0].delta.content or "", end="", flush=True) +``` diff --git a/src/any_llm/providers/copilot_sdk/__init__.py b/src/any_llm/providers/copilot_sdk/__init__.py index 89ebdb3b..54b385ca 100644 --- a/src/any_llm/providers/copilot_sdk/__init__.py +++ b/src/any_llm/providers/copilot_sdk/__init__.py @@ -1,3 +1,7 @@ from .copilot_sdk import CopilotSdkProvider -__all__ = ["CopilotSdkProvider"] +# Factory alias: AnyLLM._create_provider() derives class names via +# provider_key.capitalize() + "Provider", which yields "Copilot_sdkProvider". +Copilot_sdkProvider = CopilotSdkProvider + +__all__ = ["CopilotSdkProvider", "Copilot_sdkProvider"] From 68454fc83c36a1e210e7a079b97df620395259af Mon Sep 17 00:00:00 2001 From: "Gregory R. 
Warnes" Date: Sat, 14 Mar 2026 19:39:02 -0400 Subject: [PATCH 4/5] chore: add .github/copilot-instructions.md for VS Code Copilot --- .github/copilot-instructions.md | 51 +++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 .github/copilot-instructions.md diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md new file mode 100644 index 00000000..c2c46bd3 --- /dev/null +++ b/.github/copilot-instructions.md @@ -0,0 +1,51 @@ +# Copilot Instructions for any-llm + +> Full guidelines are in [AGENTS.md](../AGENTS.md). This file surfaces the rules most critical for AI-assisted coding. + +## Commands + +```bash +# Setup +uv venv && source .venv/bin/activate && uv sync --all-extras -U + +# Run all checks (preferred before committing) +uv run pre-commit run --all-files --verbose + +# Tests +uv run pytest -v tests/unit +uv run pytest -v tests/integration -n auto # requires API keys +``` + +## Code Style (enforced by mypy + ruff) + +- **Type hints required** on all new code; mypy runs in strict mode +- **`@override` decorator** from `typing_extensions` is required on every method that overrides a base class method — mypy enforces `explicit-override`. 
For static methods: `@staticmethod` first, then `@override` +- **Direct attribute access** (`obj.field`) preferred over `getattr(obj, "field")` for typed fields +- Line length: 120 chars (ruff) +- No decorative section-separator comments (`# ------` banners) + +## Project Structure + +``` +src/any_llm/ + providers// ← all provider-specific code goes here + types/ ← shared types + gateway/ ← optional FastAPI gateway +tests/ + unit/ ← no API keys needed + integration/ ← skip when creds unavailable + gateway/ +``` + +## Testing Rules + +- **No class-based test grouping** — all tests are standalone functions +- Add happy path + error/raise path tests for every change (~85% coverage target) +- Integration tests must `pytest.skip(...)` when credentials are unavailable +- Optional-dependency imports (e.g. `mistralai`, `cohere`) go **inside** the test function, not at the top of the file + +## Commits & PRs + +- Conventional Commits: `feat(scope): ...`, `fix: ...`, `chore(deps): ...`, `tests: ...` +- PRs must complete the checklist in `.github/pull_request_template.md` and include AI-usage disclosure when applicable +- Never commit secrets — use env vars or a gitignored `.env` From c20339f46a7bc89e7e7959885d6c55a59e0831c3 Mon Sep 17 00:00:00 2001 From: "Gregory R. 
Warnes" Date: Fri, 27 Mar 2026 22:50:56 -0400 Subject: [PATCH 5/5] fix(providers): align Copilot SDK provider with maintainer review - rename provider key and enum from copilot_sdk to copilotsdk across code, tests, docs, and extras - remove alias/banners and keep implementation style aligned with project conventions - add optional SDK test guard and retain defensive event-field handling comments - drop unrelated .github/copilot-instructions.md from this PR --- .github/copilot-instructions.md | 51 -- docs/src/content/docs/providers.md | 10 +- pyproject.toml | 4 +- src/any_llm/constants.py | 2 +- src/any_llm/providers/copilot_sdk/__init__.py | 7 - src/any_llm/providers/copilotsdk/__init__.py | 3 + .../copilotsdk.py} | 80 +-- tests/conftest.py | 4 +- ...rovider.py => test_copilotsdk_provider.py} | 489 ++++++++---------- tests/unit/test_provider.py | 2 +- 10 files changed, 245 insertions(+), 407 deletions(-) delete mode 100644 .github/copilot-instructions.md delete mode 100644 src/any_llm/providers/copilot_sdk/__init__.py create mode 100644 src/any_llm/providers/copilotsdk/__init__.py rename src/any_llm/providers/{copilot_sdk/copilot_sdk.py => copilotsdk/copilotsdk.py} (85%) rename tests/unit/providers/{test_copilot_sdk_provider.py => test_copilotsdk_provider.py} (58%) diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md deleted file mode 100644 index c2c46bd3..00000000 --- a/.github/copilot-instructions.md +++ /dev/null @@ -1,51 +0,0 @@ -# Copilot Instructions for any-llm - -> Full guidelines are in [AGENTS.md](../AGENTS.md). This file surfaces the rules most critical for AI-assisted coding. 
- -## Commands - -```bash -# Setup -uv venv && source .venv/bin/activate && uv sync --all-extras -U - -# Run all checks (preferred before committing) -uv run pre-commit run --all-files --verbose - -# Tests -uv run pytest -v tests/unit -uv run pytest -v tests/integration -n auto # requires API keys -``` - -## Code Style (enforced by mypy + ruff) - -- **Type hints required** on all new code; mypy runs in strict mode -- **`@override` decorator** from `typing_extensions` is required on every method that overrides a base class method — mypy enforces `explicit-override`. For static methods: `@staticmethod` first, then `@override` -- **Direct attribute access** (`obj.field`) preferred over `getattr(obj, "field")` for typed fields -- Line length: 120 chars (ruff) -- No decorative section-separator comments (`# ------` banners) - -## Project Structure - -``` -src/any_llm/ - providers// ← all provider-specific code goes here - types/ ← shared types - gateway/ ← optional FastAPI gateway -tests/ - unit/ ← no API keys needed - integration/ ← skip when creds unavailable - gateway/ -``` - -## Testing Rules - -- **No class-based test grouping** — all tests are standalone functions -- Add happy path + error/raise path tests for every change (~85% coverage target) -- Integration tests must `pytest.skip(...)` when credentials are unavailable -- Optional-dependency imports (e.g. 
`mistralai`, `cohere`) go **inside** the test function, not at the top of the file - -## Commits & PRs - -- Conventional Commits: `feat(scope): ...`, `fix: ...`, `chore(deps): ...`, `tests: ...` -- PRs must complete the checklist in `.github/pull_request_template.md` and include AI-usage disclosure when applicable -- Never commit secrets — use env vars or a gitignored `.env` diff --git a/docs/src/content/docs/providers.md b/docs/src/content/docs/providers.md index 51e8fdf0..cc882a96 100644 --- a/docs/src/content/docs/providers.md +++ b/docs/src/content/docs/providers.md @@ -27,7 +27,7 @@ Provider source code is in [`src/any_llm/providers/`](https://github.com/mozilla | [`bedrock`](https://aws.amazon.com/bedrock/) | AWS_BEARER_TOKEN_BEDROCK | AWS_ENDPOINT_URL_BEDROCK_RUNTIME | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | [`cerebras`](https://docs.cerebras.ai/) | CEREBRAS_API_KEY | CEREBRAS_API_BASE | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | | [`cohere`](https://cohere.com/api) | COHERE_API_KEY | COHERE_BASE_URL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | -| [`copilot_sdk`](https://github.com/github/copilot-sdk) | COPILOT_GITHUB_TOKEN | COPILOT_CLI_URL | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | +| [`copilotsdk`](https://github.com/github/copilot-sdk) | COPILOT_GITHUB_TOKEN | COPILOT_CLI_URL | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | | [`databricks`](https://docs.databricks.com/) | DATABRICKS_TOKEN | DATABRICKS_HOST | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | | [`deepseek`](https://platform.deepseek.com/) | DEEPSEEK_API_KEY | DEEPSEEK_API_BASE | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | | [`fireworks`](https://fireworks.ai/api) | FIREWORKS_API_KEY | FIREWORKS_API_BASE | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | @@ -64,9 +64,9 @@ Provider source code is in [`src/any_llm/providers/`](https://github.com/mozilla ## Provider Notes -### `copilot_sdk` — GitHub Copilot SDK +### `copilotsdk` — GitHub Copilot SDK -The `copilot_sdk` provider communicates with GitHub Copilot models via the +The `copilotsdk` provider communicates with GitHub Copilot models 
via the [`github-copilot-sdk`](https://pypi.org/project/github-copilot-sdk/) Python package, which bundles the Copilot CLI binary for your platform. @@ -75,7 +75,7 @@ which bundles the Copilot CLI binary for your platform. Install the platform-specific wheel: ```bash -pip install any-llm-sdk[copilot_sdk] +pip install any-llm-sdk[copilotsdk] ``` > **Note**: `github-copilot-sdk` ships separate wheels per OS and CPU architecture @@ -118,7 +118,7 @@ Two modes are supported, checked in order: from any_llm import AnyLLM # Token auth (or set COPILOT_GITHUB_TOKEN in environment) -llm = AnyLLM.create("copilot_sdk") +llm = AnyLLM.create("copilotsdk") # List available models models = llm.list_models() diff --git a/pyproject.toml b/pyproject.toml index e31d7b02..4a125c81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ [project.optional-dependencies] all = [ - "any-llm-sdk[mistral,anthropic,huggingface,gemini,vertexai,vertexaianthropic,cohere,cerebras,fireworks,groq,bedrock,azure,azureanthropic,azureopenai,watsonx,together,sambanova,ollama,moonshot,nebius,xai,databricks,deepseek,inception,openai,openrouter,portkey,lmstudio,llama,voyage,perplexity,platform,llamafile,llamacpp,sagemaker,gateway,zai,minimax,mzai,vllm,copilot_sdk]" + "any-llm-sdk[mistral,anthropic,huggingface,gemini,vertexai,vertexaianthropic,cohere,cerebras,fireworks,groq,bedrock,azure,azureanthropic,azureopenai,watsonx,together,sambanova,ollama,moonshot,nebius,xai,databricks,deepseek,inception,openai,openrouter,portkey,lmstudio,llama,voyage,perplexity,platform,llamafile,llamacpp,sagemaker,gateway,zai,minimax,mzai,vllm,copilotsdk]" ] platform = [ @@ -31,7 +31,7 @@ platform = [ perplexity = [] -copilot_sdk = [ +copilotsdk = [ "github-copilot-sdk>=0.1.0", ] diff --git a/src/any_llm/constants.py b/src/any_llm/constants.py index b6fb5f45..584613e5 100644 --- a/src/any_llm/constants.py +++ b/src/any_llm/constants.py @@ -54,7 +54,7 @@ class LLMProvider(StrEnum): PERPLEXITY = 
"perplexity" MINIMAX = "minimax" ZAI = "zai" - COPILOT_SDK = "copilot_sdk" + COPILOTSDK = "copilotsdk" GATEWAY = "gateway" @classmethod diff --git a/src/any_llm/providers/copilot_sdk/__init__.py b/src/any_llm/providers/copilot_sdk/__init__.py deleted file mode 100644 index 54b385ca..00000000 --- a/src/any_llm/providers/copilot_sdk/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .copilot_sdk import CopilotSdkProvider - -# Factory alias: AnyLLM._create_provider() derives class names via -# provider_key.capitalize() + "Provider", which yields "Copilot_sdkProvider". -Copilot_sdkProvider = CopilotSdkProvider - -__all__ = ["CopilotSdkProvider", "Copilot_sdkProvider"] diff --git a/src/any_llm/providers/copilotsdk/__init__.py b/src/any_llm/providers/copilotsdk/__init__.py new file mode 100644 index 00000000..2f6f2ec5 --- /dev/null +++ b/src/any_llm/providers/copilotsdk/__init__.py @@ -0,0 +1,3 @@ +from .copilotsdk import CopilotsdkProvider + +__all__ = ["CopilotsdkProvider"] diff --git a/src/any_llm/providers/copilot_sdk/copilot_sdk.py b/src/any_llm/providers/copilotsdk/copilotsdk.py similarity index 85% rename from src/any_llm/providers/copilot_sdk/copilot_sdk.py rename to src/any_llm/providers/copilotsdk/copilotsdk.py index 3f87175c..8611ffcb 100644 --- a/src/any_llm/providers/copilot_sdk/copilot_sdk.py +++ b/src/any_llm/providers/copilotsdk/copilotsdk.py @@ -36,10 +36,6 @@ from any_llm.types.model import Model -# --------------------------------------------------------------------------- -# Message / prompt helpers -# --------------------------------------------------------------------------- - def _messages_to_prompt(messages: list[dict[str, Any]]) -> str: """Flatten an OpenAI-style messages list into a single prompt string. 
@@ -128,10 +124,6 @@ def _cleanup_temp_files(paths: list[str]) -> None: pass -# --------------------------------------------------------------------------- -# Response / chunk builders -# --------------------------------------------------------------------------- - def _build_chat_completion( content: str, model_id: str, @@ -146,7 +138,7 @@ def _build_chat_completion( ) return ChatCompletion( - id=f"copilot-sdk-{int(time.time())}", + id=f"copilotsdk-{time.time_ns()}", choices=[ Choice( finish_reason="stop", @@ -175,7 +167,7 @@ def _build_chunk(delta: str, model_id: str, *, is_reasoning: bool = False) -> "C ) return ChatCompletionChunk( - id=f"copilot-sdk-{time.time_ns()}", + id=f"copilotsdk-{time.time_ns()}", choices=[ ChunkChoice( delta=ChoiceDelta( @@ -200,42 +192,30 @@ def _copilot_model_to_openai(info: "CopilotModelInfo") -> "Model": return OpenAIModel(id=info.id, created=0, owned_by="github-copilot", object="model") -# --------------------------------------------------------------------------- -# Provider -# --------------------------------------------------------------------------- - -class CopilotSdkProvider(AnyLLM): +class CopilotsdkProvider(AnyLLM): """GitHub Copilot SDK provider for any-llm. Communicates with the Copilot CLI via JSON-RPC using the ``github-copilot-sdk`` - Python package. Authentication supports two modes (checked in order): + Python package. Authentication supports two modes (checked in order): 1. **Token mode** — set ``COPILOT_GITHUB_TOKEN``, ``GITHUB_TOKEN``, or ``GH_TOKEN`` in the environment (or pass ``api_key`` explicitly). - 2. **Logged-in CLI user** — if no token is found, the Copilot CLI uses the - credentials from the local user's authenticated ``gh`` / ``copilot`` CLI session. - - The Copilot CLI binary must be installed and reachable on ``PATH`` (or pointed - to via ``COPILOT_CLI_PATH``). The platform-specific ``github-copilot-sdk`` - wheel (e.g. 
``github-copilot-sdk==X.Y.Z`` targeting your OS/arch) bundles the - binary automatically. An external CLI server can be used instead by setting - ``COPILOT_CLI_URL`` (e.g. ``localhost:9000``). - - **Supported features:** + 2. **Logged-in CLI user** — if no token is found, the Copilot CLI uses + the credentials from the local user's ``gh`` / ``copilot`` CLI session. - * Completion (non-streaming and streaming) - * Reasoning (``reasoning_effort``: ``low`` / ``medium`` / ``high`` / ``xhigh``) - * Image attachments via ``image_url`` content blocks (``data:`` URIs only) - * Model listing + Supports completion (streaming and non-streaming), reasoning + (``reasoning_effort``), image attachments via ``data:`` URIs, and model + listing. The binary is bundled by the ``github-copilot-sdk`` wheel; + ``COPILOT_CLI_URL`` overrides to an external server. Environment variables: COPILOT_GITHUB_TOKEN: GitHub token with Copilot access (optional). - GITHUB_TOKEN / GH_TOKEN: Fallback GitHub token sources (optional). + GITHUB_TOKEN / GH_TOKEN: Fallback token sources (optional). COPILOT_CLI_URL: Connect to an external CLI server instead of spawning one. COPILOT_CLI_PATH: Override the CLI binary path (default: PATH lookup). """ - PROVIDER_NAME = "copilot_sdk" + PROVIDER_NAME = "copilotsdk" ENV_API_KEY_NAME = "COPILOT_GITHUB_TOKEN" ENV_API_BASE_NAME = "COPILOT_CLI_URL" PROVIDER_DOCUMENTATION_URL = "https://github.com/github/copilot-sdk" @@ -255,8 +235,6 @@ class CopilotSdkProvider(AnyLLM): # Internal state — populated lazily on first call. _copilot_client: "CopilotClient | None" - # ------------------------------------------------------------------ auth -- - @override def _verify_and_set_api_key(self, api_key: str | None = None) -> str | None: """API key is optional: logged-in CLI mode works without any token.""" @@ -269,8 +247,6 @@ def _verify_and_set_api_key(self, api_key: str | None = None) -> str | None: # Return None (not empty string) so copilot-sdk uses logged-in credentials. 
return resolved or None - # ---------------------------------------------------------- client init -- - @override def _init_client(self, api_key: str | None = None, api_base: str | None = None, **kwargs: Any) -> None: """Store resolved auth options; actual CLI start is deferred to the first async call.""" @@ -312,13 +288,12 @@ async def _ensure_client(self) -> "CopilotClient": return self._copilot_client - # --------------------------------------------------- completion (async) -- - def _build_session_cfg(self, params: "CompletionParams", streaming: bool) -> dict[str, Any]: """Build a SessionConfig dict from CompletionParams.""" - # approve_all silently grants any permissions the session requests - # (e.g. tool calls). This mirrors how the Copilot CLI behaves in - # non-interactive mode and is appropriate for programmatic usage. + # PermissionHandler.approve_all silently grants any permissions the session + # requests (e.g. tool calls). This mirrors how the Copilot CLI behaves in + # non-interactive mode and is appropriate for programmatic usage. Callers + # that need a more restrictive policy can subclass and override this method. cfg: dict[str, Any] = {"on_permission_request": PermissionHandler.approve_all} if params.model_id: cfg["model"] = params.model_id @@ -353,7 +328,10 @@ def on_event(event: Any) -> None: elif etype == SessionEventType.SESSION_IDLE: queue.put_nowait(None) # sentinel — streaming complete normally elif etype == SessionEventType.SESSION_ERROR: - queue.put_nowait(("error", getattr(getattr(event, "data", None), "message", "Copilot session error"))) + # SDK SESSION_ERROR event fields are untyped; use getattr with a + # default so this path is safe even if the schema changes. 
+ error_msg = getattr(getattr(event, "data", None), "message", "Copilot session error") + queue.put_nowait(("error", error_msg)) unsubscribe = session.on(on_event) try: @@ -403,6 +381,8 @@ async def _acompletion( def on_reasoning(event: Any) -> None: nonlocal reasoning_content if event.type == SessionEventType.ASSISTANT_REASONING: + # SDK reasoning event fields are untyped; defensive access is + # intentional here so a schema change doesn't cause an AttributeError. reasoning_content = ( getattr(getattr(event, "data", None), "content", None) or None ) @@ -421,8 +401,6 @@ def on_reasoning(event: Any) -> None: await session.disconnect() _cleanup_temp_files(temp_paths) - # -------------------------------------------------- model listing (async) -- - @override async def _alist_models(self, **kwargs: Any) -> "Sequence[Model]": """List models available through the Copilot CLI.""" @@ -430,34 +408,32 @@ async def _alist_models(self, **kwargs: Any) -> "Sequence[Model]": models = await client.list_models() return [_copilot_model_to_openai(m) for m in models] - # ------------ Required abstract stubs (unused — _acompletion overridden) -- - @staticmethod @override def _convert_completion_params(params: "CompletionParams", **kwargs: Any) -> dict[str, Any]: - raise NotImplementedError("CopilotSdkProvider overrides _acompletion directly") + raise NotImplementedError("CopilotsdkProvider overrides _acompletion directly") @staticmethod @override def _convert_completion_response(response: Any) -> "ChatCompletion": - raise NotImplementedError("CopilotSdkProvider overrides _acompletion directly") + raise NotImplementedError("CopilotsdkProvider overrides _acompletion directly") @staticmethod @override def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> "ChatCompletionChunk": - raise NotImplementedError("CopilotSdkProvider overrides _acompletion directly") + raise NotImplementedError("CopilotsdkProvider overrides _acompletion directly") @staticmethod @override def 
_convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: - raise NotImplementedError("CopilotSdkProvider does not support embeddings") + raise NotImplementedError("CopilotsdkProvider does not support embeddings") @staticmethod @override def _convert_embedding_response(response: Any) -> "CreateEmbeddingResponse": - raise NotImplementedError("CopilotSdkProvider does not support embeddings") + raise NotImplementedError("CopilotsdkProvider does not support embeddings") @staticmethod @override def _convert_list_models_response(response: Any) -> "Sequence[Model]": - raise NotImplementedError("CopilotSdkProvider uses _alist_models directly") + raise NotImplementedError("CopilotsdkProvider uses _alist_models directly") diff --git a/tests/conftest.py b/tests/conftest.py index db7f47aa..d46d77de 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -81,7 +81,7 @@ def provider_model_map() -> dict[LLMProvider, str]: LLMProvider.LLAMACPP: "N/A", LLMProvider.MINIMAX: "MiniMax-M2", LLMProvider.ZAI: "glm-4-32b-0414-128k", - LLMProvider.COPILOT_SDK: "gpt-4o", + LLMProvider.COPILOTSDK: "gpt-4o", } @@ -152,7 +152,7 @@ def provider_client_config() -> dict[LLMProvider, dict[str, Any]]: "api_base": "https://mlrun-me8bof5t-eastus2.cognitiveservices.azure.com/", "api_version": "2025-03-01-preview", }, - LLMProvider.COPILOT_SDK: {}, + LLMProvider.COPILOTSDK: {}, } diff --git a/tests/unit/providers/test_copilot_sdk_provider.py b/tests/unit/providers/test_copilotsdk_provider.py similarity index 58% rename from tests/unit/providers/test_copilot_sdk_provider.py rename to tests/unit/providers/test_copilotsdk_provider.py index c780f3b0..315fca14 100644 --- a/tests/unit/providers/test_copilot_sdk_provider.py +++ b/tests/unit/providers/test_copilotsdk_provider.py @@ -1,16 +1,17 @@ -"""Unit tests for CopilotSdkProvider.""" +"""Unit tests for CopilotsdkProvider.""" from __future__ import annotations import asyncio import base64 import os +import tempfile from typing import Any 
from unittest.mock import AsyncMock, MagicMock, patch import pytest -from any_llm.providers.copilot_sdk.copilot_sdk import ( - CopilotSdkProvider, +from any_llm.providers.copilotsdk.copilotsdk import ( + CopilotsdkProvider, _build_chat_completion, _build_chunk, _cleanup_temp_files, @@ -20,83 +21,114 @@ ) from any_llm.types.completion import CompletionParams +pytest.importorskip("copilot") -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- -def _make_provider(**kwargs: Any) -> CopilotSdkProvider: - """Instantiate CopilotSdkProvider while bypassing real CLI startup.""" - p = object.__new__(CopilotSdkProvider) - p._resolved_token = kwargs.get("api_key") - p._cli_url = kwargs.get("api_base") - p._cli_path = None - p._extra_kwargs = {} - p._copilot_client = None - p._client_lock = asyncio.Lock() - return p +def _make_provider(**kwargs: Any) -> CopilotsdkProvider: + """Instantiate CopilotsdkProvider while bypassing real CLI startup.""" + provider = object.__new__(CopilotsdkProvider) + provider._resolved_token = kwargs.get("api_key") + provider._cli_url = kwargs.get("api_base") + provider._cli_path = None + provider._extra_kwargs = {} + provider._copilot_client = None + provider._client_lock = asyncio.Lock() + return provider def _make_model_info(model_id: str = "gpt-4o", name: str = "GPT-4o") -> Any: """Return a minimal mock ModelInfo.""" - m = MagicMock() - m.id = model_id - m.name = name - return m + model_info = MagicMock() + model_info.id = model_id + model_info.name = name + return model_info -# --------------------------------------------------------------------------- -# _messages_to_prompt -# --------------------------------------------------------------------------- +def _make_session(send_and_wait_result: Any = None, send_and_wait_error: Exception | None = None) -> Any: + """Build a mock Copilot session with a sync on() and async methods.""" + 
session = MagicMock() + session.on = MagicMock(return_value=lambda: None) + session.disconnect = AsyncMock() + if send_and_wait_error: + session.send_and_wait = AsyncMock(side_effect=send_and_wait_error) + else: + session.send_and_wait = AsyncMock(return_value=send_and_wait_result) + return session + + +def _make_data_uri(mime: str = "image/jpeg", content: bytes = b"\xff\xd8\xff") -> str: + return f"data:{mime};base64,{base64.b64encode(content).decode()}" + def test_messages_to_prompt_simple_user() -> None: - msgs = [{"role": "user", "content": "Hello"}] - assert _messages_to_prompt(msgs) == "User: Hello" + messages = [{"role": "user", "content": "Hello"}] + assert _messages_to_prompt(messages) == "User: Hello" def test_messages_to_prompt_system_prepended() -> None: - msgs = [ + messages = [ {"role": "system", "content": "Be concise."}, {"role": "user", "content": "Hi"}, ] - result = _messages_to_prompt(msgs) + result = _messages_to_prompt(messages) assert result.startswith("Be concise.") assert "User: Hi" in result def test_messages_to_prompt_multi_turn() -> None: - msgs = [ + messages = [ {"role": "user", "content": "Ping"}, {"role": "assistant", "content": "Pong"}, {"role": "user", "content": "Again"}, ] - result = _messages_to_prompt(msgs) + result = _messages_to_prompt(messages) assert "User: Ping" in result assert "Assistant: Pong" in result assert "User: Again" in result def test_messages_to_prompt_multimodal_extracts_text() -> None: - msgs = [{"role": "user", "content": [{"type": "text", "text": "Tell me"}, {"type": "image_url"}]}] - result = _messages_to_prompt(msgs) + messages = [{"role": "user", "content": [{"type": "text", "text": "Tell me"}, {"type": "image_url"}]}] + result = _messages_to_prompt(messages) assert "Tell me" in result -# --------------------------------------------------------------------------- -# _build_chat_completion -# --------------------------------------------------------------------------- +def 
test_messages_to_prompt_empty_list() -> None: + assert _messages_to_prompt([]) == "" + def test_build_chat_completion_structure() -> None: - cc = _build_chat_completion("Hello world", "gpt-4o") - assert cc.choices[0].message.content == "Hello world" - assert cc.model == "gpt-4o" - assert cc.choices[0].finish_reason == "stop" - assert cc.object == "chat.completion" + completion = _build_chat_completion("Hello world", "gpt-4o") + assert completion.choices[0].message.content == "Hello world" + assert completion.model == "gpt-4o" + assert completion.choices[0].finish_reason == "stop" + assert completion.object == "chat.completion" + + +def test_build_chat_completion_with_reasoning() -> None: + completion = _build_chat_completion("Answer", "gpt-4o", reasoning="Because math.") + assert completion.choices[0].message.reasoning is not None + assert completion.choices[0].message.reasoning.content == "Because math." -# --------------------------------------------------------------------------- -# _copilot_model_to_openai -# --------------------------------------------------------------------------- +def test_build_chat_completion_no_reasoning() -> None: + completion = _build_chat_completion("Answer", "gpt-4o") + assert completion.choices[0].message.reasoning is None + + +def test_build_chunk_content_delta() -> None: + chunk = _build_chunk("hello", "gpt-4o", is_reasoning=False) + assert chunk.choices[0].delta.content == "hello" + assert chunk.choices[0].delta.reasoning is None + assert chunk.object == "chat.completion.chunk" + + +def test_build_chunk_reasoning_delta() -> None: + chunk = _build_chunk("step 1", "gpt-4o", is_reasoning=True) + assert chunk.choices[0].delta.content is None + assert chunk.choices[0].delta.reasoning is not None + assert chunk.choices[0].delta.reasoning.content == "step 1" + def test_copilot_model_to_openai_maps_id() -> None: info = _make_model_info("claude-sonnet-4-5", "Claude Sonnet 4.5") @@ -106,89 +138,116 @@ def 
test_copilot_model_to_openai_maps_id() -> None: assert model.object == "model" -# --------------------------------------------------------------------------- -# Provider class attributes -# --------------------------------------------------------------------------- - def test_provider_required_attributes() -> None: - assert CopilotSdkProvider.PROVIDER_NAME == "copilot_sdk" - assert CopilotSdkProvider.SUPPORTS_COMPLETION is True - assert CopilotSdkProvider.SUPPORTS_LIST_MODELS is True - assert CopilotSdkProvider.SUPPORTS_EMBEDDING is False - assert CopilotSdkProvider.SUPPORTS_COMPLETION_STREAMING is True - assert CopilotSdkProvider.SUPPORTS_COMPLETION_REASONING is True - assert CopilotSdkProvider.SUPPORTS_COMPLETION_IMAGE is True - assert CopilotSdkProvider.MISSING_PACKAGES_ERROR is None + assert CopilotsdkProvider.PROVIDER_NAME == "copilotsdk" + assert CopilotsdkProvider.SUPPORTS_COMPLETION is True + assert CopilotsdkProvider.SUPPORTS_LIST_MODELS is True + assert CopilotsdkProvider.SUPPORTS_EMBEDDING is False + assert CopilotsdkProvider.SUPPORTS_COMPLETION_STREAMING is True + assert CopilotsdkProvider.SUPPORTS_COMPLETION_REASONING is True + assert CopilotsdkProvider.SUPPORTS_COMPLETION_IMAGE is True + assert CopilotsdkProvider.MISSING_PACKAGES_ERROR is None -# --------------------------------------------------------------------------- -# Auth: _verify_and_set_api_key -# --------------------------------------------------------------------------- - def test_api_key_explicit_wins() -> None: - p = object.__new__(CopilotSdkProvider) + provider = object.__new__(CopilotsdkProvider) with patch.dict(os.environ, {"COPILOT_GITHUB_TOKEN": "env-token"}, clear=False): - result = p._verify_and_set_api_key("explicit-token") + result = provider._verify_and_set_api_key("explicit-token") assert result == "explicit-token" def test_api_key_falls_back_to_copilot_env() -> None: - p = object.__new__(CopilotSdkProvider) + provider = object.__new__(CopilotsdkProvider) with 
patch.dict(os.environ, {"COPILOT_GITHUB_TOKEN": "copilot-env"}, clear=False): - result = p._verify_and_set_api_key(None) + result = provider._verify_and_set_api_key(None) assert result == "copilot-env" def test_api_key_falls_back_to_github_token() -> None: - p = object.__new__(CopilotSdkProvider) + provider = object.__new__(CopilotsdkProvider) env = {"COPILOT_GITHUB_TOKEN": "", "GITHUB_TOKEN": "gh-token"} with patch.dict(os.environ, env, clear=False): - result = p._verify_and_set_api_key(None) + result = provider._verify_and_set_api_key(None) assert result == "gh-token" def test_api_key_returns_none_when_all_absent() -> None: - """No token at all → None (triggers logged-in CLI user mode).""" - p = object.__new__(CopilotSdkProvider) + provider = object.__new__(CopilotsdkProvider) env = {"COPILOT_GITHUB_TOKEN": "", "GITHUB_TOKEN": "", "GH_TOKEN": ""} with patch.dict(os.environ, env, clear=False): - result = p._verify_and_set_api_key(None) + result = provider._verify_and_set_api_key(None) assert result is None -# --------------------------------------------------------------------------- -# _init_client -# --------------------------------------------------------------------------- - def test_init_client_stores_token_and_url() -> None: - p = object.__new__(CopilotSdkProvider) - p._init_client(api_key="my-token", api_base="localhost:9000") - assert p._resolved_token == "my-token" - assert p._cli_url == "localhost:9000" - assert p._copilot_client is None + provider = object.__new__(CopilotsdkProvider) + provider._init_client(api_key="my-token", api_base="localhost:9000") + assert provider._resolved_token == "my-token" + assert provider._cli_url == "localhost:9000" + assert provider._copilot_client is None def test_init_client_reads_cli_url_from_env() -> None: - p = object.__new__(CopilotSdkProvider) + provider = object.__new__(CopilotsdkProvider) with patch.dict(os.environ, {"COPILOT_CLI_URL": "localhost:7777"}, clear=False): - p._init_client(api_key=None, api_base=None) 
- assert p._cli_url == "localhost:7777" + provider._init_client(api_key=None, api_base=None) + assert provider._cli_url == "localhost:7777" -# --------------------------------------------------------------------------- -# _acompletion (async) -# --------------------------------------------------------------------------- +def test_extract_attachments_base64_image() -> None: + messages = [{ + "role": "user", + "content": [ + {"type": "text", "text": "Look at this"}, + {"type": "image_url", "image_url": {"url": _make_data_uri()}}, + ], + }] + attachments, temp_paths = _extract_attachments(messages) + try: + assert len(attachments) == 1 + assert attachments[0]["type"] == "file" + assert os.path.exists(attachments[0]["path"]) + assert len(temp_paths) == 1 + finally: + _cleanup_temp_files(temp_paths) + + +def test_extract_attachments_http_url_skipped() -> None: + messages = [{ + "role": "user", + "content": [{"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}}], + }] + attachments, temp_paths = _extract_attachments(messages) + assert attachments == [] + assert temp_paths == [] + + +def test_extract_attachments_no_images() -> None: + attachments, temp_paths = _extract_attachments([{"role": "user", "content": "Plain text, no images."}]) + assert attachments == [] + assert temp_paths == [] -def _make_session(send_and_wait_result: Any = None, send_and_wait_error: Exception | None = None) -> Any: - """Build a mock CopilotSession with a sync on() and async send_and_wait/disconnect.""" - session = MagicMock() - session.on = MagicMock(return_value=lambda: None) # sync; returns unsubscribe callable - session.disconnect = AsyncMock() - if send_and_wait_error: - session.send_and_wait = AsyncMock(side_effect=send_and_wait_error) - else: - session.send_and_wait = AsyncMock(return_value=send_and_wait_result) - return session + +def test_extract_attachments_malformed_data_uri_skipped() -> None: + messages = [{ + "role": "user", + "content": [{"type": 
"image_url", "image_url": {"url": "data:image/jpeg;base64,NOT_VALID!!"}}], + }] + attachments, temp_paths = _extract_attachments(messages) + assert attachments == [] + assert temp_paths == [] + + +def test_cleanup_temp_files_removes_files() -> None: + with tempfile.NamedTemporaryFile(delete=False) as handle: + path = handle.name + assert os.path.exists(path) + _cleanup_temp_files([path]) + assert not os.path.exists(path) + + +def test_cleanup_temp_files_tolerates_missing() -> None: + _cleanup_temp_files(["/tmp/does-not-exist-copilotsdk-test-xyz"]) @pytest.mark.asyncio @@ -216,7 +275,6 @@ async def test_acompletion_returns_chat_completion() -> None: @pytest.mark.asyncio async def test_acompletion_handles_none_event() -> None: - """When send_and_wait returns None, content is empty string.""" provider = _make_provider(api_key="test-token") mock_session = _make_session(send_and_wait_result=None) @@ -235,7 +293,6 @@ async def test_acompletion_handles_none_event() -> None: @pytest.mark.asyncio async def test_acompletion_disconnects_session_on_error() -> None: - """Session.disconnect() is called even when send_and_wait raises.""" provider = _make_provider(api_key="test-token") mock_session = _make_session(send_and_wait_error=RuntimeError("CLI died")) @@ -253,10 +310,6 @@ async def test_acompletion_disconnects_session_on_error() -> None: mock_session.disconnect.assert_called_once() -# --------------------------------------------------------------------------- -# _alist_models (async) -# --------------------------------------------------------------------------- - @pytest.mark.asyncio async def test_alist_models_converts_model_infos() -> None: provider = _make_provider(api_key="test-token") @@ -272,24 +325,18 @@ async def test_alist_models_converts_model_infos() -> None: models = await provider._alist_models() assert len(models) == 2 - ids = {m.id for m in models} - assert ids == {"gpt-4o", "claude-sonnet-4-5"} - for m in models: - assert m.owned_by == "github-copilot" - + 
assert {model.id for model in models} == {"gpt-4o", "claude-sonnet-4-5"} + for model in models: + assert model.owned_by == "github-copilot" -# --------------------------------------------------------------------------- -# _ensure_client (async) -# --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_ensure_client_reuses_existing_instance() -> None: - """Second call returns the same client without calling start() again.""" provider = _make_provider(api_key="test-token") mock_client = AsyncMock() mock_client.start = AsyncMock() - provider._copilot_client = mock_client # pre-seed as already initialized + provider._copilot_client = mock_client result = await provider._ensure_client() @@ -299,13 +346,12 @@ async def test_ensure_client_reuses_existing_instance() -> None: @pytest.mark.asyncio async def test_ensure_client_uses_cli_url_path() -> None: - """When _cli_url is set, only cli_url is passed (no token or cli_path).""" provider = _make_provider(api_key="tok", api_base="localhost:9000") mock_client = AsyncMock() mock_client.start = AsyncMock() - with patch("any_llm.providers.copilot_sdk.copilot_sdk.CopilotClient", return_value=mock_client) as mock_cls: + with patch("any_llm.providers.copilotsdk.copilotsdk.CopilotClient", return_value=mock_client) as mock_cls: result = await provider._ensure_client() assert result is mock_client @@ -313,138 +359,27 @@ async def test_ensure_client_uses_cli_url_path() -> None: mock_client.start.assert_called_once() -# --------------------------------------------------------------------------- -# _messages_to_prompt edge cases -# --------------------------------------------------------------------------- - -def test_messages_to_prompt_empty_list() -> None: - """Empty messages list returns an empty string without raising.""" - assert _messages_to_prompt([]) == "" - - -# --------------------------------------------------------------------------- -# _build_chat_completion — 
reasoning field -# --------------------------------------------------------------------------- - -def test_build_chat_completion_with_reasoning() -> None: - cc = _build_chat_completion("Answer", "gpt-4o", reasoning="Because math.") - assert cc.choices[0].message.reasoning is not None - assert cc.choices[0].message.reasoning.content == "Because math." - - -def test_build_chat_completion_no_reasoning() -> None: - cc = _build_chat_completion("Answer", "gpt-4o") - assert cc.choices[0].message.reasoning is None - - -# --------------------------------------------------------------------------- -# _build_chunk -# --------------------------------------------------------------------------- - -def test_build_chunk_content_delta() -> None: - chunk = _build_chunk("hello", "gpt-4o", is_reasoning=False) - assert chunk.choices[0].delta.content == "hello" - assert chunk.choices[0].delta.reasoning is None - assert chunk.object == "chat.completion.chunk" - - -def test_build_chunk_reasoning_delta() -> None: - chunk = _build_chunk("step 1", "gpt-4o", is_reasoning=True) - assert chunk.choices[0].delta.content is None - assert chunk.choices[0].delta.reasoning is not None - assert chunk.choices[0].delta.reasoning.content == "step 1" - - -# --------------------------------------------------------------------------- -# _extract_attachments -# --------------------------------------------------------------------------- - -def _make_data_uri(mime: str = "image/jpeg", content: bytes = b"\xff\xd8\xff") -> str: - return f"data:{mime};base64,{base64.b64encode(content).decode()}" - - -def test_extract_attachments_base64_image() -> None: - msgs = [{"role": "user", "content": [ - {"type": "text", "text": "Look at this"}, - {"type": "image_url", "image_url": {"url": _make_data_uri()}}, - ]}] - attachments, temp_paths = _extract_attachments(msgs) - try: - assert len(attachments) == 1 - assert attachments[0]["type"] == "file" - assert os.path.exists(attachments[0]["path"]) - assert len(temp_paths) == 1 
- finally: - _cleanup_temp_files(temp_paths) - - -def test_extract_attachments_http_url_skipped() -> None: - """HTTP image URLs are not supported (no download) and must be silently skipped.""" - msgs = [{"role": "user", "content": [ - {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}}, - ]}] - attachments, temp_paths = _extract_attachments(msgs) - assert attachments == [] - assert temp_paths == [] - - -def test_extract_attachments_no_images() -> None: - msgs = [{"role": "user", "content": "Plain text, no images."}] - attachments, temp_paths = _extract_attachments(msgs) - assert attachments == [] - assert temp_paths == [] - - -def test_extract_attachments_malformed_data_uri_skipped() -> None: - msgs = [{"role": "user", "content": [ - {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,NOT_VALID!!"}}, - ]}] - # Should not raise; malformed URIs are silently dropped. - attachments, temp_paths = _extract_attachments(msgs) - assert attachments == [] - - -def test_cleanup_temp_files_removes_files() -> None: - import tempfile - with tempfile.NamedTemporaryFile(delete=False) as fh: - path = fh.name - assert os.path.exists(path) - _cleanup_temp_files([path]) - assert not os.path.exists(path) - - -def test_cleanup_temp_files_tolerates_missing() -> None: - """Cleaning up a path that doesn't exist must not raise.""" - _cleanup_temp_files(["/tmp/does-not-exist-copilot-sdk-test-xyz"]) - - -# --------------------------------------------------------------------------- -# _acompletion — reasoning captured in non-streaming mode -# --------------------------------------------------------------------------- - @pytest.mark.asyncio async def test_acompletion_captures_reasoning() -> None: - """assistant.reasoning event content is surfaced in ChatCompletion.reasoning.""" from copilot.generated.session_events import SessionEventType provider = _make_provider(api_key="test-token") - mock_msg_event = MagicMock() - mock_msg_event.data.content = "42" - 
mock_msg_event.type = SessionEventType.ASSISTANT_MESSAGE + message_event = MagicMock() + message_event.data.content = "42" + message_event.type = SessionEventType.ASSISTANT_MESSAGE - mock_reasoning_event = MagicMock() - mock_reasoning_event.data.content = "Because 6×7=42." - mock_reasoning_event.type = SessionEventType.ASSISTANT_REASONING + reasoning_event = MagicMock() + reasoning_event.data.content = "Because 6×7=42." + reasoning_event.type = SessionEventType.ASSISTANT_REASONING def fake_on(callback: Any) -> Any: - """Immediately fire ASSISTANT_REASONING, then return a no-op unsubscribe.""" - callback(mock_reasoning_event) + callback(reasoning_event) return lambda: None mock_session = MagicMock() mock_session.on = MagicMock(side_effect=fake_on) - mock_session.send_and_wait = AsyncMock(return_value=mock_msg_event) + mock_session.send_and_wait = AsyncMock(return_value=message_event) mock_session.disconnect = AsyncMock() mock_client = AsyncMock() @@ -462,10 +397,6 @@ def fake_on(callback: Any) -> Any: assert result.choices[0].message.reasoning.content == "Because 6×7=42." 
-# --------------------------------------------------------------------------- -# _acompletion — reasoning_effort passed to session_cfg -# --------------------------------------------------------------------------- - @pytest.mark.asyncio async def test_acompletion_passes_reasoning_effort_to_session() -> None: provider = _make_provider(api_key="test-token") @@ -488,7 +419,6 @@ async def test_acompletion_passes_reasoning_effort_to_session() -> None: @pytest.mark.asyncio async def test_acompletion_omits_auto_reasoning_effort() -> None: - """reasoning_effort='auto' must NOT be forwarded (SDK doesn't accept it).""" provider = _make_provider(api_key="test-token") mock_session = _make_session() @@ -507,33 +437,28 @@ async def test_acompletion_omits_auto_reasoning_effort() -> None: assert "reasoning_effort" not in call_cfg -# --------------------------------------------------------------------------- -# _acompletion — streaming -# --------------------------------------------------------------------------- - @pytest.mark.asyncio async def test_acompletion_streaming_yields_chunks() -> None: from copilot.generated.session_events import SessionEventType provider = _make_provider(api_key="test-token") - # Build fake events that the on() callback will receive. 
- def _evt(etype: SessionEventType, delta: str) -> MagicMock: - e = MagicMock() - e.type = etype - e.data.delta_content = delta - return e + def make_event(event_type: SessionEventType, delta: str) -> MagicMock: + event = MagicMock() + event.type = event_type + event.data.delta_content = delta + return event delta_events = [ - _evt(SessionEventType.ASSISTANT_MESSAGE_DELTA, "Hello"), - _evt(SessionEventType.ASSISTANT_MESSAGE_DELTA, " world"), - _evt(SessionEventType.SESSION_IDLE, ""), + make_event(SessionEventType.ASSISTANT_MESSAGE_DELTA, "Hello"), + make_event(SessionEventType.ASSISTANT_MESSAGE_DELTA, " world"), + make_event(SessionEventType.SESSION_IDLE, ""), ] registered_callback: list[Any] = [] - def fake_on(cb: Any) -> Any: - registered_callback.append(cb) + def fake_on(callback: Any) -> Any: + registered_callback.append(callback) return lambda: None mock_session = MagicMock() @@ -541,8 +466,8 @@ def fake_on(cb: Any) -> Any: mock_session.disconnect = AsyncMock() async def fake_send(opts: Any) -> str: - for evt in delta_events: - registered_callback[0](evt) + for event in delta_events: + registered_callback[0](event) return "msg-id" mock_session.send = fake_send @@ -557,7 +482,7 @@ async def fake_send(opts: Any) -> str: stream=True, ) stream = await provider._acompletion(params) - chunks = [c async for c in stream] + chunks = [chunk async for chunk in stream] assert len(chunks) == 2 assert chunks[0].choices[0].delta.content == "Hello" @@ -571,22 +496,22 @@ async def test_acompletion_streaming_reasoning_chunks() -> None: provider = _make_provider(api_key="test-token") - def _evt(etype: SessionEventType, delta: str) -> MagicMock: - e = MagicMock() - e.type = etype - e.data.delta_content = delta - return e + def make_event(event_type: SessionEventType, delta: str) -> MagicMock: + event = MagicMock() + event.type = event_type + event.data.delta_content = delta + return event delta_events = [ - _evt(SessionEventType.ASSISTANT_REASONING_DELTA, "step 1"), - 
_evt(SessionEventType.ASSISTANT_MESSAGE_DELTA, "answer"), - _evt(SessionEventType.SESSION_IDLE, ""), + make_event(SessionEventType.ASSISTANT_REASONING_DELTA, "step 1"), + make_event(SessionEventType.ASSISTANT_MESSAGE_DELTA, "answer"), + make_event(SessionEventType.SESSION_IDLE, ""), ] registered_callback: list[Any] = [] - def fake_on(cb: Any) -> Any: - registered_callback.append(cb) + def fake_on(callback: Any) -> Any: + registered_callback.append(callback) return lambda: None mock_session = MagicMock() @@ -594,8 +519,8 @@ def fake_on(cb: Any) -> Any: mock_session.disconnect = AsyncMock() async def fake_send(opts: Any) -> str: - for evt in delta_events: - registered_callback[0](evt) + for event in delta_events: + registered_callback[0](event) return "msg-id" mock_session.send = fake_send @@ -610,43 +535,39 @@ async def fake_send(opts: Any) -> str: stream=True, ) stream = await provider._acompletion(params) - chunks = [c async for c in stream] + chunks = [chunk async for chunk in stream] assert len(chunks) == 2 - # First chunk carries reasoning assert chunks[0].choices[0].delta.reasoning is not None assert chunks[0].choices[0].delta.reasoning.content == "step 1" assert chunks[0].choices[0].delta.content is None - # Second chunk carries content assert chunks[1].choices[0].delta.content == "answer" assert chunks[1].choices[0].delta.reasoning is None @pytest.mark.asyncio async def test_acompletion_streaming_session_error_raises() -> None: - """SESSION_ERROR event must raise RuntimeError rather than completing silently.""" from copilot.generated.session_events import SessionEventType provider = _make_provider(api_key="test-token") - def _evt(etype: SessionEventType, msg: str = "") -> MagicMock: - e = MagicMock() - e.type = etype - e.data.message = msg - e.data.delta_content = "" - return e + def make_event(event_type: SessionEventType, message: str = "") -> MagicMock: + event = MagicMock() + event.type = event_type + event.data.message = message + 
event.data.delta_content = "" + return event delta_events = [ - _evt(SessionEventType.ASSISTANT_MESSAGE_DELTA, ""), # one content delta first - _evt(SessionEventType.SESSION_ERROR, "CLI crashed"), + make_event(SessionEventType.ASSISTANT_MESSAGE_DELTA), + make_event(SessionEventType.SESSION_ERROR, "CLI crashed"), ] - # Override delta_content for the first event delta_events[0].data.delta_content = "partial" registered_callback: list[Any] = [] - def fake_on(cb: Any) -> Any: - registered_callback.append(cb) + def fake_on(callback: Any) -> Any: + registered_callback.append(callback) return lambda: None mock_session = MagicMock() @@ -654,8 +575,8 @@ def fake_on(cb: Any) -> Any: mock_session.disconnect = AsyncMock() async def fake_send(opts: Any) -> str: - for evt in delta_events: - registered_callback[0](evt) + for event in delta_events: + registered_callback[0](event) return "msg-id" mock_session.send = fake_send @@ -679,9 +600,6 @@ async def fake_send(opts: Any) -> str: @pytest.mark.asyncio async def test_acompletion_image_attachments_forwarded() -> None: - """Image attachments extracted from messages must be included in msg_opts sent to session.""" - import base64 - provider = _make_provider(api_key="test-token") jpeg_data = base64.b64encode(b"\xff\xd8\xff").decode() @@ -707,8 +625,7 @@ async def test_acompletion_image_attachments_forwarded() -> None: result = await provider._acompletion(params) assert result.choices[0].message.content == "I see an image." 
-    # send_and_wait must have received msg_opts with 'attachments'
     call_kwargs = mock_session.send_and_wait.call_args[0][0]
     assert "attachments" in call_kwargs
     assert len(call_kwargs["attachments"]) == 1
     assert call_kwargs["attachments"][0]["type"] == "file"
diff --git a/tests/unit/test_provider.py b/tests/unit/test_provider.py
index 35bdb29e..75038550 100644
--- a/tests/unit/test_provider.py
+++ b/tests/unit/test_provider.py
@@ -149,7 +149,7 @@ def test_providers_raise_MissingApiKeyError(provider: LLMProvider) -> None:
         LLMProvider.VERTEXAI,
         LLMProvider.VLLM,
         LLMProvider.GATEWAY,
-        LLMProvider.COPILOT_SDK,  # uses CLI auth; no API key required
+        LLMProvider.COPILOTSDK,  # uses CLI auth; no API key required
     ):
         pytest.skip("This provider handles `api_key` differently.")
     with patch.dict(os.environ, {}, clear=True):