diff --git a/litellm/llms/litellm_proxy/skills/handler.py b/litellm/llms/litellm_proxy/skills/handler.py index 8e5070c2724..6ce2e2f2dfb 100644 --- a/litellm/llms/litellm_proxy/skills/handler.py +++ b/litellm/llms/litellm_proxy/skills/handler.py @@ -71,6 +71,17 @@ async def create_skill( """ prisma_client = await LiteLLMSkillsHandler._get_prisma_client() + # Enforce unique display_title + if data.display_title: + existing = await prisma_client.db.litellm_skillstable.find_first( + where={"display_title": data.display_title} + ) + if existing is not None: + raise ValueError( + f"A skill with display_title '{data.display_title}' already exists " + f"(id: {existing.skill_id}). Skill names must be unique." + ) + skill_id = f"litellm_skill_{uuid.uuid4()}" skill_data: Dict[str, Any] = { @@ -217,3 +228,76 @@ async def fetch_skill_from_db(skill_id: str) -> Optional[LiteLLM_SkillsTable]: f"LiteLLMSkillsHandler: Error fetching skill {skill_id}: {e}" ) return None + + @staticmethod + async def save_provider_skill_id( + skill_id: str, + provider: str, + provider_skill_id: str, + ) -> None: + """ + Save a provider-assigned skill ID for a LiteLLM skill. + + When a skill is used with a provider that has a native skills API + (e.g. Anthropic), the provider returns its own skill ID. We store + that mapping so subsequent calls can reuse it without re-creating. + + Stored in metadata._provider_skill_ids.{provider} = provider_skill_id + + Args: + skill_id: The LiteLLM skill ID + provider: Provider name (e.g. 
"anthropic") + provider_skill_id: The ID assigned by the provider + """ + import json + + prisma_client = await LiteLLMSkillsHandler._get_prisma_client() + + # Fetch current metadata + skill = await prisma_client.db.litellm_skillstable.find_unique( + where={"skill_id": skill_id} + ) + if skill is None: + verbose_logger.warning( + f"LiteLLMSkillsHandler: Cannot save provider ID - skill {skill_id} not found" + ) + return + + metadata = skill.metadata if isinstance(skill.metadata, dict) else {} + if isinstance(metadata, str): + metadata = json.loads(metadata) + + provider_ids = metadata.get("_provider_skill_ids", {}) + provider_ids[provider] = provider_skill_id + metadata["_provider_skill_ids"] = provider_ids + + await prisma_client.db.litellm_skillstable.update( + where={"skill_id": skill_id}, + data={"metadata": json.dumps(metadata)}, + ) + + verbose_logger.debug( + f"LiteLLMSkillsHandler: Saved {provider} skill ID " + f"'{provider_skill_id}' for skill {skill_id}" + ) + + @staticmethod + def get_provider_skill_id( + skill: LiteLLM_SkillsTable, + provider: str, + ) -> Optional[str]: + """ + Get the provider-assigned skill ID from a skill's metadata. + + Args: + skill: The LiteLLM skill record + provider: Provider name (e.g. "anthropic") + + Returns: + The provider's skill ID, or None if not yet registered + """ + if not skill.metadata or not isinstance(skill.metadata, dict): + return None + + provider_ids = skill.metadata.get("_provider_skill_ids", {}) + return provider_ids.get(provider) diff --git a/litellm/llms/litellm_proxy/skills/skill_applicator.py b/litellm/llms/litellm_proxy/skills/skill_applicator.py new file mode 100644 index 00000000000..937b9257415 --- /dev/null +++ b/litellm/llms/litellm_proxy/skills/skill_applicator.py @@ -0,0 +1,194 @@ +""" +Skill Applicator for Gateway Skills. + +Handles provider-specific strategies for applying skills to LLM requests. 
+Uses get_llm_provider() to resolve models to providers, then checks the +centralized beta headers config to determine if the model's provider +supports native skills (skills-2025-10-02 beta). If not, falls back to +system prompt injection. +""" + +from typing import List, Optional + +from litellm._logging import verbose_logger +from litellm.constants import ANTHROPIC_SKILLS_API_BETA_VERSION +from litellm.proxy._types import LiteLLM_SkillsTable + + +class SkillApplicator: + """ + Applies gateway skills to LLM requests using provider-specific strategies. + + Provider resolution is delegated to litellm.get_llm_provider(). + Native skills support is determined by the centralized beta headers + config (anthropic_beta_headers_config.json) — if the provider maps + skills-2025-10-02 to a non-null value, native skills are supported. + """ + + def __init__(self): + from litellm.llms.litellm_proxy.skills.prompt_injection import ( + SkillPromptInjectionHandler, + ) + + self.prompt_handler = SkillPromptInjectionHandler() + + def supports_native_skills(self, provider: str) -> bool: + """ + Check if a provider supports native skills by consulting the + centralized beta headers config. + """ + from litellm.anthropic_beta_headers_manager import is_beta_header_supported + + return is_beta_header_supported( + beta_header=ANTHROPIC_SKILLS_API_BETA_VERSION, + provider=provider, + ) + + async def apply_skills( + self, + data: dict, + skills: List[LiteLLM_SkillsTable], + provider: str, + ) -> dict: + """ + Apply skills to a request based on provider. 
+ + Args: + data: The request data dict + skills: List of skills to apply + provider: The LLM provider name (from get_llm_provider) + + Returns: + Modified request data with skills applied + """ + if not skills: + return data + + if self.supports_native_skills(provider): + verbose_logger.debug( + f"SkillApplicator: Applying {len(skills)} skills via native API " + f"for provider={provider}" + ) + return self._apply_tool_conversion_strategy(data, skills) + + verbose_logger.debug( + f"SkillApplicator: Applying {len(skills)} skills via system prompt " + f"for provider={provider}" + ) + return self._apply_system_prompt_strategy(data, skills) + + def _apply_system_prompt_strategy( + self, + data: dict, + skills: List[LiteLLM_SkillsTable], + ) -> dict: + """ + Apply skills by injecting content into system prompt. + + Format: + --- + ## Skill: {display_title} + **Description:** {description} + + ### Instructions + {SKILL.md body content} + --- + """ + skill_contents: List[str] = [] + + for skill in skills: + content = self._format_skill_content(skill) + if content: + skill_contents.append(content) + + if not skill_contents: + return data + + return self.prompt_handler.inject_skill_content_to_messages( + data, skill_contents, use_anthropic_format=False + ) + + def _apply_tool_conversion_strategy( + self, + data: dict, + skills: List[LiteLLM_SkillsTable], + ) -> dict: + """ + Apply skills by converting to Anthropic-style tools + system prompt. 
+ """ + tools = data.get("tools", []) + skill_contents: List[str] = [] + + for skill in skills: + tools.append(self.prompt_handler.convert_skill_to_anthropic_tool(skill)) + + content = self.prompt_handler.extract_skill_content(skill) + if content: + skill_contents.append(content) + + if tools: + data["tools"] = tools + + if skill_contents: + data = self.prompt_handler.inject_skill_content_to_messages( + data, skill_contents, use_anthropic_format=True + ) + + return data + + def _format_skill_content(self, skill: LiteLLM_SkillsTable) -> Optional[str]: + """ + Format skill content for system prompt injection. + """ + content = self.prompt_handler.extract_skill_content(skill) + + if not content: + content = skill.instructions + + if not content: + return None + + title = skill.display_title or skill.skill_id + parts = [f"## Skill: {title}"] + + if skill.description: + parts.append(f"**Description:** {skill.description}") + + parts.append("") + parts.append("### Instructions") + parts.append(content) + + return "\n".join(parts) + + +def get_provider_from_model(model: str) -> str: + """ + Determine the provider from a model string. + + First checks the proxy router's model list to resolve aliases + (e.g., "claude-sonnet" -> "anthropic/claude-sonnet-4-20250514"), + then uses get_llm_provider on the resolved model. 
+ """ + resolved_model = model + + # Try to resolve through the router's model list + try: + from litellm.proxy.proxy_server import llm_router + + if llm_router is not None: + deployments = llm_router.get_model_list(model_name=model) + if deployments: + resolved_model = deployments[0]["litellm_params"]["model"] + except Exception: + pass + + try: + from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider + + _, custom_llm_provider, _, _ = get_llm_provider(model=resolved_model) + return custom_llm_provider or "openai" + except Exception as e: + verbose_logger.warning( + f"SkillApplicator: Failed to determine provider for model {model}: {e}" + ) + return "openai" diff --git a/litellm/proxy/anthropic_endpoints/skills_endpoints.py b/litellm/proxy/anthropic_endpoints/skills_endpoints.py index cd19e7731f0..85c29dde26c 100644 --- a/litellm/proxy/anthropic_endpoints/skills_endpoints.py +++ b/litellm/proxy/anthropic_endpoints/skills_endpoints.py @@ -1,13 +1,19 @@ """ -Anthropic Skills API endpoints - /v1/skills +Skills API endpoints - /v1/skills + +Supports two modes controlled by litellm_settings.skills_mode: +- "litellm": Skills stored in LiteLLM DB, works with any model provider +- "passthrough": Pass-through to Anthropic API (requires Anthropic model) """ -from typing import Optional +from typing import Literal, Optional import orjson -from fastapi import APIRouter, Depends, Request, Response +from fastapi import APIRouter, Depends, HTTPException, Request, Response +from starlette.datastructures import UploadFile -from litellm.proxy._types import UserAPIKeyAuth +from litellm._logging import verbose_proxy_logger +from litellm.proxy._types import NewSkillRequest, UserAPIKeyAuth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing from litellm.proxy.common_utils.http_parsing_utils import ( @@ -23,9 +29,208 @@ router = APIRouter() +def get_skills_mode() -> 
Literal["litellm", "passthrough"]: + """ + Get the skills_mode from litellm_settings. + + Returns: + "litellm" - Skills managed by LiteLLM (stored in DB, works with any provider) + "passthrough" - Pass-through to Anthropic API (default for backwards compatibility) + """ + from litellm.proxy.proxy_server import general_settings + + # Check general_settings for skills_mode + skills_mode = general_settings.get("skills_mode") + + if skills_mode is None: + return "passthrough" + + if skills_mode not in ("litellm", "passthrough"): + verbose_proxy_logger.warning( + f"Invalid skills_mode '{skills_mode}', defaulting to 'passthrough'" + ) + return "passthrough" + + return skills_mode + + +async def _handle_litellm_create_skill( + request: Request, + user_api_key_dict: UserAPIKeyAuth, +) -> Skill: + """Handle skill creation in LiteLLM mode (local DB storage).""" + from litellm.llms.litellm_proxy.skills.handler import LiteLLMSkillsHandler + from litellm.proxy.skills_endpoints.validation import validate_skill_files + + # Parse form data + form_data = await get_form_data(request) + + # Get display_title override if provided + display_title_override = form_data.get("display_title") + + # Get files from form data + # get_form_data strips [] suffix, so "files[]" becomes "files" + files_data = form_data.get("files", []) + if not files_data: + files_data = form_data.get("files[]", []) + + if not files_data: + raise HTTPException( + status_code=400, + detail="No files provided. 
SKILL.md is required.", + ) + + # Normalize to list if single file + if not isinstance(files_data, list): + files_data = [files_data] + + # Read file contents + file_tuples = [] + for file_item in files_data: + if isinstance(file_item, UploadFile): + content = await file_item.read() + filename = file_item.filename or "unknown" + file_tuples.append((filename, content)) + elif isinstance(file_item, tuple) and len(file_item) >= 2: + filename, content = file_item[0], file_item[1] + if isinstance(content, str): + content = content.encode("utf-8") + file_tuples.append((filename, content)) + + if not file_tuples: + raise HTTPException( + status_code=400, + detail="No valid files provided. SKILL.md is required.", + ) + + # Validate files and create ZIP + zip_content, frontmatter, body, errors = validate_skill_files(file_tuples) + + if errors: + raise HTTPException( + status_code=400, + detail={"errors": errors}, + ) + + assert zip_content is not None + assert frontmatter is not None + + # Create skill request + skill_request = NewSkillRequest( + display_title=display_title_override or frontmatter.name, + description=frontmatter.description, + instructions=body, + file_content=zip_content, + file_name="skill.zip", + file_type="application/zip", + ) + + # Create skill in DB + try: + skill_record = await LiteLLMSkillsHandler.create_skill( + data=skill_request, + user_id=user_api_key_dict.user_id, + ) + except ValueError as e: + raise HTTPException(status_code=409, detail=str(e)) + + verbose_proxy_logger.debug(f"Created LiteLLM skill: {skill_record.skill_id}") + + return Skill( + id=skill_record.skill_id, + display_title=skill_record.display_title, + source=skill_record.source, + latest_version=skill_record.latest_version, + created_at=skill_record.created_at.isoformat() + if skill_record.created_at + else "", + updated_at=skill_record.updated_at.isoformat() + if skill_record.updated_at + else "", + ) + + +async def _handle_litellm_list_skills( + limit: int = 20, + page: 
Optional[str] = None, +) -> ListSkillsResponse: + """Handle skill listing in LiteLLM mode (local DB).""" + from litellm.llms.litellm_proxy.skills.handler import LiteLLMSkillsHandler + + # Clamp limit + limit = max(1, min(limit, 100)) + + # Parse page to offset + offset = 0 + if page: + try: + offset = int(page) + except ValueError: + pass + + # Fetch from DB + skills = await LiteLLMSkillsHandler.list_skills( + limit=limit + 1, + offset=offset, + ) + + has_more = len(skills) > limit + if has_more: + skills = skills[:limit] + + skill_responses = [ + Skill( + id=s.skill_id, + display_title=s.display_title, + source=s.source, + latest_version=s.latest_version, + created_at=s.created_at.isoformat() if s.created_at else "", + updated_at=s.updated_at.isoformat() if s.updated_at else "", + ) + for s in skills + ] + + return ListSkillsResponse( + data=skill_responses, + has_more=has_more, + next_page=str(offset + limit) if has_more else None, + ) + + +async def _handle_litellm_get_skill(skill_id: str) -> Skill: + """Handle skill retrieval in LiteLLM mode (local DB).""" + from litellm.llms.litellm_proxy.skills.handler import LiteLLMSkillsHandler + + try: + skill = await LiteLLMSkillsHandler.get_skill(skill_id) + except ValueError: + raise HTTPException(status_code=404, detail=f"Skill not found: {skill_id}") + + return Skill( + id=skill.skill_id, + display_title=skill.display_title, + source=skill.source, + latest_version=skill.latest_version, + created_at=skill.created_at.isoformat() if skill.created_at else "", + updated_at=skill.updated_at.isoformat() if skill.updated_at else "", + ) + + +async def _handle_litellm_delete_skill(skill_id: str) -> DeleteSkillResponse: + """Handle skill deletion in LiteLLM mode (local DB).""" + from litellm.llms.litellm_proxy.skills.handler import LiteLLMSkillsHandler + + try: + result = await LiteLLMSkillsHandler.delete_skill(skill_id) + except ValueError: + raise HTTPException(status_code=404, detail=f"Skill not found: {skill_id}") + + 
return DeleteSkillResponse(id=result["id"], type=result["type"]) + + @router.post( "/v1/skills", - tags=["[beta] Anthropic Skills API"], + tags=["[beta] Skills API"], dependencies=[Depends(user_api_key_auth)], response_model=Skill, ) @@ -36,35 +241,38 @@ async def create_skill( user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): """ - Create a new skill on Anthropic. - - Requires `?beta=true` query parameter. - - Model-based routing (for multi-account support): - - Pass model via header: `x-litellm-model: claude-account-1` - - Pass model via query: `?model=claude-account-1` - - Pass model via form field: `model=claude-account-1` - + Create a new skill. + + Behavior depends on `litellm_settings.skills_mode`: + - "litellm": Stores skill in LiteLLM DB, works with any provider + - "passthrough": Creates skill on Anthropic (requires Anthropic model) + + SKILL.md must have YAML frontmatter (for litellm mode): + ```yaml + --- + name: My Skill (max 64 chars) + description: What this skill does (max 1024 chars, optional) + --- + ``` + Example usage: ```bash - # Basic usage - curl -X POST "http://localhost:4000/v1/skills?beta=true" \ - -H "Content-Type: multipart/form-data" \ - -H "Authorization: Bearer your-key" \ - -F "display_title=My Skill" \ - -F "files[]=@skill.zip" - - # With model-based routing - curl -X POST "http://localhost:4000/v1/skills?beta=true" \ - -H "Content-Type: multipart/form-data" \ - -H "Authorization: Bearer your-key" \ - -H "x-litellm-model: claude-account-1" \ - -F "display_title=My Skill" \ - -F "files[]=@skill.zip" + curl -X POST "http://localhost:4000/v1/skills" \\ + -H "Content-Type: multipart/form-data" \\ + -H "Authorization: Bearer your-key" \\ + -F "display_title=My Skill" \\ + -F "files[]=@SKILL.md" ``` - + Returns: Skill object with id, display_title, etc. 
""" + # Check skills mode + skills_mode = get_skills_mode() + + if skills_mode == "litellm": + return await _handle_litellm_create_skill(request, user_api_key_dict) + + # Passthrough mode - forward to Anthropic from litellm.proxy.proxy_server import ( general_settings, llm_router, @@ -127,43 +335,46 @@ async def create_skill( @router.get( "/v1/skills", - tags=["[beta] Anthropic Skills API"], + tags=["[beta] Skills API"], dependencies=[Depends(user_api_key_auth)], response_model=ListSkillsResponse, ) async def list_skills( fastapi_response: Response, request: Request, - limit: Optional[int] = 10, + limit: Optional[int] = 20, + page: Optional[str] = None, after_id: Optional[str] = None, before_id: Optional[str] = None, custom_llm_provider: Optional[str] = "anthropic", user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): """ - List skills on Anthropic. - - Requires `?beta=true` query parameter. - - Model-based routing (for multi-account support): - - Pass model via header: `x-litellm-model: claude-account-1` - - Pass model via query: `?model=claude-account-1` - - Pass model via body: `{"model": "claude-account-1"}` - + List skills. 
+ + Behavior depends on `litellm_settings.skills_mode`: + - "litellm": Lists skills from LiteLLM DB + - "passthrough": Lists skills from Anthropic + + Query parameters: + - limit: Number of results (default 20, max 100) + - page: Pagination token (litellm mode only) + Example usage: ```bash - # Basic usage - curl "http://localhost:4000/v1/skills?beta=true&limit=10" \ + curl "http://localhost:4000/v1/skills?limit=10" \\ -H "Authorization: Bearer your-key" - - # With model-based routing - curl "http://localhost:4000/v1/skills?beta=true&limit=10" \ - -H "Authorization: Bearer your-key" \ - -H "x-litellm-model: claude-account-1" ``` - + Returns: ListSkillsResponse with list of skills """ + # Check skills mode + skills_mode = get_skills_mode() + + if skills_mode == "litellm": + return await _handle_litellm_list_skills(limit=limit or 20, page=page) + + # Passthrough mode - forward to Anthropic from litellm.proxy.proxy_server import ( general_settings, llm_router, @@ -235,7 +446,7 @@ async def list_skills( @router.get( "/v1/skills/{skill_id}", - tags=["[beta] Anthropic Skills API"], + tags=["[beta] Skills API"], dependencies=[Depends(user_api_key_auth)], response_model=Skill, ) @@ -247,29 +458,27 @@ async def get_skill( user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): """ - Get a specific skill by ID from Anthropic. - - Requires `?beta=true` query parameter. - - Model-based routing (for multi-account support): - - Pass model via header: `x-litellm-model: claude-account-1` - - Pass model via query: `?model=claude-account-1` - - Pass model via body: `{"model": "claude-account-1"}` - + Get a specific skill by ID. 
+ + Behavior depends on `litellm_settings.skills_mode`: + - "litellm": Gets skill from LiteLLM DB + - "passthrough": Gets skill from Anthropic + Example usage: ```bash - # Basic usage - curl "http://localhost:4000/v1/skills/skill_123?beta=true" \ + curl "http://localhost:4000/v1/skills/litellm_skill_123" \\ -H "Authorization: Bearer your-key" - - # With model-based routing - curl "http://localhost:4000/v1/skills/skill_123?beta=true" \ - -H "Authorization: Bearer your-key" \ - -H "x-litellm-model: claude-account-1" ``` - + Returns: Skill object """ + # Check skills mode + skills_mode = get_skills_mode() + + if skills_mode == "litellm": + return await _handle_litellm_get_skill(skill_id) + + # Passthrough mode - forward to Anthropic from litellm.proxy.proxy_server import ( general_settings, llm_router, @@ -336,7 +545,7 @@ async def get_skill( @router.delete( "/v1/skills/{skill_id}", - tags=["[beta] Anthropic Skills API"], + tags=["[beta] Skills API"], dependencies=[Depends(user_api_key_auth)], response_model=DeleteSkillResponse, ) @@ -348,31 +557,27 @@ async def delete_skill( user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): """ - Delete a skill by ID from Anthropic. - - Requires `?beta=true` query parameter. - - Note: Anthropic does not allow deleting skills with existing versions. - - Model-based routing (for multi-account support): - - Pass model via header: `x-litellm-model: claude-account-1` - - Pass model via query: `?model=claude-account-1` - - Pass model via body: `{"model": "claude-account-1"}` - + Delete a skill by ID. 
+ + Behavior depends on `litellm_settings.skills_mode`: + - "litellm": Deletes skill from LiteLLM DB + - "passthrough": Deletes skill from Anthropic + Example usage: ```bash - # Basic usage - curl -X DELETE "http://localhost:4000/v1/skills/skill_123?beta=true" \ + curl -X DELETE "http://localhost:4000/v1/skills/litellm_skill_123" \\ -H "Authorization: Bearer your-key" - - # With model-based routing - curl -X DELETE "http://localhost:4000/v1/skills/skill_123?beta=true" \ - -H "Authorization: Bearer your-key" \ - -H "x-litellm-model: claude-account-1" ``` - + Returns: DeleteSkillResponse with type="skill_deleted" """ + # Check skills mode + skills_mode = get_skills_mode() + + if skills_mode == "litellm": + return await _handle_litellm_delete_skill(skill_id) + + # Passthrough mode - forward to Anthropic from litellm.proxy.proxy_server import ( general_settings, llm_router, diff --git a/litellm/proxy/hooks/litellm_skills/main.py b/litellm/proxy/hooks/litellm_skills/main.py index 83e419bc23c..c4b367abf57 100644 --- a/litellm/proxy/hooks/litellm_skills/main.py +++ b/litellm/proxy/hooks/litellm_skills/main.py @@ -14,12 +14,10 @@ execution automatically and returns final response with file_ids Usage: - # Simple - LiteLLM handles everything automatically via proxy - # The container parameter triggers the SkillsInjectionHook response = await litellm.acompletion( model="gpt-4o-mini", messages=[{"role": "user", "content": "Create a bouncing ball GIF"}], - container={"skills": [{"skill_id": "litellm:skill_abc123"}]}, + container={"skills": [{"skill_id": "litellm_skill_abc123"}]}, ) # Response includes file_ids for generated files """ @@ -43,7 +41,7 @@ class SkillsInjectionHook(CustomLogger): Pre/Post-call hook that processes skills from container.skills parameter. 
Pre-call (async_pre_call_hook): - - Skills with 'litellm:' prefix are fetched from LiteLLM DB + - Skills with 'litellm_' prefix are fetched from LiteLLM DB - For Anthropic models: native skills pass through, LiteLLM skills converted to tools - For non-Anthropic models: LiteLLM skills are converted to tools + execute_code tool @@ -78,7 +76,7 @@ async def async_pre_call_hook( Process skills from container.skills before the LLM call. 1. Check if container.skills exists in request - 2. Separate skills by prefix (litellm: vs native) + 2. Separate skills by prefix (litellm_ vs native) 3. Fetch LiteLLM skills from database 4. For Anthropic: keep native skills in container 5. For non-Anthropic: convert LiteLLM skills to tools, inject content, add execute_code @@ -99,38 +97,84 @@ async def async_pre_call_hook( f"SkillsInjectionHook: Processing {len(skills)} skills" ) + from fastapi import HTTPException + + from litellm.llms.litellm_proxy.skills.skill_applicator import ( + SkillApplicator, + get_provider_from_model, + ) + + model = data.get("model", "") + provider = get_provider_from_model(model) + applicator = SkillApplicator() + litellm_skills: List[LiteLLM_SkillsTable] = [] anthropic_skills: List[Dict[str, Any]] = [] - # Separate skills by prefix + # Classify and validate skills for skill in skills: if not isinstance(skill, dict): continue skill_id = skill.get("skill_id", "") - if skill_id.startswith("litellm_"): - # Fetch from LiteLLM DB + + if skill_id.startswith("litellm_skill_"): + # LiteLLM gateway-managed skill — fetch from DB db_skill = await self._fetch_skill_from_db(skill_id) if db_skill: litellm_skills.append(db_skill) else: - verbose_proxy_logger.warning( - f"SkillsInjectionHook: Skill '{skill_id}' not found in LiteLLM DB" + raise HTTPException( + status_code=404, + detail=f"Skill not found: {skill_id}", + ) + elif skill_id.startswith("skill_"): + # Native Anthropic skill — only allowed with native-skills providers + if not 
applicator.supports_native_skills(provider): + raise HTTPException( + status_code=400, + detail=f"Anthropic skill '{skill_id}' cannot be used with " + f"model '{model}' (provider '{provider}' does not support " + f"native skills). Use a litellm_skill_* ID instead.", ) - else: - # Native Anthropic skill - pass through anthropic_skills.append(skill) + else: + raise HTTPException( + status_code=400, + detail=f"Invalid skill_id '{skill_id}'. Must start with " + f"'litellm_skill_' (gateway skill) or 'skill_' (Anthropic native).", + ) - # Check if using messages API spec (anthropic_messages call type) - # Messages API always uses Anthropic-style tool format + # When the request comes through /v1/messages (anthropic_messages), + # we must inject into the top-level 'system' param because + # anthropic_messages() has separate 'messages' and 'system' params. use_anthropic_format = call_type == "anthropic_messages" - if len(litellm_skills) > 0: + if litellm_skills and applicator.supports_native_skills(provider): + # Native skills path: convert to tools + system prompt data = self._process_for_messages_api( data=data, litellm_skills=litellm_skills, use_anthropic_format=use_anthropic_format, ) + elif litellm_skills: + # Non-native path: inject into system prompt only + skill_contents = [] + for skill in litellm_skills: + content = applicator._format_skill_content(skill) + if content: + skill_contents.append(content) + + if skill_contents: + data = self.prompt_handler.inject_skill_content_to_messages( + data, skill_contents, use_anthropic_format=use_anthropic_format + ) + + # Rebuild container: keep only native Anthropic skills + if anthropic_skills: + data["container"] = {"skills": anthropic_skills} + else: + data.pop("container", None) return data @@ -906,9 +950,5 @@ def _attach_files_to_response( return response -# Global instance for registration +# Global instance for registration (registered when skills_mode is enabled) skills_injection_hook = SkillsInjectionHook() - 
-import litellm - -litellm.logging_callback_manager.add_litellm_callback(skills_injection_hook) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index e982c934aa6..f375889db8a 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -3236,6 +3236,17 @@ async def load_config( # noqa: PLR0915 _license_check.license_str = general_settings["litellm_license"] premium_user = _license_check.is_premium() + ### SKILLS MODE ### + skills_mode = general_settings.get("skills_mode", None) + if skills_mode == "litellm": + from litellm.proxy.hooks.litellm_skills.main import ( + skills_injection_hook, + ) + + litellm.logging_callback_manager.add_litellm_callback( + skills_injection_hook + ) + router_params: dict = { "cache_responses": litellm.cache is not None, # cache if user passed in cache values diff --git a/litellm/proxy/skills_endpoints/__init__.py b/litellm/proxy/skills_endpoints/__init__.py new file mode 100644 index 00000000000..9ba5ce23b2b --- /dev/null +++ b/litellm/proxy/skills_endpoints/__init__.py @@ -0,0 +1,18 @@ +""" +Skills validation utilities for LiteLLM Proxy. + +This module provides validation utilities for skill file uploads, +including YAML frontmatter parsing and ZIP creation. +""" + +from litellm.proxy.skills_endpoints.validation import ( + SkillFrontmatter, + parse_skill_md, + validate_skill_files, +) + +__all__ = [ + "SkillFrontmatter", + "parse_skill_md", + "validate_skill_files", +] diff --git a/litellm/proxy/skills_endpoints/validation.py b/litellm/proxy/skills_endpoints/validation.py new file mode 100644 index 00000000000..bd8e8a566ae --- /dev/null +++ b/litellm/proxy/skills_endpoints/validation.py @@ -0,0 +1,208 @@ +""" +Validation utilities for Gateway Skills API. + +Handles YAML frontmatter parsing and validation from SKILL.md files. 
+""" + +import io +import re +import zipfile +from typing import Any, Dict, List, Optional, Tuple + +from pydantic import BaseModel, Field, field_validator + +from litellm._logging import verbose_logger + +# Maximum file size for skill uploads (8MB) +MAX_SKILL_FILE_SIZE = 8 * 1024 * 1024 + +# YAML frontmatter regex pattern +FRONTMATTER_PATTERN = re.compile(r"^---\s*\n(.*?)\n---\s*\n?", re.DOTALL) + + +class SkillFrontmatter(BaseModel): + """ + Pydantic model for SKILL.md YAML frontmatter. + + Validates the frontmatter according to the spec: + - name: required, max 64 characters + - description: optional, max 1024 characters + """ + + name: str = Field(..., max_length=64, description="Skill name (max 64 chars)") + description: Optional[str] = Field( + None, max_length=1024, description="Skill description (max 1024 chars)" + ) + + @field_validator("name") + @classmethod + def validate_name(cls, v: str) -> str: + """Validate skill name is not empty and within limits.""" + if not v or not v.strip(): + raise ValueError("Skill name cannot be empty") + return v.strip() + + @field_validator("description") + @classmethod + def validate_description(cls, v: Optional[str]) -> Optional[str]: + """Validate description is within limits.""" + if v is not None: + return v.strip() if v.strip() else None + return v + + +def parse_yaml_frontmatter(content: str) -> Optional[Dict[str, Any]]: + """ + Parse YAML frontmatter from markdown content. 
+ + Frontmatter is expected in the format: + --- + name: My Skill + description: Does something cool + --- + + Args: + content: The markdown content with potential frontmatter + + Returns: + Dict of frontmatter values, or None if no frontmatter found + """ + try: + import yaml + except ImportError: + verbose_logger.warning( + "PyYAML not installed, cannot parse SKILL.md frontmatter" + ) + return None + + match = FRONTMATTER_PATTERN.match(content) + if not match: + return None + + yaml_content = match.group(1) + try: + return yaml.safe_load(yaml_content) + except yaml.YAMLError as e: + verbose_logger.warning(f"Failed to parse YAML frontmatter: {e}") + return None + + +def parse_skill_md(content: str) -> Tuple[Optional[SkillFrontmatter], str]: + """ + Parse SKILL.md content to extract frontmatter and body. + + Args: + content: The full SKILL.md content + + Returns: + Tuple of (SkillFrontmatter or None, body content without frontmatter) + """ + # Try to parse frontmatter + frontmatter_data = parse_yaml_frontmatter(content) + + # Extract body (content after frontmatter) + body = content + match = FRONTMATTER_PATTERN.match(content) + if match: + body = content[match.end() :].strip() + + # Validate frontmatter if present + if frontmatter_data: + try: + frontmatter = SkillFrontmatter(**frontmatter_data) + return frontmatter, body + except Exception as e: + verbose_logger.warning(f"Invalid SKILL.md frontmatter: {e}") + return None, body + + return None, body + + +def validate_skill_files( + files: List[Tuple[str, bytes]], +) -> Tuple[Optional[bytes], Optional[SkillFrontmatter], str, List[str]]: + """ + Validate uploaded skill files and create a ZIP archive. 
+
+    Validates:
+    - SKILL.md is present
+    - Total size is under 8MB limit
+    - YAML frontmatter is valid (if present)
+
+    Args:
+        files: List of (filename, content) tuples
+
+    Returns:
+        Tuple of (zip_content, frontmatter, body_content, error_messages)
+        If errors, zip_content and frontmatter will be None
+    """
+    errors: List[str] = []
+    skill_md_content: Optional[str] = None
+    total_size = 0
+
+    # Check for SKILL.md and calculate total size
+    for filename, content in files:
+        total_size += len(content)
+        # NOTE(review): endswith("/SKILL.md") also matches copies nested at any
+        # depth (not just the root, despite the error wording below), and a
+        # later match silently overwrites an earlier one — confirm intended.
+        if filename == "SKILL.md" or filename.endswith("/SKILL.md"):
+            try:
+                skill_md_content = content.decode("utf-8")
+            except UnicodeDecodeError:
+                errors.append("SKILL.md must be valid UTF-8 text")
+
+    # Validate SKILL.md presence
+    if skill_md_content is None:
+        errors.append("SKILL.md is required at the root of the skill files")
+
+    # Validate total size
+    if total_size > MAX_SKILL_FILE_SIZE:
+        errors.append(
+            f"Total file size ({total_size / 1024 / 1024:.2f}MB) exceeds "
+            f"limit ({MAX_SKILL_FILE_SIZE / 1024 / 1024}MB)"
+        )
+
+    if errors:
+        return None, None, "", errors
+
+    # Parse SKILL.md frontmatter
+    # The assert only narrows the Optional type for static checkers: a missing
+    # SKILL.md already appended an error and returned above.
+    assert skill_md_content is not None  # We checked above
+    frontmatter, body = parse_skill_md(skill_md_content)
+
+    # Validate frontmatter is present and valid
+    if frontmatter is None:
+        errors.append(
+            "SKILL.md must have valid YAML frontmatter with at least a 'name' field"
+        )
+        return None, None, body, errors
+
+    # Create ZIP archive
+    # All uploaded files (including SKILL.md) are packed under their original
+    # names with DEFLATE compression.
+    zip_buffer = io.BytesIO()
+    with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
+        for filename, content in files:
+            zf.writestr(filename, content)
+
+    return zip_buffer.getvalue(), frontmatter, body, []
+
+
+def extract_skill_name_from_zip(zip_content: bytes) -> Optional[str]:
+    """
+    Extract skill name from a ZIP file containing SKILL.md.
+
+    Args:
+        zip_content: The ZIP file content as bytes
+
+    Returns:
+        The skill name from frontmatter, or None if not found
+    """
+    # Best-effort: any ZIP, decode, or parse error below is logged and yields
+    # None rather than raising to the caller.
+    try:
+        zip_buffer = io.BytesIO(zip_content)
+        with zipfile.ZipFile(zip_buffer, "r") as zf:
+            for name in zf.namelist():
+                # First member ending in "SKILL.md" (at any depth) whose
+                # frontmatter parses wins.
+                if name.endswith("SKILL.md"):
+                    content = zf.read(name).decode("utf-8")
+                    frontmatter, _ = parse_skill_md(content)
+                    if frontmatter:
+                        return frontmatter.name
+    except Exception as e:
+        verbose_logger.warning(f"Failed to extract skill name from ZIP: {e}")
+
+    return None
diff --git a/tests/litellm/proxy/skills_endpoints/__init__.py b/tests/litellm/proxy/skills_endpoints/__init__.py
new file mode 100644
index 00000000000..99e4c9350d9
--- /dev/null
+++ b/tests/litellm/proxy/skills_endpoints/__init__.py
@@ -0,0 +1 @@
+"""Tests for Skills API endpoints."""
diff --git a/tests/litellm/proxy/skills_endpoints/test_create_skill_endpoint.py b/tests/litellm/proxy/skills_endpoints/test_create_skill_endpoint.py
new file mode 100644
index 00000000000..0c7c93ab66e
--- /dev/null
+++ b/tests/litellm/proxy/skills_endpoints/test_create_skill_endpoint.py
@@ -0,0 +1,346 @@
+"""
+Tests for the create skill endpoint form data parsing.
+
+Simulates actual curl/multipart uploads to verify the endpoint
+correctly handles file uploads in litellm mode.
+"""
+
+import io
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from fastapi import UploadFile
+from starlette.datastructures import Headers
+
+from litellm.proxy._types import LiteLLM_SkillsTable, UserAPIKeyAuth
+
+
+def _make_upload_file(filename: str, content: bytes) -> UploadFile:
+    """Create a FastAPI UploadFile matching what multipart form parsing produces."""
+    return UploadFile(
+        filename=filename,
+        file=io.BytesIO(content),
+        headers=Headers({"content-type": "application/octet-stream"}),
+    )
+
+
+def _make_db_skill(**overrides) -> LiteLLM_SkillsTable:
+    """Create a mock DB skill record."""
+    from datetime import datetime
+
+    # Baseline record; individual tests override fields via **overrides.
+    # Fixed dates keep the fixture deterministic across runs.
+    defaults = {
+        "skill_id": "litellm_skill_test123",
+        "display_title": "Test Skill",
+        "description": None,
+        "instructions": "Test instructions",
+        "source": "custom",
+        "file_content": b"fake-zip",
+        "file_name": "skill.zip",
+        "file_type": "application/zip",
+        "created_at": datetime(2026, 3, 21),
+        "updated_at": datetime(2026, 3, 21),
+    }
+    defaults.update(overrides)
+    return LiteLLM_SkillsTable(**defaults)
+
+
+# Minimal valid SKILL.md payload: YAML frontmatter (name, description) plus a body.
+SKILL_MD_CONTENT = b"""---
+name: test-skill
+description: A test skill
+---
+
+Test instructions here.
+"""
+
+
+class TestCreateSkillFormParsing:
+    """Tests that simulate actual curl multipart uploads."""
+
+    @pytest.mark.asyncio
+    async def test_files_bracket_key_single_upload(self):
+        """
+        Simulate: curl -F "files[]=@SKILL.md;filename=skill/SKILL.md"
+
+        get_form_data strips [] so key becomes "files" with value in a list.
+ """ + from litellm.proxy.anthropic_endpoints.skills_endpoints import ( + _handle_litellm_create_skill, + ) + + upload = _make_upload_file("skill/SKILL.md", SKILL_MD_CONTENT) + + mock_request = MagicMock() + # get_form_data strips [] and appends to list + mock_request.form = AsyncMock(return_value={"files[]": upload}) + + mock_db_skill = _make_db_skill(display_title="test-skill") + + with patch( + "litellm.proxy.anthropic_endpoints.skills_endpoints.get_form_data", + new_callable=AsyncMock, + return_value={"files": [upload], "display_title": "Test Skill"}, + ): + with patch( + "litellm.llms.litellm_proxy.skills.handler.LiteLLMSkillsHandler.create_skill", + new_callable=AsyncMock, + return_value=mock_db_skill, + ): + user = UserAPIKeyAuth(api_key="test-key") + result = await _handle_litellm_create_skill(mock_request, user) + + assert result.id == "litellm_skill_test123" + + @pytest.mark.asyncio + async def test_files_key_without_brackets(self): + """ + Test that files under plain "files" key also works. + """ + from litellm.proxy.anthropic_endpoints.skills_endpoints import ( + _handle_litellm_create_skill, + ) + + upload = _make_upload_file("skill/SKILL.md", SKILL_MD_CONTENT) + mock_request = MagicMock() + + mock_db_skill = _make_db_skill() + + with patch( + "litellm.proxy.anthropic_endpoints.skills_endpoints.get_form_data", + new_callable=AsyncMock, + return_value={"files": [upload], "display_title": "Test"}, + ): + with patch( + "litellm.llms.litellm_proxy.skills.handler.LiteLLMSkillsHandler.create_skill", + new_callable=AsyncMock, + return_value=mock_db_skill, + ): + user = UserAPIKeyAuth(api_key="test-key") + result = await _handle_litellm_create_skill(mock_request, user) + + assert result.id == "litellm_skill_test123" + + @pytest.mark.asyncio + async def test_single_upload_file_not_in_list(self): + """ + Test that a single UploadFile (not wrapped in list) is handled. 
+ """ + from litellm.proxy.anthropic_endpoints.skills_endpoints import ( + _handle_litellm_create_skill, + ) + + upload = _make_upload_file("skill/SKILL.md", SKILL_MD_CONTENT) + mock_request = MagicMock() + + mock_db_skill = _make_db_skill() + + # Single UploadFile, not in a list + with patch( + "litellm.proxy.anthropic_endpoints.skills_endpoints.get_form_data", + new_callable=AsyncMock, + return_value={"files": upload, "display_title": "Test"}, + ): + with patch( + "litellm.llms.litellm_proxy.skills.handler.LiteLLMSkillsHandler.create_skill", + new_callable=AsyncMock, + return_value=mock_db_skill, + ): + user = UserAPIKeyAuth(api_key="test-key") + result = await _handle_litellm_create_skill(mock_request, user) + + assert result.id == "litellm_skill_test123" + + @pytest.mark.asyncio + async def test_no_files_returns_400(self): + """Test that missing files returns 400.""" + from fastapi import HTTPException + + from litellm.proxy.anthropic_endpoints.skills_endpoints import ( + _handle_litellm_create_skill, + ) + + mock_request = MagicMock() + + with patch( + "litellm.proxy.anthropic_endpoints.skills_endpoints.get_form_data", + new_callable=AsyncMock, + return_value={"display_title": "Test"}, + ): + user = UserAPIKeyAuth(api_key="test-key") + with pytest.raises(HTTPException) as exc_info: + await _handle_litellm_create_skill(mock_request, user) + + assert exc_info.value.status_code == 400 + assert "No files provided" in str(exc_info.value.detail) + + @pytest.mark.asyncio + async def test_empty_files_list_returns_400(self): + """Test that empty files list returns 400.""" + from fastapi import HTTPException + + from litellm.proxy.anthropic_endpoints.skills_endpoints import ( + _handle_litellm_create_skill, + ) + + mock_request = MagicMock() + + with patch( + "litellm.proxy.anthropic_endpoints.skills_endpoints.get_form_data", + new_callable=AsyncMock, + return_value={"files": [], "display_title": "Test"}, + ): + user = UserAPIKeyAuth(api_key="test-key") + with 
pytest.raises(HTTPException) as exc_info: + await _handle_litellm_create_skill(mock_request, user) + + assert exc_info.value.status_code == 400 + + @pytest.mark.asyncio + async def test_tuple_format_files(self): + """Test that (filename, content) tuple format works.""" + from litellm.proxy.anthropic_endpoints.skills_endpoints import ( + _handle_litellm_create_skill, + ) + + mock_request = MagicMock() + mock_db_skill = _make_db_skill() + + with patch( + "litellm.proxy.anthropic_endpoints.skills_endpoints.get_form_data", + new_callable=AsyncMock, + return_value={ + "files": [("skill/SKILL.md", SKILL_MD_CONTENT)], + "display_title": "Test", + }, + ): + with patch( + "litellm.llms.litellm_proxy.skills.handler.LiteLLMSkillsHandler.create_skill", + new_callable=AsyncMock, + return_value=mock_db_skill, + ): + user = UserAPIKeyAuth(api_key="test-key") + result = await _handle_litellm_create_skill(mock_request, user) + + assert result.id == "litellm_skill_test123" + + @pytest.mark.asyncio + async def test_multiple_files(self): + """Test uploading SKILL.md plus additional files.""" + from litellm.proxy.anthropic_endpoints.skills_endpoints import ( + _handle_litellm_create_skill, + ) + + skill_md = _make_upload_file("skill/SKILL.md", SKILL_MD_CONTENT) + helper_py = _make_upload_file("skill/helper.py", b"def helper(): return 42") + mock_request = MagicMock() + mock_db_skill = _make_db_skill() + + with patch( + "litellm.proxy.anthropic_endpoints.skills_endpoints.get_form_data", + new_callable=AsyncMock, + return_value={ + "files": [skill_md, helper_py], + "display_title": "Multi File Skill", + }, + ): + with patch( + "litellm.llms.litellm_proxy.skills.handler.LiteLLMSkillsHandler.create_skill", + new_callable=AsyncMock, + return_value=mock_db_skill, + ) as mock_create: + user = UserAPIKeyAuth(api_key="test-key") + result = await _handle_litellm_create_skill(mock_request, user) + + assert result.id == "litellm_skill_test123" + # Verify the create was called with file content + 
call_args = mock_create.call_args + data = call_args[1]["data"] + assert data.file_content is not None + + @pytest.mark.asyncio + async def test_invalid_frontmatter_returns_400(self): + """Test that SKILL.md with invalid frontmatter returns 400.""" + from fastapi import HTTPException + + from litellm.proxy.anthropic_endpoints.skills_endpoints import ( + _handle_litellm_create_skill, + ) + + bad_skill_md = b"""--- +description: Missing required name field +--- + +Body content. +""" + upload = _make_upload_file("skill/SKILL.md", bad_skill_md) + mock_request = MagicMock() + + with patch( + "litellm.proxy.anthropic_endpoints.skills_endpoints.get_form_data", + new_callable=AsyncMock, + return_value={"files": [upload], "display_title": "Test"}, + ): + user = UserAPIKeyAuth(api_key="test-key") + with pytest.raises(HTTPException) as exc_info: + await _handle_litellm_create_skill(mock_request, user) + + assert exc_info.value.status_code == 400 + + @pytest.mark.asyncio + async def test_display_title_override(self): + """Test that display_title from form overrides frontmatter name.""" + from litellm.proxy.anthropic_endpoints.skills_endpoints import ( + _handle_litellm_create_skill, + ) + + upload = _make_upload_file("skill/SKILL.md", SKILL_MD_CONTENT) + mock_request = MagicMock() + mock_db_skill = _make_db_skill(display_title="Custom Title") + + with patch( + "litellm.proxy.anthropic_endpoints.skills_endpoints.get_form_data", + new_callable=AsyncMock, + return_value={ + "files": [upload], + "display_title": "Custom Title", + }, + ): + with patch( + "litellm.llms.litellm_proxy.skills.handler.LiteLLMSkillsHandler.create_skill", + new_callable=AsyncMock, + return_value=mock_db_skill, + ) as mock_create: + user = UserAPIKeyAuth(api_key="test-key") + await _handle_litellm_create_skill(mock_request, user) + + call_args = mock_create.call_args + data = call_args[1]["data"] + assert data.display_title == "Custom Title" + + @pytest.mark.asyncio + async def 
test_display_title_falls_back_to_frontmatter_name(self): + """Test that without display_title override, frontmatter name is used.""" + from litellm.proxy.anthropic_endpoints.skills_endpoints import ( + _handle_litellm_create_skill, + ) + + upload = _make_upload_file("skill/SKILL.md", SKILL_MD_CONTENT) + mock_request = MagicMock() + mock_db_skill = _make_db_skill(display_title="test-skill") + + with patch( + "litellm.proxy.anthropic_endpoints.skills_endpoints.get_form_data", + new_callable=AsyncMock, + return_value={"files": [upload]}, + ): + with patch( + "litellm.llms.litellm_proxy.skills.handler.LiteLLMSkillsHandler.create_skill", + new_callable=AsyncMock, + return_value=mock_db_skill, + ) as mock_create: + user = UserAPIKeyAuth(api_key="test-key") + await _handle_litellm_create_skill(mock_request, user) + + call_args = mock_create.call_args + data = call_args[1]["data"] + # Should fall back to frontmatter name "test-skill" + assert data.display_title == "test-skill" diff --git a/tests/litellm/proxy/skills_endpoints/test_skills_handler.py b/tests/litellm/proxy/skills_endpoints/test_skills_handler.py new file mode 100644 index 00000000000..baf30c0bc50 --- /dev/null +++ b/tests/litellm/proxy/skills_endpoints/test_skills_handler.py @@ -0,0 +1,484 @@ +""" +Tests for LiteLLMSkillsHandler - database CRUD operations. + +Tests skill creation, listing, retrieval, and deletion through mocked Prisma client. 
+""" + +import base64 +from datetime import datetime +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from litellm.proxy._types import LiteLLM_SkillsTable, NewSkillRequest + + +def _make_prisma_skill( + skill_id: str = "litellm_skill_test123", + display_title: str = "Test Skill", + description: str = "A test skill", + instructions: str = "Do testing.", + source: str = "custom", + file_content: bytes = b"fake-zip", + file_name: str = "skill.zip", + file_type: str = "application/zip", +): + """Create a mock Prisma skill record.""" + mock = MagicMock() + mock.model_dump.return_value = { + "skill_id": skill_id, + "display_title": display_title, + "description": description, + "instructions": instructions, + "source": source, + "latest_version": None, + "file_content": base64.b64encode(file_content).decode("utf-8"), + "file_name": file_name, + "file_type": file_type, + "metadata": None, + "created_at": datetime(2026, 3, 21), + "created_by": "user1", + "updated_at": datetime(2026, 3, 21), + "updated_by": "user1", + } + return mock + + +class TestCreateSkill: + """Tests for LiteLLMSkillsHandler.create_skill.""" + + @pytest.mark.asyncio + async def test_create_skill_success(self): + """Test successful skill creation stores data in DB.""" + from litellm.llms.litellm_proxy.skills.handler import LiteLLMSkillsHandler + + mock_prisma = MagicMock() + mock_prisma.db.litellm_skillstable.find_first = AsyncMock(return_value=None) + mock_prisma.db.litellm_skillstable.create = AsyncMock( + return_value=_make_prisma_skill() + ) + + with patch.object( + LiteLLMSkillsHandler, + "_get_prisma_client", + new_callable=AsyncMock, + return_value=mock_prisma, + ): + request = NewSkillRequest( + display_title="Test Skill", + description="A test skill", + instructions="Do testing.", + file_content=b"fake-zip", + file_name="skill.zip", + file_type="application/zip", + ) + + result = await LiteLLMSkillsHandler.create_skill( + data=request, user_id="user1" + ) + + assert 
isinstance(result, LiteLLM_SkillsTable) + assert result.display_title == "Test Skill" + assert result.description == "A test skill" + mock_prisma.db.litellm_skillstable.create.assert_called_once() + + @pytest.mark.asyncio + async def test_create_skill_generates_litellm_prefix_id(self): + """Test that created skill IDs have litellm_skill_ prefix.""" + from litellm.llms.litellm_proxy.skills.handler import LiteLLMSkillsHandler + + mock_prisma = MagicMock() + mock_prisma.db.litellm_skillstable.find_first = AsyncMock(return_value=None) + mock_prisma.db.litellm_skillstable.create = AsyncMock( + return_value=_make_prisma_skill() + ) + + with patch.object( + LiteLLMSkillsHandler, + "_get_prisma_client", + new_callable=AsyncMock, + return_value=mock_prisma, + ): + request = NewSkillRequest( + display_title="Test", + instructions="test", + ) + + await LiteLLMSkillsHandler.create_skill(data=request) + + call_args = mock_prisma.db.litellm_skillstable.create.call_args + skill_data = call_args[1]["data"] + assert skill_data["skill_id"].startswith("litellm_skill_") + + @pytest.mark.asyncio + async def test_create_skill_no_prisma_raises(self): + """Test that creating a skill without Prisma client raises ValueError.""" + from litellm.llms.litellm_proxy.skills.handler import LiteLLMSkillsHandler + + with patch( + "litellm.llms.litellm_proxy.skills.handler.LiteLLMSkillsHandler._get_prisma_client", + new_callable=AsyncMock, + side_effect=ValueError("Prisma client is not initialized"), + ): + request = NewSkillRequest(display_title="Test", instructions="test") + + with pytest.raises(ValueError, match="Prisma client"): + await LiteLLMSkillsHandler.create_skill(data=request) + + @pytest.mark.asyncio + async def test_create_skill_duplicate_title_raises(self): + """Test that creating a skill with a duplicate display_title raises ValueError.""" + from litellm.llms.litellm_proxy.skills.handler import LiteLLMSkillsHandler + + existing_skill = MagicMock() + existing_skill.skill_id = 
"litellm_skill_existing" + + mock_prisma = MagicMock() + mock_prisma.db.litellm_skillstable.find_first = AsyncMock( + return_value=existing_skill + ) + + with patch.object( + LiteLLMSkillsHandler, + "_get_prisma_client", + new_callable=AsyncMock, + return_value=mock_prisma, + ): + request = NewSkillRequest( + display_title="Duplicate Name", + instructions="test", + ) + + with pytest.raises(ValueError, match="already exists"): + await LiteLLMSkillsHandler.create_skill(data=request) + + # Should never reach create + mock_prisma.db.litellm_skillstable.create.assert_not_called() + + +class TestListSkills: + """Tests for LiteLLMSkillsHandler.list_skills.""" + + @pytest.mark.asyncio + async def test_list_skills_returns_records(self): + """Test listing skills returns all records.""" + from litellm.llms.litellm_proxy.skills.handler import LiteLLMSkillsHandler + + mock_prisma = MagicMock() + mock_prisma.db.litellm_skillstable.find_many = AsyncMock( + return_value=[ + _make_prisma_skill(skill_id="litellm_skill_1", display_title="Skill 1"), + _make_prisma_skill(skill_id="litellm_skill_2", display_title="Skill 2"), + ] + ) + + with patch.object( + LiteLLMSkillsHandler, + "_get_prisma_client", + new_callable=AsyncMock, + return_value=mock_prisma, + ): + results = await LiteLLMSkillsHandler.list_skills(limit=10) + + assert len(results) == 2 + assert all(isinstance(r, LiteLLM_SkillsTable) for r in results) + + @pytest.mark.asyncio + async def test_list_skills_empty(self): + """Test listing skills returns empty list when none exist.""" + from litellm.llms.litellm_proxy.skills.handler import LiteLLMSkillsHandler + + mock_prisma = MagicMock() + mock_prisma.db.litellm_skillstable.find_many = AsyncMock(return_value=[]) + + with patch.object( + LiteLLMSkillsHandler, + "_get_prisma_client", + new_callable=AsyncMock, + return_value=mock_prisma, + ): + results = await LiteLLMSkillsHandler.list_skills() + assert results == [] + + @pytest.mark.asyncio + async def 
test_list_skills_with_pagination(self): + """Test listing skills passes limit and offset to Prisma.""" + from litellm.llms.litellm_proxy.skills.handler import LiteLLMSkillsHandler + + mock_prisma = MagicMock() + mock_prisma.db.litellm_skillstable.find_many = AsyncMock(return_value=[]) + + with patch.object( + LiteLLMSkillsHandler, + "_get_prisma_client", + new_callable=AsyncMock, + return_value=mock_prisma, + ): + await LiteLLMSkillsHandler.list_skills(limit=5, offset=10) + + call_args = mock_prisma.db.litellm_skillstable.find_many.call_args + assert call_args[1]["take"] == 5 + assert call_args[1]["skip"] == 10 + + +class TestGetSkill: + """Tests for LiteLLMSkillsHandler.get_skill.""" + + @pytest.mark.asyncio + async def test_get_skill_found(self): + """Test getting a skill that exists.""" + from litellm.llms.litellm_proxy.skills.handler import LiteLLMSkillsHandler + + mock_prisma = MagicMock() + mock_prisma.db.litellm_skillstable.find_unique = AsyncMock( + return_value=_make_prisma_skill(skill_id="litellm_skill_abc") + ) + + with patch.object( + LiteLLMSkillsHandler, + "_get_prisma_client", + new_callable=AsyncMock, + return_value=mock_prisma, + ): + result = await LiteLLMSkillsHandler.get_skill("litellm_skill_abc") + + assert isinstance(result, LiteLLM_SkillsTable) + assert result.skill_id == "litellm_skill_abc" + + @pytest.mark.asyncio + async def test_get_skill_not_found(self): + """Test getting a skill that doesn't exist raises ValueError.""" + from litellm.llms.litellm_proxy.skills.handler import LiteLLMSkillsHandler + + mock_prisma = MagicMock() + mock_prisma.db.litellm_skillstable.find_unique = AsyncMock(return_value=None) + + with patch.object( + LiteLLMSkillsHandler, + "_get_prisma_client", + new_callable=AsyncMock, + return_value=mock_prisma, + ): + with pytest.raises(ValueError, match="Skill not found"): + await LiteLLMSkillsHandler.get_skill("nonexistent") + + +class TestDeleteSkill: + """Tests for LiteLLMSkillsHandler.delete_skill.""" + + 
@pytest.mark.asyncio + async def test_delete_skill_success(self): + """Test deleting an existing skill.""" + from litellm.llms.litellm_proxy.skills.handler import LiteLLMSkillsHandler + + mock_prisma = MagicMock() + mock_prisma.db.litellm_skillstable.find_unique = AsyncMock( + return_value=_make_prisma_skill(skill_id="litellm_skill_del") + ) + mock_prisma.db.litellm_skillstable.delete = AsyncMock() + + with patch.object( + LiteLLMSkillsHandler, + "_get_prisma_client", + new_callable=AsyncMock, + return_value=mock_prisma, + ): + result = await LiteLLMSkillsHandler.delete_skill("litellm_skill_del") + + assert result["id"] == "litellm_skill_del" + assert result["type"] == "skill_deleted" + mock_prisma.db.litellm_skillstable.delete.assert_called_once() + + @pytest.mark.asyncio + async def test_delete_skill_not_found(self): + """Test deleting a nonexistent skill raises ValueError.""" + from litellm.llms.litellm_proxy.skills.handler import LiteLLMSkillsHandler + + mock_prisma = MagicMock() + mock_prisma.db.litellm_skillstable.find_unique = AsyncMock(return_value=None) + + with patch.object( + LiteLLMSkillsHandler, + "_get_prisma_client", + new_callable=AsyncMock, + return_value=mock_prisma, + ): + with pytest.raises(ValueError, match="Skill not found"): + await LiteLLMSkillsHandler.delete_skill("nonexistent") + + +class TestFetchSkillFromDb: + """Tests for LiteLLMSkillsHandler.fetch_skill_from_db (convenience method).""" + + @pytest.mark.asyncio + async def test_fetch_returns_none_on_not_found(self): + """Test that fetch_skill_from_db returns None instead of raising.""" + from litellm.llms.litellm_proxy.skills.handler import LiteLLMSkillsHandler + + mock_prisma = MagicMock() + mock_prisma.db.litellm_skillstable.find_unique = AsyncMock(return_value=None) + + with patch.object( + LiteLLMSkillsHandler, + "_get_prisma_client", + new_callable=AsyncMock, + return_value=mock_prisma, + ): + result = await LiteLLMSkillsHandler.fetch_skill_from_db("nonexistent") + assert result is 
None + + @pytest.mark.asyncio + async def test_fetch_returns_none_on_error(self): + """Test that fetch_skill_from_db returns None on unexpected errors.""" + from litellm.llms.litellm_proxy.skills.handler import LiteLLMSkillsHandler + + with patch.object( + LiteLLMSkillsHandler, + "_get_prisma_client", + new_callable=AsyncMock, + side_effect=RuntimeError("DB connection lost"), + ): + result = await LiteLLMSkillsHandler.fetch_skill_from_db("any_id") + assert result is None + + +class TestProviderSkillIds: + """Tests for provider skill ID save/retrieve.""" + + @pytest.mark.asyncio + async def test_save_provider_skill_id(self): + """Test saving an Anthropic skill ID for a LiteLLM skill.""" + from litellm.llms.litellm_proxy.skills.handler import LiteLLMSkillsHandler + + mock_skill = MagicMock() + mock_skill.metadata = {} + + mock_prisma = MagicMock() + mock_prisma.db.litellm_skillstable.find_unique = AsyncMock( + return_value=mock_skill + ) + mock_prisma.db.litellm_skillstable.update = AsyncMock() + + with patch.object( + LiteLLMSkillsHandler, + "_get_prisma_client", + new_callable=AsyncMock, + return_value=mock_prisma, + ): + await LiteLLMSkillsHandler.save_provider_skill_id( + skill_id="litellm_skill_abc", + provider="anthropic", + provider_skill_id="sk_ant_123", + ) + + call_args = mock_prisma.db.litellm_skillstable.update.call_args + assert call_args[1]["where"] == {"skill_id": "litellm_skill_abc"} + import json + + saved_metadata = json.loads(call_args[1]["data"]["metadata"]) + assert saved_metadata["_provider_skill_ids"]["anthropic"] == "sk_ant_123" + + @pytest.mark.asyncio + async def test_save_provider_skill_id_preserves_existing_metadata(self): + """Test that saving a provider ID doesn't clobber existing metadata.""" + from litellm.llms.litellm_proxy.skills.handler import LiteLLMSkillsHandler + + mock_skill = MagicMock() + mock_skill.metadata = {"custom_key": "custom_value"} + + mock_prisma = MagicMock() + mock_prisma.db.litellm_skillstable.find_unique = 
AsyncMock( + return_value=mock_skill + ) + mock_prisma.db.litellm_skillstable.update = AsyncMock() + + with patch.object( + LiteLLMSkillsHandler, + "_get_prisma_client", + new_callable=AsyncMock, + return_value=mock_prisma, + ): + await LiteLLMSkillsHandler.save_provider_skill_id( + skill_id="litellm_skill_abc", + provider="anthropic", + provider_skill_id="sk_ant_456", + ) + + call_args = mock_prisma.db.litellm_skillstable.update.call_args + import json + + saved_metadata = json.loads(call_args[1]["data"]["metadata"]) + assert saved_metadata["custom_key"] == "custom_value" + assert saved_metadata["_provider_skill_ids"]["anthropic"] == "sk_ant_456" + + @pytest.mark.asyncio + async def test_save_provider_skill_id_skill_not_found(self): + """Test that saving provider ID for missing skill doesn't raise.""" + from litellm.llms.litellm_proxy.skills.handler import LiteLLMSkillsHandler + + mock_prisma = MagicMock() + mock_prisma.db.litellm_skillstable.find_unique = AsyncMock(return_value=None) + mock_prisma.db.litellm_skillstable.update = AsyncMock() + + with patch.object( + LiteLLMSkillsHandler, + "_get_prisma_client", + new_callable=AsyncMock, + return_value=mock_prisma, + ): + # Should not raise + await LiteLLMSkillsHandler.save_provider_skill_id( + skill_id="nonexistent", + provider="anthropic", + provider_skill_id="sk_ant_789", + ) + + mock_prisma.db.litellm_skillstable.update.assert_not_called() + + def test_get_provider_skill_id_found(self): + """Test retrieving a saved Anthropic skill ID.""" + from litellm.llms.litellm_proxy.skills.handler import LiteLLMSkillsHandler + + skill = LiteLLM_SkillsTable( + skill_id="litellm_skill_abc", + metadata={"_provider_skill_ids": {"anthropic": "sk_ant_123"}}, + ) + + result = LiteLLMSkillsHandler.get_provider_skill_id(skill, "anthropic") + assert result == "sk_ant_123" + + def test_get_provider_skill_id_not_found(self): + """Test retrieving provider ID when none saved.""" + from litellm.llms.litellm_proxy.skills.handler import 
LiteLLMSkillsHandler + + skill = LiteLLM_SkillsTable( + skill_id="litellm_skill_abc", + metadata={}, + ) + + result = LiteLLMSkillsHandler.get_provider_skill_id(skill, "anthropic") + assert result is None + + def test_get_provider_skill_id_no_metadata(self): + """Test retrieving provider ID when metadata is None.""" + from litellm.llms.litellm_proxy.skills.handler import LiteLLMSkillsHandler + + skill = LiteLLM_SkillsTable( + skill_id="litellm_skill_abc", + metadata=None, + ) + + result = LiteLLMSkillsHandler.get_provider_skill_id(skill, "anthropic") + assert result is None + + def test_get_provider_skill_id_different_provider(self): + """Test that provider IDs are namespaced per provider.""" + from litellm.llms.litellm_proxy.skills.handler import LiteLLMSkillsHandler + + skill = LiteLLM_SkillsTable( + skill_id="litellm_skill_abc", + metadata={"_provider_skill_ids": {"anthropic": "sk_ant_123"}}, + ) + + assert LiteLLMSkillsHandler.get_provider_skill_id(skill, "anthropic") == "sk_ant_123" + assert LiteLLMSkillsHandler.get_provider_skill_id(skill, "openai") is None diff --git a/tests/litellm/proxy/skills_endpoints/test_skills_injection_hook.py b/tests/litellm/proxy/skills_endpoints/test_skills_injection_hook.py new file mode 100644 index 00000000000..3562bb9e6de --- /dev/null +++ b/tests/litellm/proxy/skills_endpoints/test_skills_injection_hook.py @@ -0,0 +1,944 @@ +""" +Tests for SkillsInjectionHook - pre-call skill processing and system prompt injection. 
+ +Covers: +- Request opt-in via container.skills +- OpenAI system prompt injection (with and without existing system message) +- Anthropic tool conversion +- Provider fallback behavior +- Skills skipped for non-completion call types +""" + +import io +import zipfile +from datetime import datetime +from unittest.mock import AsyncMock, patch + +import pytest + +from litellm.proxy._types import LiteLLM_SkillsTable, UserAPIKeyAuth +from litellm.proxy.hooks.litellm_skills.main import SkillsInjectionHook + + +def _make_skill( + skill_id: str = "litellm_skill_test1", + display_title: str = "Test Skill", + description: str = "A test skill", + instructions: str = "Follow these instructions for testing.", + file_content: bytes | None = None, +) -> LiteLLM_SkillsTable: + """Create a LiteLLM_SkillsTable for testing.""" + if file_content is None: + # Create a minimal ZIP with SKILL.md + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w") as zf: + zf.writestr( + "test_skill/SKILL.md", + "---\nname: test-skill\ndescription: A test\n---\n\n" + "Follow these instructions for testing.", + ) + file_content = buf.getvalue() + + return LiteLLM_SkillsTable( + skill_id=skill_id, + display_title=display_title, + description=description, + instructions=instructions, + source="custom", + file_content=file_content, + file_name="skill.zip", + file_type="application/zip", + created_at=datetime(2026, 3, 21), + updated_at=datetime(2026, 3, 21), + ) + + +def _make_skill_with_code( + skill_id: str = "litellm_skill_code1", + display_title: str = "Code Skill", +) -> LiteLLM_SkillsTable: + """Create a skill with Python files in the ZIP.""" + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w") as zf: + zf.writestr( + "code_skill/SKILL.md", + "---\nname: code-skill\n---\n\nA skill with code.", + ) + zf.writestr("code_skill/main.py", "def run(): return 42") + + return LiteLLM_SkillsTable( + skill_id=skill_id, + display_title=display_title, + instructions="A skill with code.", + source="custom", + 
class TestPreCallHookOptIn:
    """Tests for the opt-in activation model via container.skills."""

    @staticmethod
    async def _run(hook, payload, call_type="completion"):
        """Drive the pre-call hook with a stub auth object and no cache."""
        return await hook.async_pre_call_hook(
            user_api_key_dict=UserAPIKeyAuth(api_key="test"),
            cache=None,
            data=payload,
            call_type=call_type,
        )

    @pytest.mark.asyncio
    async def test_no_container_passes_through(self):
        """A request without a container field comes back unchanged."""
        payload = {
            "model": "gpt-4o",
            "messages": [{"role": "user", "content": "Hello"}],
        }

        result = await self._run(SkillsInjectionHook(), payload)

        assert result == payload
        assert "tools" not in result

    @pytest.mark.asyncio
    async def test_empty_container_passes_through(self):
        """A request with an empty container dict comes back unchanged."""
        payload = {
            "model": "gpt-4o",
            "messages": [{"role": "user", "content": "Hello"}],
            "container": {},
        }

        result = await self._run(SkillsInjectionHook(), payload)

        assert result == payload

    @pytest.mark.asyncio
    async def test_non_completion_call_type_skipped(self):
        """Call types other than completion are left untouched."""
        payload = {
            "model": "gpt-4o",
            "container": {"skills": [{"skill_id": "litellm_skill_x"}]},
        }

        result = await self._run(
            SkillsInjectionHook(), payload, call_type="embedding"
        )

        assert result == payload

    @pytest.mark.asyncio
    async def test_litellm_skill_fetched_from_db(self):
        """litellm_* skill IDs are resolved through the DB fetch helper."""
        hook = SkillsInjectionHook()

        with patch.object(
            hook,
            "_fetch_skill_from_db",
            new_callable=AsyncMock,
            return_value=_make_skill(),
        ):
            payload = {
                "model": "gpt-4o",
                "messages": [{"role": "user", "content": "Hello"}],
                "container": {
                    "skills": [{"skill_id": "litellm_skill_test1", "type": "anthropic"}]
                },
            }

            result = await self._run(hook, payload)

            # The container directive is consumed during processing.
            assert "container" not in result

    @pytest.mark.asyncio
    async def test_missing_skill_raises_404(self):
        """Referencing a skill absent from the DB raises HTTP 404."""
        from fastapi import HTTPException

        hook = SkillsInjectionHook()

        with patch.object(
            hook, "_fetch_skill_from_db", new_callable=AsyncMock, return_value=None
        ):
            payload = {
                "model": "gpt-4o",
                "messages": [{"role": "user", "content": "Hello"}],
                "container": {
                    "skills": [{"skill_id": "litellm_skill_nonexistent"}]
                },
            }

            with pytest.raises(HTTPException) as exc_info:
                await self._run(hook, payload)

            assert exc_info.value.status_code == 404
            assert "litellm_skill_nonexistent" in str(exc_info.value.detail)

    @pytest.mark.asyncio
    async def test_anthropic_native_skill_on_native_provider(self):
        """Anthropic skill_ IDs pass straight through on native providers."""
        with patch(
            "litellm.llms.litellm_proxy.skills.skill_applicator.get_provider_from_model",
            return_value="anthropic",
        ):
            payload = {
                "model": "claude-sonnet-4-20250514",
                "messages": [{"role": "user", "content": "Hello"}],
                "container": {
                    "skills": [{"skill_id": "skill_01abc123", "type": "anthropic"}]
                },
            }

            result = await self._run(SkillsInjectionHook(), payload)

            # Native Anthropic skill IDs should pass through untouched.
            assert isinstance(result, dict)

    @pytest.mark.asyncio
    async def test_anthropic_native_skill_on_non_native_provider_fails(self):
        """Anthropic skill_ IDs are rejected on providers without native skills."""
        from fastapi import HTTPException

        with patch(
            "litellm.llms.litellm_proxy.skills.skill_applicator.get_provider_from_model",
            return_value="openai",
        ):
            payload = {
                "model": "gpt-4o",
                "messages": [{"role": "user", "content": "Hello"}],
                "container": {
                    "skills": [{"skill_id": "skill_01abc123", "type": "anthropic"}]
                },
            }

            with pytest.raises(HTTPException) as exc_info:
                await self._run(SkillsInjectionHook(), payload)

            assert exc_info.value.status_code == 400
            assert "does not support native skills" in str(exc_info.value.detail)

    @pytest.mark.asyncio
    async def test_invalid_skill_id_prefix_fails(self):
        """Skill IDs with an unrecognized prefix are rejected with HTTP 400."""
        from fastapi import HTTPException

        with patch(
            "litellm.llms.litellm_proxy.skills.skill_applicator.get_provider_from_model",
            return_value="openai",
        ):
            payload = {
                "model": "gpt-4o",
                "messages": [{"role": "user", "content": "Hello"}],
                "container": {
                    "skills": [{"skill_id": "random_garbage_id", "type": "anthropic"}]
                },
            }

            with pytest.raises(HTTPException) as exc_info:
                await self._run(SkillsInjectionHook(), payload)

            assert exc_info.value.status_code == 400
            assert "Invalid skill_id" in str(exc_info.value.detail)
class TestSystemPromptInjection:
    """Tests for OpenAI-style system prompt injection."""

    @staticmethod
    def _handler():
        """Lazily import (mirrors production import path) and instantiate."""
        from litellm.llms.litellm_proxy.skills.prompt_injection import (
            SkillPromptInjectionHandler,
        )

        return SkillPromptInjectionHandler()

    def test_inject_into_new_system_message(self):
        """With no system message present, one is created up front."""
        payload = {"messages": [{"role": "user", "content": "Hello"}]}

        result = self._handler().inject_skill_content_to_messages(
            payload,
            ["## Skill: Test\n\nDo testing."],
            use_anthropic_format=False,
        )

        messages = result["messages"]
        assert messages[0]["role"] == "system"
        assert "Available Skills" in messages[0]["content"]
        assert "Do testing." in messages[0]["content"]
        assert messages[1]["role"] == "user"

    def test_inject_into_existing_system_message(self):
        """Skill content is appended to an existing system message."""
        payload = {
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello"},
            ],
        }

        result = self._handler().inject_skill_content_to_messages(
            payload,
            ["## Skill: Test\n\nDo testing."],
            use_anthropic_format=False,
        )

        messages = result["messages"]
        assert messages[0]["role"] == "system"
        assert messages[0]["content"].startswith("You are a helpful assistant.")
        assert "Available Skills" in messages[0]["content"]
        assert "Do testing." in messages[0]["content"]

    def test_inject_anthropic_format_uses_system_param(self):
        """Anthropic format targets the top-level 'system' parameter."""
        payload = {"messages": [{"role": "user", "content": "Hello"}]}

        result = self._handler().inject_skill_content_to_messages(
            payload,
            ["## Skill: Test\n\nDo testing."],
            use_anthropic_format=True,
        )

        assert "system" in result
        assert "Available Skills" in result["system"]
        assert "Do testing." in result["system"]

    def test_inject_anthropic_format_appends_to_existing_system(self):
        """Anthropic format appends to an existing system parameter."""
        payload = {
            "system": "You are a helpful assistant.",
            "messages": [{"role": "user", "content": "Hello"}],
        }

        result = self._handler().inject_skill_content_to_messages(
            payload,
            ["## Skill: Test\n\nDo testing."],
            use_anthropic_format=True,
        )

        assert result["system"].startswith("You are a helpful assistant.")
        assert "Do testing." in result["system"]

    def test_inject_multiple_skills(self):
        """Each injected skill keeps its own section in the system prompt."""
        payload = {"messages": [{"role": "user", "content": "Hello"}]}

        result = self._handler().inject_skill_content_to_messages(
            payload,
            [
                "## Skill: Alpha\n\nAlpha instructions.",
                "## Skill: Beta\n\nBeta instructions.",
            ],
            use_anthropic_format=False,
        )

        system_content = result["messages"][0]["content"]
        assert "Alpha instructions." in system_content
        assert "Beta instructions." in system_content

    def test_inject_empty_list_no_change(self):
        """An empty skill list leaves the request untouched."""
        payload = {"messages": [{"role": "user", "content": "Hello"}]}

        result = self._handler().inject_skill_content_to_messages(
            payload, [], use_anthropic_format=False
        )

        assert len(result["messages"]) == 1
        assert result["messages"][0]["role"] == "user"
class TestSkillApplicator:
    """Tests for SkillApplicator provider-specific strategies."""

    @staticmethod
    def _applicator():
        """Lazily import and instantiate the applicator under test."""
        from litellm.llms.litellm_proxy.skills.skill_applicator import SkillApplicator

        return SkillApplicator()

    @pytest.mark.asyncio
    async def test_openai_uses_system_prompt_strategy(self):
        """OpenAI requests receive skills via system prompt injection."""
        payload = {"messages": [{"role": "user", "content": "Hello"}]}

        result = await self._applicator().apply_skills(
            payload, [_make_skill()], provider="openai"
        )

        # Skill content lands in a leading system message.
        assert result["messages"][0]["role"] == "system"
        assert "Test Skill" in result["messages"][0]["content"]

    @pytest.mark.asyncio
    async def test_anthropic_uses_tool_conversion_strategy(self):
        """Anthropic requests receive skills as native-format tools."""
        payload = {"messages": [{"role": "user", "content": "Hello"}]}

        result = await self._applicator().apply_skills(
            payload, [_make_skill()], provider="anthropic"
        )

        # Anthropic tool shape: name / description / input_schema.
        assert "tools" in result
        assert len(result["tools"]) >= 1
        tool = result["tools"][0]
        assert "name" in tool
        assert "input_schema" in tool

    @pytest.mark.asyncio
    async def test_azure_uses_system_prompt_strategy(self):
        """Azure requests use system prompt injection (same as OpenAI)."""
        payload = {"messages": [{"role": "user", "content": "Hello"}]}

        result = await self._applicator().apply_skills(
            payload, [_make_skill()], provider="azure"
        )

        assert result["messages"][0]["role"] == "system"
        assert "Test Skill" in result["messages"][0]["content"]

    @pytest.mark.asyncio
    async def test_unknown_provider_defaults_to_system_prompt(self):
        """Unrecognized providers fall back to system prompt injection."""
        payload = {"messages": [{"role": "user", "content": "Hello"}]}

        result = await self._applicator().apply_skills(
            payload, [_make_skill()], provider="unknown_provider"
        )

        assert result["messages"][0]["role"] == "system"

    @pytest.mark.asyncio
    async def test_empty_skills_list_no_change(self):
        """An empty skills list leaves the request untouched."""
        payload = {"messages": [{"role": "user", "content": "Hello"}]}

        result = await self._applicator().apply_skills(payload, [], provider="openai")

        assert len(result["messages"]) == 1
        assert result["messages"][0]["role"] == "user"

    def test_supports_native_skills(self):
        """Only Anthropic is reported as supporting native skills."""
        applicator = self._applicator()

        assert applicator.supports_native_skills("anthropic") is True
        assert applicator.supports_native_skills("openai") is False
        assert applicator.supports_native_skills("azure") is False
        assert applicator.supports_native_skills("bedrock") is False
        assert applicator.supports_native_skills("gemini") is False
        assert applicator.supports_native_skills("some_new_provider") is False
class TestSkillContentExtraction:
    """Tests for skill content extraction from ZIP files."""

    @staticmethod
    def _handler():
        """Lazily import and instantiate the prompt-injection handler."""
        from litellm.llms.litellm_proxy.skills.prompt_injection import (
            SkillPromptInjectionHandler,
        )

        return SkillPromptInjectionHandler()

    def test_extract_skill_md_from_zip(self):
        """SKILL.md body text is pulled out of the skill's ZIP archive."""
        content = self._handler().extract_skill_content(_make_skill())

        assert content is not None
        assert "Follow these instructions" in content

    def test_extract_fallback_to_instructions(self):
        """Without file_content, the instructions field is used verbatim."""
        skill = LiteLLM_SkillsTable(
            skill_id="litellm_skill_nf",
            instructions="Fallback instructions",
            source="custom",
        )

        content = self._handler().extract_skill_content(skill)

        assert content == "Fallback instructions"

    def test_extract_all_files_from_zip(self):
        """Every file inside the skill ZIP is returned by name."""
        files = self._handler().extract_all_files(_make_skill_with_code())

        assert "SKILL.md" in files
        assert "main.py" in files
        assert files["main.py"] == b"def run(): return 42"


class TestMessagesAPIProcessing:
    """Tests for _process_for_messages_api in the hook."""

    def test_process_removes_container(self):
        """Processing strips the container field from the request."""
        payload = {
            "model": "gpt-4o",
            "messages": [{"role": "user", "content": "Hello"}],
            "container": {"skills": [{"skill_id": "litellm_skill_test1"}]},
        }

        result = SkillsInjectionHook()._process_for_messages_api(
            data=payload, litellm_skills=[_make_skill()], use_anthropic_format=False
        )

        assert "container" not in result

    def test_process_adds_tools(self):
        """Processing adds the skill as an Anthropic-style tool."""
        payload = {
            "model": "gpt-4o",
            "messages": [{"role": "user", "content": "Hello"}],
        }

        result = SkillsInjectionHook()._process_for_messages_api(
            data=payload, litellm_skills=[_make_skill()], use_anthropic_format=False
        )

        assert "tools" in result
        # Skill tool plus, potentially, a code-execution tool.
        assert len(result["tools"]) >= 1

    def test_process_injects_system_prompt(self):
        """Processing injects the skill content into a system message."""
        payload = {
            "model": "gpt-4o",
            "messages": [{"role": "user", "content": "Hello"}],
        }

        result = SkillsInjectionHook()._process_for_messages_api(
            data=payload, litellm_skills=[_make_skill()], use_anthropic_format=False
        )

        system_msgs = [m for m in result["messages"] if m.get("role") == "system"]
        assert len(system_msgs) > 0

    def test_process_with_code_files_enables_execution(self):
        """Skills that ship code files switch on code-execution metadata."""
        payload = {
            "model": "gpt-4o",
            "messages": [{"role": "user", "content": "Generate something"}],
        }

        result = SkillsInjectionHook()._process_for_messages_api(
            data=payload,
            litellm_skills=[_make_skill_with_code()],
            use_anthropic_format=False,
        )

        assert result.get("litellm_metadata", {}).get(
            "_litellm_code_execution_enabled"
        )
        assert "_skill_files" in result.get("litellm_metadata", {})
class TestPreCallProviderRouting:
    """
    Tests that the pre-call hook routes to the correct strategy based on provider.

    Non-native providers (OpenAI, Azure, Bedrock, etc.):
    - Skill content injected into system prompt
    - No skill tools added
    - container removed

    Native providers (Anthropic, azure_ai, databricks):
    - Skill converted to Anthropic-style tool
    - Skill content injected into system prompt
    - Code execution tool added (if skill has files)
    - container removed
    """

    @staticmethod
    def _patch_provider(provider):
        """Patch provider detection so routing is deterministic."""
        return patch(
            "litellm.llms.litellm_proxy.skills.skill_applicator.get_provider_from_model",
            return_value=provider,
        )

    @staticmethod
    async def _run(hook, payload):
        """Drive the pre-call hook for a completion call."""
        return await hook.async_pre_call_hook(
            user_api_key_dict=UserAPIKeyAuth(api_key="test"),
            cache=None,
            data=payload,
            call_type="completion",
        )

    @pytest.mark.asyncio
    async def test_openai_gets_system_prompt_only(self):
        """OpenAI should get system prompt injection, no tools from skill."""
        hook = SkillsInjectionHook()
        with patch.object(
            hook,
            "_fetch_skill_from_db",
            new_callable=AsyncMock,
            return_value=_make_skill(),
        ), self._patch_provider("openai"):
            payload = {
                "model": "gpt-4o",
                "messages": [{"role": "user", "content": "Hello"}],
                "container": {"skills": [{"skill_id": "litellm_skill_test1"}]},
            }

            result = await self._run(hook, payload)

        # System prompt injected once, with the skill content.
        system_msgs = [m for m in result["messages"] if m.get("role") == "system"]
        assert len(system_msgs) == 1
        assert "Test Skill" in system_msgs[0]["content"]
        # Non-native provider: no tools, and the container is consumed.
        assert "tools" not in result
        assert "container" not in result

    @pytest.mark.asyncio
    async def test_azure_gets_system_prompt_only(self):
        """Azure should get system prompt injection, no tools from skill."""
        hook = SkillsInjectionHook()
        with patch.object(
            hook,
            "_fetch_skill_from_db",
            new_callable=AsyncMock,
            return_value=_make_skill(),
        ), self._patch_provider("azure"):
            payload = {
                "model": "azure/gpt-4o",
                "messages": [{"role": "user", "content": "Hello"}],
                "container": {"skills": [{"skill_id": "litellm_skill_test1"}]},
            }

            result = await self._run(hook, payload)

        system_msgs = [m for m in result["messages"] if m.get("role") == "system"]
        assert len(system_msgs) == 1
        assert "tools" not in result
        assert "container" not in result

    @pytest.mark.asyncio
    async def test_bedrock_gets_system_prompt_only(self):
        """Bedrock should get system prompt injection, no tools from skill."""
        hook = SkillsInjectionHook()
        with patch.object(
            hook,
            "_fetch_skill_from_db",
            new_callable=AsyncMock,
            return_value=_make_skill(),
        ), self._patch_provider("bedrock"):
            payload = {
                "model": "bedrock/anthropic.claude-v2",
                "messages": [{"role": "user", "content": "Hello"}],
                "container": {"skills": [{"skill_id": "litellm_skill_test1"}]},
            }

            result = await self._run(hook, payload)

        system_msgs = [m for m in result["messages"] if m.get("role") == "system"]
        assert len(system_msgs) == 1
        assert "tools" not in result
        assert "container" not in result

    @pytest.mark.asyncio
    async def test_anthropic_gets_tools_and_system_prompt(self):
        """Anthropic should get tool conversion + system prompt injection."""
        hook = SkillsInjectionHook()
        with patch.object(
            hook,
            "_fetch_skill_from_db",
            new_callable=AsyncMock,
            return_value=_make_skill(),
        ), self._patch_provider("anthropic"):
            payload = {
                "model": "claude-sonnet-4-20250514",
                "messages": [{"role": "user", "content": "Hello"}],
                "container": {"skills": [{"skill_id": "litellm_skill_test1"}]},
            }

            result = await self._run(hook, payload)

        # Native provider: the skill is exposed as a tool.
        assert "tools" in result
        assert len(result["tools"]) >= 1
        assert "container" not in result

    @pytest.mark.asyncio
    async def test_openai_preserves_existing_system_message(self):
        """OpenAI skill injection should append to an existing system message."""
        hook = SkillsInjectionHook()
        with patch.object(
            hook,
            "_fetch_skill_from_db",
            new_callable=AsyncMock,
            return_value=_make_skill(),
        ), self._patch_provider("openai"):
            payload = {
                "model": "gpt-4o",
                "messages": [
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": "Hello"},
                ],
                "container": {"skills": [{"skill_id": "litellm_skill_test1"}]},
            }

            result = await self._run(hook, payload)

        system_msg = result["messages"][0]
        assert system_msg["role"] == "system"
        # Original content preserved, skill content appended.
        assert system_msg["content"].startswith("You are a helpful assistant.")
        assert "Test Skill" in system_msg["content"]

    @pytest.mark.asyncio
    async def test_openai_no_tools_even_with_code_files(self):
        """OpenAI should NOT get tools even if the skill has Python files."""
        hook = SkillsInjectionHook()
        with patch.object(
            hook,
            "_fetch_skill_from_db",
            new_callable=AsyncMock,
            return_value=_make_skill_with_code(),
        ), self._patch_provider("openai"):
            payload = {
                "model": "gpt-4o",
                "messages": [{"role": "user", "content": "Hello"}],
                "container": {"skills": [{"skill_id": "litellm_skill_code1"}]},
            }

            result = await self._run(hook, payload)

        # The prompt carries the skill content ...
        system_msgs = [m for m in result["messages"] if m.get("role") == "system"]
        assert len(system_msgs) == 1
        # ... but OpenAI never receives skill tools.
        assert "tools" not in result

    @pytest.mark.asyncio
    async def test_multiple_skills_openai(self):
        """Multiple skills should all be injected into the system prompt."""
        alpha = _make_skill(
            skill_id="litellm_skill_a",
            display_title="Skill Alpha",
            instructions="Alpha instructions.",
        )
        beta = _make_skill(
            skill_id="litellm_skill_b",
            display_title="Skill Beta",
            instructions="Beta instructions.",
        )

        async def _fetch(skill_id):
            return {"litellm_skill_a": alpha, "litellm_skill_b": beta}.get(skill_id)

        hook = SkillsInjectionHook()
        with patch.object(
            hook, "_fetch_skill_from_db", side_effect=_fetch
        ), self._patch_provider("openai"):
            payload = {
                "model": "gpt-4o",
                "messages": [{"role": "user", "content": "Hello"}],
                "container": {
                    "skills": [
                        {"skill_id": "litellm_skill_a"},
                        {"skill_id": "litellm_skill_b"},
                    ]
                },
            }

            result = await self._run(hook, payload)

        system_content = result["messages"][0]["content"]
        assert "Skill Alpha" in system_content
        assert "Skill Beta" in system_content
        assert "tools" not in result
"""
Tests for Skills mode switching (litellm vs passthrough).
"""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest


class TestGetSkillsMode:
    """Tests for get_skills_mode function."""

    @staticmethod
    def _mode_under(settings):
        """Resolve the skills mode with general_settings replaced wholesale."""
        with patch.dict(
            "litellm.proxy.proxy_server.general_settings", settings, clear=True
        ):
            from litellm.proxy.anthropic_endpoints.skills_endpoints import (
                get_skills_mode,
            )

            return get_skills_mode()

    def test_default_mode_is_passthrough(self):
        """With no setting at all, the mode defaults to passthrough."""
        assert self._mode_under({}) == "passthrough"

    def test_litellm_mode(self):
        """skills_mode='litellm' is recognized."""
        assert self._mode_under({"skills_mode": "litellm"}) == "litellm"

    def test_passthrough_mode_explicit(self):
        """An explicit skills_mode='passthrough' is recognized."""
        assert self._mode_under({"skills_mode": "passthrough"}) == "passthrough"

    def test_invalid_mode_defaults_to_passthrough(self):
        """An unrecognized skills_mode falls back to passthrough."""
        assert self._mode_under({"skills_mode": "invalid_mode"}) == "passthrough"

    def test_none_mode_defaults_to_passthrough(self):
        """skills_mode=None falls back to passthrough."""
        assert self._mode_under({"skills_mode": None}) == "passthrough"


class TestSkillsEndpointsModeRouting:
    """Tests for endpoint routing based on skills_mode."""

    @staticmethod
    def _litellm_mode():
        """Force skills_mode='litellm' for the duration of the test."""
        return patch.dict(
            "litellm.proxy.proxy_server.general_settings",
            {"skills_mode": "litellm"},
            clear=True,
        )

    @pytest.mark.asyncio
    async def test_create_skill_litellm_mode_routes_to_handler(self):
        """create_skill in litellm mode calls _handle_litellm_create_skill."""
        from litellm.proxy._types import UserAPIKeyAuth
        from litellm.types.llms.anthropic_skills import Skill

        expected = Skill(
            id="litellm_skill_test123",
            display_title="Test Skill",
            source="litellm",
            created_at="2026-03-21T00:00:00Z",
            updated_at="2026-03-21T00:00:00Z",
        )

        with self._litellm_mode(), patch(
            "litellm.proxy.anthropic_endpoints.skills_endpoints._handle_litellm_create_skill",
            new_callable=AsyncMock,
            return_value=expected,
        ) as handler_mock:
            from litellm.proxy.anthropic_endpoints.skills_endpoints import (
                create_skill,
            )

            result = await create_skill(
                fastapi_response=MagicMock(),
                request=MagicMock(),
                user_api_key_dict=UserAPIKeyAuth(api_key="test-key"),
            )

            handler_mock.assert_called_once()
            assert result == expected

    @pytest.mark.asyncio
    async def test_list_skills_litellm_mode_routes_to_handler(self):
        """list_skills in litellm mode calls _handle_litellm_list_skills."""
        from litellm.proxy._types import UserAPIKeyAuth
        from litellm.types.llms.anthropic_skills import ListSkillsResponse

        expected = ListSkillsResponse(data=[], has_more=False)

        with self._litellm_mode(), patch(
            "litellm.proxy.anthropic_endpoints.skills_endpoints._handle_litellm_list_skills",
            new_callable=AsyncMock,
            return_value=expected,
        ) as handler_mock:
            from litellm.proxy.anthropic_endpoints.skills_endpoints import (
                list_skills,
            )

            result = await list_skills(
                fastapi_response=MagicMock(),
                request=MagicMock(),
                limit=20,
                user_api_key_dict=UserAPIKeyAuth(api_key="test-key"),
            )

            handler_mock.assert_called_once_with(limit=20, page=None)
            assert result == expected

    @pytest.mark.asyncio
    async def test_get_skill_litellm_mode_routes_to_handler(self):
        """get_skill in litellm mode calls _handle_litellm_get_skill."""
        from litellm.proxy._types import UserAPIKeyAuth
        from litellm.types.llms.anthropic_skills import Skill

        expected = Skill(
            id="litellm_skill_123",
            display_title="Test",
            source="litellm",
            created_at="2026-03-21T00:00:00Z",
            updated_at="2026-03-21T00:00:00Z",
        )

        with self._litellm_mode(), patch(
            "litellm.proxy.anthropic_endpoints.skills_endpoints._handle_litellm_get_skill",
            new_callable=AsyncMock,
            return_value=expected,
        ) as handler_mock:
            from litellm.proxy.anthropic_endpoints.skills_endpoints import (
                get_skill,
            )

            result = await get_skill(
                skill_id="litellm_skill_123",
                fastapi_response=MagicMock(),
                request=MagicMock(),
                user_api_key_dict=UserAPIKeyAuth(api_key="test-key"),
            )

            handler_mock.assert_called_once_with("litellm_skill_123")
            assert result == expected

    @pytest.mark.asyncio
    async def test_delete_skill_litellm_mode_routes_to_handler(self):
        """delete_skill in litellm mode calls _handle_litellm_delete_skill."""
        from litellm.proxy._types import UserAPIKeyAuth
        from litellm.types.llms.anthropic_skills import DeleteSkillResponse

        expected = DeleteSkillResponse(id="litellm_skill_123", type="skill_deleted")

        with self._litellm_mode(), patch(
            "litellm.proxy.anthropic_endpoints.skills_endpoints._handle_litellm_delete_skill",
            new_callable=AsyncMock,
            return_value=expected,
        ) as handler_mock:
            from litellm.proxy.anthropic_endpoints.skills_endpoints import (
                delete_skill,
            )

            result = await delete_skill(
                skill_id="litellm_skill_123",
                fastapi_response=MagicMock(),
                request=MagicMock(),
                user_api_key_dict=UserAPIKeyAuth(api_key="test-key"),
            )

            handler_mock.assert_called_once_with("litellm_skill_123")
            assert result == expected
{"skills_mode": "litellm"}, + clear=True, + ): + with patch( + "litellm.proxy.anthropic_endpoints.skills_endpoints._handle_litellm_delete_skill", + new_callable=AsyncMock, + return_value=mock_response, + ) as mock_handler: + from litellm.proxy.anthropic_endpoints.skills_endpoints import ( + delete_skill, + ) + + mock_request = MagicMock() + mock_fastapi_response = MagicMock() + mock_user = UserAPIKeyAuth(api_key="test-key") + + result = await delete_skill( + skill_id="litellm_skill_123", + fastapi_response=mock_fastapi_response, + request=mock_request, + user_api_key_dict=mock_user, + ) + + mock_handler.assert_called_once_with("litellm_skill_123") + assert result == mock_response diff --git a/tests/litellm/proxy/skills_endpoints/test_validation.py b/tests/litellm/proxy/skills_endpoints/test_validation.py new file mode 100644 index 00000000000..29d590f7a57 --- /dev/null +++ b/tests/litellm/proxy/skills_endpoints/test_validation.py @@ -0,0 +1,275 @@ +""" +Tests for Skills validation utilities (YAML frontmatter parsing, file validation). +""" + +import pytest + +from litellm.proxy.skills_endpoints.validation import ( + SkillFrontmatter, + parse_skill_md, + validate_skill_files, +) + + +class TestParseSKillMd: + """Tests for parse_skill_md function.""" + + def test_parse_valid_frontmatter(self): + """Test parsing valid YAML frontmatter.""" + content = """--- +name: Test Skill +description: A test skill for testing +--- + +# Instructions + +Use this skill to do testing. +""" + frontmatter, body = parse_skill_md(content) + + assert frontmatter is not None + assert frontmatter.name == "Test Skill" + assert frontmatter.description == "A test skill for testing" + assert "# Instructions" in body + assert "Use this skill to do testing." in body + + def test_parse_frontmatter_name_only(self): + """Test parsing frontmatter with only required name field.""" + content = """--- +name: Minimal Skill +--- + +Instructions here. 
+""" + frontmatter, body = parse_skill_md(content) + + assert frontmatter is not None + assert frontmatter.name == "Minimal Skill" + assert frontmatter.description is None + assert "Instructions here." in body + + def test_parse_no_frontmatter(self): + """Test parsing content without frontmatter.""" + content = """# Just Markdown + +No YAML frontmatter here. +""" + frontmatter, body = parse_skill_md(content) + + assert frontmatter is None + assert "# Just Markdown" in body + + def test_parse_empty_frontmatter(self): + """Test parsing empty frontmatter block.""" + content = """--- +--- + +Content after empty frontmatter. +""" + frontmatter, body = parse_skill_md(content) + + # Empty frontmatter should fail validation (no name) + assert frontmatter is None + + def test_parse_frontmatter_missing_name(self): + """Test parsing frontmatter without required name field.""" + content = """--- +description: Description but no name +--- + +Body content. +""" + frontmatter, body = parse_skill_md(content) + + # Should fail validation since name is required + assert frontmatter is None + + def test_parse_frontmatter_name_too_long(self): + """Test that name exceeding 64 characters fails validation.""" + long_name = "A" * 65 # 65 chars, exceeds limit + content = f"""--- +name: {long_name} +--- + +Body content. +""" + frontmatter, body = parse_skill_md(content) + + # Should fail validation due to name length + assert frontmatter is None + + def test_parse_frontmatter_name_at_limit(self): + """Test that name exactly 64 characters is valid.""" + exact_name = "A" * 64 # exactly 64 chars + content = f"""--- +name: {exact_name} +--- + +Body content. 
+""" + frontmatter, body = parse_skill_md(content) + + assert frontmatter is not None + assert frontmatter.name == exact_name + + def test_parse_frontmatter_description_too_long(self): + """Test that description exceeding 1024 characters fails validation.""" + long_desc = "B" * 1025 # 1025 chars, exceeds limit + content = f"""--- +name: Valid Name +description: {long_desc} +--- + +Body content. +""" + frontmatter, body = parse_skill_md(content) + + # Should fail validation due to description length + assert frontmatter is None + + +class TestSkillFrontmatter: + """Tests for SkillFrontmatter Pydantic model.""" + + def test_valid_frontmatter(self): + """Test creating valid frontmatter.""" + fm = SkillFrontmatter(name="Test", description="A test skill") + assert fm.name == "Test" + assert fm.description == "A test skill" + + def test_frontmatter_name_only(self): + """Test creating frontmatter with name only.""" + fm = SkillFrontmatter(name="NameOnly") + assert fm.name == "NameOnly" + assert fm.description is None + + def test_frontmatter_name_required(self): + """Test that name is required.""" + with pytest.raises(Exception): + SkillFrontmatter(description="No name") + + def test_frontmatter_name_max_length(self): + """Test name max length constraint.""" + # Should work at 64 + fm = SkillFrontmatter(name="A" * 64) + assert len(fm.name) == 64 + + # Should fail at 65 + with pytest.raises(Exception): + SkillFrontmatter(name="A" * 65) + + def test_frontmatter_description_max_length(self): + """Test description max length constraint.""" + # Should work at 1024 + fm = SkillFrontmatter(name="Test", description="B" * 1024) + assert len(fm.description) == 1024 + + # Should fail at 1025 + with pytest.raises(Exception): + SkillFrontmatter(name="Test", description="B" * 1025) + + +class TestValidateSkillFiles: + """Tests for validate_skill_files function.""" + + def test_valid_skill_files(self): + """Test validating files with valid SKILL.md.""" + skill_md_content = b"""--- 
+name: Test Skill +description: A test +--- + +Instructions here. +""" + files = [("SKILL.md", skill_md_content)] + + zip_content, frontmatter, body, errors = validate_skill_files(files) + + assert errors == [] + assert zip_content is not None + assert frontmatter is not None + assert frontmatter.name == "Test Skill" + assert "Instructions here." in body + + def test_missing_skill_md(self): + """Test validation fails without SKILL.md.""" + files = [("README.md", b"# Readme\nSome content")] + + zip_content, frontmatter, body, errors = validate_skill_files(files) + + assert "SKILL.md is required" in errors[0] + assert zip_content is None + + def test_file_size_limit(self): + """Test validation fails for files exceeding 8MB.""" + # Create content > 8MB + large_content = b"x" * (8 * 1024 * 1024 + 1) # 8MB + 1 byte + files = [ + ("SKILL.md", b"---\nname: Test\n---\nContent"), + ("large_file.bin", large_content), + ] + + zip_content, frontmatter, body, errors = validate_skill_files(files) + + assert any("8MB" in err or "size" in err.lower() for err in errors) + + def test_nested_skill_md(self): + """Test that SKILL.md in nested folder is found.""" + skill_md_content = b"""--- +name: Nested Skill +--- + +Nested instructions. +""" + files = [("subfolder/SKILL.md", skill_md_content)] + + zip_content, frontmatter, body, errors = validate_skill_files(files) + + assert errors == [] + assert frontmatter is not None + assert frontmatter.name == "Nested Skill" + + def test_multiple_files_creates_zip(self): + """Test that multiple files are packed into a valid ZIP.""" + import zipfile + from io import BytesIO + + skill_md = b"""--- +name: Multi File Skill +--- + +Use the helper module. 
+""" + helper_py = b"def helper(): return 42" + + files = [ + ("SKILL.md", skill_md), + ("helper.py", helper_py), + ] + + zip_content, frontmatter, body, errors = validate_skill_files(files) + + assert errors == [] + assert zip_content is not None + + # Verify it's a valid ZIP + zip_buffer = BytesIO(zip_content) + with zipfile.ZipFile(zip_buffer, "r") as zf: + names = zf.namelist() + assert any("SKILL.md" in n for n in names) + assert any("helper.py" in n for n in names) + + def test_invalid_frontmatter_in_skill_md(self): + """Test validation fails for SKILL.md with invalid frontmatter.""" + # Missing required name + skill_md = b"""--- +description: No name field +--- + +Body content. +""" + files = [("SKILL.md", skill_md)] + + zip_content, frontmatter, body, errors = validate_skill_files(files) + + assert any("frontmatter" in err.lower() or "name" in err.lower() for err in errors)