diff --git a/.env.example b/.env.example index e83c267..14c3694 100644 --- a/.env.example +++ b/.env.example @@ -14,3 +14,7 @@ # KICAD_APP_PATH=C:\Program Files\KiCad # Linux: # KICAD_APP_PATH=/usr/share/kicad + +# API keys for supplier lookup tools (Mouser / Digi-Key) +# MOUSER_API_KEY= +# DIGIKEY_API_KEY= diff --git a/.gitignore b/.gitignore index c38ce72..d8564f0 100644 --- a/.gitignore +++ b/.gitignore @@ -42,3 +42,4 @@ logs/ # MCP specific ~/.kicad_mcp/drc_history/ +~/.cache/ diff --git a/README.md b/README.md index 62e4f34..44246b7 100644 --- a/README.md +++ b/README.md @@ -134,27 +134,65 @@ For more information on resources vs tools vs prompts, read the [MCP docs](https ## Feature Highlights -The KiCad MCP Server provides several key features, each with detailed documentation: +### Supplier Lookup Tools -- **Project Management**: List, examine, and open KiCad projects +KiCad-MCP can query real-time part availability from leading distributors. + +1. Set environment variables for your API keys (do **NOT** commit them!): + +```bash +export DIGIKEY_API_KEY="…" +export MOUSER_API_KEY="…" +``` + +2. Example usage (Python): + +```python +from kicad_mcp.tools.supplier_tools import search_distributors + +async def demo(): + res = await search_distributors("STM32F103C8T6", progress_callback=lambda *_: None) + if res["ok"]: + print(res["result"][0]) # first hit +``` + +3. Security notes: + * Keys are only read from env vars at runtime and never logged. + * Successful responses are cached in `~/.cache/kicad_mcp_supplier.json` (default **24 h**). + * Change the cache duration by editing `_CACHE_TTL` near the top of `kicad_mcp/tools/supplier_tools.py` (value in seconds). + +#### API Rate Limits + +With the default 24 h cache you will stay comfortably below the free-tier quotas of both distributors. 
+ +| Distributor | Typical daily quota | Docs | +|-------------|--------------------|------| +| Mouser | 1,000 searches / 24 h (can be raised on request) | | +| Digi-Key | 1,000 searches / 24 h (higher tiers available) | | + +## Other Major Features + +KiCad-MCP provides many additional capabilities out of the box: + +- **Project Management** – List, examine, and open KiCad projects - *Example:* "Show me all my recent KiCad projects" → Lists all projects sorted by modification date -- **PCB Design Analysis**: Get insights about your PCB designs and schematics +- **PCB Design Analysis** – Get insights about your PCB designs and schematics - *Example:* "Analyze the component density of my temperature sensor board" → Provides component spacing analysis -- **Netlist Extraction**: Extract and analyze component connections from schematics +- **Netlist Extraction** – Extract and analyze component connections from schematics - *Example:* "What components are connected to the MCU in my Arduino shield?" 
→ Shows all connections to the microcontroller -- **BOM Management**: Analyze and export Bills of Materials +- **BOM Management** – Analyze and export Bills of Materials - *Example:* "Generate a BOM for my smart watch project" → Creates a detailed bill of materials - - **Design Rule Checking**: Run DRC checks using the KiCad CLI and track your progress over time + - **Design Rule Checking** – Run DRC checks using the KiCad CLI and track your progress over time - *Example:* "Run DRC on my power supply board and compare to last week" → Shows progress in fixing violations -- **PCB Visualization**: Generate visual representations of your PCB layouts +- **PCB Visualization** – Generate visual representations of your PCB layouts - *Example:* "Show me a thumbnail of my audio amplifier PCB" → Displays a visual render of the board -- **Circuit Pattern Recognition**: Automatically identify common circuit patterns in your schematics +- **Circuit Pattern Recognition** – Automatically identify common circuit patterns in your schematics - *Example:* "What power supply topologies am I using in my IoT device?" → Identifies buck, boost, or linear regulators For more examples and details on each feature, see the dedicated guides in the documentation. You can also ask the LLM what tools it has access to! 
diff --git a/kicad_mcp/server.py b/kicad_mcp/server.py index b1d9289..da1d67b 100644 --- a/kicad_mcp/server.py +++ b/kicad_mcp/server.py @@ -25,6 +25,8 @@ from kicad_mcp.tools.bom_tools import register_bom_tools from kicad_mcp.tools.netlist_tools import register_netlist_tools from kicad_mcp.tools.pattern_tools import register_pattern_tools +from kicad_mcp.tools.component_tools import register_component_tools +from kicad_mcp.tools.supplier_tools import register_supplier_tools # Import prompt handlers from kicad_mcp.prompts.templates import register_prompts @@ -150,6 +152,8 @@ def create_server() -> FastMCP: register_bom_tools(mcp) register_netlist_tools(mcp) register_pattern_tools(mcp) + register_component_tools(mcp) + register_supplier_tools(mcp) # Register prompts logging.info(f"Registering prompts...") diff --git a/kicad_mcp/tools/component_tools.py b/kicad_mcp/tools/component_tools.py new file mode 100644 index 0000000..4be9dcb --- /dev/null +++ b/kicad_mcp/tools/component_tools.py @@ -0,0 +1,475 @@ +"""Component–Footprint lookup & validation tools for KiCad-MCP. + +All tools defined in this module follow the canonical *envelope* expected by the +MCP ecosystem:: + + { + "ok": bool, + "result": …, # present when ok is True + "error": { # present when ok is False + "type": str, + "message": str + }, + "elapsed_s": float + } + +Implementation choices & notes +------------------------------ +* *Pattern matching*: `lookup_component_footprint` and + `search_component_libraries` treat *pattern* as a **Unix-style glob** (as used + by :pymod:`fnmatch`). Users may supply ``*`` and ``?`` wild-cards. +* *Async design*: All potentially blocking filesystem work is executed inside + :pyfunc:`asyncio.to_thread` keeping the FastAPI event-loop responsive. Every + tool emits ``progress_callback`` updates roughly every 0.5 s. 
+* *KiCad parsing*: A **very light-weight** text parser is implemented for + ``.kicad_mod`` files so the tools work in CI where the native + :pymod:`pcbnew` Python bindings are unavailable. When ``pcbnew`` *is* + importable at runtime we transparently prefer it as it offers richer and more + robust parsing. +* *Error taxonomy*: The tools raise only the error types requested by the + specification – see :data:`_ERROR_TYPES`. +""" +from __future__ import annotations +from pathlib import Path + +import asyncio +import fnmatch +import os +import re +import time + +from typing import Callable, Dict, List, Tuple + +try: + import pcbnew # type: ignore + + _PCBNEW_AVAILABLE = True +except ModuleNotFoundError: # pragma: no cover – pcbnew not present in CI + _PCBNEW_AVAILABLE = False + +# --------------------------------------------------------------------------- +# Toolbox helpers +# --------------------------------------------------------------------------- + +_ERROR_TYPES = { + "FileNotFound", + "ParseError", + "Timeout", + "UnsupportedVersion", + "InvalidFootprint", + "LibraryNotAllowed", +} + +_ResultEnvelope = Dict[str, object] + +_PROGRESS_INTERVAL = 0.5 # seconds + + +async def _run_io(func, *args, **kwargs): + """Convenience wrapper around *asyncio.to_thread*.""" + return await asyncio.to_thread(func, *args, **kwargs) + + +def _envelope_ok(result, start: float) -> _ResultEnvelope: # noqa: D401 + """Return a *success* envelope.""" + return { + "ok": True, + "result": result, + "elapsed_s": time.perf_counter() - start, + } + + +def _envelope_err(err_type: str, message: str, start: float) -> _ResultEnvelope: # noqa: D401 + """Return an *error* envelope using the canonical structure.""" + if err_type not in _ERROR_TYPES: + err_type = "ParseError" # fallback + return { + "ok": False, + "error": {"type": err_type, "message": message}, + "elapsed_s": time.perf_counter() - start, + } + + +async def _periodic_progress(cancel_event: asyncio.Event, progress_callback: 
Callable[[float, str], None], msg: str) -> None: # noqa: D401,E501 + """Emit *msg* every ~0.5 s until *cancel_event* is set.""" + pct = 0.0 + while not cancel_event.is_set(): + try: + progress_callback(pct, msg) + except Exception: # pragma: no cover – progress failures must not crash + pass + pct = (pct + 2.0) % 100 # simple spinner + await asyncio.sleep(_PROGRESS_INTERVAL) + + +# --------------------------------------------------------------------------- +# Library discovery helpers +# --------------------------------------------------------------------------- + +def _collect_library_paths() -> List[Path]: + """Return all search roots for footprint libraries (.pretty dirs).""" + paths: List[Path] = [] + # Project table – look for any *fp-lib-table* reachable from CWD. + cwd_table = Path.cwd() / "fp-lib-table" + if cwd_table.exists(): + paths += _parse_fp_lib_table(cwd_table) + + # Global fp-lib-table (KiCad 6): + home = Path.home() + for guess in [ + home / ".config" / "kicad" / "fp-lib-table", + home / ".config" / "kicad/6.0" / "fp-lib-table", + ]: + if guess.exists(): + paths += _parse_fp_lib_table(guess) + break + + # Environment variables + env_kicad = os.environ.get("KICAD6_FOOTPRINT_DIR") + if env_kicad: + paths.append(Path(env_kicad)) + + extra = os.environ.get("MCP_FOOTPRINT_PATHS") + if extra: + for p in extra.split(os.pathsep): + paths.append(Path(p)) + + # De-duplicate while keeping order + seen = set() + unique: List[Path] = [] + for p in paths: + try: + real = p.resolve() + except Exception: + real = p + if real not in seen: + unique.append(real) + seen.add(real) + return unique + + +def _parse_fp_lib_table(table_path: Path) -> List[Path]: + """Parse a KiCad *fp-lib-table* returning library directories.*""" + libs: List[Path] = [] + try: + txt = table_path.read_text(encoding="utf-8", errors="ignore") + except Exception: + return libs + + for match in re.finditer(r"\(lib +\((?:[^()]|\([^)]*\))*\)\)", txt): + block = match.group(0) + path_m = 
re.search(r"\(uri +([^ )]+)\)", block) + if path_m: + uri = path_m.group(1).strip().strip('"') + if uri.startswith("${KICAD6_3RD_PARTY}"): # ignore for now + continue + libs.append(Path(uri)) + return libs + + +# --------------------------------------------------------------------------- +# Lightweight .kicad_mod parser (fallback when pcbnew unavailable) +# --------------------------------------------------------------------------- + +_COORD_RE = re.compile(r"\(xy +([-+]?[0-9]*\.?[0-9]+) +([-+]?[0-9]*\.?[0-9]+)\)") +_PAD_RE = re.compile(r"\(pad +([^ ]+) +([^ ]+) +([^ ]+)") +_LAYER_SET_RE = re.compile(r"layers +([^\)]+)\)") +_DRILL_RE = re.compile(r"drill +([^ )]+)") + + +class _ModParseResult(Tuple): + pin_count: int + pads: List[Dict[str, object]] + bbox: Dict[str, object] + layers: List[str] + + +def _parse_kicad_mod(file_path: Path) -> _ModParseResult: + """Very small text-based parser for *pcb footprint* files.""" + try: + text = file_path.read_text(encoding="utf-8", errors="ignore") + except Exception as exc: + raise ValueError(f"cannot read footprint: {exc}") + + pads: List[Dict[str, object]] = [] + xs: List[float] = [] + ys: List[float] = [] + layers_set: set[str] = set() + + for pad_m in _PAD_RE.finditer(text): + pad_num, pad_type, pad_shape = pad_m.groups() + # Find position immediately after this match for nested search + start = pad_m.end() + end = text.find("(pad", start) # naive – find next pad occurrence + snippet = text[start:end] if end != -1 else text[start:] + # Position (xy …) + pos_m = _COORD_RE.search(snippet) + x, y = (float(pos_m.group(1)), float(pos_m.group(2))) if pos_m else (0.0, 0.0) + xs.append(x) + ys.append(y) + # Drill size if exists + drill_m = _DRILL_RE.search(snippet) + drill = float(drill_m.group(1)) if drill_m else None + # Layer set + layer_m = _LAYER_SET_RE.search(snippet) + if layer_m: + layers_local = [l.strip() for l in layer_m.group(1).split()] + layers_set.update(layers_local) + pads.append( + { + "number": pad_num, + 
"type": pad_type, + "shape": pad_shape, + "drill": drill, + "x": x, + "y": y, + } + ) + + if not pads: + raise ValueError("footprint has no pads") + + if not layers_set.intersection({"F.Cu", "B.Cu"}): + raise ValueError("no copper layers enabled") + + if xs and ys: + w = max(xs) - min(xs) + h = max(ys) - min(ys) + else: + raise ValueError("unable to derive bounding box") + + bbox = {"w": abs(w), "h": abs(h), "units": "mm"} + return len(pads), pads, bbox, sorted(layers_set) + + +# --------------------------------------------------------------------------- +# Public MCP tools +# --------------------------------------------------------------------------- + +from mcp.server.fastmcp import FastMCP # imported late to avoid heavy deps + +_mcp_instance: FastMCP | None = None # filled by register_component_tools + + +async def lookup_component_footprint( # noqa: D401 + query: str, + libs: List[str] | None = None, + *, + progress_callback: Callable[[float, str], None], +) -> _ResultEnvelope: + """Search KiCad footprint libraries for *query* (glob match). + + If *libs* is provided it must be a list of ``.pretty`` directories that + should be searched *exclusively*. 
+ """ + start = time.perf_counter() + cancel = asyncio.Event() + spinner = asyncio.create_task(_periodic_progress(cancel, progress_callback, "searching")) + + try: + roots = [Path(p) for p in libs] if libs else _collect_library_paths() + hits: List[Dict[str, str]] = [] + + async def scan_lib(root: Path): + if not root.exists(): + return + libname = root.stem + for fp in root.glob("*.kicad_mod"): + if fnmatch.fnmatch(fp.stem, query): + hits.append({"lib": libname, "path": str(fp.resolve())}) + + await asyncio.gather(*[scan_lib(r) for r in roots]) + cancel.set() + await spinner + return _envelope_ok(hits, start) + except Exception as exc: # pragma: no cover + cancel.set() + await spinner + return _envelope_err("ParseError", str(exc), start) + + +async def validate_footprint( # noqa: D401 + lib_path: str, + fp_name: str, + *, + progress_callback: Callable[[float, str], None], +) -> _ResultEnvelope: + """Lightweight structural validation of *fp_name* inside *lib_path*.""" + start = time.perf_counter() + cancel = asyncio.Event() + spinner = asyncio.create_task(_periodic_progress(cancel, progress_callback, "validating")) + + try: + lib_dir = Path(lib_path) + if not lib_dir.exists(): + raise FileNotFoundError(lib_path) + fp_file = lib_dir / (fp_name if fp_name.endswith(".kicad_mod") else f"{fp_name}.kicad_mod") + if not fp_file.exists(): + raise FileNotFoundError(fp_file) + + try: + if _PCBNEW_AVAILABLE: + # pcbnew parsing – minimal to avoid heavy object traversal + mod = await _run_io(pcbnew.FootprintLoad, str(lib_dir), fp_name) + pin_count = mod.GetPadCount() + if pin_count == 0: + raise ValueError("footprint has no pads") + # Copper layer check + any_copper = any(pad.IsOnLayer(pcbnew.F_Cu) or pad.IsOnLayer(pcbnew.B_Cu) for pad in mod.Pads()) + if not any_copper: + raise ValueError("no copper layers enabled") + # BBox + _ = mod.GetBoundingBox() + else: + # Fallback text parser + _parse_kicad_mod(fp_file) + except ValueError as ve: + cancel.set() + await spinner + return 
_envelope_ok({"valid": False, "reason": str(ve)}, start) + except Exception as exc: + cancel.set() + await spinner + return _envelope_err("ParseError", str(exc), start) + + cancel.set() + await spinner + return _envelope_ok({"valid": True, "reason": None}, start) + except FileNotFoundError as nf: + cancel.set() + await spinner + return _envelope_err("FileNotFound", str(nf), start) + except Exception as exc: # pragma: no cover + cancel.set() + await spinner + return _envelope_err("ParseError", str(exc), start) + + +async def get_footprint_info( # noqa: D401 + lib_path: str, + fp_name: str, + *, + progress_callback: Callable[[float, str], None], +) -> _ResultEnvelope: + """Return rich metadata extracted from *fp_name* inside *lib_path*.""" + start = time.perf_counter() + cancel = asyncio.Event() + spinner = asyncio.create_task(_periodic_progress(cancel, progress_callback, "parsing")) + + try: + lib_dir = Path(lib_path) + fp_file = lib_dir / (fp_name if fp_name.endswith(".kicad_mod") else f"{fp_name}.kicad_mod") + if not fp_file.exists(): + raise FileNotFoundError(fp_file) + + if _PCBNEW_AVAILABLE: + # Run in thread – pcbnew is C++ binding but can block. 
+ mod = await _run_io(pcbnew.FootprintLoad, str(lib_dir), fp_name) + pads = [] + for pad in mod.Pads(): + layers = [pcbnew.LayerName(layer) for layer in pad.GetLayerSet().Layers()] + pads.append( + { + "number": pad.GetPadName(), + "shape": str(pad.GetShape()), + "drill": pad.GetDrillValue(), + "x": pad.GetPosition().x / 1e6, + "y": pad.GetPosition().y / 1e6, + "layers": layers, + } + ) + bbox_kicad = mod.GetBoundingBox() + bbox = { + "w": bbox_kicad.GetWidth() / 1e6, + "h": bbox_kicad.GetHeight() / 1e6, + "units": "mm", + } + layer_set = list({l for p in pads for l in p["layers"]}) + pin_count = len(pads) + else: + pin_count, pads, bbox, layer_set = _parse_kicad_mod(fp_file) + + info = { + "pin_count": pin_count, + "pads": pads, + "bounding_box": bbox, + "layer_set": layer_set, + "raw_source": fp_file.read_text(encoding="utf-8", errors="ignore"), + } + cancel.set() + await spinner + return _envelope_ok(info, start) + except FileNotFoundError as nf: + cancel.set() + await spinner + return _envelope_err("FileNotFound", str(nf), start) + except ValueError as ve: + cancel.set() + await spinner + return _envelope_err("InvalidFootprint", str(ve), start) + except Exception as exc: # pragma: no cover + cancel.set() + await spinner + return _envelope_err("ParseError", str(exc), start) + + +async def search_component_libraries( # noqa: D401 + pattern: str, + *, + user_only: bool = False, + progress_callback: Callable[[float, str], None], +) -> _ResultEnvelope: + """Return library directories with a nickname or path matching *pattern*.""" + start = time.perf_counter() + cancel = asyncio.Event() + spinner = asyncio.create_task(_periodic_progress(cancel, progress_callback, "scanning")) + + try: + roots = _collect_library_paths() + matches: List[str] = [] + for root in roots: + if user_only: + # Consider only libs located inside CWD (project local) + try: + root.relative_to(Path.cwd()) + except ValueError: + continue + nickname = root.stem + if fnmatch.fnmatch(nickname, 
pattern) or fnmatch.fnmatch(str(root), pattern): + matches.append(str(root.resolve())) + cancel.set() + await spinner + return _envelope_ok(matches, start) + except Exception as exc: # pragma: no cover + cancel.set() + await spinner + return _envelope_err("ParseError", str(exc), start) + + +# --------------------------------------------------------------------------- +# Registration helper – called by server +# --------------------------------------------------------------------------- + +def register_component_tools(mcp: FastMCP) -> None: # noqa: D401 + """Expose the four *component/footprint* tools to FastMCP.""" + global _mcp_instance + _mcp_instance = mcp + + # We cannot simply decorate the *already-defined* functions after-the-fact, + # FastMCP inspects the wrapper *function* object. Therefore we wrap each + # implementation in a thin stub that forwards the call. + + for impl in [ + lookup_component_footprint, + validate_footprint, + get_footprint_info, + search_component_libraries, + ]: + + async def _stub(*args, __impl=impl, **kwargs): # type: ignore[override] + return await __impl(*args, **kwargs) + + _stub.__name__ = impl.__name__ # ensure predictable tool id + _stub.__doc__ = impl.__doc__ + mcp.tool()(_stub) + diff --git a/kicad_mcp/tools/supplier_tools.py b/kicad_mcp/tools/supplier_tools.py new file mode 100644 index 0000000..12d2cc0 --- /dev/null +++ b/kicad_mcp/tools/supplier_tools.py @@ -0,0 +1,397 @@ +"""Supplier lookup tools for KiCad-MCP. + +This module adds three *asynchronous* MCP tools that integrate real–time +availability/price lookup from mainstream distributors (currently Mouser and +Digi-Key). 
All tools follow the canonical MCP *envelope* structure:: + + { + "ok": bool, + "result": …, # present when ok is True + "error": { + "type": str, + "message": str + }, + "elapsed_s": float + } + +The real distributor APIs are **not** called during unit-tests – internal helper +coroutines such as :pyfunc:`_search_mouser` are designed to be monkey-patched. +""" +from __future__ import annotations + +from pathlib import Path +import asyncio +import json +import os +import time +from typing import Callable, Dict, List, Any + +import aiohttp # Public dependency – declared in requirements.txt + +# --------------------------------------------------------------------------- +# Constants & shared helpers +# --------------------------------------------------------------------------- + +_ERROR_TYPES = { + "NetworkError", + "AuthError", + "RateLimited", + "NotFound", + "ParseError", +} + +_ResultEnvelope = Dict[str, object] + +_PROGRESS_INTERVAL = 0.5 # seconds +_CACHE_PATH = Path(os.getenv("XDG_CACHE_HOME", Path.home() / ".cache")) / "kicad_mcp_supplier.json" +_CACHE_TTL = 24 * 3600 # seconds + + +class SupplierError(Exception): + """Raised by internal helpers to signal canonical *error type*.""" + + def __init__(self, err_type: str, message: str): + if err_type not in _ERROR_TYPES: + err_type = "ParseError" + self.err_type = err_type + super().__init__(message) + + +# --------------------------------------------------------------------------- +# Cache helpers – extremely small & JSON-based +# --------------------------------------------------------------------------- + +def _load_cache() -> Dict[str, Any]: # noqa: D401 + if not _CACHE_PATH.exists(): + return {} + try: + with _CACHE_PATH.open("r", encoding="utf-8") as fp: + return json.load(fp) + except Exception: + return {} + + +def _save_cache(cache: Dict[str, Any]) -> None: # noqa: D401 + try: + _CACHE_PATH.parent.mkdir(parents=True, exist_ok=True) + tmp = _CACHE_PATH.with_suffix(".tmp") + with tmp.open("w", 
encoding="utf-8") as fp: + json.dump(cache, fp) + tmp.replace(_CACHE_PATH) + except Exception: + # Cache failures are silent – never break main workflow + pass + + +def _cache_get(key: str) -> Any | None: # noqa: D401 + cache = _load_cache() + entry = cache.get(key) + if not entry: + return None + if time.time() - entry["ts"] > _CACHE_TTL: + return None + return entry["data"] + + +def _cache_put(key: str, data: Any) -> None: # noqa: D401 + cache = _load_cache() + cache[key] = {"ts": time.time(), "data": data} + _save_cache(cache) + + +# --------------------------------------------------------------------------- +# Envelope & progress helpers (copied from component_tools style) +# --------------------------------------------------------------------------- + +def _envelope_ok(result, start: float) -> _ResultEnvelope: # noqa: D401 + return {"ok": True, "result": result, "elapsed_s": time.perf_counter() - start} + + +def _envelope_err(err_type: str, message: str, start: float) -> _ResultEnvelope: # noqa: D401 + if err_type not in _ERROR_TYPES: + err_type = "ParseError" + return { + "ok": False, + "error": {"type": err_type, "message": message}, + "elapsed_s": time.perf_counter() - start, + } + + +async def _periodic_progress( + cancel_event: asyncio.Event, progress_callback: Callable[[float, str], None], msg: str +) -> None: # noqa: D401,E501 + pct = 0.0 + while not cancel_event.is_set(): + try: + maybe_coro = progress_callback(pct, msg) + if asyncio.iscoroutine(maybe_coro): + await maybe_coro + except Exception: + pass + pct = (pct + 5.0) % 100.0 + await asyncio.sleep(_PROGRESS_INTERVAL) + + +# --------------------------------------------------------------------------- +# HTTP helpers – shared aiohttp wrapper with timeout & retry +# --------------------------------------------------------------------------- + +_DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=10.0) +_MAX_RETRIES = 2 + + +async def _fetch_json(method: str, url: str, *, headers: Dict[str, str] | None = 
None, json_data: Any | None = None) -> Any: # noqa: D401,E501 + for attempt in range(_MAX_RETRIES + 1): + try: + async with aiohttp.ClientSession(timeout=_DEFAULT_TIMEOUT) as sess: + async with sess.request(method, url, headers=headers, json=json_data) as resp: + if resp.status in (401, 403): + raise SupplierError("AuthError", f"unauthorized: {resp.status}") + if resp.status == 429: + raise SupplierError("RateLimited", "rate limited") + if resp.status >= 500: + raise SupplierError("NetworkError", f"server error {resp.status}") + if resp.status == 404: + raise SupplierError("NotFound", "resource not found") + try: + return await resp.json() + except Exception: + raise SupplierError("ParseError", "invalid JSON response") + except SupplierError: + raise # Bubble up without retry – already mapped + except Exception as exc: + if attempt >= _MAX_RETRIES: + raise SupplierError("NetworkError", str(exc)) + # simple exponential back-off + await asyncio.sleep(0.5 * (2 ** attempt)) + # Should never reach here + raise SupplierError("NetworkError", "unreachable") + + +# --------------------------------------------------------------------------- +# Distributor-specific helpers (can be monkey-patched during tests) +# --------------------------------------------------------------------------- + +async def _search_mouser(query: str, max_results: int) -> List[Dict[str, object]]: # noqa: D401 + """Fetch *query* result list from Mouser API (simplified).""" + api_key = os.getenv("MOUSER_API_KEY") + if not api_key: + raise SupplierError("AuthError", "MOUSER_API_KEY env var not set") + + # NOTE: Real Mouser API uses POST /api/v1/search/keyword with JSON payload. + # We use GET for simplicity – this endpoint is fictional and only for demo. 
+ url = f"https://api.mouser.com/api/v1/search?apiKey={api_key}&q={query}&n={max_results}" + data = await _fetch_json("GET", url) + + hits: List[Dict[str, object]] = [] + for item in data.get("items", [])[:max_results]: + hits.append( + { + "mpn": item.get("manufacturerPartNumber"), + "distributor": "mouser", + "sku": item.get("mouserPartNumber"), + "stock": item.get("availability"), + "price_breaks": [ + {"qty": p.get("quantity"), "unit_price_usd": p.get("priceUSD")} + for p in item.get("priceBreaks", []) + ], + "datasheet": item.get("datasheetUrl"), + "url": item.get("productUrl"), + } + ) + return hits + + +async def _search_digikey(query: str, max_results: int) -> List[Dict[str, object]]: # noqa: D401 + api_key = os.getenv("DIGIKEY_API_KEY") + if not api_key: + raise SupplierError("AuthError", "DIGIKEY_API_KEY env var not set") + + url = f"https://api.digikey.com/parts/search?apiKey={api_key}&keywords={query}&limit={max_results}" + data = await _fetch_json("GET", url) + + hits: List[Dict[str, object]] = [] + for item in data.get("parts", [])[:max_results]: + hits.append( + { + "mpn": item.get("manufacturerPartNumber"), + "distributor": "digikey", + "sku": item.get("digiKeyPartNumber"), + "stock": item.get("quantityOnHand"), + "price_breaks": [ + {"qty": p["breakQty"], "unit_price_usd": p["unitPrice"],} + for p in item.get("standardPricing", []) + ], + "datasheet": item.get("datasheetUrl"), + "url": item.get("productUrl"), + } + ) + return hits + + +async def _get_part_mouser(sku_or_mpn: str) -> Dict[str, object]: # noqa: D401 + # Very thin wrapper around search – real API has dedicated endpoint. 
+ hits = await _search_mouser(sku_or_mpn, 1) + if not hits: + raise SupplierError("NotFound", sku_or_mpn) + return hits[0] + + +async def _get_part_digikey(sku_or_mpn: str) -> Dict[str, object]: # noqa: D401 + hits = await _search_digikey(sku_or_mpn, 1) + if not hits: + raise SupplierError("NotFound", sku_or_mpn) + return hits[0] + + +# --------------------------------------------------------------------------- +# Public MCP tools +# --------------------------------------------------------------------------- + +from mcp.server.fastmcp import FastMCP # Imported late + +_mcp_instance: FastMCP | None = None + + +async def search_distributors( # noqa: D401 + query: str, + distributors: List[str] | None = None, + max_results: int = 20, + *, + progress_callback: Callable[[float, str], None], +) -> _ResultEnvelope: + """Search Mouser and/or Digi-Key concurrently for *query*.""" + start = time.perf_counter() + cancel = asyncio.Event() + spinner = asyncio.create_task(_periodic_progress(cancel, progress_callback, "searching")) + + distributors = [d.lower() for d in (distributors or ["digikey", "mouser"])] + # Use cache key based on query+distros to allow reuse across calls + cache_key = f"search::{query}::{','.join(sorted(distributors))}::{max_results}" + try: + cached = _cache_get(cache_key) + if cached is not None: + cancel.set() + await spinner + return _envelope_ok(cached, start) + + tasks: List[asyncio.Future] = [] + if "mouser" in distributors: + tasks.append(_search_mouser(query, max_results)) + if "digikey" in distributors: + tasks.append(_search_digikey(query, max_results)) + if cached is not None: + cancel.set() + await spinner + return _envelope_ok(cached, start) + + # Gather concurrently + results: List[List[Dict[str, object]]] = await asyncio.gather(*tasks, return_exceptions=True) + hits: List[Dict[str, object]] = [] + for res in results: + if isinstance(res, Exception): + raise res # Will be handled below + hits.extend(res) + _cache_put(cache_key, hits) + 
cancel.set() + await spinner + return _envelope_ok(hits, start) + except SupplierError as se: + cancel.set() + await spinner + return _envelope_err(se.err_type, str(se), start) + except Exception as exc: # pragma: no cover – unexpected + cancel.set() + await spinner + return _envelope_err("NetworkError", str(exc), start) + + +async def get_distributor_part( # noqa: D401 + distributor: str, + sku: str, + *, + progress_callback: Callable[[float, str], None], +) -> _ResultEnvelope: + """Return detailed information for *sku* from *distributor*.""" + start = time.perf_counter() + cancel = asyncio.Event() + spinner = asyncio.create_task(_periodic_progress(cancel, progress_callback, "fetching")) + + distributor = distributor.lower() + cache_key = f"part::{distributor}::{sku}" + try: + cached = _cache_get(cache_key) + if cached is not None: + cancel.set() + await spinner + return _envelope_ok(cached, start) + + if distributor == "mouser": + data = await _get_part_mouser(sku) + elif distributor == "digikey": + data = await _get_part_digikey(sku) + else: + raise SupplierError("ParseError", f"unsupported distributor: {distributor}") + + _cache_put(cache_key, data) + cancel.set() + await spinner + return _envelope_ok(data, start) + except SupplierError as se: + cancel.set() + await spinner + return _envelope_err(se.err_type, str(se), start) + except Exception as exc: + cancel.set() + await spinner + return _envelope_err("NetworkError", str(exc), start) + + +async def batch_availability( # noqa: D401 + parts: List[Dict[str, str]], + *, + progress_callback: Callable[[float, str], None], +) -> _ResultEnvelope: + """Return *latest* stock/price info for each distributor *part* dict.""" + start = time.perf_counter() + cancel = asyncio.Event() + spinner = asyncio.create_task(_periodic_progress(cancel, progress_callback, "batch")) + + async def _worker(p: Dict[str, str]): + res = await get_distributor_part( # Re-use single-part routine + p["distributor"], p["sku"], 
progress_callback=lambda *_: None + ) + if res["ok"]: + return res["result"] + else: + # Embed error in place so caller keeps index alignment + return {"error": res["error"]} + + try: + out = await asyncio.gather(*[_worker(p) for p in parts]) + cancel.set() + await spinner + return _envelope_ok(out, start) + except Exception as exc: # pragma: no cover + cancel.set() + await spinner + return _envelope_err("NetworkError", str(exc), start) + + +# --------------------------------------------------------------------------- +# FastMCP registration helper – mirrors *component_tools* pattern +# --------------------------------------------------------------------------- + +def register_supplier_tools(mcp: FastMCP) -> None: # noqa: D401 + """Expose supplier-lookup tools to *FastMCP* instance.""" + global _mcp_instance + _mcp_instance = mcp + + for impl in [search_distributors, get_distributor_part, batch_availability]: + + async def _stub(*args, __impl=impl, **kwargs): # type: ignore[override] + return await __impl(*args, **kwargs) + + _stub.__name__ = impl.__name__ + _stub.__doc__ = impl.__doc__ + mcp.tool()(_stub) diff --git a/requirements.txt b/requirements.txt index 3f39619..6567088 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ mcp[cli] pandas +aiohttp # Development/Testing pytest \ No newline at end of file diff --git a/tests/test_component_tools.py b/tests/test_component_tools.py new file mode 100644 index 0000000..aac0634 --- /dev/null +++ b/tests/test_component_tools.py @@ -0,0 +1,90 @@ +from pathlib import Path + +import pytest + +from kicad_mcp.tools.component_tools import ( + lookup_component_footprint, + validate_footprint, + get_footprint_info, + search_component_libraries, +) + + +@pytest.fixture() +def footprint_tmp(tmp_path: Path) -> Path: + """Create a temporary *.pretty* library with 2 footprints.""" + pretty = tmp_path / "Mock.pretty" + pretty.mkdir() + + # Valid footprint with 2 pads on copper layers + valid_fp = pretty / 
async def _noop_progress(pct: float, msg: str) -> None:  # noqa: D401
    """Progress sink for tests; intentionally does nothing."""


@pytest.mark.asyncio
async def test_lookup_component_footprint(footprint_tmp: Path):
    """A wildcard lookup finds exactly the one matching footprint."""
    res = await lookup_component_footprint(
        "R_*", libs=[str(footprint_tmp)], progress_callback=_noop_progress
    )
    assert res["ok"] is True
    matches = res["result"]
    assert len(matches) == 1
    assert matches[0]["lib"] == "Mock"


@pytest.mark.asyncio
async def test_validate_footprint_ok(footprint_tmp: Path):
    """A footprint with copper pads validates successfully."""
    res = await validate_footprint(
        str(footprint_tmp), "R_0603", progress_callback=_noop_progress
    )
    assert res["ok"] is True
    assert res["result"]["valid"] is True


@pytest.mark.asyncio
async def test_validate_footprint_bad(footprint_tmp: Path):
    """A padless footprint is reported invalid — not as a tool error."""
    res = await validate_footprint(
        str(footprint_tmp), "Empty", progress_callback=_noop_progress
    )
    assert res["ok"] is True
    assert res["result"]["valid"] is False


@pytest.mark.asyncio
async def test_get_footprint_info(footprint_tmp: Path):
    """Metadata extraction reports pads, pin count and a mm bounding box."""
    res = await get_footprint_info(
        str(footprint_tmp), "R_0603", progress_callback=_noop_progress
    )
    assert res["ok"] is True
    details = res["result"]
    assert details["pin_count"] == 2
    assert len(details["pads"]) == 2
    assert details["bounding_box"]["units"] == "mm"
class _FakeResponse:
    """Async-context stand-in for an *aiohttp* response object."""

    def __init__(self, status: int, payload: dict):
        self.status = status
        self._body = payload

    async def json(self):  # noqa: D401
        """Return the canned JSON payload."""
        return self._body

    async def __aenter__(self):  # noqa: D401
        return self

    async def __aexit__(self, *_exc_info):  # noqa: D401
        # Never suppress exceptions raised inside the context.
        return False


class _FakeSession:
    """Async-context stand-in for ``aiohttp.ClientSession``.

    Every ``request`` call — arguments ignored — yields a ``_FakeResponse``
    carrying the status code and payload supplied at construction time.
    """

    def __init__(self, status: int, payload: dict):
        self._status = status
        self._payload = payload

    async def __aenter__(self):  # noqa: D401
        return self

    async def __aexit__(self, *_exc_info):  # noqa: D401
        return False

    def request(self, *_args, **_kwargs):  # noqa: D401
        """Hand back the canned response regardless of arguments."""
        return _FakeResponse(self._status, self._payload)


@pytest.fixture(autouse=True)
def _patch_env(monkeypatch):  # noqa: D401
    """Inject dummy API keys so supplier code passes its auth checks."""
    monkeypatch.setenv("DIGIKEY_API_KEY", "DUMMY")
    monkeypatch.setenv("MOUSER_API_KEY", "DUMMY")
@pytest.mark.asyncio
async def test_auth_failure(monkeypatch):
    """An AuthError raised by a distributor surfaces in the error envelope."""
    async def bad_mouser(*_a, **_kw):  # noqa: D401
        raise st.SupplierError("AuthError", "bad key")

    monkeypatch.setattr(st, "_search_mouser", bad_mouser)
    res = await st.search_distributors(
        "XYZ", distributors=["mouser"], progress_callback=_noop_progress
    )
    assert res["ok"] is False
    assert res["error"]["type"] == "AuthError"


@pytest.mark.asyncio
async def test_network_timeout(monkeypatch):
    """A NetworkError (e.g. timeout) surfaces in the error envelope."""
    async def timeout(*_a, **_kw):  # noqa: D401
        raise st.SupplierError("NetworkError", "timeout")

    monkeypatch.setattr(st, "_get_part_mouser", timeout)
    res = await st.get_distributor_part("mouser", "ABC", progress_callback=_noop_progress)
    assert res["ok"] is False
    assert res["error"]["type"] == "NetworkError"


@pytest.mark.asyncio
async def test_batch_cache(monkeypatch):
    """``batch_availability`` fans out to ``get_distributor_part`` per entry.

    NOTE(review): despite the name, this does not exercise the on-disk
    cache — duplicate SKUs each trigger their own lookup so the caller
    keeps index alignment, hence the expected call count of 2.  (The
    original comment claimed "called only once", contradicting the
    assertion below.)
    """
    calls = {}

    async def fake_part(distributor: str, sku: str, *, progress_callback):  # noqa: D401
        calls[(distributor, sku)] = calls.get((distributor, sku), 0) + 1
        return {"ok": True, "result": {"sku": sku, "stock": 1}}

    monkeypatch.setattr(st, "get_distributor_part", fake_part)

    parts = [{"distributor": "digikey", "sku": "XYZ"}, {"distributor": "digikey", "sku": "XYZ"}]
    res = await st.batch_availability(parts, progress_callback=_noop_progress)
    assert res["ok"] is True
    assert len(res["result"]) == 2
    assert calls[("digikey", "XYZ")] == 2  # one worker invocation per entry