diff --git a/docs/src/content/docs/providers.md b/docs/src/content/docs/providers.md index 11938b4f..cc882a96 100644 --- a/docs/src/content/docs/providers.md +++ b/docs/src/content/docs/providers.md @@ -27,6 +27,7 @@ Provider source code is in [`src/any_llm/providers/`](https://github.com/mozilla | [`bedrock`](https://aws.amazon.com/bedrock/) | AWS_BEARER_TOKEN_BEDROCK | AWS_ENDPOINT_URL_BEDROCK_RUNTIME | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | [`cerebras`](https://docs.cerebras.ai/) | CEREBRAS_API_KEY | CEREBRAS_API_BASE | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | | [`cohere`](https://cohere.com/api) | COHERE_API_KEY | COHERE_BASE_URL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | +| [`copilotsdk`](https://github.com/github/copilot-sdk) | COPILOT_GITHUB_TOKEN | COPILOT_CLI_URL | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | | [`databricks`](https://docs.databricks.com/) | DATABRICKS_TOKEN | DATABRICKS_HOST | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | | [`deepseek`](https://platform.deepseek.com/) | DEEPSEEK_API_KEY | DEEPSEEK_API_BASE | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | | [`fireworks`](https://fireworks.ai/api) | FIREWORKS_API_KEY | FIREWORKS_API_BASE | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | @@ -60,3 +61,81 @@ Provider source code is in [`src/any_llm/providers/`](https://github.com/mozilla | [`xai`](https://x.ai/) | XAI_API_KEY | XAI_API_BASE | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | | [`zai`](https://docs.z.ai/guides/develop/python/introduction) | ZAI_API_KEY | ZAI_BASE_URL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | + +## Provider Notes + +### `copilotsdk` — GitHub Copilot SDK + +The `copilotsdk` provider communicates with GitHub Copilot models via the +[`github-copilot-sdk`](https://pypi.org/project/github-copilot-sdk/) Python package, +which bundles the Copilot CLI binary for your platform. + +#### Installation + +Install the platform-specific wheel: + +```bash +pip install any-llm-sdk[copilotsdk] +``` + +> **Note**: `github-copilot-sdk` ships separate wheels per OS and CPU architecture +> (e.g. `macosx_arm64`, `linux_x86_64`). 
`pip` selects the correct wheel automatically +> on supported platforms. If installation fails, check [PyPI](https://pypi.org/project/github-copilot-sdk/#files) +> for available platform tags. + +#### Authentication + +Two modes are supported, checked in order: + +1. **Token mode** — set one of these environment variables: + + ```bash + export COPILOT_GITHUB_TOKEN="ghp_your_token" + # or + export GITHUB_TOKEN="ghp_your_token" + # or + export GH_TOKEN="ghp_your_token" + ``` + + The variables are read in the order shown; an `api_key` passed directly to `AnyLLM.create()` takes precedence over all of them. + +2. **Logged-in CLI user** — if no token is set, the provider uses the credentials + from the local `gh` / `copilot` CLI session (run `gh auth login` first). No + environment variable is required in this mode. + +#### Configuration + +| Environment Variable | Purpose | Default | +| --- | --- | --- | +| `COPILOT_GITHUB_TOKEN` | GitHub token with Copilot access | — | +| `GITHUB_TOKEN` / `GH_TOKEN` | Fallback token sources | — | +| `COPILOT_CLI_URL` | Connect to an external CLI server instead of spawning one (e.g. 
`localhost:9000`); the token and `COPILOT_CLI_PATH` are then not passed to the SDK | — | +| `COPILOT_CLI_PATH` | Override the Copilot CLI binary path | PATH lookup | + +#### Usage + +```python +from any_llm import AnyLLM + +# Auth: pass api_key=..., or set COPILOT_GITHUB_TOKEN, or rely on the logged-in CLI session +llm = AnyLLM.create("copilotsdk") + +# List available models +models = llm.list_models() + +# Completion with reasoning +response = llm.completion( + model="claude-sonnet-4-5", + messages=[{"role": "user", "content": "Explain async generators in Python."}], + reasoning_effort="high", +) +print(response.choices[0].message.content) + +# Streaming +for chunk in llm.completion( + model="gpt-4o", + messages=[{"role": "user", "content": "Hello!"}], + stream=True, +): + print(chunk.choices[0].delta.content or "", end="", flush=True) +``` diff --git a/pyproject.toml b/pyproject.toml index 783677a4..4a125c81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ [project.optional-dependencies] all = [ - "any-llm-sdk[mistral,anthropic,huggingface,gemini,vertexai,vertexaianthropic,cohere,cerebras,fireworks,groq,bedrock,azure,azureanthropic,azureopenai,watsonx,together,sambanova,ollama,moonshot,nebius,xai,databricks,deepseek,inception,openai,openrouter,portkey,lmstudio,llama,voyage,perplexity,platform,llamafile,llamacpp,sagemaker,gateway,zai,minimax,mzai,vllm]" + "any-llm-sdk[mistral,anthropic,huggingface,gemini,vertexai,vertexaianthropic,cohere,cerebras,fireworks,groq,bedrock,azure,azureanthropic,azureopenai,watsonx,together,sambanova,ollama,moonshot,nebius,xai,databricks,deepseek,inception,openai,openrouter,portkey,lmstudio,llama,voyage,perplexity,platform,llamafile,llamacpp,sagemaker,gateway,zai,minimax,mzai,vllm,copilotsdk]" ] platform = [ @@ -31,6 +31,10 @@ platform = [ perplexity = [] +copilotsdk = [ + "github-copilot-sdk>=0.1.0", +] + mistral = [ "mistralai>=1.9.3", ] diff --git a/src/any_llm/constants.py b/src/any_llm/constants.py index efcf6035..584613e5 100644 --- a/src/any_llm/constants.py +++ b/src/any_llm/constants.py @@ 
-54,6 +54,7 @@ class LLMProvider(StrEnum): PERPLEXITY = "perplexity" MINIMAX = "minimax" ZAI = "zai" + COPILOTSDK = "copilotsdk" GATEWAY = "gateway" @classmethod diff --git a/src/any_llm/providers/copilotsdk/__init__.py b/src/any_llm/providers/copilotsdk/__init__.py new file mode 100644 index 00000000..2f6f2ec5 --- /dev/null +++ b/src/any_llm/providers/copilotsdk/__init__.py @@ -0,0 +1,3 @@ +from .copilotsdk import CopilotsdkProvider + +__all__ = ["CopilotsdkProvider"] diff --git a/src/any_llm/providers/copilotsdk/copilotsdk.py b/src/any_llm/providers/copilotsdk/copilotsdk.py new file mode 100644 index 00000000..8611ffcb --- /dev/null +++ b/src/any_llm/providers/copilotsdk/copilotsdk.py @@ -0,0 +1,439 @@ +from __future__ import annotations + +import asyncio +import base64 +import mimetypes +import os +import tempfile +import time +from typing import TYPE_CHECKING, Any + +from typing_extensions import override + +from any_llm.any_llm import AnyLLM + +# Eagerly initialise the MIME-type database so that mimetypes.guess_extension() +# is thread-safe from the very first call (lazy initialisation is not thread-safe). +mimetypes.init() + +MISSING_PACKAGES_ERROR: ImportError | None = None +try: + from copilot import CopilotClient, PermissionHandler + from copilot.generated.session_events import SessionEventType + from copilot.types import ModelInfo as CopilotModelInfo +except ImportError as e: + MISSING_PACKAGES_ERROR = e + +if TYPE_CHECKING: + from collections.abc import AsyncIterator, Sequence + + from any_llm.types.completion import ( + ChatCompletion, + ChatCompletionChunk, + CompletionParams, + CreateEmbeddingResponse, + ) + from any_llm.types.model import Model + + +def _messages_to_prompt(messages: list[dict[str, Any]]) -> str: + """Flatten an OpenAI-style messages list into a single prompt string. + + System messages become an instruction header; prior conversational turns + are formatted as a transcript; the final user message is the prompt. 
+ Image content blocks are intentionally omitted here — they are forwarded + as file attachments via :func:`_extract_attachments`. + """ + system_parts: list[str] = [] + conversation_parts: list[str] = [] + + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", "") + # Multimodal content blocks: extract text only; images handled separately. + if isinstance(content, list): + content = " ".join( + block.get("text", "") + for block in content + if isinstance(block, dict) and block.get("type") == "text" + ) + + if role == "system": + system_parts.append(str(content)) + elif role == "assistant": + conversation_parts.append(f"Assistant: {content}") + else: + conversation_parts.append(f"User: {content}") + + parts: list[str] = [] + if system_parts: + parts.append("\n".join(system_parts)) + if conversation_parts: + parts.append("\n\n".join(conversation_parts)) + return "\n\n".join(parts) + + +def _extract_attachments(messages: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], list[str]]: + """Extract ``image_url`` content blocks from messages as file attachments. + + Decodes base64 ``data:`` URIs to temporary files on disk, which the + Copilot SDK then passes to the CLI as ``FileAttachment`` objects. + + Returns: + (attachments, temp_paths) where ``temp_paths`` must be cleaned up + by the caller after the ``session.send()`` call completes. + """ + attachments: list[dict[str, Any]] = [] + temp_paths: list[str] = [] + + for msg in messages: + content = msg.get("content", "") + if not isinstance(content, list): + continue + for block in content: + if not isinstance(block, dict) or block.get("type") != "image_url": + continue + url = (block.get("image_url") or {}).get("url", "") + if not url.startswith("data:"): + # HTTP/HTTPS URLs would need downloading — skipped for now. 
+ continue + try: + header, b64data = url.split(",", 1) + mime = header.split(";")[0].split(":")[1] + ext = mimetypes.guess_extension(mime) or ".bin" + # guess_extension returns ".jpe" for image/jpeg on some platforms. + if ext == ".jpe": + ext = ".jpg" + raw = base64.b64decode(b64data, validate=True) + with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as fh: + fh.write(raw) + temp_paths.append(fh.name) + attachments.append({"type": "file", "path": fh.name}) + except Exception: # noqa: BLE001 + pass # Skip malformed data URIs rather than failing the whole request. + + return attachments, temp_paths + + +def _cleanup_temp_files(paths: list[str]) -> None: + """Remove temporary image files created by :func:`_extract_attachments`.""" + for path in paths: + try: + os.unlink(path) + except OSError: + pass + + +def _build_chat_completion( + content: str, + model_id: str, + reasoning: str | None = None, +) -> "ChatCompletion": + """Wrap a plain text response into an OpenAI-compatible ChatCompletion.""" + from any_llm.types.completion import ( # noqa: PLC0415 + ChatCompletion, + ChatCompletionMessage, + Choice, + Reasoning, + ) + + return ChatCompletion( + id=f"copilotsdk-{time.time_ns()}", + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage( + role="assistant", + content=content, + reasoning=Reasoning(content=reasoning) if reasoning else None, + ), + logprobs=None, + ) + ], + created=int(time.time()), + model=model_id, + object="chat.completion", + ) + + +def _build_chunk(delta: str, model_id: str, *, is_reasoning: bool = False) -> "ChatCompletionChunk": + """Wrap a streaming delta into an OpenAI-compatible ChatCompletionChunk.""" + from any_llm.types.completion import ( # noqa: PLC0415 + ChatCompletionChunk, + ChunkChoice, + ChoiceDelta, + Reasoning, + ) + + return ChatCompletionChunk( + id=f"copilotsdk-{time.time_ns()}", + choices=[ + ChunkChoice( + delta=ChoiceDelta( + role="assistant", + content=None if is_reasoning else 
delta, + reasoning=Reasoning(content=delta) if is_reasoning else None, + ), + finish_reason=None, + index=0, + ) + ], + created=int(time.time()), + model=model_id, + object="chat.completion.chunk", + ) + + +def _copilot_model_to_openai(info: "CopilotModelInfo") -> "Model": + """Convert a copilot-sdk ModelInfo to an OpenAI-compatible Model.""" + from openai.types.model import Model as OpenAIModel # noqa: PLC0415 + + return OpenAIModel(id=info.id, created=0, owned_by="github-copilot", object="model") + + +class CopilotsdkProvider(AnyLLM): + """GitHub Copilot SDK provider for any-llm. + + Communicates with the Copilot CLI via JSON-RPC using the ``github-copilot-sdk`` + Python package. Authentication supports two modes (checked in order): + + 1. **Token mode** — set ``COPILOT_GITHUB_TOKEN``, ``GITHUB_TOKEN``, or + ``GH_TOKEN`` in the environment (or pass ``api_key`` explicitly). + 2. **Logged-in CLI user** — if no token is found, the Copilot CLI uses + the credentials from the local user's ``gh`` / ``copilot`` CLI session. + + Supports completion (streaming and non-streaming), reasoning + (``reasoning_effort``), image attachments via ``data:`` URIs, and model + listing. The binary is bundled by the ``github-copilot-sdk`` wheel; + ``COPILOT_CLI_URL`` overrides to an external server. + + Environment variables: + COPILOT_GITHUB_TOKEN: GitHub token with Copilot access (optional). + GITHUB_TOKEN / GH_TOKEN: Fallback token sources (optional). + COPILOT_CLI_URL: Connect to an external CLI server instead of spawning one. + COPILOT_CLI_PATH: Override the CLI binary path (default: PATH lookup). 
+ """ + + PROVIDER_NAME = "copilotsdk" + ENV_API_KEY_NAME = "COPILOT_GITHUB_TOKEN" + ENV_API_BASE_NAME = "COPILOT_CLI_URL" + PROVIDER_DOCUMENTATION_URL = "https://github.com/github/copilot-sdk" + + SUPPORTS_COMPLETION = True + SUPPORTS_COMPLETION_STREAMING = True + SUPPORTS_COMPLETION_REASONING = True + SUPPORTS_COMPLETION_IMAGE = True + SUPPORTS_COMPLETION_PDF = False + SUPPORTS_EMBEDDING = False + SUPPORTS_RESPONSES = False + SUPPORTS_LIST_MODELS = True + SUPPORTS_BATCH = False + + MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR + + # Internal state — populated lazily on first call. + _copilot_client: "CopilotClient | None" + + @override + def _verify_and_set_api_key(self, api_key: str | None = None) -> str | None: + """API key is optional: logged-in CLI mode works without any token.""" + resolved = ( + api_key + or os.getenv("COPILOT_GITHUB_TOKEN") + or os.getenv("GITHUB_TOKEN") + or os.getenv("GH_TOKEN") + ) + # Return None (not empty string) so copilot-sdk uses logged-in credentials. + return resolved or None + + @override + def _init_client(self, api_key: str | None = None, api_base: str | None = None, **kwargs: Any) -> None: + """Store resolved auth options; actual CLI start is deferred to the first async call.""" + self._resolved_token: str | None = api_key + self._cli_url: str | None = api_base or os.getenv("COPILOT_CLI_URL") + self._cli_path: str | None = os.getenv("COPILOT_CLI_PATH") or None + self._extra_kwargs = kwargs + self._copilot_client = None + self._client_lock = asyncio.Lock() + + async def _ensure_client(self) -> "CopilotClient": + """Lazily create and start a CopilotClient, reusing it across calls. + + The lock prevents concurrent coroutines from each spawning a separate + CLI process when the client has not yet been initialized. + """ + if self._copilot_client is not None: + return self._copilot_client + + async with self._client_lock: + # Re-check inside the lock: another coroutine may have initialized + # the client while we were waiting. 
+ if self._copilot_client is not None: + return self._copilot_client + + opts: dict[str, Any] = {} + if self._cli_url: + opts["cli_url"] = self._cli_url + else: + if self._cli_path: + opts["cli_path"] = self._cli_path + if self._resolved_token: + opts["github_token"] = self._resolved_token + + # Pass None (not {}) so the SDK uses its own defaults when no + # options are configured. + self._copilot_client = CopilotClient(opts or None) + await self._copilot_client.start() + + return self._copilot_client + + def _build_session_cfg(self, params: "CompletionParams", streaming: bool) -> dict[str, Any]: + """Build a SessionConfig dict from CompletionParams.""" + # PermissionHandler.approve_all silently grants any permissions the session + # requests (e.g. tool calls). This mirrors how the Copilot CLI behaves in + # non-interactive mode and is appropriate for programmatic usage. Callers + # that need a more restrictive policy can subclass and override this method. + cfg: dict[str, Any] = {"on_permission_request": PermissionHandler.approve_all} + if params.model_id: + cfg["model"] = params.model_id + # Pass reasoning_effort through; omit values the SDK doesn't accept. + if params.reasoning_effort and params.reasoning_effort not in ("auto", "none"): + cfg["reasoning_effort"] = params.reasoning_effort + if streaming: + cfg["streaming"] = True + return cfg + + async def _stream_from_session( + self, + session: Any, + msg_opts: dict[str, Any], + model_id: str, + temp_paths: list[str], + ) -> "AsyncIterator[ChatCompletionChunk]": + """Async generator that streams chunks from a live Copilot session. + + Bridges the SDK's event-callback model to an async iterator via an + asyncio Queue. The session is disconnected and temp files cleaned up + when the generator exits (normally or via exception/cancellation). 
+ """ + queue: asyncio.Queue[tuple[str, str] | None] = asyncio.Queue() + + def on_event(event: Any) -> None: + etype = event.type + if etype == SessionEventType.ASSISTANT_MESSAGE_DELTA: + queue.put_nowait(("content", event.data.delta_content or "")) + elif etype == SessionEventType.ASSISTANT_REASONING_DELTA: + queue.put_nowait(("reasoning", event.data.delta_content or "")) + elif etype == SessionEventType.SESSION_IDLE: + queue.put_nowait(None) # sentinel — streaming complete normally + elif etype == SessionEventType.SESSION_ERROR: + # SDK SESSION_ERROR event fields are untyped; use getattr with a + # default so this path is safe even if the schema changes. + error_msg = getattr(getattr(event, "data", None), "message", "Copilot session error") + queue.put_nowait(("error", error_msg)) + + unsubscribe = session.on(on_event) + try: + await session.send(msg_opts) + while True: + item = await queue.get() + if item is None: + break + kind, delta = item + if kind == "error": + raise RuntimeError(delta) + yield _build_chunk(delta, model_id, is_reasoning=(kind == "reasoning")) + finally: + unsubscribe() + await session.disconnect() + _cleanup_temp_files(temp_paths) + + @override + async def _acompletion( + self, + params: "CompletionParams", + **kwargs: Any, + ) -> "ChatCompletion | AsyncIterator[ChatCompletionChunk]": + """Send a completion request via a fresh Copilot session.""" + client = await self._ensure_client() + prompt = _messages_to_prompt(params.messages) + model_id = params.model_id or self.PROVIDER_NAME + attachments, temp_paths = _extract_attachments(params.messages) + + session_cfg = self._build_session_cfg(params, streaming=bool(params.stream)) + # A new session is created per request; the Copilot SDK session + # lifecycle is lightweight (no separate process per session). 
+ session = await client.create_session(session_cfg) + + msg_opts: dict[str, Any] = {"prompt": prompt} + if attachments: + msg_opts["attachments"] = attachments + + if params.stream: + # Return the async generator directly; it owns the session lifecycle. + return self._stream_from_session(session, msg_opts, model_id, temp_paths) + + # Non-streaming: capture reasoning alongside the final message event. + try: + reasoning_content: str | None = None + + def on_reasoning(event: Any) -> None: + nonlocal reasoning_content + if event.type == SessionEventType.ASSISTANT_REASONING: + # SDK reasoning event fields are untyped; defensive access is + # intentional here so a schema change doesn't cause an AttributeError. + reasoning_content = ( + getattr(getattr(event, "data", None), "content", None) or None + ) + + unsubscribe = session.on(on_reasoning) + try: + event = await session.send_and_wait(msg_opts) + finally: + unsubscribe() + + content = "" + if event is not None and hasattr(event, "data") and hasattr(event.data, "content"): + content = event.data.content or "" + return _build_chat_completion(content, model_id, reasoning_content) + finally: + await session.disconnect() + _cleanup_temp_files(temp_paths) + + @override + async def _alist_models(self, **kwargs: Any) -> "Sequence[Model]": + """List models available through the Copilot CLI.""" + client = await self._ensure_client() + models = await client.list_models() + return [_copilot_model_to_openai(m) for m in models] + + @staticmethod + @override + def _convert_completion_params(params: "CompletionParams", **kwargs: Any) -> dict[str, Any]: + raise NotImplementedError("CopilotsdkProvider overrides _acompletion directly") + + @staticmethod + @override + def _convert_completion_response(response: Any) -> "ChatCompletion": + raise NotImplementedError("CopilotsdkProvider overrides _acompletion directly") + + @staticmethod + @override + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> 
"ChatCompletionChunk": + raise NotImplementedError("CopilotsdkProvider overrides _acompletion directly") + + @staticmethod + @override + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + raise NotImplementedError("CopilotsdkProvider does not support embeddings") + + @staticmethod + @override + def _convert_embedding_response(response: Any) -> "CreateEmbeddingResponse": + raise NotImplementedError("CopilotsdkProvider does not support embeddings") + + @staticmethod + @override + def _convert_list_models_response(response: Any) -> "Sequence[Model]": + raise NotImplementedError("CopilotsdkProvider uses _alist_models directly") diff --git a/tests/conftest.py b/tests/conftest.py index 5caaea1d..d46d77de 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -81,6 +81,7 @@ def provider_model_map() -> dict[LLMProvider, str]: LLMProvider.LLAMACPP: "N/A", LLMProvider.MINIMAX: "MiniMax-M2", LLMProvider.ZAI: "glm-4-32b-0414-128k", + LLMProvider.COPILOTSDK: "gpt-4o", } @@ -151,6 +152,7 @@ def provider_client_config() -> dict[LLMProvider, dict[str, Any]]: "api_base": "https://mlrun-me8bof5t-eastus2.cognitiveservices.azure.com/", "api_version": "2025-03-01-preview", }, + LLMProvider.COPILOTSDK: {}, } diff --git a/tests/unit/providers/test_copilotsdk_provider.py b/tests/unit/providers/test_copilotsdk_provider.py new file mode 100644 index 00000000..315fca14 --- /dev/null +++ b/tests/unit/providers/test_copilotsdk_provider.py @@ -0,0 +1,631 @@ +"""Unit tests for CopilotsdkProvider.""" +from __future__ import annotations + +import asyncio +import base64 +import os +import tempfile +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from any_llm.providers.copilotsdk.copilotsdk import ( + CopilotsdkProvider, + _build_chat_completion, + _build_chunk, + _cleanup_temp_files, + _copilot_model_to_openai, + _extract_attachments, + _messages_to_prompt, +) +from any_llm.types.completion import CompletionParams + 
+pytest.importorskip("copilot") + + +def _make_provider(**kwargs: Any) -> CopilotsdkProvider: + """Instantiate CopilotsdkProvider while bypassing real CLI startup.""" + provider = object.__new__(CopilotsdkProvider) + provider._resolved_token = kwargs.get("api_key") + provider._cli_url = kwargs.get("api_base") + provider._cli_path = None + provider._extra_kwargs = {} + provider._copilot_client = None + provider._client_lock = asyncio.Lock() + return provider + + +def _make_model_info(model_id: str = "gpt-4o", name: str = "GPT-4o") -> Any: + """Return a minimal mock ModelInfo.""" + model_info = MagicMock() + model_info.id = model_id + model_info.name = name + return model_info + + +def _make_session(send_and_wait_result: Any = None, send_and_wait_error: Exception | None = None) -> Any: + """Build a mock Copilot session with a sync on() and async methods.""" + session = MagicMock() + session.on = MagicMock(return_value=lambda: None) + session.disconnect = AsyncMock() + if send_and_wait_error: + session.send_and_wait = AsyncMock(side_effect=send_and_wait_error) + else: + session.send_and_wait = AsyncMock(return_value=send_and_wait_result) + return session + + +def _make_data_uri(mime: str = "image/jpeg", content: bytes = b"\xff\xd8\xff") -> str: + return f"data:{mime};base64,{base64.b64encode(content).decode()}" + + +def test_messages_to_prompt_simple_user() -> None: + messages = [{"role": "user", "content": "Hello"}] + assert _messages_to_prompt(messages) == "User: Hello" + + +def test_messages_to_prompt_system_prepended() -> None: + messages = [ + {"role": "system", "content": "Be concise."}, + {"role": "user", "content": "Hi"}, + ] + result = _messages_to_prompt(messages) + assert result.startswith("Be concise.") + assert "User: Hi" in result + + +def test_messages_to_prompt_multi_turn() -> None: + messages = [ + {"role": "user", "content": "Ping"}, + {"role": "assistant", "content": "Pong"}, + {"role": "user", "content": "Again"}, + ] + result = 
_messages_to_prompt(messages) + assert "User: Ping" in result + assert "Assistant: Pong" in result + assert "User: Again" in result + + +def test_messages_to_prompt_multimodal_extracts_text() -> None: + messages = [{"role": "user", "content": [{"type": "text", "text": "Tell me"}, {"type": "image_url"}]}] + result = _messages_to_prompt(messages) + assert "Tell me" in result + + +def test_messages_to_prompt_empty_list() -> None: + assert _messages_to_prompt([]) == "" + + +def test_build_chat_completion_structure() -> None: + completion = _build_chat_completion("Hello world", "gpt-4o") + assert completion.choices[0].message.content == "Hello world" + assert completion.model == "gpt-4o" + assert completion.choices[0].finish_reason == "stop" + assert completion.object == "chat.completion" + + +def test_build_chat_completion_with_reasoning() -> None: + completion = _build_chat_completion("Answer", "gpt-4o", reasoning="Because math.") + assert completion.choices[0].message.reasoning is not None + assert completion.choices[0].message.reasoning.content == "Because math." 
+ + +def test_build_chat_completion_no_reasoning() -> None: + completion = _build_chat_completion("Answer", "gpt-4o") + assert completion.choices[0].message.reasoning is None + + +def test_build_chunk_content_delta() -> None: + chunk = _build_chunk("hello", "gpt-4o", is_reasoning=False) + assert chunk.choices[0].delta.content == "hello" + assert chunk.choices[0].delta.reasoning is None + assert chunk.object == "chat.completion.chunk" + + +def test_build_chunk_reasoning_delta() -> None: + chunk = _build_chunk("step 1", "gpt-4o", is_reasoning=True) + assert chunk.choices[0].delta.content is None + assert chunk.choices[0].delta.reasoning is not None + assert chunk.choices[0].delta.reasoning.content == "step 1" + + +def test_copilot_model_to_openai_maps_id() -> None: + info = _make_model_info("claude-sonnet-4-5", "Claude Sonnet 4.5") + model = _copilot_model_to_openai(info) + assert model.id == "claude-sonnet-4-5" + assert model.owned_by == "github-copilot" + assert model.object == "model" + + +def test_provider_required_attributes() -> None: + assert CopilotsdkProvider.PROVIDER_NAME == "copilotsdk" + assert CopilotsdkProvider.SUPPORTS_COMPLETION is True + assert CopilotsdkProvider.SUPPORTS_LIST_MODELS is True + assert CopilotsdkProvider.SUPPORTS_EMBEDDING is False + assert CopilotsdkProvider.SUPPORTS_COMPLETION_STREAMING is True + assert CopilotsdkProvider.SUPPORTS_COMPLETION_REASONING is True + assert CopilotsdkProvider.SUPPORTS_COMPLETION_IMAGE is True + assert CopilotsdkProvider.MISSING_PACKAGES_ERROR is None + + +def test_api_key_explicit_wins() -> None: + provider = object.__new__(CopilotsdkProvider) + with patch.dict(os.environ, {"COPILOT_GITHUB_TOKEN": "env-token"}, clear=False): + result = provider._verify_and_set_api_key("explicit-token") + assert result == "explicit-token" + + +def test_api_key_falls_back_to_copilot_env() -> None: + provider = object.__new__(CopilotsdkProvider) + with patch.dict(os.environ, {"COPILOT_GITHUB_TOKEN": "copilot-env"}, 
clear=False): + result = provider._verify_and_set_api_key(None) + assert result == "copilot-env" + + +def test_api_key_falls_back_to_github_token() -> None: + provider = object.__new__(CopilotsdkProvider) + env = {"COPILOT_GITHUB_TOKEN": "", "GITHUB_TOKEN": "gh-token"} + with patch.dict(os.environ, env, clear=False): + result = provider._verify_and_set_api_key(None) + assert result == "gh-token" + + +def test_api_key_returns_none_when_all_absent() -> None: + provider = object.__new__(CopilotsdkProvider) + env = {"COPILOT_GITHUB_TOKEN": "", "GITHUB_TOKEN": "", "GH_TOKEN": ""} + with patch.dict(os.environ, env, clear=False): + result = provider._verify_and_set_api_key(None) + assert result is None + + +def test_init_client_stores_token_and_url() -> None: + provider = object.__new__(CopilotsdkProvider) + provider._init_client(api_key="my-token", api_base="localhost:9000") + assert provider._resolved_token == "my-token" + assert provider._cli_url == "localhost:9000" + assert provider._copilot_client is None + + +def test_init_client_reads_cli_url_from_env() -> None: + provider = object.__new__(CopilotsdkProvider) + with patch.dict(os.environ, {"COPILOT_CLI_URL": "localhost:7777"}, clear=False): + provider._init_client(api_key=None, api_base=None) + assert provider._cli_url == "localhost:7777" + + +def test_extract_attachments_base64_image() -> None: + messages = [{ + "role": "user", + "content": [ + {"type": "text", "text": "Look at this"}, + {"type": "image_url", "image_url": {"url": _make_data_uri()}}, + ], + }] + attachments, temp_paths = _extract_attachments(messages) + try: + assert len(attachments) == 1 + assert attachments[0]["type"] == "file" + assert os.path.exists(attachments[0]["path"]) + assert len(temp_paths) == 1 + finally: + _cleanup_temp_files(temp_paths) + + +def test_extract_attachments_http_url_skipped() -> None: + messages = [{ + "role": "user", + "content": [{"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}}], + }] + 
attachments, temp_paths = _extract_attachments(messages) + assert attachments == [] + assert temp_paths == [] + + +def test_extract_attachments_no_images() -> None: + attachments, temp_paths = _extract_attachments([{"role": "user", "content": "Plain text, no images."}]) + assert attachments == [] + assert temp_paths == [] + + +def test_extract_attachments_malformed_data_uri_skipped() -> None: + messages = [{ + "role": "user", + "content": [{"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,NOT_VALID!!"}}], + }] + attachments, temp_paths = _extract_attachments(messages) + assert attachments == [] + assert temp_paths == [] + + +def test_cleanup_temp_files_removes_files() -> None: + with tempfile.NamedTemporaryFile(delete=False) as handle: + path = handle.name + assert os.path.exists(path) + _cleanup_temp_files([path]) + assert not os.path.exists(path) + + +def test_cleanup_temp_files_tolerates_missing() -> None: + _cleanup_temp_files(["/tmp/does-not-exist-copilotsdk-test-xyz"]) + + +@pytest.mark.asyncio +async def test_acompletion_returns_chat_completion() -> None: + provider = _make_provider(api_key="test-token") + + mock_event = MagicMock() + mock_event.data.content = "The answer is 4." + + mock_session = _make_session(send_and_wait_result=mock_event) + mock_client = AsyncMock() + mock_client.create_session = AsyncMock(return_value=mock_session) + + with patch.object(provider, "_ensure_client", AsyncMock(return_value=mock_client)): + params = CompletionParams( + model_id="gpt-4o", + messages=[{"role": "user", "content": "What is 2+2?"}], + ) + result = await provider._acompletion(params) + + assert result.choices[0].message.content == "The answer is 4." 
+ assert result.model == "gpt-4o" + mock_session.disconnect.assert_called_once() + + +@pytest.mark.asyncio +async def test_acompletion_handles_none_event() -> None: + provider = _make_provider(api_key="test-token") + + mock_session = _make_session(send_and_wait_result=None) + mock_client = AsyncMock() + mock_client.create_session = AsyncMock(return_value=mock_session) + + with patch.object(provider, "_ensure_client", AsyncMock(return_value=mock_client)): + params = CompletionParams( + model_id="gpt-4o", + messages=[{"role": "user", "content": "Hello"}], + ) + result = await provider._acompletion(params) + + assert result.choices[0].message.content == "" + + +@pytest.mark.asyncio +async def test_acompletion_disconnects_session_on_error() -> None: + provider = _make_provider(api_key="test-token") + + mock_session = _make_session(send_and_wait_error=RuntimeError("CLI died")) + mock_client = AsyncMock() + mock_client.create_session = AsyncMock(return_value=mock_session) + + with patch.object(provider, "_ensure_client", AsyncMock(return_value=mock_client)): + params = CompletionParams( + model_id="gpt-4o", + messages=[{"role": "user", "content": "Hello"}], + ) + with pytest.raises(RuntimeError, match="CLI died"): + await provider._acompletion(params) + + mock_session.disconnect.assert_called_once() + + +@pytest.mark.asyncio +async def test_alist_models_converts_model_infos() -> None: + provider = _make_provider(api_key="test-token") + + raw_models = [ + _make_model_info("gpt-4o", "GPT-4o"), + _make_model_info("claude-sonnet-4-5", "Claude Sonnet 4.5"), + ] + mock_client = AsyncMock() + mock_client.list_models = AsyncMock(return_value=raw_models) + + with patch.object(provider, "_ensure_client", AsyncMock(return_value=mock_client)): + models = await provider._alist_models() + + assert len(models) == 2 + assert {model.id for model in models} == {"gpt-4o", "claude-sonnet-4-5"} + for model in models: + assert model.owned_by == "github-copilot" + + +@pytest.mark.asyncio 
+async def test_ensure_client_reuses_existing_instance() -> None: + provider = _make_provider(api_key="test-token") + + mock_client = AsyncMock() + mock_client.start = AsyncMock() + provider._copilot_client = mock_client + + result = await provider._ensure_client() + + assert result is mock_client + mock_client.start.assert_not_called() + + +@pytest.mark.asyncio +async def test_ensure_client_uses_cli_url_path() -> None: + provider = _make_provider(api_key="tok", api_base="localhost:9000") + + mock_client = AsyncMock() + mock_client.start = AsyncMock() + + with patch("any_llm.providers.copilotsdk.copilotsdk.CopilotClient", return_value=mock_client) as mock_cls: + result = await provider._ensure_client() + + assert result is mock_client + mock_cls.assert_called_once_with({"cli_url": "localhost:9000"}) + mock_client.start.assert_called_once() + + +@pytest.mark.asyncio +async def test_acompletion_captures_reasoning() -> None: + from copilot.generated.session_events import SessionEventType + + provider = _make_provider(api_key="test-token") + + message_event = MagicMock() + message_event.data.content = "42" + message_event.type = SessionEventType.ASSISTANT_MESSAGE + + reasoning_event = MagicMock() + reasoning_event.data.content = "Because 6×7=42." 
+ reasoning_event.type = SessionEventType.ASSISTANT_REASONING + + def fake_on(callback: Any) -> Any: + callback(reasoning_event) + return lambda: None + + mock_session = MagicMock() + mock_session.on = MagicMock(side_effect=fake_on) + mock_session.send_and_wait = AsyncMock(return_value=message_event) + mock_session.disconnect = AsyncMock() + + mock_client = AsyncMock() + mock_client.create_session = AsyncMock(return_value=mock_session) + + with patch.object(provider, "_ensure_client", AsyncMock(return_value=mock_client)): + params = CompletionParams( + model_id="gpt-4o", + messages=[{"role": "user", "content": "What is 6×7?"}], + ) + result = await provider._acompletion(params) + + assert result.choices[0].message.content == "42" + assert result.choices[0].message.reasoning is not None + assert result.choices[0].message.reasoning.content == "Because 6×7=42." + + +@pytest.mark.asyncio +async def test_acompletion_passes_reasoning_effort_to_session() -> None: + provider = _make_provider(api_key="test-token") + + mock_session = _make_session() + mock_client = AsyncMock() + mock_client.create_session = AsyncMock(return_value=mock_session) + + with patch.object(provider, "_ensure_client", AsyncMock(return_value=mock_client)): + params = CompletionParams( + model_id="gpt-4o", + messages=[{"role": "user", "content": "Think hard."}], + reasoning_effort="high", + ) + await provider._acompletion(params) + + call_cfg = mock_client.create_session.call_args[0][0] + assert call_cfg.get("reasoning_effort") == "high" + + +@pytest.mark.asyncio +async def test_acompletion_omits_auto_reasoning_effort() -> None: + provider = _make_provider(api_key="test-token") + + mock_session = _make_session() + mock_client = AsyncMock() + mock_client.create_session = AsyncMock(return_value=mock_session) + + with patch.object(provider, "_ensure_client", AsyncMock(return_value=mock_client)): + params = CompletionParams( + model_id="gpt-4o", + messages=[{"role": "user", "content": "Hi"}], + 
reasoning_effort="auto", + ) + await provider._acompletion(params) + + call_cfg = mock_client.create_session.call_args[0][0] + assert "reasoning_effort" not in call_cfg + + +@pytest.mark.asyncio +async def test_acompletion_streaming_yields_chunks() -> None: + from copilot.generated.session_events import SessionEventType + + provider = _make_provider(api_key="test-token") + + def make_event(event_type: SessionEventType, delta: str) -> MagicMock: + event = MagicMock() + event.type = event_type + event.data.delta_content = delta + return event + + delta_events = [ + make_event(SessionEventType.ASSISTANT_MESSAGE_DELTA, "Hello"), + make_event(SessionEventType.ASSISTANT_MESSAGE_DELTA, " world"), + make_event(SessionEventType.SESSION_IDLE, ""), + ] + + registered_callback: list[Any] = [] + + def fake_on(callback: Any) -> Any: + registered_callback.append(callback) + return lambda: None + + mock_session = MagicMock() + mock_session.on = MagicMock(side_effect=fake_on) + mock_session.disconnect = AsyncMock() + + async def fake_send(opts: Any) -> str: + for event in delta_events: + registered_callback[0](event) + return "msg-id" + + mock_session.send = fake_send + + mock_client = AsyncMock() + mock_client.create_session = AsyncMock(return_value=mock_session) + + with patch.object(provider, "_ensure_client", AsyncMock(return_value=mock_client)): + params = CompletionParams( + model_id="gpt-4o", + messages=[{"role": "user", "content": "Hi"}], + stream=True, + ) + stream = await provider._acompletion(params) + chunks = [chunk async for chunk in stream] + + assert len(chunks) == 2 + assert chunks[0].choices[0].delta.content == "Hello" + assert chunks[1].choices[0].delta.content == " world" + mock_session.disconnect.assert_called_once() + + +@pytest.mark.asyncio +async def test_acompletion_streaming_reasoning_chunks() -> None: + from copilot.generated.session_events import SessionEventType + + provider = _make_provider(api_key="test-token") + + def make_event(event_type: 
SessionEventType, delta: str) -> MagicMock: + event = MagicMock() + event.type = event_type + event.data.delta_content = delta + return event + + delta_events = [ + make_event(SessionEventType.ASSISTANT_REASONING_DELTA, "step 1"), + make_event(SessionEventType.ASSISTANT_MESSAGE_DELTA, "answer"), + make_event(SessionEventType.SESSION_IDLE, ""), + ] + + registered_callback: list[Any] = [] + + def fake_on(callback: Any) -> Any: + registered_callback.append(callback) + return lambda: None + + mock_session = MagicMock() + mock_session.on = MagicMock(side_effect=fake_on) + mock_session.disconnect = AsyncMock() + + async def fake_send(opts: Any) -> str: + for event in delta_events: + registered_callback[0](event) + return "msg-id" + + mock_session.send = fake_send + + mock_client = AsyncMock() + mock_client.create_session = AsyncMock(return_value=mock_session) + + with patch.object(provider, "_ensure_client", AsyncMock(return_value=mock_client)): + params = CompletionParams( + model_id="gpt-4o", + messages=[{"role": "user", "content": "Think step by step."}], + stream=True, + ) + stream = await provider._acompletion(params) + chunks = [chunk async for chunk in stream] + + assert len(chunks) == 2 + assert chunks[0].choices[0].delta.reasoning is not None + assert chunks[0].choices[0].delta.reasoning.content == "step 1" + assert chunks[0].choices[0].delta.content is None + assert chunks[1].choices[0].delta.content == "answer" + assert chunks[1].choices[0].delta.reasoning is None + + +@pytest.mark.asyncio +async def test_acompletion_streaming_session_error_raises() -> None: + from copilot.generated.session_events import SessionEventType + + provider = _make_provider(api_key="test-token") + + def make_event(event_type: SessionEventType, message: str = "") -> MagicMock: + event = MagicMock() + event.type = event_type + event.data.message = message + event.data.delta_content = "" + return event + + delta_events = [ + make_event(SessionEventType.ASSISTANT_MESSAGE_DELTA), + 
make_event(SessionEventType.SESSION_ERROR, "CLI crashed"), + ] + delta_events[0].data.delta_content = "partial" + + registered_callback: list[Any] = [] + + def fake_on(callback: Any) -> Any: + registered_callback.append(callback) + return lambda: None + + mock_session = MagicMock() + mock_session.on = MagicMock(side_effect=fake_on) + mock_session.disconnect = AsyncMock() + + async def fake_send(opts: Any) -> str: + for event in delta_events: + registered_callback[0](event) + return "msg-id" + + mock_session.send = fake_send + + mock_client = AsyncMock() + mock_client.create_session = AsyncMock(return_value=mock_session) + + with patch.object(provider, "_ensure_client", AsyncMock(return_value=mock_client)): + params = CompletionParams( + model_id="gpt-4o", + messages=[{"role": "user", "content": "Hi"}], + stream=True, + ) + stream = await provider._acompletion(params) + with pytest.raises(RuntimeError): + async for _ in stream: + pass + + mock_session.disconnect.assert_called_once() + + +@pytest.mark.asyncio +async def test_acompletion_image_attachments_forwarded() -> None: + provider = _make_provider(api_key="test-token") + + jpeg_data = base64.b64encode(b"\xff\xd8\xff").decode() + data_uri = f"data:image/jpeg;base64,{jpeg_data}" + + mock_event = MagicMock() + mock_event.data.content = "I see an image." + mock_session = _make_session(send_and_wait_result=mock_event) + mock_client = AsyncMock() + mock_client.create_session = AsyncMock(return_value=mock_session) + + with patch.object(provider, "_ensure_client", AsyncMock(return_value=mock_client)): + params = CompletionParams( + model_id="gpt-4o", + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": "What is this?"}, + {"type": "image_url", "image_url": {"url": data_uri}}, + ], + }], + ) + result = await provider._acompletion(params) + + assert result.choices[0].message.content == "I see an image." 
+ call_kwargs = mock_session.send_and_wait.call_args[0][0] + assert "attachments" in call_kwargs + assert len(call_kwargs["attachments"]) == 1 + assert call_kwargs["attachments"][0]["type"] == "file" \ No newline at end of file diff --git a/tests/unit/test_provider.py b/tests/unit/test_provider.py index 3f3b8bb9..75038550 100644 --- a/tests/unit/test_provider.py +++ b/tests/unit/test_provider.py @@ -149,6 +149,7 @@ def test_providers_raise_MissingApiKeyError(provider: LLMProvider) -> None: LLMProvider.VERTEXAI, LLMProvider.VLLM, LLMProvider.GATEWAY, + LLMProvider.COPILOTSDK, # uses CLI auth; no API key required ): pytest.skip("This provider handles `api_key` differently.") with patch.dict(os.environ, {}, clear=True):