diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f50e7fcaa..da1a6952b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -165,13 +165,12 @@ jobs: echo "✓ Datamodel generation complete" - name: Verify datamodel was generated + working-directory: config run: | echo "Checking for datamodel.py..." - ls -lh config/datamodel.py || echo "❌ datamodel.py not found!" + ls -lh datamodel.py || { echo "❌ datamodel.py not found!"; exit 1; } echo "Attempting import test..." - cd config && python -c "from datamodel import LlamaFarmConfig; print('✓ Direct import successful')" || echo "❌ Direct import failed!" - echo "Attempting module import test..." - python -c "from config.datamodel import LlamaFarmConfig; print('✓ Module import successful')" || echo "❌ Module import failed!" + uv run python -c "from datamodel import LlamaFarmConfig; print('✓ Import successful')" - name: Check if component has tests id: check-tests @@ -225,9 +224,19 @@ jobs: continue-on-error: false - name: Set up Ollama - uses: pydantic/ollama-action@v3 - with: - model: nomic-embed-text + env: + # Pin install script to v0.19.0 release with checksum verification + OLLAMA_INSTALL_URL: https://github.com/ollama/ollama/releases/download/v0.19.0/install.sh + OLLAMA_INSTALL_SHA256: 25f64b810b947145095956533e1bdf56eacea2673c55a7e586be4515fc882c9f + run: | + sudo apt-get install -y --no-install-recommends zstd + curl -fsSL "$OLLAMA_INSTALL_URL" -o install-ollama.sh + echo "${OLLAMA_INSTALL_SHA256} install-ollama.sh" | sha256sum -c - + sh install-ollama.sh + rm install-ollama.sh + ollama serve & + sleep 3 + ollama pull nomic-embed-text - name: Run tests if: steps.check-tests.outputs.has_tests == 'true' diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 0f6b6cc2c..fe625e8a0 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -23,7 +23,7 @@ jobs: packages: write strategy: matrix: - service: [designer, server, rag, runtime] + service: [designer, server, rag, runtime, edge-runtime, edge-runtime-lite] include: - service: designer context: ./designer @@ -41,6 +41,14 @@ jobs: context: ./ dockerfile: ./runtimes/universal/Dockerfile description: "LlamaFarm Universal Runtime - Universal Runtime for all models" + - service: edge-runtime + context: ./ + dockerfile: ./runtimes/edge/Dockerfile + description: "LlamaFarm Edge Runtime - Lightweight runtime for edge/drone deployment" + - service: edge-runtime-lite + context: ./ + dockerfile: ./runtimes/edge/Dockerfile + description: "LlamaFarm Edge Runtime (Lite) - Language-only runtime without vision deps" steps: - name: Checkout repository @@ -109,6 +117,7 @@ jobs: build-args: | GIT_SHA=${{ github.sha }} PYTORCH_VARIANT=${{ matrix.service == 'runtime' && github.event_name == 'pull_request' && 'cpu' || '' }} + ENABLE_VISION=${{ matrix.service == 'edge-runtime-lite' && 'false' || '' }} - name: Upload AMD64 image artifact (PR only) if: github.event_name == 'pull_request' @@ -127,7 +136,7 @@ jobs: packages: write strategy: matrix: - service: [designer, server, rag, runtime] + service: [designer, server, rag, runtime, edge-runtime, edge-runtime-lite] include: - service: designer context: ./designer @@ -145,6 +154,14 @@ jobs: context: ./ dockerfile: ./runtimes/universal/Dockerfile description: "LlamaFarm Universal Runtime - model serving for GGUF and Transformers" + - service: edge-runtime + context: ./ + dockerfile: ./runtimes/edge/Dockerfile + description: "LlamaFarm Edge Runtime - Lightweight runtime for edge/drone deployment" + - service: edge-runtime-lite + context: ./ + dockerfile: ./runtimes/edge/Dockerfile + description: "LlamaFarm Edge Runtime (Lite) - Language-only runtime without vision deps" steps: - name: Checkout repository @@ -195,6 +212,7 @@ jobs: outputs: ${{ github.event_name == 'pull_request' && format('type=docker,dest={0}/{1}-arm64.tar', runner.temp, matrix.service) || '' }} build-args: | GIT_SHA=${{ github.sha }} + ENABLE_VISION=${{ matrix.service == 'edge-runtime-lite' && 'false' || '' }} - name: Upload ARM64 image artifact (PR only) if: github.event_name == 'pull_request' @@ -214,7 +232,7 @@ jobs: packages: write strategy: matrix: - service: [designer, server, rag] + service: [designer, server, rag, runtime, edge-runtime, edge-runtime-lite] steps: - name: Checkout repository @@ -368,11 +386,11 @@ jobs: IMAGE_TAG: pr-${{ github.event.number }} run: | # Tag the loaded images with the expected format for docker-compose - SERVICES=(designer server rag runtime) + SERVICES=(designer server rag runtime edge-runtime edge-runtime-lite) for SERVICE in "${SERVICES[@]}"; do # Find the loaded images for this service - AMD64_IMAGE=$(docker images --format "table {{.Repository}}:{{.Tag}}" | grep "$SERVICE" | grep "amd64" | head -1 | tr -d ' ') + AMD64_IMAGE=$(docker images --format "table {{.Repository}}:{{.Tag}}" | grep "/${SERVICE}:" | grep "amd64" | head -1 | tr -d ' ') if [ -n "$AMD64_IMAGE" ]; then # Tag for docker-compose (use AMD64 for testing) @@ -465,6 +483,8 @@ jobs: docker compose -f docker-compose.yml logs --tail=50 designer docker compose -f docker-compose.yml logs --tail=50 rag docker compose -f docker-compose.yml logs --tail=50 runtime + docker compose -f docker-compose.yml logs --tail=50 edge-runtime + docker compose -f docker-compose.yml logs --tail=50 edge-runtime-lite - name: Show logs on failure if: failure() @@ -479,7 +499,7 @@ jobs: docker compose -f docker-compose.yml logs --tail=200 || true echo "" echo "=== Individual service logs ===" - for service in server designer rag runtime; do + for service in server designer rag runtime edge-runtime edge-runtime-lite; do echo "--- Logs for $service ---" docker compose -f docker-compose.yml logs --tail=100 "$service" || true echo "" @@ -504,14 +524,22 @@ jobs: security-events: write strategy: matrix: - service: [designer, server, rag, runtime] + service: [designer, server, rag, runtime, edge-runtime, edge-runtime-lite] steps: - name: Checkout repository uses: actions/checkout@v4 + + - name: Determine image tag + id: tag + run: | + # Use branch name as tag (matches create-manifest metadata) + TAG="${GITHUB_REF_NAME}" + echo "image=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}/${{ matrix.service }}:${TAG}" >> "$GITHUB_OUTPUT" + - name: Run Trivy vulnerability scanner uses: aquasecurity/trivy-action@master with: - image-ref: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}/${{ matrix.service }}:latest + image-ref: ${{ steps.tag.outputs.image }} format: "sarif" output: "trivy-results-${{ matrix.service }}.sarif" diff --git a/.github/workflows/release-docker.yml b/.github/workflows/release-docker.yml index f2db564da..571a09688 100644 --- a/.github/workflows/release-docker.yml +++ b/.github/workflows/release-docker.yml @@ -57,7 +57,7 @@ jobs: echo "${{ secrets.GITHUB_TOKEN }}" | docker login "${REGISTRY}" -u "${{ github.actor }}" --password-stdin # Services to retag - SERVICES=(designer server rag runtime) + SERVICES=(designer server rag runtime edge-runtime) # Wait for all Docker images to be available echo "Waiting for Docker images to be built and pushed..." diff --git a/cli/cmd/orchestrator/python_env.go b/cli/cmd/orchestrator/python_env.go index 06b122b9e..814546ad2 100644 --- a/cli/cmd/orchestrator/python_env.go +++ b/cli/cmd/orchestrator/python_env.go @@ -142,24 +142,30 @@ func (m *PythonEnvManager) getEnv() []string { // Start with the current environment env := os.Environ() - // Filter out Python-related environment variables that could interfere - // with UV's managed Python environment + // Filter out environment variables that could interfere with UV's managed + // Python environment or cause incorrect package resolution. + // UV index vars are stripped here so that only services that explicitly + // declare them in their Env map (e.g. universal-runtime) will have them. + // This prevents the PyTorch CPU index from leaking into server/rag, where + // it can cause install failures (e.g. markupsafe with only cp314 wheels). filteredEnv := make([]string, 0, len(env)) pythonEnvVars := map[string]bool{ - "VIRTUAL_ENV": true, - "PYTHONHOME": true, - "PYTHONPATH": true, - "PYTHONSTARTUP": true, - "PYTHONEXECUTABLE": true, - "PYTHONUSERBASE": true, - "CONDA_DEFAULT_ENV": true, - "CONDA_PREFIX": true, + "VIRTUAL_ENV": true, + "PYTHONHOME": true, + "PYTHONPATH": true, + "PYTHONSTARTUP": true, + "PYTHONEXECUTABLE": true, + "PYTHONUSERBASE": true, + "CONDA_DEFAULT_ENV": true, + "CONDA_PREFIX": true, "CONDA_PYTHON_EXE": true, - "PYENV_VERSION": true, - "PYENV_VIRTUAL_ENV": true, - "PIPENV_ACTIVE": true, - "POETRY_ACTIVE": true, - "PDM_PYTHON": true, + "PYENV_VERSION": true, + "PYENV_VIRTUAL_ENV": true, + "PIPENV_ACTIVE": true, + "POETRY_ACTIVE": true, + "PDM_PYTHON": true, + "UV_EXTRA_INDEX_URL": true, + "UV_INDEX_STRATEGY": true, } for _, e := range env { diff --git a/cli/cmd/orchestrator/services.go b/cli/cmd/orchestrator/services.go index bc154d9f1..b2c6c95dc 100644 --- a/cli/cmd/orchestrator/services.go +++ b/cli/cmd/orchestrator/services.go @@ -150,8 +150,8 @@ var ServiceGraph = map[string]*ServiceDefinition{ "LLAMAFARM_GGUF_FORCE_CPU": "", // Set to "1" to force CPU for GGUF inference (avoids Metal SIGSEGV in CI) "HF_TOKEN": "", // In CI environments, use CPU-only PyTorch to avoid downloading 3GB+ of CUDA packages - "UV_EXTRA_INDEX_URL": "${UV_EXTRA_INDEX_URL}", - "UV_INDEX_STRATEGY": "", // Inherit from parent env (e.g. unsafe-best-match in CI) + "UV_EXTRA_INDEX_URL": "${UV_EXTRA_INDEX_URL}", + "UV_INDEX_STRATEGY": "${UV_INDEX_STRATEGY}", }, HealthComponent: "universal-runtime", HardwarePackages: []HardwarePackageSpec{ diff --git a/common/llamafarm_common/__init__.py b/common/llamafarm_common/__init__.py index 57960d165..c421fe1f8 100644 --- a/common/llamafarm_common/__init__.py +++ b/common/llamafarm_common/__init__.py @@ -17,6 +17,14 @@ select_gguf_file_with_logging, ) +# Submodules also importable as llamafarm_common.safe_home, etc. +# Kept as submodule imports to avoid adding their deps to the top-level namespace. +# Usage: +# from llamafarm_common.safe_home import safe_home, get_data_dir +# from llamafarm_common.device import get_optimal_device, get_device_info +# from llamafarm_common.model_cache import ModelCache +# from llamafarm_common.model_format import detect_model_format + __all__ = [ "GGUF_QUANTIZATION_PREFERENCE_ORDER", "get_gguf_file_path", diff --git a/common/llamafarm_common/device.py b/common/llamafarm_common/device.py new file mode 100644 index 000000000..3ab64ea29 --- /dev/null +++ b/common/llamafarm_common/device.py @@ -0,0 +1,210 @@ +""" +Device detection and optimization utilities. + +PyTorch is optional - this module provides fallback behavior for GGUF-only +deployments where torch is not installed. llama.cpp has its own GPU detection +independent of PyTorch. +""" + +from __future__ import annotations + +import logging +import platform +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import torch as torch_type + +logger = logging.getLogger(__name__) + +# Cached torch module reference (lazy loaded) +_torch: torch_type | None = None +_torch_available: bool | None = None + +# Cached device detection result +_optimal_device: str | None = None + + +def _get_torch() -> torch_type | None: + """Lazy-load torch module. Returns None if not installed.""" + global _torch, _torch_available + + if _torch_available is None: + try: + import torch + + _torch = torch + _torch_available = True + logger.debug(f"PyTorch {torch.__version__} loaded successfully") + except ImportError: + _torch = None + _torch_available = False + logger.info("PyTorch not installed - encoder models will not be available") + + return _torch + + +def is_torch_available() -> bool: + """Check if PyTorch is available without importing it.""" + _get_torch() + return _torch_available or False + + +def get_optimal_device() -> str: + """ + Detect the optimal device for the current platform. + + Results are cached so detection (and its log messages) only runs once. + + Returns: + str: Device name ("cuda", "mps", or "cpu") + + Note: + If PyTorch is not installed, always returns "cpu". + This allows GGUF models to still use GPU via llama.cpp's own detection. + """ + global _optimal_device + if _optimal_device is not None: + return _optimal_device + + _optimal_device = _detect_device() + return _optimal_device + + +def _detect_device() -> str: + """Run device detection once (called by get_optimal_device).""" + import os + + # Allow forcing CPU via environment variable + force_cpu = os.environ.get("TRANSFORMERS_FORCE_CPU", "").lower() in ( + "1", + "true", + "yes", + ) + if force_cpu: + logger.info("Forcing CPU device (TRANSFORMERS_FORCE_CPU=1)") + return "cpu" + + # Try to use PyTorch for device detection + torch = _get_torch() + if torch is None: + logger.info("PyTorch not available - using CPU for encoder models") + return "cpu" + + # Check for CUDA + if torch.cuda.is_available(): + logger.info(f"CUDA available: {torch.cuda.get_device_name(0)}") + return "cuda" + + # Check for MPS (Apple Silicon) + # Note: MPS has a 4GB temporary buffer limit which can cause issues with some models + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + # Check if user wants to skip MPS due to known limitations + skip_mps = os.environ.get("TRANSFORMERS_SKIP_MPS", "").lower() in ( + "1", + "true", + "yes", + ) + if skip_mps: + logger.info("Skipping MPS (TRANSFORMERS_SKIP_MPS=1), using CPU") + return "cpu" + logger.info("MPS (Apple Silicon) available") + logger.warning( + "MPS has a 4GB temporary buffer limit. Set TRANSFORMERS_SKIP_MPS=1 to use CPU if you encounter errors." + ) + return "mps" + + # Fallback to CPU + logger.info("Using CPU (no GPU acceleration)") + return "cpu" + + +def get_device_info() -> dict: + """ + Get detailed device information. + + Returns: + dict: Device information including platform, acceleration, memory + """ + device = get_optimal_device() + torch = _get_torch() + + info = { + "device": device, + "platform": platform.system(), + "python_version": platform.python_version(), + "torch_version": torch.__version__ if torch else "not installed", + "torch_available": torch is not None, + } + + if torch is not None: + if device == "cuda": + gpu_count = torch.cuda.device_count() + # Primary GPU info (backward compatible) + free_0, total_0 = torch.cuda.mem_get_info(0) + info.update( + { + "gpu_name": torch.cuda.get_device_name(0), + "gpu_memory_total": total_0, + "gpu_memory_free": free_0, + "gpu_memory_allocated": torch.cuda.memory_allocated(0), + "gpu_count": gpu_count, + } + ) + # Per-GPU details for multi-GPU systems + if gpu_count > 1: + gpus = [] + for i in range(gpu_count): + free, total = torch.cuda.mem_get_info(i) + gpus.append( + { + "index": i, + "name": torch.cuda.get_device_name(i), + "memory_total": total, + "memory_free": free, + "memory_allocated": torch.cuda.memory_allocated(i), + } + ) + info["gpus"] = gpus + elif device == "mps": + info.update( + { + "gpu_name": "Apple Silicon (MPS)", + "architecture": platform.machine(), + } + ) + + return info + + +def get_gguf_gpu_layers() -> int: + """ + Get the number of GPU layers to use for GGUF models. + + IMPORTANT: llama.cpp has its own GPU detection (CUDA, Metal, Vulkan, etc.) + that is independent of PyTorch. We should always try to use GPU layers (-1) + and let llama.cpp fall back to CPU if no GPU backend is available. + This allows users with CPU-only PyTorch but GPU llama.cpp to get acceleration. + + Returns: + int: Number of GPU layers (-1 for all layers on GPU, 0 for CPU only) + """ + import os + + force_cpu = os.environ.get("LLAMAFARM_GGUF_FORCE_CPU", "").lower() in ( + "1", + "true", + "yes", + ) + + if force_cpu: + logger.info("Configuring for CPU-only inference (LLAMAFARM_GGUF_FORCE_CPU=1)") + return 0 + + # Use all layers on GPU - llama.cpp will use whatever backend is available + # (CUDA, Metal, Vulkan, etc.) and fall back to CPU if none are available + logger.info( + "Configuring for GPU acceleration (all layers on GPU, llama.cpp will " + "auto-detect available backends)" + ) + return -1 diff --git a/common/llamafarm_common/model_cache.py b/common/llamafarm_common/model_cache.py new file mode 100644 index 000000000..0e7b832f3 --- /dev/null +++ b/common/llamafarm_common/model_cache.py @@ -0,0 +1,188 @@ +"""TTL-based model cache using cachetools. + +Provides a cache that: +- Automatically tracks last access time +- Refreshes TTL on access (not just on write) +- Supports async cleanup callbacks before expiration +""" + +import time +from collections.abc import Iterator +from typing import Generic, TypeVar + +from cachetools import TTLCache + +T = TypeVar("T") + + +class ModelCache(Generic[T]): + """TTL-based cache for models with async cleanup support. + + Uses cachetools.TTLCache internally but refreshes TTL on read access + (not just write), and provides methods for async cleanup before items + expire. + + This is designed for ML model caching where: + - Models should stay loaded while being actively used + - Idle models should be unloaded after a timeout + - Unloading requires calling an async cleanup method + + Example: + cache = ModelCache[BaseModel](ttl=300) # 5 minute TTL + + # Set a model + cache["encoder:model-id"] = model + + # Get model (refreshes TTL) + model = cache.get("encoder:model-id") + + # In cleanup task: + for key, model in cache.pop_expired(): + await model.unload() + """ + + def __init__(self, ttl: float, maxsize: int = 1000): + """Initialize the cache. + + Args: + ttl: Time-to-live in seconds. Items are considered expired + after this many seconds of inactivity (no read or write). + maxsize: Maximum number of items to store. + """ + self._ttl = ttl + self._maxsize = maxsize + # Internal TTLCache with very long TTL - we manage expiry ourselves + # to support async callbacks before removal + self._cache: TTLCache[str, T] = TTLCache(maxsize=maxsize, ttl=ttl * 10) + # Track access times ourselves for TTL-on-read behavior + self._timer = time.monotonic + self._access: dict[str, float] = {} + + @property + def ttl(self) -> float: + """Get the TTL in seconds.""" + return self._ttl + + def __contains__(self, key: str) -> bool: + return key in self._cache + + def __len__(self) -> int: + return len(self._cache) + + def __iter__(self) -> Iterator[str]: + return iter(self._cache) + + def get(self, key: str, default: T | None = None) -> T | None: + """Get item and refresh its TTL. + + Args: + key: Cache key + default: Value to return if key not found + + Returns: + The cached item, or default if not found + """ + if key not in self._cache: + return default + self._access[key] = self._timer() + return self._cache[key] + + def __getitem__(self, key: str) -> T: + """Get item and refresh TTL. Raises KeyError if not found.""" + if key not in self._cache: + raise KeyError(key) + self._access[key] = self._timer() + return self._cache[key] + + def __setitem__(self, key: str, value: T) -> None: + """Set item with fresh TTL.""" + self._cache[key] = value + self._access[key] = self._timer() + + def __delitem__(self, key: str) -> None: + """Remove item from cache.""" + del self._cache[key] + self._access.pop(key, None) + + def pop(self, key: str, *args) -> T: + """Remove and return item. + + Args: + key: Cache key + *args: Optional default value + + Returns: + The removed item, or default if provided and key not found + """ + self._access.pop(key, None) + return self._cache.pop(key, *args) + + def keys(self): + """Return view of cache keys.""" + return self._cache.keys() + + def values(self): + """Return view of cache values.""" + return self._cache.values() + + def items(self): + """Return view of cache items.""" + return self._cache.items() + + def clear(self) -> None: + """Clear all items from cache.""" + self._cache.clear() + self._access.clear() + + def get_idle_time(self, key: str) -> float | None: + """Get seconds since last access for a key. + + Args: + key: Cache key + + Returns: + Seconds since last access, or None if key not found + """ + if key not in self._access: + return None + return self._timer() - self._access[key] + + def is_expired(self, key: str) -> bool: + """Check if an item has exceeded its TTL. + + Args: + key: Cache key + + Returns: + True if item exists and is expired, False otherwise + """ + idle_time = self.get_idle_time(key) + return idle_time is not None and idle_time > self._ttl + + def get_expired_keys(self) -> list[str]: + """Get list of keys that have exceeded their TTL. + + Returns: + List of expired cache keys + """ + now = self._timer() + cutoff = now - self._ttl + return [k for k, t in self._access.items() if t < cutoff] + + def pop_expired(self) -> list[tuple[str, T]]: + """Remove and return all expired items. + + This is the main method for cleanup tasks. It returns all expired + items so the caller can perform async cleanup (like calling unload()). + + Returns: + List of (key, value) tuples for expired items + """ + expired_keys = self.get_expired_keys() + result = [] + for key in expired_keys: + if key in self._cache: + value = self._cache.pop(key) + self._access.pop(key, None) + result.append((key, value)) + return result diff --git a/common/llamafarm_common/model_format.py b/common/llamafarm_common/model_format.py new file mode 100644 index 000000000..2db9ba1ea --- /dev/null +++ b/common/llamafarm_common/model_format.py @@ -0,0 +1,172 @@ +"""Model format detection utilities. + +Detects whether a HuggingFace model repository contains GGUF or transformers format files. + +Note: Core GGUF utilities (list_gguf_files, select_gguf_file, get_gguf_file_path, etc.) +are provided by llamafarm_common.model_utils and re-exported here for backward compatibility. + +Performance optimizations: +- Results are cached to avoid repeated API calls within a session +- Checks local HuggingFace cache before making network requests +""" + +import logging + +from huggingface_hub import HfApi, scan_cache_dir +from huggingface_hub.utils import HFCacheInfo +from .model_utils import ( + GGUF_QUANTIZATION_PREFERENCE_ORDER, + get_gguf_file_path, + list_gguf_files, + parse_model_with_quantization, + parse_quantization_from_filename, + select_gguf_file, + select_gguf_file_with_logging, +) + +logger = logging.getLogger(__name__) + +# Cache detection results to avoid repeated filesystem checks +_format_cache: dict[str, str] = {} + +# Cache for local repo info to avoid repeated cache scans +_local_cache_info: HFCacheInfo | None = None + +# Re-export commonly used functions for backward compatibility +__all__ = [ + "GGUF_QUANTIZATION_PREFERENCE_ORDER", + "parse_model_with_quantization", + "parse_quantization_from_filename", + "select_gguf_file", + "select_gguf_file_with_logging", + "detect_model_format", + "list_gguf_files", + "get_gguf_file_path", + "clear_format_cache", +] + + +def _check_local_cache_for_model(model_id: str) -> list[str] | None: + """Check if model files are available in local HuggingFace cache. + + This avoids making network requests when we can determine format locally. + + Args: + model_id: HuggingFace model identifier + + Returns: + List of cached filenames if model is cached, None otherwise + """ + global _local_cache_info + + try: + # Scan cache once and reuse (scanning is ~10-50ms) + if _local_cache_info is None: + _local_cache_info = scan_cache_dir() + + # Look for this model in cache + for repo in _local_cache_info.repos: + if repo.repo_id == model_id and repo.repo_type == "model": + # Found cached repo - collect all filenames across revisions + filenames = set() + for revision in repo.revisions: + for file in revision.files: + filenames.add(file.file_name) + if filenames: + logger.debug( + f"Found {len(filenames)} files in local cache for {model_id}" + ) + return list(filenames) + + return None + + except Exception as e: + logger.debug(f"Could not scan local cache: {e}") + return None + + +def detect_model_format(model_id: str, token: str | None = None) -> str: + """ + Detect if a HuggingFace model is GGUF or transformers format. + + This function first checks if the model is in the local HuggingFace cache, + and only makes API calls if not cached locally. Results are cached in memory + to avoid repeated checks within a session. + + Args: + model_id: HuggingFace model identifier (e.g., "unsloth/Qwen3-0.6B-GGUF" or "unsloth/Qwen3-0.6B-GGUF:Q4_K_M") + token: Optional HuggingFace authentication token for gated models + + Returns: + "gguf" if model contains .gguf files, "transformers" otherwise + + Raises: + Exception: If model cannot be accessed + + Examples: + >>> detect_model_format("unsloth/Qwen3-0.6B-GGUF") + "gguf" + >>> detect_model_format("unsloth/Qwen3-0.6B-GGUF:Q4_K_M") + "gguf" + >>> detect_model_format("google/gemma-3-1b-it") + "transformers" + """ + # Parse model ID to remove quantization suffix if present + base_model_id, _ = parse_model_with_quantization(model_id) + + # Check memory cache first (fastest) + if base_model_id in _format_cache: + logger.debug( + f"Using cached format for {base_model_id}: {_format_cache[base_model_id]}" + ) + return _format_cache[base_model_id] + + logger.info(f"Detecting format for model: {base_model_id}") + + # Try local cache first to avoid API call + local_files = _check_local_cache_for_model(base_model_id) + if local_files is not None: + has_gguf = any(f.endswith(".gguf") for f in local_files) + if has_gguf: + logger.info("Detected GGUF format from local cache (found .gguf files)") + _format_cache[base_model_id] = "gguf" + return "gguf" + else: + logger.info( + "Detected transformers format from local cache (no .gguf files)" + ) + _format_cache[base_model_id] = "transformers" + return "transformers" + + # Not in local cache - must query API + try: + api = HfApi() + all_files = api.list_repo_files(repo_id=base_model_id, token=token) + + # Check if any .gguf files exist + has_gguf = any(f.endswith(".gguf") for f in all_files) + + if has_gguf: + logger.info("Detected GGUF format (found .gguf files)") + _format_cache[base_model_id] = "gguf" + return "gguf" + + # No GGUF files found - assume transformers format + logger.info("Detected transformers format (no .gguf files found)") + _format_cache[base_model_id] = "transformers" + return "transformers" + + except Exception as e: + logger.error(f"Error detecting model format for {base_model_id}: {e}") + raise + + +def clear_format_cache(): + """Clear the format detection cache. + + Useful for testing or when model repositories are updated. + """ + global _format_cache, _local_cache_info + _format_cache = {} + _local_cache_info = None + logger.debug("Format detection cache cleared") diff --git a/common/llamafarm_common/pidfile.py b/common/llamafarm_common/pidfile.py index aea8484cd..1d6d487d0 100644 --- a/common/llamafarm_common/pidfile.py +++ b/common/llamafarm_common/pidfile.py @@ -14,13 +14,9 @@ def get_pid_dir() -> Path: """Get the directory for PID files.""" - try: - _home = Path.home() - except RuntimeError: - _fb = os.environ.get("USERPROFILE") or os.environ.get("APPDATA") or os.environ.get("LOCALAPPDATA") - _home = Path(_fb) if _fb else Path.cwd() - lf_data_dir = os.getenv("LF_DATA_DIR", _home / ".llamafarm") - pid_dir = Path(lf_data_dir) / "pids" + from .safe_home import get_data_dir + + pid_dir = get_data_dir() / "pids" pid_dir.mkdir(parents=True, exist_ok=True) return pid_dir diff --git a/common/llamafarm_common/safe_home.py b/common/llamafarm_common/safe_home.py new file mode 100644 index 000000000..28c004c02 --- /dev/null +++ b/common/llamafarm_common/safe_home.py @@ -0,0 +1,34 @@ +"""Safe home directory resolution for embedded Python environments. + +Path.home() raises RuntimeError in PyApp-embedded Python on Windows +when HOME/USERPROFILE env vars are absent during bootstrap. +""" + +import os +from pathlib import Path + + +def safe_home() -> Path: + """Return the user's home directory with fallback for embedded Python.""" + try: + return Path.home() + except RuntimeError: + fb = ( + os.environ.get("USERPROFILE") + or os.environ.get("APPDATA") + or os.environ.get("LOCALAPPDATA") + ) + if fb: + return Path(fb) + try: + return Path.cwd() + except OSError: + return Path(".") + + +def get_data_dir() -> Path: + """Return the LlamaFarm data directory (LF_DATA_DIR or ~/.llamafarm).""" + env = os.environ.get("LF_DATA_DIR") + if env: + return Path(env) + return safe_home() / ".llamafarm" diff --git a/common/pyproject.toml b/common/pyproject.toml index e705b5441..a50fdd4df 100644 --- a/common/pyproject.toml +++ b/common/pyproject.toml @@ -7,6 +7,7 @@ dependencies = [ "huggingface_hub>=0.24.0", "hf-transfer>=0.1.9", # High-speed downloads (set HF_HUB_ENABLE_HF_TRANSFER=1) "filelock>=3.16.1", # Pinned <=3.20.0 by PyTorch CPU index; don't raise above + "cachetools>=6.0.0", # TTL-based model caching (used by model_cache module) ] [project.optional-dependencies] diff --git a/common/uv.lock b/common/uv.lock index 0376df407..d26de5116 100644 --- a/common/uv.lock +++ b/common/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.10" resolution-markers = [ "platform_machine == 'x86_64' and sys_platform == 'linux'", @@ -28,6 +28,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7f/9c/36c5c37947ebfb8c7f22e0eb6e4d188ee2d53aa3880f3f2744fb894f0cb1/anyio-4.12.0-py3-none-any.whl", hash = "sha256:dad2376a628f98eeca4881fc56cd06affd18f659b17a747d3ff0307ced94b1bb", size = 113362, upload-time = "2025-11-28T23:36:57.897Z" }, ] +[[package]] +name = "cachetools" +version = "7.0.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/af/dd/57fe3fdb6e65b25a5987fd2cdc7e22db0aef508b91634d2e57d22928d41b/cachetools-7.0.5.tar.gz", hash = "sha256:0cd042c24377200c1dcd225f8b7b12b0ca53cc2c961b43757e774ebe190fd990", size = 37367, upload-time = "2026-03-09T20:51:29.451Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/06/f3/39cf3367b8107baa44f861dc802cbf16263c945b62d8265d36034fc07bea/cachetools-7.0.5-py3-none-any.whl", hash = "sha256:46bc8ebefbe485407621d0a4264b23c080cedd913921bad7ac3ed2f26c183114", size = 13918, upload-time = "2026-03-09T20:51:27.33Z" }, +] + [[package]] name = "certifi" version = "2025.11.12" @@ -103,15 +112,23 @@ version = "0.1.9" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/1a/eb/8fc64f40388c29ce8ce3b2b180a089d4d6b25b1d0d232d016704cb852104/hf_transfer-0.1.9.tar.gz", hash = "sha256:035572865dab29d17e783fbf1e84cf1cb24f3fcf8f1b17db1cfc7fdf139f02bf", size = 25201, upload-time = "2025-01-07T10:05:12.947Z" } wheels = [ + { url = "https://files.pythonhosted.org/packages/a4/78/0dce00208f585fae675f40033ef9a30dedfa83665d5ac79f16beb4a0a6c2/hf_transfer-0.1.9-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:6e94e8822da79573c9b6ae4d6b2f847c59a7a06c5327d7db20751b68538dc4f6", size = 1386084, upload-time = "2025-01-07T10:04:47.874Z" }, { url = "https://files.pythonhosted.org/packages/ea/2e/3d60b1a9e9f29a2152aa66c823bf5e399ae7be3fef310ff0de86779c5d2d/hf_transfer-0.1.9-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:3ebc4ab9023414880c8b1d3c38174d1c9989eb5022d37e814fa91a3060123eb0", size = 1343558, upload-time = "2025-01-07T10:04:42.313Z" }, { url = "https://files.pythonhosted.org/packages/fb/38/130a5ac3747f104033591bcac1c961cb1faadfdc91704f59b09c0b465ff2/hf_transfer-0.1.9-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8674026f21ed369aa2a0a4b46000aca850fc44cd2b54af33a172ce5325b4fc82", size = 3726676, upload-time = "2025-01-07T10:04:11.539Z" }, + { url = "https://files.pythonhosted.org/packages/15/a1/f4e27c5ad17aac616ae0849e2aede5aae31db8267a948c6b3eeb9fd96446/hf_transfer-0.1.9-cp313-cp313t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3a736dfbb2c84f5a2c975478ad200c0c8bfcb58a25a35db402678fb87ce17fa4", size = 3062920, upload-time = "2025-01-07T10:04:16.297Z" }, + { url = "https://files.pythonhosted.org/packages/50/d0/2b213eb1ea8b1252ccaf1a6c804d0aba03fea38aae4124df6a3acb70511a/hf_transfer-0.1.9-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2c7fc1b85f4d0f76e452765d7648c9f4bfd0aedb9ced2ae1ebfece2d8cfaf8e2", size = 3398837, upload-time = "2025-01-07T10:04:22.778Z" }, { url = "https://files.pythonhosted.org/packages/8c/8a/79dbce9006e0bd6b74516f97451a7b7c64dbbb426df15d901dd438cfeee3/hf_transfer-0.1.9-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d991376f0eac70a60f0cbc95602aa708a6f7c8617f28b4945c1431d67b8e3c8", size = 3546986, upload-time = "2025-01-07T10:04:36.415Z" }, { url = "https://files.pythonhosted.org/packages/a9/f7/9ac239b6ee6fe0bad130325d987a93ea58c4118e50479f0786f1733b37e8/hf_transfer-0.1.9-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:e6ac4eddcd99575ed3735ed911ddf9d1697e2bd13aa3f0ad7e3904dd4863842e", size = 4071715, upload-time = "2025-01-07T10:04:53.224Z" }, + { url = "https://files.pythonhosted.org/packages/d8/a3/0ed697279f5eeb7a40f279bd783cf50e6d0b91f24120dcf66ef2cf8822b4/hf_transfer-0.1.9-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:57fd9880da1ee0f47250f735f791fab788f0aa1ee36afc49f761349869c8b4d9", size = 3388081, upload-time = "2025-01-07T10:04:57.818Z" }, { url = "https://files.pythonhosted.org/packages/45/07/6661e43fbee09594a8a5e9bb778107d95fe38dac4c653982afe03d32bd4d/hf_transfer-0.1.9-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:a5b366d34cd449fe9b20ef25941e6eef0460a2f74e7389f02e673e1f88ebd538", size = 3690551, upload-time = "2025-01-07T10:05:09.238Z" }, + { url = "https://files.pythonhosted.org/packages/81/f5/461d2e5f307e5048289b1168d5c642ae3bb2504e88dff1a38b92ed990a21/hf_transfer-0.1.9-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:e66acf91df4a8b72f60223059df3003062a5ae111757187ed1a06750a30e911b", size = 1393046, upload-time = "2025-01-07T10:04:51.003Z" }, { url = "https://files.pythonhosted.org/packages/41/ba/8d9fd9f1083525edfcb389c93738c802f3559cb749324090d7109c8bf4c2/hf_transfer-0.1.9-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:8669dbcc7a3e2e8d61d42cd24da9c50d57770bd74b445c65123291ca842a7e7a", size = 1348126, upload-time = "2025-01-07T10:04:45.712Z" }, { url = "https://files.pythonhosted.org/packages/8e/a2/cd7885bc9959421065a6fae0fe67b6c55becdeda4e69b873e52976f9a9f0/hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8fd0167c4407a3bc4cdd0307e65ada2294ec04f1813d8a69a5243e379b22e9d8", size = 3728604, upload-time = "2025-01-07T10:04:14.173Z" }, + { url = "https://files.pythonhosted.org/packages/f6/2e/a072cf196edfeda3310c9a5ade0a0fdd785e6154b3ce24fc738c818da2a7/hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ee8b10afedcb75f71091bcc197c526a6ebf5c58bbbadb34fdeee6160f55f619f", size = 3064995, upload-time = "2025-01-07T10:04:18.663Z" }, + { url = "https://files.pythonhosted.org/packages/29/63/b560d39651a56603d64f1a0212d0472a44cbd965db2fa62b99d99cb981bf/hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fc6bd19e1cc177c66bdef15ef8636ad3bde79d5a4f608c158021153b4573509d", size = 3400839, upload-time = "2025-01-07T10:04:26.122Z" }, { url = "https://files.pythonhosted.org/packages/d6/d8/f87ea6f42456254b48915970ed98e993110521e9263472840174d32c880d/hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cdca9bfb89e6f8f281890cc61a8aff2d3cecaff7e1a4d275574d96ca70098557", size = 3552664, upload-time = "2025-01-07T10:04:40.123Z" }, { url = "https://files.pythonhosted.org/packages/d6/56/1267c39b65fc8f4e2113b36297320f102718bf5799b544a6cbe22013aa1d/hf_transfer-0.1.9-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:89a23f58b7b7effbc047b8ca286f131b17728c99a9f972723323003ffd1bb916", size = 4073732, upload-time = "2025-01-07T10:04:55.624Z" }, + { url = "https://files.pythonhosted.org/packages/82/1a/9c748befbe3decf7cb415e34f8a0c3789a0a9c55910dea73d581e48c0ce5/hf_transfer-0.1.9-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:dc7fff1345980d6c0ebb92c811d24afa4b98b3e07ed070c8e38cc91fd80478c5", size = 3390096, upload-time = "2025-01-07T10:04:59.98Z" }, { url = "https://files.pythonhosted.org/packages/e7/6e/e597b04f753f1b09e6893075d53a82a30c13855cbaa791402695b01e369f/hf_transfer-0.1.9-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:d2fde99d502093ade3ab1b53f80da18480e9902aa960dab7f74fb1b9e5bc5746", size = 3695243, upload-time = "2025-01-07T10:05:11.411Z" }, { url = "https://files.pythonhosted.org/packages/a1/14/f1e15b851d1c2af5b0b1a82bf8eb10bda2da62d98180220ba6fd8879bb5b/hf_transfer-0.1.9-cp38-abi3-win_amd64.whl", hash = "sha256:16f208fc678911c37e11aa7b586bc66a37d02e636208f18b6bc53d29b5df40ad", size = 1160240, upload-time = "2025-01-07T10:05:14.324Z" }, ] @@ -122,18 +139,21 @@ version = "1.2.0" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/5e/6e/0f11bacf08a67f7fb5ee09740f2ca54163863b07b70d579356e9222ce5d8/hf_xet-1.2.0.tar.gz", hash = "sha256:a8c27070ca547293b6890c4bf389f713f80e8c478631432962bb7f4bc0bd7d7f", size = 506020, upload-time = "2025-10-24T19:04:32.129Z" } wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/a5/85ef910a0aa034a2abcfadc360ab5ac6f6bc4e9112349bd40ca97551cff0/hf_xet-1.2.0-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:ceeefcd1b7aed4956ae8499e2199607765fbd1c60510752003b6cc0b8413b649", size = 2861870, upload-time = "2025-10-24T19:04:11.422Z" }, { url = "https://files.pythonhosted.org/packages/ea/40/e2e0a7eb9a51fe8828ba2d47fe22a7e74914ea8a0db68a18c3aa7449c767/hf_xet-1.2.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b70218dd548e9840224df5638fdc94bd033552963cfa97f9170829381179c813", size = 2717584, upload-time = "2025-10-24T19:04:09.586Z" }, { url = "https://files.pythonhosted.org/packages/a5/7d/daf7f8bc4594fdd59a8a596f9e3886133fdc68e675292218a5e4c1b7e834/hf_xet-1.2.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d40b18769bb9a8bc82a9ede575ce1a44c75eb80e7375a01d76259089529b5dc", size = 3315004, upload-time = "2025-10-24T19:04:00.314Z" }, { url = "https://files.pythonhosted.org/packages/b1/ba/45ea2f605fbf6d81c8b21e4d970b168b18a53515923010c312c06cd83164/hf_xet-1.2.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:cd3a6027d59cfb60177c12d6424e31f4b5ff13d8e3a1247b3a584bf8977e6df5", size = 3222636, upload-time = "2025-10-24T19:03:58.111Z" }, { url = "https://files.pythonhosted.org/packages/4a/1d/04513e3cab8f29ab8c109d309ddd21a2705afab9d52f2ba1151e0c14f086/hf_xet-1.2.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:6de1fc44f58f6dd937956c8d304d8c2dea264c80680bcfa61ca4a15e7b76780f", size = 3408448, upload-time = "2025-10-24T19:04:20.951Z" }, { url = "https://files.pythonhosted.org/packages/f0/7c/60a2756d7feec7387db3a1176c632357632fbe7849fce576c5559d4520c7/hf_xet-1.2.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:f182f264ed2acd566c514e45da9f2119110e48a87a327ca271027904c70c5832", size = 3503401, upload-time = "2025-10-24T19:04:22.549Z" }, { url = "https://files.pythonhosted.org/packages/4e/64/48fffbd67fb418ab07451e4ce641a70de1c40c10a13e25325e24858ebe5a/hf_xet-1.2.0-cp313-cp313t-win_amd64.whl", hash = "sha256:293a7a3787e5c95d7be1857358a9130694a9c6021de3f27fa233f37267174382", size = 2900866, upload-time = "2025-10-24T19:04:33.461Z" }, + { url = "https://files.pythonhosted.org/packages/e2/51/f7e2caae42f80af886db414d4e9885fac959330509089f97cccb339c6b87/hf_xet-1.2.0-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:10bfab528b968c70e062607f663e21e34e2bba349e8038db546646875495179e", size = 2861861, upload-time = "2025-10-24T19:04:19.01Z" }, { url = "https://files.pythonhosted.org/packages/6e/1d/a641a88b69994f9371bd347f1dd35e5d1e2e2460a2e350c8d5165fc62005/hf_xet-1.2.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:2a212e842647b02eb6a911187dc878e79c4aa0aa397e88dd3b26761676e8c1f8", size = 2717699, upload-time = "2025-10-24T19:04:17.306Z" }, { url = "https://files.pythonhosted.org/packages/df/e0/e5e9bba7d15f0318955f7ec3f4af13f92e773fbb368c0b8008a5acbcb12f/hf_xet-1.2.0-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:30e06daccb3a7d4c065f34fc26c14c74f4653069bb2b194e7f18f17cbe9939c0", size = 3314885, upload-time = "2025-10-24T19:04:07.642Z" }, { url = "https://files.pythonhosted.org/packages/21/90/b7fe5ff6f2b7b8cbdf1bd56145f863c90a5807d9758a549bf3d916aa4dec/hf_xet-1.2.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:29c8fc913a529ec0a91867ce3d119ac1aac966e098cf49501800c870328cc090", size = 3221550, upload-time = "2025-10-24T19:04:05.55Z" }, { url = "https://files.pythonhosted.org/packages/6f/cb/73f276f0a7ce46cc6a6ec7d6c7d61cbfe5f2e107123d9bbd0193c355f106/hf_xet-1.2.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:66e159cbfcfbb29f920db2c09ed8b660eb894640d284f102ada929b6e3dc410a", size = 3408010, upload-time = "2025-10-24T19:04:28.598Z" }, { url = "https://files.pythonhosted.org/packages/b8/1e/d642a12caa78171f4be64f7cd9c40e3ca5279d055d0873188a58c0f5fbb9/hf_xet-1.2.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:9c91d5ae931510107f148874e9e2de8a16052b6f1b3ca3c1b12f15ccb491390f", size = 3503264, upload-time = "2025-10-24T19:04:30.397Z" }, { url = "https://files.pythonhosted.org/packages/17/b5/33764714923fa1ff922770f7ed18c2daae034d21ae6e10dbf4347c854154/hf_xet-1.2.0-cp314-cp314t-win_amd64.whl", hash = "sha256:210d577732b519ac6ede149d2f2f34049d44e8622bf14eb3d63bbcd2d4b332dc", size = 2901071, upload-time = "2025-10-24T19:04:37.463Z" }, + { url = "https://files.pythonhosted.org/packages/96/2d/22338486473df5923a9ab7107d375dbef9173c338ebef5098ef593d2b560/hf_xet-1.2.0-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:46740d4ac024a7ca9b22bebf77460ff43332868b661186a8e46c227fdae01848", size = 2866099, upload-time = "2025-10-24T19:04:15.366Z" }, { url = "https://files.pythonhosted.org/packages/7f/8c/c5becfa53234299bc2210ba314eaaae36c2875e0045809b82e40a9544f0c/hf_xet-1.2.0-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:27df617a076420d8845bea087f59303da8be17ed7ec0cd7ee3b9b9f579dff0e4", size = 2722178, upload-time = "2025-10-24T19:04:13.695Z" }, { url = "https://files.pythonhosted.org/packages/9a/92/cf3ab0b652b082e66876d08da57fcc6fa2f0e6c70dfbbafbd470bb73eb47/hf_xet-1.2.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3651fd5bfe0281951b988c0facbe726aa5e347b103a675f49a3fa8144c7968fd", size = 3320214, upload-time = "2025-10-24T19:04:03.596Z" }, { url = "https://files.pythonhosted.org/packages/46/92/3f7ec4a1b6a65bf45b059b6d4a5d38988f63e193056de2f420137e3c3244/hf_xet-1.2.0-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:d06fa97c8562fb3ee7a378dd9b51e343bc5bc8190254202c9771029152f5e08c", size = 3229054, upload-time = "2025-10-24T19:04:01.949Z" }, @@ -214,6 +234,7 @@ name = "llamafarm-common" version = "0.1.0" source = { editable = "." } dependencies = [ + { name = "cachetools", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'AMD64' and sys_platform == 'win32')" }, { name = "filelock", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'AMD64' and sys_platform == 'win32')" }, { name = "hf-transfer", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'AMD64' and sys_platform == 'win32')" }, { name = "huggingface-hub", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'AMD64' and sys_platform == 'win32')" }, @@ -231,6 +252,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "cachetools", specifier = ">=6.0.0" }, { name = "filelock", specifier = ">=3.16.1" }, { name = "hf-transfer", specifier = ">=0.1.9" }, { name = "huggingface-hub", specifier = ">=0.24.0" }, @@ -292,42 +314,58 @@ version = "6.0.3" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/05/8e/961c0007c59b8dd7729d542c61a4d537767a59645b82a0b521206e1e25c2/pyyaml-6.0.3.tar.gz", hash = "sha256:d76623373421df22fb4cf8817020cbb7ef15c725b9d5e45f17e189bfc384190f", size = 130960, upload-time = "2025-09-25T21:33:16.546Z" } wheels = [ + { url = "https://files.pythonhosted.org/packages/f4/a0/39350dd17dd6d6c6507025c0e53aef67a9293a6d37d3511f23ea510d5800/pyyaml-6.0.3-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:214ed4befebe12df36bcc8bc2b64b396ca31be9304b8f59e25c11cf94a4c033b", size = 184227, upload-time = "2025-09-25T21:31:46.04Z" }, { url = "https://files.pythonhosted.org/packages/05/14/52d505b5c59ce73244f59c7a50ecf47093ce4765f116cdb98286a71eeca2/pyyaml-6.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:02ea2dfa234451bbb8772601d7b8e426c2bfa197136796224e50e35a78777956", size = 174019, upload-time = "2025-09-25T21:31:47.706Z" }, { url = "https://files.pythonhosted.org/packages/43/f7/0e6a5ae5599c838c696adb4e6330a59f463265bfa1e116cfd1fbb0abaaae/pyyaml-6.0.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b30236e45cf30d2b8e7b3e85881719e98507abed1011bf463a8fa23e9c3e98a8", size = 740646, upload-time = "2025-09-25T21:31:49.21Z" }, + { url = "https://files.pythonhosted.org/packages/2f/3a/61b9db1d28f00f8fd0ae760459a5c4bf1b941baf714e207b6eb0657d2578/pyyaml-6.0.3-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:66291b10affd76d76f54fad28e22e51719ef9ba22b29e1d7d03d6777a9174198", size = 840793, upload-time = "2025-09-25T21:31:50.735Z" }, { url = "https://files.pythonhosted.org/packages/7a/1e/7acc4f0e74c4b3d9531e24739e0ab832a5edf40e64fbae1a9c01941cabd7/pyyaml-6.0.3-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9c7708761fccb9397fe64bbc0395abcae8c4bf7b0eac081e12b809bf47700d0b", size = 770293, upload-time = "2025-09-25T21:31:51.828Z" }, { url = "https://files.pythonhosted.org/packages/8b/ef/abd085f06853af0cd59fa5f913d61a8eab65d7639ff2a658d18a25d6a89d/pyyaml-6.0.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:418cf3f2111bc80e0933b2cd8cd04f286338bb88bdc7bc8e6dd775ebde60b5e0", size = 732872, upload-time = "2025-09-25T21:31:53.282Z" }, { url = "https://files.pythonhosted.org/packages/1f/15/2bc9c8faf6450a8b3c9fc5448ed869c599c0a74ba2669772b1f3a0040180/pyyaml-6.0.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:5e0b74767e5f8c593e8c9b5912019159ed0533c70051e9cce3e8b6aa699fcd69", size = 758828, upload-time = "2025-09-25T21:31:54.807Z" }, { url = "https://files.pythonhosted.org/packages/2a/fa/926c003379b19fca39dd4634818b00dec6c62d87faf628d1394e137354d4/pyyaml-6.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:bdb2c67c6c1390b63c6ff89f210c8fd09d9a1217a465701eac7316313c915e4c", size = 158561, upload-time = "2025-09-25T21:31:57.406Z" }, + { url = "https://files.pythonhosted.org/packages/6d/16/a95b6757765b7b031c9374925bb718d55e0a9ba8a1b6a12d25962ea44347/pyyaml-6.0.3-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:44edc647873928551a01e7a563d7452ccdebee747728c1080d881d68af7b997e", size = 185826, upload-time = "2025-09-25T21:31:58.655Z" }, { url = "https://files.pythonhosted.org/packages/16/19/13de8e4377ed53079ee996e1ab0a9c33ec2faf808a4647b7b4c0d46dd239/pyyaml-6.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:652cb6edd41e718550aad172851962662ff2681490a8a711af6a4d288dd96824", size = 175577, upload-time = "2025-09-25T21:32:00.088Z" }, { url = "https://files.pythonhosted.org/packages/0c/62/d2eb46264d4b157dae1275b573017abec435397aa59cbcdab6fc978a8af4/pyyaml-6.0.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:10892704fc220243f5305762e276552a0395f7beb4dbf9b14ec8fd43b57f126c", size = 775556, upload-time = "2025-09-25T21:32:01.31Z" }, + { url = "https://files.pythonhosted.org/packages/10/cb/16c3f2cf3266edd25aaa00d6c4350381c8b012ed6f5276675b9eba8d9ff4/pyyaml-6.0.3-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:850774a7879607d3a6f50d36d04f00ee69e7fc816450e5f7e58d7f17f1ae5c00", size = 882114, upload-time = "2025-09-25T21:32:03.376Z" }, { url = "https://files.pythonhosted.org/packages/71/60/917329f640924b18ff085ab889a11c763e0b573da888e8404ff486657602/pyyaml-6.0.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b8bb0864c5a28024fac8a632c443c87c5aa6f215c0b126c449ae1a150412f31d", size = 806638, upload-time = "2025-09-25T21:32:04.553Z" }, { url = "https://files.pythonhosted.org/packages/dd/6f/529b0f316a9fd167281a6c3826b5583e6192dba792dd55e3203d3f8e655a/pyyaml-6.0.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1d37d57ad971609cf3c53ba6a7e365e40660e3be0e5175fa9f2365a379d6095a", size = 767463, upload-time = "2025-09-25T21:32:06.152Z" }, { url = "https://files.pythonhosted.org/packages/f2/6a/b627b4e0c1dd03718543519ffb2f1deea4a1e6d42fbab8021936a4d22589/pyyaml-6.0.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:37503bfbfc9d2c40b344d06b2199cf0e96e97957ab1c1b546fd4f87e53e5d3e4", size = 794986, upload-time = "2025-09-25T21:32:07.367Z" }, { url = "https://files.pythonhosted.org/packages/da/e3/ea007450a105ae919a72393cb06f122f288ef60bba2dc64b26e2646fa315/pyyaml-6.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:9f3bfb4965eb874431221a3ff3fdcddc7e74e3b07799e0e84ca4a0f867d449bf", size = 158763, upload-time = "2025-09-25T21:32:09.96Z" }, + { url = "https://files.pythonhosted.org/packages/d1/33/422b98d2195232ca1826284a76852ad5a86fe23e31b009c9886b2d0fb8b2/pyyaml-6.0.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7f047e29dcae44602496db43be01ad42fc6f1cc0d8cd6c83d342306c32270196", size = 182063, upload-time = "2025-09-25T21:32:11.445Z" }, { url = "https://files.pythonhosted.org/packages/89/a0/6cf41a19a1f2f3feab0e9c0b74134aa2ce6849093d5517a0c550fe37a648/pyyaml-6.0.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fc09d0aa354569bc501d4e787133afc08552722d3ab34836a80547331bb5d4a0", size = 173973, upload-time = "2025-09-25T21:32:12.492Z" }, { url = "https://files.pythonhosted.org/packages/ed/23/7a778b6bd0b9a8039df8b1b1d80e2e2ad78aa04171592c8a5c43a56a6af4/pyyaml-6.0.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9149cad251584d5fb4981be1ecde53a1ca46c891a79788c0df828d2f166bda28", size = 775116, upload-time = "2025-09-25T21:32:13.652Z" }, + { url = "https://files.pythonhosted.org/packages/65/30/d7353c338e12baef4ecc1b09e877c1970bd3382789c159b4f89d6a70dc09/pyyaml-6.0.3-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5fdec68f91a0c6739b380c83b951e2c72ac0197ace422360e6d5a959d8d97b2c", size = 844011, upload-time = "2025-09-25T21:32:15.21Z" }, { url = "https://files.pythonhosted.org/packages/8b/9d/b3589d3877982d4f2329302ef98a8026e7f4443c765c46cfecc8858c6b4b/pyyaml-6.0.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ba1cc08a7ccde2d2ec775841541641e4548226580ab850948cbfda66a1befcdc", size = 807870, upload-time = "2025-09-25T21:32:16.431Z" }, { url = "https://files.pythonhosted.org/packages/05/c0/b3be26a015601b822b97d9149ff8cb5ead58c66f981e04fedf4e762f4bd4/pyyaml-6.0.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8dc52c23056b9ddd46818a57b78404882310fb473d63f17b07d5c40421e47f8e", size = 761089, upload-time = "2025-09-25T21:32:17.56Z" }, { url = "https://files.pythonhosted.org/packages/be/8e/98435a21d1d4b46590d5459a22d88128103f8da4c2d4cb8f14f2a96504e1/pyyaml-6.0.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:41715c910c881bc081f1e8872880d3c650acf13dfa8214bad49ed4cede7c34ea", size = 790181, upload-time = "2025-09-25T21:32:18.834Z" }, { url = "https://files.pythonhosted.org/packages/86/bf/899e81e4cce32febab4fb42bb97dcdf66bc135272882d1987881a4b519e9/pyyaml-6.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:5fcd34e47f6e0b794d17de1b4ff496c00986e1c83f7ab2fb8fcfe9616ff7477b", size = 154003, upload-time = "2025-09-25T21:32:21.167Z" }, + { url = "https://files.pythonhosted.org/packages/1a/08/67bd04656199bbb51dbed1439b7f27601dfb576fb864099c7ef0c3e55531/pyyaml-6.0.3-cp312-cp312-win_arm64.whl", hash = "sha256:64386e5e707d03a7e172c0701abfb7e10f0fb753ee1d773128192742712a98fd", size = 140344, upload-time = "2025-09-25T21:32:22.617Z" }, + { url = "https://files.pythonhosted.org/packages/d1/11/0fd08f8192109f7169db964b5707a2f1e8b745d4e239b784a5a1dd80d1db/pyyaml-6.0.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8da9669d359f02c0b91ccc01cac4a67f16afec0dac22c2ad09f46bee0697eba8", size = 181669, upload-time = "2025-09-25T21:32:23.673Z" }, { url = "https://files.pythonhosted.org/packages/b1/16/95309993f1d3748cd644e02e38b75d50cbc0d9561d21f390a76242ce073f/pyyaml-6.0.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:2283a07e2c21a2aa78d9c4442724ec1eb15f5e42a723b99cb3d822d48f5f7ad1", size = 173252, upload-time = "2025-09-25T21:32:25.149Z" }, { url = "https://files.pythonhosted.org/packages/50/31/b20f376d3f810b9b2371e72ef5adb33879b25edb7a6d072cb7ca0c486398/pyyaml-6.0.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ee2922902c45ae8ccada2c5b501ab86c36525b883eff4255313a253a3160861c", size = 767081, upload-time = "2025-09-25T21:32:26.575Z" }, + { url = "https://files.pythonhosted.org/packages/49/1e/a55ca81e949270d5d4432fbbd19dfea5321eda7c41a849d443dc92fd1ff7/pyyaml-6.0.3-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:a33284e20b78bd4a18c8c2282d549d10bc8408a2a7ff57653c0cf0b9be0afce5", size = 841159, upload-time = "2025-09-25T21:32:27.727Z" }, { url = "https://files.pythonhosted.org/packages/74/27/e5b8f34d02d9995b80abcef563ea1f8b56d20134d8f4e5e81733b1feceb2/pyyaml-6.0.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0f29edc409a6392443abf94b9cf89ce99889a1dd5376d94316ae5145dfedd5d6", size = 801626, upload-time = "2025-09-25T21:32:28.878Z" }, { url = "https://files.pythonhosted.org/packages/f9/11/ba845c23988798f40e52ba45f34849aa8a1f2d4af4b798588010792ebad6/pyyaml-6.0.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f7057c9a337546edc7973c0d3ba84ddcdf0daa14533c2065749c9075001090e6", size = 753613, upload-time = "2025-09-25T21:32:30.178Z" }, { url = "https://files.pythonhosted.org/packages/3d/e0/7966e1a7bfc0a45bf0a7fb6b98ea03fc9b8d84fa7f2229e9659680b69ee3/pyyaml-6.0.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:eda16858a3cab07b80edaf74336ece1f986ba330fdb8ee0d6c0d68fe82bc96be", size = 794115, upload-time = "2025-09-25T21:32:31.353Z" }, { url = "https://files.pythonhosted.org/packages/97/c9/39d5b874e8b28845e4ec2202b5da735d0199dbe5b8fb85f91398814a9a46/pyyaml-6.0.3-cp313-cp313-win_amd64.whl", hash = "sha256:79005a0d97d5ddabfeeea4cf676af11e647e41d81c9a7722a193022accdb6b7c", size = 154090, upload-time = "2025-09-25T21:32:33.659Z" }, + { url = "https://files.pythonhosted.org/packages/73/e8/2bdf3ca2090f68bb3d75b44da7bbc71843b19c9f2b9cb9b0f4ab7a5a4329/pyyaml-6.0.3-cp313-cp313-win_arm64.whl", hash = "sha256:5498cd1645aa724a7c71c8f378eb29ebe23da2fc0d7a08071d89469bf1d2defb", size = 140246, upload-time = "2025-09-25T21:32:34.663Z" }, + { url = "https://files.pythonhosted.org/packages/9d/8c/f4bd7f6465179953d3ac9bc44ac1a8a3e6122cf8ada906b4f96c60172d43/pyyaml-6.0.3-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:8d1fab6bb153a416f9aeb4b8763bc0f22a5586065f86f7664fc23339fc1c1fac", size = 181814, upload-time = "2025-09-25T21:32:35.712Z" }, { url = "https://files.pythonhosted.org/packages/bd/9c/4d95bb87eb2063d20db7b60faa3840c1b18025517ae857371c4dd55a6b3a/pyyaml-6.0.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:34d5fcd24b8445fadc33f9cf348c1047101756fd760b4dacb5c3e99755703310", size = 173809, upload-time = "2025-09-25T21:32:36.789Z" }, { url = "https://files.pythonhosted.org/packages/92/b5/47e807c2623074914e29dabd16cbbdd4bf5e9b2db9f8090fa64411fc5382/pyyaml-6.0.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:501a031947e3a9025ed4405a168e6ef5ae3126c59f90ce0cd6f2bfc477be31b7", size = 766454, upload-time = "2025-09-25T21:32:37.966Z" }, + { url = "https://files.pythonhosted.org/packages/02/9e/e5e9b168be58564121efb3de6859c452fccde0ab093d8438905899a3a483/pyyaml-6.0.3-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:b3bc83488de33889877a0f2543ade9f70c67d66d9ebb4ac959502e12de895788", size = 836355, upload-time = "2025-09-25T21:32:39.178Z" }, { url = "https://files.pythonhosted.org/packages/88/f9/16491d7ed2a919954993e48aa941b200f38040928474c9e85ea9e64222c3/pyyaml-6.0.3-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c458b6d084f9b935061bc36216e8a69a7e293a2f1e68bf956dcd9e6cbcd143f5", size = 794175, upload-time = "2025-09-25T21:32:40.865Z" }, { url = "https://files.pythonhosted.org/packages/dd/3f/5989debef34dc6397317802b527dbbafb2b4760878a53d4166579111411e/pyyaml-6.0.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:7c6610def4f163542a622a73fb39f534f8c101d690126992300bf3207eab9764", size = 755228, upload-time = "2025-09-25T21:32:42.084Z" }, { url = "https://files.pythonhosted.org/packages/d7/ce/af88a49043cd2e265be63d083fc75b27b6ed062f5f9fd6cdc223ad62f03e/pyyaml-6.0.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:5190d403f121660ce8d1d2c1bb2ef1bd05b5f68533fc5c2ea899bd15f4399b35", size = 789194, upload-time = "2025-09-25T21:32:43.362Z" }, { url = "https://files.pythonhosted.org/packages/23/20/bb6982b26a40bb43951265ba29d4c246ef0ff59c9fdcdf0ed04e0687de4d/pyyaml-6.0.3-cp314-cp314-win_amd64.whl", hash = "sha256:4a2e8cebe2ff6ab7d1050ecd59c25d4c8bd7e6f400f5f82b96557ac0abafd0ac", size = 156429, upload-time = "2025-09-25T21:32:57.844Z" }, + { url = "https://files.pythonhosted.org/packages/f4/f4/a4541072bb9422c8a883ab55255f918fa378ecf083f5b85e87fc2b4eda1b/pyyaml-6.0.3-cp314-cp314-win_arm64.whl", hash = "sha256:93dda82c9c22deb0a405ea4dc5f2d0cda384168e466364dec6255b293923b2f3", size = 143912, upload-time = "2025-09-25T21:32:59.247Z" }, + { url = "https://files.pythonhosted.org/packages/7c/f9/07dd09ae774e4616edf6cda684ee78f97777bdd15847253637a6f052a62f/pyyaml-6.0.3-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:02893d100e99e03eda1c8fd5c441d8c60103fd175728e23e431db1b589cf5ab3", size = 189108, upload-time = "2025-09-25T21:32:44.377Z" }, { url = "https://files.pythonhosted.org/packages/4e/78/8d08c9fb7ce09ad8c38ad533c1191cf27f7ae1effe5bb9400a46d9437fcf/pyyaml-6.0.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:c1ff362665ae507275af2853520967820d9124984e0f7466736aea23d8611fba", size = 183641, upload-time = "2025-09-25T21:32:45.407Z" }, { url = "https://files.pythonhosted.org/packages/7b/5b/3babb19104a46945cf816d047db2788bcaf8c94527a805610b0289a01c6b/pyyaml-6.0.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6adc77889b628398debc7b65c073bcb99c4a0237b248cacaf3fe8a557563ef6c", size = 831901, upload-time = "2025-09-25T21:32:48.83Z" }, + { url = "https://files.pythonhosted.org/packages/8b/cc/dff0684d8dc44da4d22a13f35f073d558c268780ce3c6ba1b87055bb0b87/pyyaml-6.0.3-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:a80cb027f6b349846a3bf6d73b5e95e782175e52f22108cfa17876aaeff93702", size = 861132, upload-time = "2025-09-25T21:32:50.149Z" }, { url = "https://files.pythonhosted.org/packages/b1/5e/f77dc6b9036943e285ba76b49e118d9ea929885becb0a29ba8a7c75e29fe/pyyaml-6.0.3-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:00c4bdeba853cc34e7dd471f16b4114f4162dc03e6b7afcc2128711f0eca823c", size = 839261, upload-time = "2025-09-25T21:32:51.808Z" }, { url = "https://files.pythonhosted.org/packages/ce/88/a9db1376aa2a228197c58b37302f284b5617f56a5d959fd1763fb1675ce6/pyyaml-6.0.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:66e1674c3ef6f541c35191caae2d429b967b99e02040f5ba928632d9a7f0f065", size = 805272, upload-time = "2025-09-25T21:32:52.941Z" }, { url = "https://files.pythonhosted.org/packages/da/92/1446574745d74df0c92e6aa4a7b0b3130706a4142b2d1a5869f2eaa423c6/pyyaml-6.0.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:16249ee61e95f858e83976573de0f5b2893b3677ba71c9dd36b9cf8be9ac6d65", size = 829923, upload-time = "2025-09-25T21:32:54.537Z" }, { url = "https://files.pythonhosted.org/packages/f0/7a/1c7270340330e575b92f397352af856a8c06f230aa3e76f86b39d01b416a/pyyaml-6.0.3-cp314-cp314t-win_amd64.whl", hash = "sha256:4ad1906908f2f5ae4e5a8ddfce73c320c2a1429ec52eafd27138b7f1cbe341c9", size = 174062, upload-time = "2025-09-25T21:32:55.767Z" }, + { url = "https://files.pythonhosted.org/packages/f1/12/de94a39c2ef588c7e6455cfbe7343d3b2dc9d6b6b2f40c4c6565744c873d/pyyaml-6.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:ebc55a14a21cb14062aa4162f906cd962b28e2e9ea38f9b4391244cd8de4ae0b", size = 149341, upload-time = "2025-09-25T21:32:56.828Z" }, ] [[package]] @@ -345,30 +383,35 @@ version = "2.3.0" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/52/ed/3f73f72945444548f33eba9a87fc7a6e969915e7b1acc8260b30e1f76a2f/tomli-2.3.0.tar.gz", hash = "sha256:64be704a875d2a59753d80ee8a533c3fe183e3f06807ff7dc2232938ccb01549", size = 17392, upload-time = "2025-10-08T22:01:47.119Z" } wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/2e/299f62b401438d5fe1624119c723f5d877acc86a4c2492da405626665f12/tomli-2.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:88bd15eb972f3664f5ed4b57c1634a97153b4bac4479dcb6a495f41921eb7f45", size = 153236, upload-time = "2025-10-08T22:01:00.137Z" }, { url = "https://files.pythonhosted.org/packages/86/7f/d8fffe6a7aefdb61bced88fcb5e280cfd71e08939da5894161bd71bea022/tomli-2.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:883b1c0d6398a6a9d29b508c331fa56adbcdff647f6ace4dfca0f50e90dfd0ba", size = 148084, upload-time = "2025-10-08T22:01:01.63Z" }, { url = "https://files.pythonhosted.org/packages/47/5c/24935fb6a2ee63e86d80e4d3b58b222dafaf438c416752c8b58537c8b89a/tomli-2.3.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d1381caf13ab9f300e30dd8feadb3de072aeb86f1d34a8569453ff32a7dea4bf", size = 234832, upload-time = "2025-10-08T22:01:02.543Z" }, { url = "https://files.pythonhosted.org/packages/89/da/75dfd804fc11e6612846758a23f13271b76d577e299592b4371a4ca4cd09/tomli-2.3.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a0e285d2649b78c0d9027570d4da3425bdb49830a6156121360b3f8511ea3441", size = 242052, upload-time = "2025-10-08T22:01:03.836Z" }, { url = "https://files.pythonhosted.org/packages/70/8c/f48ac899f7b3ca7eb13af73bacbc93aec37f9c954df3c08ad96991c8c373/tomli-2.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0a154a9ae14bfcf5d8917a59b51ffd5a3ac1fd149b71b47a3a104ca4edcfa845", size = 239555, upload-time = "2025-10-08T22:01:04.834Z" }, { url = "https://files.pythonhosted.org/packages/ba/28/72f8afd73f1d0e7829bfc093f4cb98ce0a40ffc0cc997009ee1ed94ba705/tomli-2.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:74bf8464ff93e413514fefd2be591c3b0b23231a77f901db1eb30d6f712fc42c", size = 245128, upload-time = "2025-10-08T22:01:05.84Z" }, { url = "https://files.pythonhosted.org/packages/0a/fe/3d3420c4cb1ad9cb462fb52967080575f15898da97e21cb6f1361d505383/tomli-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:4dc4ce8483a5d429ab602f111a93a6ab1ed425eae3122032db7e9acf449451be", size = 107165, upload-time = "2025-10-08T22:01:08.107Z" }, + { url = "https://files.pythonhosted.org/packages/ff/b7/40f36368fcabc518bb11c8f06379a0fd631985046c038aca08c6d6a43c6e/tomli-2.3.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d7d86942e56ded512a594786a5ba0a5e521d02529b3826e7761a05138341a2ac", size = 154891, upload-time = "2025-10-08T22:01:09.082Z" }, { url = "https://files.pythonhosted.org/packages/f9/3f/d9dd692199e3b3aab2e4e4dd948abd0f790d9ded8cd10cbaae276a898434/tomli-2.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:73ee0b47d4dad1c5e996e3cd33b8a76a50167ae5f96a2607cbe8cc773506ab22", size = 148796, upload-time = "2025-10-08T22:01:10.266Z" }, { url = "https://files.pythonhosted.org/packages/60/83/59bff4996c2cf9f9387a0f5a3394629c7efa5ef16142076a23a90f1955fa/tomli-2.3.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:792262b94d5d0a466afb5bc63c7daa9d75520110971ee269152083270998316f", size = 242121, upload-time = "2025-10-08T22:01:11.332Z" }, { url = "https://files.pythonhosted.org/packages/45/e5/7c5119ff39de8693d6baab6c0b6dcb556d192c165596e9fc231ea1052041/tomli-2.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4f195fe57ecceac95a66a75ac24d9d5fbc98ef0962e09b2eddec5d39375aae52", size = 250070, upload-time = "2025-10-08T22:01:12.498Z" }, { url = "https://files.pythonhosted.org/packages/45/12/ad5126d3a278f27e6701abde51d342aa78d06e27ce2bb596a01f7709a5a2/tomli-2.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e31d432427dcbf4d86958c184b9bfd1e96b5b71f8eb17e6d02531f434fd335b8", size = 245859, upload-time = "2025-10-08T22:01:13.551Z" }, { url = "https://files.pythonhosted.org/packages/fb/a1/4d6865da6a71c603cfe6ad0e6556c73c76548557a8d658f9e3b142df245f/tomli-2.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7b0882799624980785240ab732537fcfc372601015c00f7fc367c55308c186f6", size = 250296, upload-time = "2025-10-08T22:01:14.614Z" }, { url = "https://files.pythonhosted.org/packages/06/1e/f22f100db15a68b520664eb3328fb0ae4e90530887928558112c8d1f4515/tomli-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:1cb4ed918939151a03f33d4242ccd0aa5f11b3547d0cf30f7c74a408a5b99878", size = 107698, upload-time = "2025-10-08T22:01:16.51Z" }, + { url = "https://files.pythonhosted.org/packages/89/48/06ee6eabe4fdd9ecd48bf488f4ac783844fd777f547b8d1b61c11939974e/tomli-2.3.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5192f562738228945d7b13d4930baffda67b69425a7f0da96d360b0a3888136b", size = 154819, upload-time = "2025-10-08T22:01:17.964Z" }, { url = "https://files.pythonhosted.org/packages/f1/01/88793757d54d8937015c75dcdfb673c65471945f6be98e6a0410fba167ed/tomli-2.3.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:be71c93a63d738597996be9528f4abe628d1adf5e6eb11607bc8fe1a510b5dae", size = 148766, upload-time = "2025-10-08T22:01:18.959Z" }, { url = "https://files.pythonhosted.org/packages/42/17/5e2c956f0144b812e7e107f94f1cc54af734eb17b5191c0bbfb72de5e93e/tomli-2.3.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c4665508bcbac83a31ff8ab08f424b665200c0e1e645d2bd9ab3d3e557b6185b", size = 240771, upload-time = "2025-10-08T22:01:20.106Z" }, { url = "https://files.pythonhosted.org/packages/d5/f4/0fbd014909748706c01d16824eadb0307115f9562a15cbb012cd9b3512c5/tomli-2.3.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4021923f97266babc6ccab9f5068642a0095faa0a51a246a6a02fccbb3514eaf", size = 248586, upload-time = "2025-10-08T22:01:21.164Z" }, { url = "https://files.pythonhosted.org/packages/30/77/fed85e114bde5e81ecf9bc5da0cc69f2914b38f4708c80ae67d0c10180c5/tomli-2.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a4ea38c40145a357d513bffad0ed869f13c1773716cf71ccaa83b0fa0cc4e42f", size = 244792, upload-time = "2025-10-08T22:01:22.417Z" }, { url = "https://files.pythonhosted.org/packages/55/92/afed3d497f7c186dc71e6ee6d4fcb0acfa5f7d0a1a2878f8beae379ae0cc/tomli-2.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ad805ea85eda330dbad64c7ea7a4556259665bdf9d2672f5dccc740eb9d3ca05", size = 248909, upload-time = "2025-10-08T22:01:23.859Z" }, { url = "https://files.pythonhosted.org/packages/b2/b7/718cd1da0884f281f95ccfa3a6cc572d30053cba64603f79d431d3c9b61b/tomli-2.3.0-cp313-cp313-win_amd64.whl", hash = "sha256:0c95ca56fbe89e065c6ead5b593ee64b84a26fca063b5d71a1122bf26e533999", size = 107705, upload-time = "2025-10-08T22:01:26.153Z" }, + { url = "https://files.pythonhosted.org/packages/19/94/aeafa14a52e16163008060506fcb6aa1949d13548d13752171a755c65611/tomli-2.3.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:cebc6fe843e0733ee827a282aca4999b596241195f43b4cc371d64fc6639da9e", size = 154244, upload-time = "2025-10-08T22:01:27.06Z" }, { url = "https://files.pythonhosted.org/packages/db/e4/1e58409aa78eefa47ccd19779fc6f36787edbe7d4cd330eeeedb33a4515b/tomli-2.3.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:4c2ef0244c75aba9355561272009d934953817c49f47d768070c3c94355c2aa3", size = 148637, upload-time = "2025-10-08T22:01:28.059Z" }, { url = "https://files.pythonhosted.org/packages/26/b6/d1eccb62f665e44359226811064596dd6a366ea1f985839c566cd61525ae/tomli-2.3.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c22a8bf253bacc0cf11f35ad9808b6cb75ada2631c2d97c971122583b129afbc", size = 241925, upload-time = "2025-10-08T22:01:29.066Z" }, { url = "https://files.pythonhosted.org/packages/70/91/7cdab9a03e6d3d2bb11beae108da5bdc1c34bdeb06e21163482544ddcc90/tomli-2.3.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0eea8cc5c5e9f89c9b90c4896a8deefc74f518db5927d0e0e8d4a80953d774d0", size = 249045, upload-time = "2025-10-08T22:01:31.98Z" }, { url = "https://files.pythonhosted.org/packages/15/1b/8c26874ed1f6e4f1fcfeb868db8a794cbe9f227299402db58cfcc858766c/tomli-2.3.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:b74a0e59ec5d15127acdabd75ea17726ac4c5178ae51b85bfe39c4f8a278e879", size = 245835, upload-time = "2025-10-08T22:01:32.989Z" }, { url = "https://files.pythonhosted.org/packages/fd/42/8e3c6a9a4b1a1360c1a2a39f0b972cef2cc9ebd56025168c4137192a9321/tomli-2.3.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:b5870b50c9db823c595983571d1296a6ff3e1b88f734a4c8f6fc6188397de005", size = 253109, upload-time = "2025-10-08T22:01:34.052Z" }, { url = "https://files.pythonhosted.org/packages/b9/74/cb1abc870a418ae99cd5c9547d6bce30701a954e0e721821df483ef7223c/tomli-2.3.0-cp314-cp314-win_amd64.whl", hash = "sha256:b273fcbd7fc64dc3600c098e39136522650c49bca95df2d11cf3b626422392c8", size = 107964, upload-time = "2025-10-08T22:01:36.057Z" }, + { url = "https://files.pythonhosted.org/packages/54/78/5c46fff6432a712af9f792944f4fcd7067d8823157949f4e40c56b8b3c83/tomli-2.3.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:940d56ee0410fa17ee1f12b817b37a4d4e4dc4d27340863cc67236c74f582e77", size = 163065, upload-time = "2025-10-08T22:01:37.27Z" }, { url = "https://files.pythonhosted.org/packages/39/67/f85d9bd23182f45eca8939cd2bc7050e1f90c41f4a2ecbbd5963a1d1c486/tomli-2.3.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:f85209946d1fe94416debbb88d00eb92ce9cd5266775424ff81bc959e001acaf", size = 159088, upload-time = "2025-10-08T22:01:38.235Z" }, { url = "https://files.pythonhosted.org/packages/26/5a/4b546a0405b9cc0659b399f12b6adb750757baf04250b148d3c5059fc4eb/tomli-2.3.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a56212bdcce682e56b0aaf79e869ba5d15a6163f88d5451cbde388d48b13f530", size = 268193, upload-time = "2025-10-08T22:01:39.712Z" }, { url = "https://files.pythonhosted.org/packages/42/4f/2c12a72ae22cf7b59a7fe75b3465b7aba40ea9145d026ba41cb382075b0e/tomli-2.3.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c5f3ffd1e098dfc032d4d3af5c0ac64f6d286d98bc148698356847b80fa4de1b", size = 275488, upload-time = "2025-10-08T22:01:40.773Z" }, diff --git a/packages/llamafarm-llama/src/llamafarm_llama/_binary.py b/packages/llamafarm-llama/src/llamafarm_llama/_binary.py index 5ced88aa6..f960e666a 100644 --- a/packages/llamafarm-llama/src/llamafarm_llama/_binary.py +++ b/packages/llamafarm-llama/src/llamafarm_llama/_binary.py @@ -42,15 +42,50 @@ def _read_llama_cpp_version() -> str: def _get_llamafarm_release_version() -> str: - """Get LlamaFarm release version for ARM64 binary downloads.""" + """Get LlamaFarm release version for ARM64 binary downloads. + + The ARM64 llama.cpp binary is published as part of the main LlamaFarm + monorepo release (e.g., v0.0.28), NOT the llamafarm-llama package version. + These versions are decoupled. + + Priority: + 1. LLAMAFARM_RELEASE_VERSION env var (explicit override) + 2. GitHub API query for latest release with the ARM64 binary + 3. Hardcoded fallback + """ + # 1. Env var override + env_version = os.environ.get("LLAMAFARM_RELEASE_VERSION") + if env_version: + if not env_version.startswith("v"): + env_version = f"v{env_version}" + logger.info(f"Using LlamaFarm release version from env: {env_version}") + return env_version + + # 2. Query GitHub API for latest release with ARM64 binary try: - version = metadata.version("llamafarm-llama") - if version and not version.startswith("0.0.0"): - return f"v{version}" - except metadata.PackageNotFoundError: - pass - # Fallback for dev installs - return "v0.0.1" + import json + + req = Request( + "https://api.github.com/repos/llama-farm/llamafarm/releases/latest", + headers={"User-Agent": "llamafarm-llama", "Accept": "application/vnd.github.v3+json"}, + ) + with urlopen(req, timeout=10) as response: + data = json.loads(response.read()) + tag = data.get("tag_name") + assets = data.get("assets", []) + asset_names = [a.get("name", "") for a in assets] + if tag and any("arm64" in name for name in asset_names): + logger.info(f"Using latest LlamaFarm release: {tag}") + return tag + elif tag: + logger.debug(f"Latest release {tag} has no ARM64 asset, skipping") + except Exception as e: + logger.debug(f"Could not query GitHub for latest release: {e}") + + # 3. Hardcoded fallback (last known good release with ARM64 binary) + fallback = "v0.0.28" + logger.info(f"Using fallback LlamaFarm release version: {fallback}") + return fallback # Binary URLs from llama.cpp GitHub releases # Format: https://github.com/ggml-org/llama.cpp/releases/download/{version}/{artifact} @@ -70,7 +105,7 @@ def _get_llamafarm_release_version() -> str: }, # Linux ARM64 (LlamaFarm provided - not available from upstream) ("linux", "arm64", "cpu"): { - "artifact": "https://github.com/llama-farm/llamafarm/releases/download/{llamafarm_version}/llama-{version}-bin-linux-arm64.tar.gz", + "artifact": "https://github.com/llama-farm/llamafarm/releases/download/{llamafarm_version}/llama-{version}-bin-linux-arm64.zip", "lib": "libllama.so", "sha256": None, }, diff --git a/rag/uv.lock b/rag/uv.lock index f83c06ccf..4c80a75aa 100644 --- a/rag/uv.lock +++ b/rag/uv.lock @@ -2297,6 +2297,7 @@ name = "llamafarm-common" version = "0.1.0" source = { editable = "../common" } dependencies = [ + { name = "cachetools", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'AMD64' and sys_platform == 'win32')" }, { name = "filelock", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'AMD64' and sys_platform == 'win32')" }, { name = "hf-transfer", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'AMD64' and sys_platform == 'win32')" }, { name = "huggingface-hub", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'AMD64' and sys_platform == 'win32')" }, @@ -2304,6 +2305,7 @@ dependencies = [ [package.metadata] requires-dist = [ + { name = "cachetools", specifier = ">=6.0.0" }, { name = "filelock", specifier = ">=3.16.1" }, { name = "hf-transfer", specifier = ">=0.1.9" }, { name = "huggingface-hub", specifier = ">=0.24.0" }, @@ -5323,6 +5325,10 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0f/8b/4b61d6e13f7108f36910df9ab4b58fd389cc2520d54d81b88660804aad99/torch-2.10.0-2-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:418997cb02d0a0f1497cf6a09f63166f9f5df9f3e16c8a716ab76a72127c714f", size = 79423467, upload-time = "2026-02-10T21:44:48.711Z" }, { url = "https://files.pythonhosted.org/packages/d3/54/a2ba279afcca44bbd320d4e73675b282fcee3d81400ea1b53934efca6462/torch-2.10.0-2-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:13ec4add8c3faaed8d13e0574f5cd4a323c11655546f91fbe6afa77b57423574", size = 79498202, upload-time = "2026-02-10T21:44:52.603Z" }, { url = "https://files.pythonhosted.org/packages/ec/23/2c9fe0c9c27f7f6cb865abcea8a4568f29f00acaeadfc6a37f6801f84cb4/torch-2.10.0-2-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:e521c9f030a3774ed770a9c011751fb47c4d12029a3d6522116e48431f2ff89e", size = 79498254, upload-time = "2026-02-10T21:44:44.095Z" }, + { url = "https://files.pythonhosted.org/packages/36/ab/7b562f1808d3f65414cd80a4f7d4bb00979d9355616c034c171249e1a303/torch-2.10.0-3-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:ac5bdcbb074384c66fa160c15b1ead77839e3fe7ed117d667249afce0acabfac", size = 915518691, upload-time = "2026-03-11T14:15:43.147Z" }, + { url = "https://files.pythonhosted.org/packages/b3/7a/abada41517ce0011775f0f4eacc79659bc9bc6c361e6bfe6f7052a6b9363/torch-2.10.0-3-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:98c01b8bb5e3240426dcde1446eed6f40c778091c8544767ef1168fc663a05a6", size = 915622781, upload-time = "2026-03-11T14:17:11.354Z" }, + { url = "https://files.pythonhosted.org/packages/ab/c6/4dfe238342ffdcec5aef1c96c457548762d33c40b45a1ab7033bb26d2ff2/torch-2.10.0-3-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:80b1b5bfe38eb0e9f5ff09f206dcac0a87aadd084230d4a36eea5ec5232c115b", size = 915627275, upload-time = "2026-03-11T14:16:11.325Z" }, + { url = "https://files.pythonhosted.org/packages/d8/f0/72bf18847f58f877a6a8acf60614b14935e2f156d942483af1ffc081aea0/torch-2.10.0-3-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:46b3574d93a2a8134b3f5475cfb98e2eb46771794c57015f6ad1fb795ec25e49", size = 915523474, upload-time = "2026-03-11T14:17:44.422Z" }, { url = "https://files.pythonhosted.org/packages/78/89/f5554b13ebd71e05c0b002f95148033e730d3f7067f67423026cc9c69410/torch-2.10.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:3282d9febd1e4e476630a099692b44fdc214ee9bf8ee5377732d9d9dfe5712e4", size = 145992610, upload-time = "2026-01-21T16:25:26.327Z" }, { url = "https://files.pythonhosted.org/packages/ae/30/a3a2120621bf9c17779b169fc17e3dc29b230c29d0f8222f499f5e159aa8/torch-2.10.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a2f9edd8dbc99f62bc4dfb78af7bf89499bca3d753423ac1b4e06592e467b763", size = 915607863, upload-time = "2026-01-21T16:25:06.696Z" }, { url = "https://files.pythonhosted.org/packages/6f/3d/c87b33c5f260a2a8ad68da7147e105f05868c281c63d65ed85aa4da98c66/torch-2.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:29b7009dba4b7a1c960260fc8ac85022c784250af43af9fb0ebafc9883782ebd", size = 113723116, upload-time = "2026-01-21T16:25:21.916Z" }, diff --git a/runtimes/edge/Dockerfile b/runtimes/edge/Dockerfile new file mode 100644 index 000000000..5d74baffb --- /dev/null +++ b/runtimes/edge/Dockerfile @@ -0,0 +1,105 @@ +# ============================================================ +# Builder stage — install build tools and compile dependencies +# ============================================================ +FROM ubuntu:24.04 AS builder + +ENV DEBIAN_FRONTEND=noninteractive + +# Build-time option: set to "false" to skip vision deps (ultralytics/YOLO) +ARG ENABLE_VISION=true + +RUN apt-get update && apt-get install -y --no-install-recommends \ + python3 \ + python3-pip \ + python3-venv \ + python3-dev \ + build-essential \ + cmake \ + git \ + && rm -rf /var/lib/apt/lists/* + +# Create a self-contained venv we can COPY to the runtime stage +RUN python3 -m venv /opt/venv +ENV PATH="/opt/venv/bin:$PATH" + +# Install uv for fast dependency management +COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv + +WORKDIR /app + +# Copy local dependency sources +COPY common /deps/common +COPY packages/llamafarm-llama /deps/llamafarm-llama + +# Copy edge runtime pyproject.toml for dependency resolution +COPY runtimes/edge/pyproject.toml ./ + +# Install CPU-only PyTorch first (saves ~2GB on AMD64 vs CUDA default). +# ARM64 PyPI wheels are already CPU-only, so this is a no-op there. +RUN uv pip install --no-cache \ + torch torchvision \ + --extra-index-url https://download.pytorch.org/whl/cpu + +# Install local deps as non-editable (no need to copy /deps/ to runtime) +RUN uv pip install --no-cache /deps/common /deps/llamafarm-llama + +# Install edge runtime deps (vision conditional) +RUN if [ "$ENABLE_VISION" = "true" ]; then \ + uv pip install --no-cache --no-sources ".[vision]" && \ + uv pip install --no-cache pi-heif; \ + else \ + uv pip install --no-cache --no-sources "."; \ + fi + +# Copy application code +COPY runtimes/edge/ . + +# ============================================================ +# Runtime stage — minimal image with only what's needed to run +# ============================================================ +FROM ubuntu:24.04 + +ENV DEBIAN_FRONTEND=noninteractive + +# Runtime-only system libraries (no build-essential, cmake, git, python3-dev) +RUN apt-get update && apt-get install -y --no-install-recommends \ + python3 \ + python3-venv \ + libgl1 \ + libglib2.0-0 \ + libxcb1 \ + && ln -sf /usr/bin/python3 /usr/bin/python \ + && rm -rf /var/lib/apt/lists/* + +# Copy the pre-built venv from builder +COPY --from=builder /opt/venv /opt/venv +ENV PATH="/opt/venv/bin:$PATH" + +# Copy application code +WORKDIR /app +COPY --from=builder /app/ ./ + +# Create non-root user for security +RUN useradd --create-home --shell /bin/bash edge && \ + chown -R edge:edge /app +USER edge + +# Create data directory +RUN mkdir -p /home/edge/.llamafarm + +# Health check +HEALTHCHECK --interval=30s --timeout=5s --start-period=60s --retries=3 \ + CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:11540/health')" || exit 1 + +# Expose port +EXPOSE 11540 + +# Environment defaults +ENV LF_RUNTIME_PORT=11540 \ + LF_RUNTIME_HOST=0.0.0.0 \ + LOG_LEVEL=INFO \ + MODEL_UNLOAD_TIMEOUT=300 \ + YOLO_AUTOINSTALL=false + +# Run the edge runtime +CMD ["python", "server.py"] diff --git a/runtimes/edge/config/model_context_defaults.yaml b/runtimes/edge/config/model_context_defaults.yaml new file mode 100644 index 000000000..cc97cdabd --- /dev/null +++ b/runtimes/edge/config/model_context_defaults.yaml @@ -0,0 +1,34 @@ +# Default context sizes for GGUF models +# Patterns use Unix shell-style wildcards (*, ?, [seq]) +# More specific patterns should be listed first + +# Memory usage factor for computing max context size +# 0.8 = use 80% of available memory (aggressive but safe for most systems) +memory_usage_factor: 0.8 + +model_defaults: + # Exact model matches (highest priority) + - pattern: "unsloth/Qwen2.5-Coder-1.5B-Instruct-GGUF" + n_ctx: 32768 + notes: "Qwen 2.5 supports 32k context" + + - pattern: "unsloth/gpt-oss-*" + n_ctx: 8192 + notes: "GPT-OSS models default to 8k context" + + # Wildcard patterns (lower priority) + - pattern: "*Qwen2.5*" + n_ctx: 32768 + notes: "Qwen 2.5 family supports 32k context" + + - pattern: "*Llama-3*" + n_ctx: 8192 + notes: "Llama 3 family default" + + - pattern: "*Mistral*" + n_ctx: 32768 + notes: "Mistral models support 32k context" + + - pattern: "*" + n_ctx: 4096 + notes: "Fallback default for unknown models" diff --git a/runtimes/edge/core/__init__.py b/runtimes/edge/core/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/runtimes/edge/core/logging.py b/runtimes/edge/core/logging.py new file mode 100644 index 000000000..6f6ede255 --- /dev/null +++ b/runtimes/edge/core/logging.py @@ -0,0 +1,156 @@ +# src/core/logger.py +import logging +from typing import Any + +import structlog +from structlog.types import EventDict, Processor + + +def _coerce_log_level(level: Any) -> int | str: + """Allow level as int, numeric string, or name. + + Returns an int for numeric inputs; otherwise an upper-cased level name. + """ + if isinstance(level, int): + return level + if isinstance(level, str): + s = level.strip() + if s.isdigit(): + try: + return int(s) + except Exception: + return s.upper() + return s.upper() + return level + + +def drop_color_message_key(_, __, event_dict: EventDict) -> EventDict: + """ + Uvicorn logs the message a second time in the extra `color_message`, but we don't + need it. This processor drops the key from the event dict if it exists. + """ + event_dict.pop("color_message", None) + return event_dict + + +def setup_logging(json_logs: bool = False, log_level: str = "INFO", log_file: str = ""): + """Setup logging with structlog, similar to server/core/logging.py.""" + timestamper = structlog.processors.TimeStamper(fmt="iso") + + shared_processors: list[Processor] = [ + structlog.contextvars.merge_contextvars, + structlog.stdlib.add_logger_name, + structlog.stdlib.add_log_level, + structlog.stdlib.PositionalArgumentsFormatter(), + structlog.stdlib.ExtraAdder(), + drop_color_message_key, + timestamper, + structlog.processors.StackInfoRenderer(), + ] + + if json_logs: + # Format the exception only for JSON logs, as we want to pretty-print them when + # using the ConsoleRenderer + shared_processors.append(structlog.processors.format_exc_info) + + structlog.configure( + processors=shared_processors + + [ + structlog.stdlib.ProcessorFormatter.wrap_for_formatter, + ], + logger_factory=structlog.stdlib.LoggerFactory(), + cache_logger_on_first_use=True, + ) + + log_renderer: structlog.types.Processor + if json_logs: + log_renderer = structlog.processors.JSONRenderer() + else: + log_renderer = structlog.dev.ConsoleRenderer( + exception_formatter=structlog.dev.plain_traceback + ) + + formatter = structlog.stdlib.ProcessorFormatter( + # These run ONLY on `logging` entries that do NOT originate within + # structlog. + foreign_pre_chain=shared_processors, + # These run on ALL entries after the pre_chain is done. + processors=[ + # Remove _record & _from_structlog. + structlog.stdlib.ProcessorFormatter.remove_processors_meta, + log_renderer, + ], + ) + + # Clear all existing handlers from root logger to prevent duplication + root_logger = logging.getLogger() + for handler in root_logger.handlers[:]: + root_logger.removeHandler(handler) + + # Add console handler (stdout) + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + root_logger.addHandler(console_handler) + + # Add file handler if LOG_FILE is specified + if log_file: + try: + # Ensure parent directory exists + from pathlib import Path + + log_path = Path(log_file) + log_path.parent.mkdir(parents=True, exist_ok=True) + + file_handler = logging.FileHandler(log_file, mode="a") + file_handler.setFormatter(formatter) + root_logger.addHandler(file_handler) + + # Log that file logging is enabled + root_logger.info(f"File logging enabled: {log_file}") + except Exception as e: + # If file logging fails, log to console but don't crash + root_logger.error(f"Failed to set up file logging to {log_file}: {e}") + + root_logger.setLevel(_coerce_log_level(log_level)) + + # Always use info level for httpcore.xxx logs + for logger_name in ["httpcore.connection", "httpcore.http11"]: + logger = logging.getLogger(logger_name) + logger.setLevel(logging.INFO) + + # Configure uvicorn loggers to use our root logger setup + for logger_name in ["uvicorn", "uvicorn.error"]: + uvicorn_logger = logging.getLogger(logger_name) + # Clear any existing handlers to prevent duplication + for handler in uvicorn_logger.handlers[:]: + uvicorn_logger.removeHandler(handler) + # Let logs propagate to root logger (which has our structlog handler) + uvicorn_logger.name = "uvicorn" + uvicorn_logger.setLevel(_coerce_log_level(log_level)) + + +class UniversalRuntimeLogger: + """Logger wrapper for universal runtime, similar to FastAPIStructLogger.""" + + def __init__(self, log_name: str = "universal-runtime"): + self.logger = structlog.stdlib.get_logger(log_name) + + def debug(self, event: str | None = None, *args: Any, **kw: Any): + self.logger.debug(event, *args, **kw) + + def info(self, event: str | None = None, *args: Any, **kw: Any): + self.logger.info(event, *args, **kw) + + def warning(self, event: str | None = None, *args: Any, **kw: Any): + self.logger.warning(event, *args, **kw) + + warn = warning + + def error(self, event: str | None = None, *args: Any, **kw: Any): + self.logger.error(event, *args, **kw) + + def critical(self, event: str | None = None, *args: Any, **kw: Any): + self.logger.critical(event, *args, **kw) + + def exception(self, event: str | None = None, *args: Any, **kw: Any): + self.logger.exception(event, *args, **kw) diff --git a/runtimes/edge/models/__init__.py b/runtimes/edge/models/__init__.py new file mode 100644 index 000000000..1e8ee4d42 --- /dev/null +++ b/runtimes/edge/models/__init__.py @@ -0,0 +1,45 @@ +""" +Model wrappers for Edge Runtime. + +Only includes model types needed for edge inference: +- Language models (GGUF and transformers) +- Vision models (YOLO detection, CLIP classification) +""" + +from .base import BaseModel +from .clip_model import CLIPModel +from .gguf_language_model import GGUFLanguageModel +from .language_model import LanguageModel +from .vision_base import ( + ClassificationModel, + ClassificationResult, + DetectionBox, + DetectionModel, + DetectionResult, + EmbeddingResult, + VisionModel, + VisionResult, +) +from .yolo_model import YOLOModel + +try: + from .hailo_model import HailoYOLOModel +except ImportError: + HailoYOLOModel = None # type: ignore[assignment,misc] + +__all__ = [ + "BaseModel", + "LanguageModel", + "GGUFLanguageModel", + "YOLOModel", + "HailoYOLOModel", + "CLIPModel", + "VisionModel", + "DetectionModel", + "ClassificationModel", + "VisionResult", + "DetectionBox", + "DetectionResult", + "ClassificationResult", + "EmbeddingResult", +] diff --git a/runtimes/edge/models/base.py b/runtimes/edge/models/base.py new file mode 100644 index 000000000..88de4f14c --- /dev/null +++ b/runtimes/edge/models/base.py @@ -0,0 +1,156 @@ +""" +Base model class for all HuggingFace models (transformers & diffusers). +""" + +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + import torch + from transformers import PreTrainedTokenizerBase + +logger = logging.getLogger(__name__) + + +class BaseModel(ABC): + """Base class for all model types (transformers, diffusers, etc.).""" + + def __init__(self, model_id: str, device: str, token: str | None = None): + self.model_id = model_id + self.device = device + self.token = token # HuggingFace authentication token + self.model: Any | None = None + self.tokenizer: PreTrainedTokenizerBase | None = None + self.processor: Any | None = None # For vision/audio models + self.feature_extractor: Any | None = None # For audio models + self.pipe: Any | None = None # For diffusion models + self.model_type = "unknown" + self.supports_streaming = False + + @abstractmethod + async def load(self) -> None: + """Load the model and associated components.""" + pass + + async def unload(self) -> None: + """Unload the model and free resources. + + Default implementation for transformers models. Subclasses should override + if they need custom cleanup (e.g., GGUF models with llama-cpp). + """ + logger.info(f"Unloading model: {self.model_id}") + + # Move model to CPU to free GPU memory + if self.model is not None and hasattr(self.model, "to"): + try: + self.model = self.model.to("cpu") + except Exception as e: + logger.warning(f"Could not move model to CPU: {e}") + + # Clear references + self.model = None + self.tokenizer = None + self.processor = None + self.feature_extractor = None + self.pipe = None + + # Clear GPU cache if torch is available + try: + import torch + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + logger.debug("Cleared CUDA cache") + + if hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"): + try: + torch.mps.empty_cache() + logger.debug("Cleared MPS cache") + except Exception: + # MPS cache clearing can fail on some macOS versions; not critical + pass + except ImportError: + # torch not installed (GGUF-only deployment) + pass + + logger.info(f"Model unloaded: {self.model_id}") + + def get_model_info(self) -> dict[str, Any]: + """Get information about the loaded model.""" + return { + "model_id": self.model_id, + "model_type": self.model_type, + "device": self.device, + "supports_streaming": self.supports_streaming, + } + + def get_dtype(self, force_float32: bool = False): + """Get optimal torch dtype for the device. + + Args: + force_float32: Force float32 for models with MPS compatibility issues + """ + import torch + + if force_float32: + return torch.float32 + if self.device == "cuda" or self.device == "mps": + return torch.float16 + else: + return torch.float32 + + def to_device(self, tensor: torch.Tensor, dtype: torch.dtype | None = None): + """Move tensor to device with correct dtype. + + This helper ensures tensors are moved to device with matching dtype + to avoid MPS mixed precision issues. + + Args: + tensor: Tensor to move + dtype: Optional dtype override. If None, only moves to device without + changing dtype for integer tensors, or uses get_dtype() for floats. + """ + import torch + + # Don't change dtype for integer tensors (e.g., input_ids, attention_mask) + if tensor.dtype in ( + torch.int32, + torch.int64, + torch.long, + torch.int, + torch.bool, + ): + return tensor.to(device=self.device) + + if dtype is None: + dtype = self.get_dtype() + return tensor.to(device=self.device, dtype=dtype) + + def apply_optimizations(self): + """Apply platform-specific optimizations.""" + if self.pipe is None: + return + + try: + if self.device == "mps": + # MPS optimizations + self.pipe.enable_attention_slicing() + logger.info("Enabled attention slicing for MPS") + elif self.device == "cuda": + # CUDA optimizations + try: + self.pipe.enable_xformers_memory_efficient_attention() + logger.info("Enabled xformers memory efficient attention") + except Exception: + logger.warning("xformers not available, skipping") + + try: + self.pipe.enable_model_cpu_offload() + logger.info("Enabled model CPU offload") + except Exception as e: + logger.warning(f"Could not enable model CPU offload: {e}") + except Exception as e: + logger.warning(f"Could not apply optimizations: {e}") diff --git a/runtimes/edge/models/clip_model.py b/runtimes/edge/models/clip_model.py new file mode 100644 index 000000000..c4fa5ac5b --- /dev/null +++ b/runtimes/edge/models/clip_model.py @@ -0,0 +1,197 @@ +"""CLIP-based image classification and embedding model.""" + +from __future__ import annotations + +import asyncio +import logging +import time +from typing import TYPE_CHECKING + +import numpy as np + +from .vision_base import ClassificationModel, ClassificationResult, EmbeddingResult + +if TYPE_CHECKING: + import torch + from transformers import AutoModel, AutoProcessor + +logger = logging.getLogger(__name__) + +CLIP_VARIANTS = { + "clip-vit-base": "openai/clip-vit-base-patch32", + "clip-vit-large": "openai/clip-vit-large-patch14", + "siglip-base": "google/siglip-base-patch16-224", + "siglip-large": "google/siglip-large-patch16-256", +} + + +class CLIPModel(ClassificationModel): + """CLIP-based classifier with zero-shot classification and embedding support.""" + + def __init__(self, model_id: str = "clip-vit-base", device: str = "auto", + token: str | None = None, prompt_template: str = "a photo of a {}"): + super().__init__(model_id, device, token) + self.prompt_template = prompt_template + self.clip_model: AutoModel | None = None + self.processor: AutoProcessor | None = None + self._class_embeddings: torch.Tensor | None = None + self._embedding_dim: int = 0 + self._cached_class_key: tuple | None = None + self._class_lock = asyncio.Lock() + + async def load(self) -> None: + if self._loaded: + return + from transformers import AutoModel, AutoProcessor + + self.device = self._resolve_device(self.device) + logger.info(f"Loading CLIP model {self.model_id} on {self.device}") + start = time.perf_counter() + + hf_id = CLIP_VARIANTS.get(self.model_id, self.model_id) + + def _load(): + model = AutoModel.from_pretrained(hf_id, token=self.token) + proc = AutoProcessor.from_pretrained(hf_id, token=self.token) + model = model.to(self.device) + model.eval() + return model, proc + + self.clip_model, self.processor = await asyncio.to_thread(_load) + self._embedding_dim = getattr(self.clip_model.config, 'projection_dim', None) or getattr(self.clip_model.config, 'hidden_size', 512) + self._loaded = True + logger.info(f"CLIP loaded in {(time.perf_counter() - start) * 1000:.0f}ms (dim={self._embedding_dim})") + + async def unload(self) -> None: + self.clip_model = None + self.processor = None + self._class_embeddings = None + self._loaded = False + await super().unload() + + async def _encode_classes(self, class_names: list[str]) -> tuple: + """Pre-compute text embeddings for class names. + + Returns (class_names, embeddings) so callers can use them without + sharing mutable instance state across concurrent requests. + """ + import torch + class_key = tuple(class_names) + # Cache check: skip re-encoding if same classes and embeddings exist + if class_key == self._cached_class_key and self._class_embeddings is not None: + return class_names, self._class_embeddings + + prompts = [self.prompt_template.format(n) for n in class_names] + + def _encode(): + inputs = self.processor(text=prompts, return_tensors="pt", + padding=True, truncation=True).to(self.device) + with torch.no_grad(): + feats = self.clip_model.get_text_features(**inputs) + return feats / feats.norm(dim=-1, keepdim=True) + + embeddings = await asyncio.to_thread(_encode) + # Update shared cache for future requests with the same classes + self._class_embeddings = embeddings + self._cached_class_key = class_key + self.class_names = class_names + return class_names, embeddings + + async def classify(self, image: bytes | np.ndarray, + classes: list[str] | None = None, + top_k: int = 5) -> ClassificationResult: + if not self._loaded: + await self.load() + import torch + + # Resolve class names and embeddings for this request. + # Use local variables to avoid races from concurrent calls. + if classes is not None: + if not classes: + raise ValueError("Empty classes list provided.") + async with self._class_lock: + req_class_names, req_embeddings = await self._encode_classes(classes) + elif self._class_embeddings is not None and self._cached_class_key is not None: + req_class_names = list(self._cached_class_key) + req_embeddings = self._class_embeddings + else: + raise ValueError("No classes provided.") + + start = time.perf_counter() + pil_image = self._image_to_pil(image) + + def _infer(): + inputs = self.processor(images=pil_image, return_tensors="pt").to(self.device) + with torch.no_grad(): + feats = self.clip_model.get_image_features(**inputs) + feats = feats / feats.norm(dim=-1, keepdim=True) + sim = (feats @ req_embeddings.T).squeeze() + if sim.ndim == 0: + sim = sim.unsqueeze(0) + return sim.softmax(dim=-1).cpu().numpy() + + probs = await asyncio.to_thread(_infer) + inference_time = (time.perf_counter() - start) * 1000 + + effective_k = min(top_k, len(req_class_names)) + top_idx = np.argsort(probs)[::-1][:effective_k] + best = int(top_idx[0]) + + return ClassificationResult( + confidence=float(probs[best]), + inference_time_ms=inference_time, + model_name=self.model_id, + class_name=req_class_names[best], + class_id=best, + all_scores={req_class_names[i]: float(probs[i]) for i in top_idx}, + ) + + async def embed_images(self, images: list[bytes | np.ndarray]) -> EmbeddingResult: + """Generate embeddings for images.""" + if not self._loaded: + await self.load() + import torch + + start = time.perf_counter() + pil_images = [self._image_to_pil(img) for img in images] + + def _embed(): + inputs = self.processor(images=pil_images, return_tensors="pt").to(self.device) + with torch.no_grad(): + feats = self.clip_model.get_image_features(**inputs) + feats = feats / feats.norm(dim=-1, keepdim=True) + return feats.cpu().numpy().tolist() + + embeddings = await asyncio.to_thread(_embed) + return EmbeddingResult( + confidence=1.0, inference_time_ms=(time.perf_counter() - start) * 1000, + model_name=self.model_id, embeddings=embeddings, dimensions=self._embedding_dim, + ) + + async def embed_texts(self, texts: list[str]) -> EmbeddingResult: + """Generate embeddings for texts.""" + if not self._loaded: + await self.load() + import torch + + start = time.perf_counter() + + def _embed(): + inputs = self.processor(text=texts, return_tensors="pt", + padding=True, truncation=True).to(self.device) + with torch.no_grad(): + feats = self.clip_model.get_text_features(**inputs) + feats = feats / feats.norm(dim=-1, keepdim=True) + return feats.cpu().numpy().tolist() + + embeddings = await asyncio.to_thread(_embed) + return EmbeddingResult( + confidence=1.0, inference_time_ms=(time.perf_counter() - start) * 1000, + model_name=self.model_id, embeddings=embeddings, dimensions=self._embedding_dim, + ) + + def get_model_info(self) -> dict: + info = super().get_model_info() + info.update({"variant": self.model_id, "embedding_dim": self._embedding_dim, + "num_classes": len(self.class_names)}) + return info diff --git a/runtimes/edge/models/gguf_language_model.py b/runtimes/edge/models/gguf_language_model.py new file mode 100644 index 000000000..380fb82f9 --- /dev/null +++ b/runtimes/edge/models/gguf_language_model.py @@ -0,0 +1,1647 @@ +""" +GGUF language model wrapper using llama-cpp. + +Provides the same interface as LanguageModel but uses llama-cpp for +GGUF quantized models, enabling faster inference and lower memory usage. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +import sys +from collections.abc import AsyncGenerator +from concurrent.futures import ThreadPoolExecutor +from functools import lru_cache +from typing import TYPE_CHECKING + +from utils.context_calculator import get_default_context_size +from utils.context_manager import ContextBudget, ContextManager, ContextUsage +from utils.gguf_metadata_cache import get_gguf_metadata_cached +from utils.gpu_allocator import ( + SPLIT_MODE_LAYER, + SPLIT_MODE_NONE, + InsufficientVRAMError, + get_llama_gpu_params, +) +from utils.model_format import get_gguf_file_path +from utils.token_counter import TokenCounter + +from .base import BaseModel + +if TYPE_CHECKING: + from llamafarm_llama import Llama + +logger = logging.getLogger(__name__) + + +@lru_cache(maxsize=1) +def _is_unified_memory_gpu() -> bool: + """Detect NVIDIA Jetson/Tegra unified memory GPU platforms. + + Jetson devices have unified memory where CPU and GPU share RAM. On these systems, + running inference through ThreadPoolExecutor can cause performance issues due to + thread context switching overhead. Running synchronously avoids this overhead and + provides stability benefits by keeping CUDA operations in predictable thread contexts. + + Supported platforms: + - NVIDIA Jetson Orin (Nano, NX, AGX) + - NVIDIA Jetson Xavier (NX, AGX) + - NVIDIA Jetson TX2, Nano + + Environment variable override: + LLAMAFARM_SYNC_INFERENCE=1 # Force synchronous inference + LLAMAFARM_SYNC_INFERENCE=0 # Force asynchronous inference (ThreadPoolExecutor) + + Returns: + True if synchronous inference should be used (Jetson/Tegra or override) + """ + # Check for environment variable override first + override = os.environ.get("LLAMAFARM_SYNC_INFERENCE", "").lower() + if override in ("1", "true", "yes"): + logger.info("Sync inference ENABLED via LLAMAFARM_SYNC_INFERENCE=1") + return True + if override in ("0", "false", "no"): + logger.info("Sync inference DISABLED via LLAMAFARM_SYNC_INFERENCE=0") + return False + + # Auto-detect: NVIDIA Tegra/Jetson (unified memory iGPU) + try: + if os.path.exists("/proc/device-tree/compatible"): + with open("/proc/device-tree/compatible", "rb") as f: + compatible = f.read().decode("utf-8", errors="ignore").lower() + if "tegra" in compatible or "jetson" in compatible: + logger.info("NVIDIA Jetson/Tegra detected (sync inference enabled)") + return True + # Fallback: check kernel version string + if os.path.exists("/proc/version"): + with open("/proc/version") as f: + if "tegra" in f.read().lower(): + logger.info("NVIDIA Tegra kernel detected (sync inference enabled)") + return True + except Exception as e: + logger.debug(f"Unified memory GPU detection failed: {e}") + + # Apple Silicon and other platforms use async inference (ThreadPoolExecutor) + # which was the original behavior before Jetson optimizations + return False + + +class GGUFLanguageModel(BaseModel): + """Wrapper for GGUF models using llama-cpp. + + This class provides an interface compatible with LanguageModel but uses + llama-cpp for inference with GGUF quantized models. GGUF models + offer: + - 50-75% smaller file sizes (4-bit/8-bit quantization) + - 2-3x faster inference on Apple Silicon (Metal) + - Significantly lower memory requirements + - Optimized CPU inference + + The model is automatically configured for the target device (Metal/CUDA/CPU) + and supports both streaming and non-streaming generation. + """ + + def __init__( + self, + model_id: str, + device: str, + token: str | None = None, + n_ctx: int | None = None, + n_batch: int | None = None, + n_gpu_layers: int | None = None, + n_threads: int | None = None, + flash_attn: bool | None = None, + use_mmap: bool | None = None, + use_mlock: bool | None = None, + cache_type_k: str | None = None, + cache_type_v: str | None = None, + preferred_quantization: str | None = None, + mmproj_path: str | None = None, + auto_detect_mmproj: bool = True, + ): + """Initialize GGUF language model. + + Args: + model_id: HuggingFace model identifier (e.g., "unsloth/Qwen3-0.6B-GGUF") + device: Target device ("cuda", "mps", or "cpu") + token: Optional HuggingFace authentication token for gated models + n_ctx: Optional context window size. If None, will be computed automatically + based on available memory and model defaults. + n_batch: Optional batch size for prompt processing. If None, defaults to 2048. + Critical for memory: lower values (e.g., 512) reduce compute buffer size. + n_gpu_layers: Optional number of layers to offload to GPU. If None, will be + auto-detected based on device. Use -1 for all layers. + n_threads: Optional number of CPU threads. If None, auto-detected. + Set to match CPU core count (e.g., 6 for Jetson Orin Nano). + flash_attn: Optional flag to enable/disable flash attention. If None, + defaults to True for faster inference on supported hardware. + use_mmap: Optional flag for memory-mapped file loading. If None, defaults to False. + False is safer for unified memory platforms (Jetson, Apple Silicon) where + mmap can cause compute graph splits. Set to True for discrete GPUs with + separate VRAM if memory swapping is desired. + use_mlock: Optional flag to lock model in RAM. If None, defaults to False. + Set False on 8GB devices to allow OS memory management. + cache_type_k: Optional KV cache key quantization type (e.g., "q4_0", "q8_0", "f16"). + Using "q4_0" can reduce KV cache memory by ~4x. Critical for + memory-constrained devices like Jetson Orin Nano (8GB shared). + cache_type_v: Optional KV cache value quantization type. Same options as cache_type_k. + Setting both to "q4_0" provides maximum memory savings. + preferred_quantization: Optional quantization preference (e.g., "Q4_K_M", "Q8_0"). + If None, defaults to Q4_K_M. Only downloads the specified + quantization to save disk space. + mmproj_path: Optional path to multimodal projector file for audio/vision models. + If None and auto_detect_mmproj is True, will try to find mmproj + file in the same repository. + auto_detect_mmproj: If True (default), automatically detect and download mmproj + files for multimodal models like Qwen2.5-Omni. + """ + super().__init__(model_id, device, token=token) + self.model_type = "language" + self.supports_streaming = True + self.llama: Llama | None = None + self.requested_n_ctx = self.n_ctx = n_ctx # Store requested value + self.actual_n_ctx: int | None = None # Will be computed during load() + self.requested_n_batch = n_batch # Store requested value (None = default 2048) + self.requested_n_gpu_layers = ( + n_gpu_layers # Store requested value (None = auto) + ) + self.requested_n_threads = n_threads # Store requested value (None = auto) + self.requested_flash_attn = ( + flash_attn # Store requested value (None = default True) + ) + self.requested_use_mmap = ( + use_mmap # Store requested value (None = default False) + ) + self.requested_use_mlock = ( + use_mlock # Store requested value (None = default False) + ) + self.requested_cache_type_k = ( + cache_type_k # Store requested value (None = default f16) + ) + self.requested_cache_type_v = ( + cache_type_v # Store requested value (None = default f16) + ) + self.preferred_quantization = preferred_quantization + self.requested_mmproj_path = mmproj_path # Explicit mmproj path + self.auto_detect_mmproj = auto_detect_mmproj # Auto-detect mmproj files + self._executor = ThreadPoolExecutor(max_workers=1) + + # Context management (initialized during load()) + self._token_counter: TokenCounter | None = None + self._context_manager: ContextManager | None = None + + # Cached GGUF metadata (extracted once during load()) + self._chat_template: str | None = None + self._special_tokens: dict[str, str] | None = None + + # Multimodal support (set during load() if mmproj is loaded) + self._supports_audio: bool = False + self._supports_vision: bool = False + + def _get_available_memory_mb(self) -> int | None: + """Get available system memory in MB for Memory Guard check. + + This helps prevent OOM errors on memory-constrained devices like Jetson + by detecting low memory conditions before attempting to allocate large buffers. + + Returns: + Available memory in MB, or None if unable to determine. + """ + try: + # Try Linux /proc/meminfo first (works on Jetson and most Linux) + with open("/proc/meminfo") as f: + for line in f: + if "MemAvailable" in line: + # Format: "MemAvailable: 1234567 kB" + return int(line.split()[1]) // 1024 + except (FileNotFoundError, PermissionError, OSError): + # /proc/meminfo unavailable (non-Linux or restricted) — try psutil next + logger.debug("Could not read /proc/meminfo, falling back to psutil", exc_info=True) + + # Fallback: try psutil if available + try: + import psutil + + return int(psutil.virtual_memory().available / (1024 * 1024)) + except ImportError: + pass + + # Unable to determine available memory + return None + + async def load(self) -> None: + """Load the GGUF model using llama-cpp. + + This method: + 1. Locates the .gguf file in the HuggingFace cache + 2. Computes optimal context size based on memory and configuration + 3. Configures GPU layers based on the target device + 4. Initializes the llama-cpp Llama instance + 5. Runs initialization in a thread pool (blocking operation) + + Raises: + FileNotFoundError: If no .gguf file found in model repository + Exception: If model loading fails + """ + + # Re-create executor if it was destroyed by unload() + # CRITICAL: Single-threaded executor prevents concurrent access to non-thread-safe llama.cpp + if self._executor is None: + self._executor = ThreadPoolExecutor(max_workers=1) + + logger.info(f"Loading GGUF model: {self.model_id}") + + # Get path to .gguf file in HF cache + # This will intelligently select and download only the preferred quantization + gguf_path = get_gguf_file_path( + self.model_id, + self.token, + preferred_quantization=self.preferred_quantization, + ) + + # On Windows, convert backslashes to forward slashes for llama.cpp compatibility + # The underlying C library can have issues with Windows-style paths + if sys.platform == "win32": + gguf_path = gguf_path.replace("\\", "/") + + logger.info(f"GGUF file located at: {gguf_path}") + + # Store path for later use (e.g., Jinja2 template extraction) + self._gguf_path = gguf_path + + # Compute optimal context size + self.actual_n_ctx, warnings = get_default_context_size( + model_id=self.model_id, + gguf_path=gguf_path, + device=self.device, + config_n_ctx=self.requested_n_ctx, + ) + + # Log warnings to stderr + for warning in warnings: + logger.warning(warning) + + logger.info(f"Using context size: {self.actual_n_ctx}") + + # Configure GPU layers for llama.cpp + # Use explicitly requested value if provided, otherwise auto-detect + from utils.device import get_gguf_gpu_layers + + if self.requested_n_gpu_layers is not None: + n_gpu_layers = self.requested_n_gpu_layers + logger.info(f"Using configured n_gpu_layers: {n_gpu_layers}") + else: + n_gpu_layers = get_gguf_gpu_layers() + logger.info(f"Auto-detected n_gpu_layers: {n_gpu_layers}") + + # GPU allocation: select optimal GPU(s) based on free VRAM + # This prevents OOM crashes on multi-GPU systems by routing models + # to the GPU with the most free VRAM (split_mode=NONE) instead of + # splitting across all GPUs (llama.cpp's default split_mode=LAYER) + gpu_params = {} + try: + metadata = get_gguf_metadata_cached(gguf_path) + gpu_params = get_llama_gpu_params( + model_size_bytes=metadata.file_size_bytes, + n_ctx=self.actual_n_ctx, + n_gpu_layers=n_gpu_layers, + total_layers=metadata.n_layer, + n_layer=metadata.n_layer, + n_head_kv=metadata.n_head_kv, + head_k_size=metadata.head_k_size, + head_v_size=metadata.head_v_size, + ) + if gpu_params: + gpu_idx = gpu_params.get("gpu_index") + logger.info( + f"GPU allocation: main_gpu={gpu_params.get('main_gpu')}, " + f"split_mode={gpu_params.get('split_mode')}, " + f"gpu_index={gpu_idx}" + ) + # Re-compute context size using the allocated GPU memory. + # - Single-GPU (SPLIT_MODE_NONE): use the specific GPU's + # free VRAM via gpu_index. + # - Multi-GPU (SPLIT_MODE_LAYER): use the combined free VRAM + # across all participating devices, since both model weights + # and KV cache are distributed proportionally. + split_mode = gpu_params.get("split_mode") + if split_mode == SPLIT_MODE_NONE and gpu_idx is not None: + new_n_ctx, new_warnings = get_default_context_size( + model_id=self.model_id, + gguf_path=gguf_path, + device=self.device, + config_n_ctx=self.requested_n_ctx, + gpu_index=gpu_idx, + ) + elif split_mode == SPLIT_MODE_LAYER: + new_n_ctx, new_warnings = get_default_context_size( + model_id=self.model_id, + gguf_path=gguf_path, + device=self.device, + config_n_ctx=self.requested_n_ctx, + available_memory_override=gpu_params["total_free_vram"], + ) + else: + new_n_ctx, new_warnings = self.actual_n_ctx, [] + + if new_n_ctx != self.actual_n_ctx: + label = ( + f"GPU {gpu_idx}" + if split_mode == SPLIT_MODE_NONE + else "multi-GPU split" + ) + logger.info( + f"Context size adjusted for {label}: " + f"{self.actual_n_ctx} -> {new_n_ctx}" + ) + self.actual_n_ctx = new_n_ctx + for w in new_warnings: + logger.warning(w) + + # Context changed — re-run allocation so tensor_split + # and per-device feasibility reflect the actual KV + # cache size. Without this the stale split computed + # for the old n_ctx can OOM on a weaker GPU. + if split_mode == SPLIT_MODE_LAYER: + gpu_params = get_llama_gpu_params( + model_size_bytes=metadata.file_size_bytes, + n_ctx=self.actual_n_ctx, + n_gpu_layers=n_gpu_layers, + total_layers=metadata.n_layer, + n_layer=metadata.n_layer, + n_head_kv=metadata.n_head_kv, + head_k_size=metadata.head_k_size, + head_v_size=metadata.head_v_size, + ) + logger.info( + "Re-allocated GPUs for updated context: " + f"split_mode={gpu_params.get('split_mode')}, " + f"main_gpu={gpu_params.get('main_gpu')}" + ) + else: + logger.debug("No CUDA GPUs detected, using default GPU allocation") + except InsufficientVRAMError as e: + if e.gpu_details: + logger.error(f"GPU allocation failed:\n{e.gpu_details}") + else: + logger.error(f"GPU allocation failed: {e}") + raise RuntimeError(str(e)) from e + except Exception as e: + logger.warning(f"GPU allocation failed, using defaults: {e}") + + # Configure batch size (critical for memory on constrained devices) + # Default 2048 for fast prompt processing, but lower values reduce memory + n_batch = self.requested_n_batch if self.requested_n_batch is not None else 2048 + + # Memory Guard: Check available memory and reduce n_batch if needed + # This prevents "Error 12" (CUDA OOM) on memory-constrained devices like Jetson + available_mb = self._get_available_memory_mb() + if available_mb is not None and available_mb < 3000 and n_batch > 512: + logger.warning( + f"Low memory detected ({available_mb}MB available). " + f"Reducing n_batch from {n_batch} to 512 to prevent OOM." + ) + n_batch = 512 + + logger.info(f"Using n_batch: {n_batch}") + + # Configure thread count (None = auto-detect in Llama class) + n_threads = self.requested_n_threads + if n_threads is not None: + logger.info(f"Using configured n_threads: {n_threads}") + + # Configure flash attention (default True for faster inference) + flash_attn = ( + self.requested_flash_attn if self.requested_flash_attn is not None else True + ) + logger.info(f"Using flash_attn: {flash_attn}") + + # Configure memory mapping - default False for unified memory platforms (Jetson, Apple Silicon) + # Memory mapping can cause compute graph splits on unified memory systems where CPU and GPU + # share the same physical memory. This results in suboptimal performance. For discrete GPUs + # with separate VRAM, mmap may be beneficial for memory-constrained scenarios. + use_mmap = ( + self.requested_use_mmap if self.requested_use_mmap is not None else False + ) + logger.info(f"Using use_mmap: {use_mmap}") + + # Configure memory locking (default False to allow OS memory management) + use_mlock = ( + self.requested_use_mlock if self.requested_use_mlock is not None else False + ) + logger.info(f"Using use_mlock: {use_mlock}") + + # Configure KV cache quantization (None = default f16, use q4_0 for memory savings) + cache_type_k = self.requested_cache_type_k + cache_type_v = self.requested_cache_type_v + if cache_type_k is not None: + logger.info(f"Using cache_type_k: {cache_type_k}") + if cache_type_v is not None: + logger.info(f"Using cache_type_v: {cache_type_v}") + + # Detect or use explicit mmproj path for multimodal models + mmproj_path = self.requested_mmproj_path + if mmproj_path is None and self.auto_detect_mmproj: + try: + from llamafarm_common import get_mmproj_file_path + + mmproj_path = get_mmproj_file_path(self.model_id, self.token) + if mmproj_path: + logger.info(f"Auto-detected mmproj file: {mmproj_path}") + except Exception as e: + logger.debug(f"mmproj auto-detection failed: {e}") + + # Load model using llama-cpp + # Run in thread pool since Llama() initialization is blocking + loop = asyncio.get_running_loop() + + def _load_model(): + import os + + try: + from llamafarm_llama import Llama + except ImportError as e: + raise ImportError( + "llamafarm-llama is required for GGUF models but is not installed. " + "Install it with: pip install llamafarm-llama" + ) from e + + # Verify resolved path stays within the HuggingFace cache directory + from huggingface_hub.constants import HF_HUB_CACHE + + resolved = os.path.realpath(gguf_path) + hf_cache_resolved = os.path.realpath(HF_HUB_CACHE) + if not resolved.startswith(hf_cache_resolved + os.sep): + raise ValueError( + f"GGUF path outside HuggingFace cache: {gguf_path}" + ) + + # Verify file exists and is readable before attempting to load + if not os.path.exists(resolved): + raise FileNotFoundError(f"GGUF file not found: {gguf_path}") + if not os.access(resolved, os.R_OK): + raise PermissionError(f"GGUF file not readable: {gguf_path}") + + file_size_mb = os.path.getsize(resolved) / (1024 * 1024) + logger.info(f"Loading GGUF file ({file_size_mb:.1f} MB): {gguf_path}") + + try: + # Build GPU-specific kwargs from allocation + gpu_kwargs = {} + if gpu_params.get("main_gpu") is not None: + gpu_kwargs["main_gpu"] = gpu_params["main_gpu"] + if gpu_params.get("split_mode") is not None: + gpu_kwargs["split_mode"] = gpu_params["split_mode"] + if gpu_params.get("tensor_split") is not None: + gpu_kwargs["tensor_split"] = gpu_params["tensor_split"] + + return Llama( + model_path=gguf_path, + mmproj_path=mmproj_path, # Multimodal projector for audio/vision + n_ctx=self.actual_n_ctx, # Use computed context size + n_batch=n_batch, # Batch size for prompt processing + n_gpu_layers=n_gpu_layers, # GPU layer offloading + n_threads=n_threads, # CPU threads (None = auto) + flash_attn=flash_attn, # Flash attention optimization + use_mmap=use_mmap, # Memory-mapped file loading + use_mlock=use_mlock, # Lock model in RAM + cache_type_k=cache_type_k, # KV cache key quantization + cache_type_v=cache_type_v, # KV cache value quantization + verbose=False, # Disable verbose logging (managed by ggml_logging) + seed=-1, # Random seed (-1 = random) + **gpu_kwargs, + ) + except ValueError as e: + # Provide more helpful error message for common issues + error_msg = str(e) + if "Failed to load model from file" in error_msg: + logger.error( + f"llama.cpp failed to load model. This can be caused by:\n" + f" 1. Corrupted GGUF file - try deleting and re-downloading\n" + f" 2. Incompatible llama-cpp binary - try reinstalling\n" + f" 3. Unsupported GGUF format version\n" + f" File: {gguf_path}\n" + f" Size: {file_size_mb:.1f} MB\n" + f" Context: {self.actual_n_ctx}" + ) + raise + + try: + # On unified memory platforms (Jetson Tegra, Apple Silicon), load model + # synchronously to ensure GPU context is created optimally and avoid + # thread context switching overhead in shared memory architecture + if _is_unified_memory_gpu(): + logger.info( + "Loading model synchronously (unified memory GPU optimization)" + ) + self.llama = _load_model() + else: + self.llama = await loop.run_in_executor(self._executor, _load_model) + + # Initialize context management + self._token_counter = TokenCounter(self.llama) + budget = ContextBudget.from_context_size(self.actual_n_ctx) + self._context_manager = ContextManager(self._token_counter, budget) + + # Pre-extract and cache GGUF metadata for chat template rendering + # This avoids re-reading the large GGUF file on every request + try: + from utils.jinja_tools import ( + get_chat_template_from_gguf, + get_special_tokens_from_gguf, + supports_native_tools, + ) + + self._chat_template = get_chat_template_from_gguf(gguf_path) + if self._chat_template: + has_tools = supports_native_tools(self._chat_template) + logger.info( + f"Chat template cached ({len(self._chat_template)} chars), " + f"supports_native_tools={has_tools}" + ) + else: + logger.debug("No chat template found in GGUF metadata") + + self._special_tokens = get_special_tokens_from_gguf(gguf_path) + logger.debug( + f"Special tokens cached: bos='{self._special_tokens.get('bos_token', '')}', " + f"eos='{self._special_tokens.get('eos_token', '')}'" + ) + except Exception as e: + logger.warning(f"Failed to cache GGUF metadata: {e}") + self._chat_template = None + self._special_tokens = None + + # Check multimodal capabilities + if self.llama and hasattr(self.llama, "supports_audio"): + self._supports_audio = self.llama.supports_audio + self._supports_vision = getattr(self.llama, "supports_vision", False) + if self._supports_audio or self._supports_vision: + logger.info( + f"Multimodal capabilities: audio={self._supports_audio}, " + f"vision={self._supports_vision}" + ) + + logger.info( + f"GGUF model loaded successfully on {self.device} " + f"with {n_gpu_layers} GPU layers and context size {self.actual_n_ctx}" + ) + except Exception: + # Clean up executor if load fails to prevent resource leak + if hasattr(self, "_executor"): + self._executor.shutdown(wait=False) + raise + + @property + def supports_audio(self) -> bool: + """Whether this model supports direct audio input. + + Returns True if the model was loaded with a multimodal projector + that supports audio processing (e.g., Qwen2.5-Omni). + """ + return self._supports_audio + + @property + def supports_vision(self) -> bool: + """Whether this model supports direct image/vision input. + + Returns True if the model was loaded with a multimodal projector + that supports vision processing. + """ + return self._supports_vision + + def format_messages(self, messages: list[dict]) -> str: + """Format chat messages into a prompt string. + + Converts OpenAI-style chat messages into a single prompt string + suitable for the model. Uses a simple template format. + + Args: + messages: List of message dicts with 'role' and 'content' keys + + Returns: + Formatted prompt string + + Examples: + >>> messages = [ + ... {"role": "system", "content": "You are helpful"}, + ... {"role": "user", "content": "Hello"} + ... ] + >>> model.format_messages(messages) + 'System: You are helpful\\nUser: Hello\\nAssistant:' + """ + prompt_parts = [] + + for msg in messages: + role = msg.get("role", "") + content = msg.get("content", "") + + if role == "system": + prompt_parts.append(f"System: {content}") + elif role == "user": + prompt_parts.append(f"User: {content}") + elif role == "assistant": + prompt_parts.append(f"Assistant: {content}") + + # Add final prompt for assistant response + prompt_parts.append("Assistant:") + return "\n".join(prompt_parts) + + @property + def token_counter(self) -> TokenCounter | None: + """Get the token counter for this model.""" + return self._token_counter + + @property + def context_manager(self) -> ContextManager | None: + """Get the context manager for this model.""" + return self._context_manager + + def count_tokens(self, text: str) -> int: + """Count tokens in text using the model's tokenizer. + + Args: + text: Text to count tokens for. + + Returns: + Number of tokens. + + Raises: + RuntimeError: If model not loaded. + """ + if self._token_counter is None: + raise RuntimeError("Model not loaded. Call load() first.") + return self._token_counter.count_tokens(text) + + def validate_context(self, messages: list[dict]) -> ContextUsage: + """Validate messages fit within context and return usage info. + + Args: + messages: List of chat messages to validate. + + Returns: + ContextUsage with token counts and overflow status. + + Raises: + RuntimeError: If model not loaded. + """ + if self._context_manager is None: + raise RuntimeError("Model not loaded. Call load() first.") + return self._context_manager.validate_messages(messages) + + def _render_with_jinja2( + self, + messages: list[dict], + tools: list[dict], + ) -> str | None: + """Try to render messages with tools using Jinja2 template. + + This uses the model's native chat template (cached from GGUF metadata) to render + the prompt with tool definitions, which produces better results for models + that were trained with native tool calling support. + + Args: + messages: List of message dicts with 'role' and 'content' keys + tools: List of tool definitions in OpenAI format + + Returns: + Rendered prompt string if the model supports native tools, None otherwise. + """ + # Use cached template (extracted once during load()) + template = self._chat_template + if not template: + logger.debug("Jinja2 rendering skipped: no chat template cached") + return None + + try: + from utils.jinja_tools import ( + render_chat_with_tools, + supports_native_tools, + ) + + has_tools = supports_native_tools(template) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + f"Using cached chat template ({len(template)} chars), " + f"supports_native_tools={has_tools}" + ) + # Log first 500 chars of template for debugging + logger.debug(f"Template preview: {template[:500]}...") + + if not has_tools: + logger.debug( + "Jinja2 rendering skipped: template does not support native tools " + "('tools' variable not found in template)" + ) + return None + + # Use cached special tokens + special_tokens = self._special_tokens or {} + + # Debug log tools being used in Jinja2 path + if logger.isEnabledFor(logging.DEBUG): + import json + + tool_names = [ + t.get("function", {}).get("name", "unknown") for t in tools + ] + logger.debug(f"Tools provided (Jinja2 path): {tool_names}") + logger.debug(f"Full tool definitions:\n{json.dumps(tools, indent=2)}") + + # Render the template with tools + prompt = render_chat_with_tools( + template=template, + messages=messages, + tools=tools, + add_generation_prompt=True, + bos_token=special_tokens.get("bos_token", ""), + eos_token=special_tokens.get("eos_token", ""), + ) + + logger.debug( + f"Rendered prompt with Jinja2 native tool support " + f"({len(prompt)} chars, {len(tools)} tools)" + ) + return prompt + + except Exception as e: + logger.debug(f"Jinja2 tool rendering failed, will use fallback: {e}") + return None + + def _prepare_messages_with_tools( + self, + messages: list[dict], + tools: list[dict] | None = None, + tool_choice: str | dict | None = None, + ) -> list[dict]: + """Prepare messages with tool definitions using prompt injection. + + This is the fallback approach when Jinja2 rendering is not available. + Tools are injected into the system message using XML format. + + Args: + messages: List of message dicts with 'role' and 'content' keys + tools: Optional list of tool definitions in OpenAI format + tool_choice: Tool choice strategy: + - None or "auto": Model may call tools (default) + - "none": Model should not call tools + - "required": Model must call at least one tool + - {"type": "function", "function": {"name": "X"}}: Must call specific function + + Returns: + Messages with tools injected (if tools provided) + """ + if not tools: + return messages + + # Debug log tools and tool_choice + if logger.isEnabledFor(logging.DEBUG): + import json + + tool_names = [t.get("function", {}).get("name", "unknown") for t in tools] + logger.debug(f"Tools provided: {tool_names}") + logger.debug(f"Tool choice: {tool_choice}") + logger.debug(f"Full tool definitions:\n{json.dumps(tools, indent=2)}") + + # Inject tools into messages using prompt-based approach + from utils.tool_calling import inject_tools_into_messages + + logger.debug( + f"Using prompt-based tool injection with tool_choice={tool_choice}" + ) + return inject_tools_into_messages(messages, tools, tool_choice=tool_choice) + + def prepare_messages_for_context_validation( + self, + messages: list[dict], + tools: list[dict] | None = None, + tool_choice: str | dict | None = None, + ) -> tuple[list[dict], bool, str | None]: + """Prepare message shape for context checks and indicate generation strategy. + + Returns: + Tuple of (messages_for_context, already_injected, native_rendered_prompt). + - already_injected=True means tool content is already present in returned + messages and should not be injected again during generation. + - native_rendered_prompt is populated when native Jinja2 tool rendering + is used for generation. + """ + if not tools: + return messages, False, None + + native_rendered_prompt = self._render_with_jinja2(messages, tools) + if native_rendered_prompt is not None: + # Context validation should count the exact prompt that will be sent via + # create_completion() for native tool-capable models. + return messages, False, native_rendered_prompt + + return self._prepare_messages_with_tools(messages, tools, tool_choice), True, None + + async def _generate_from_prompt( + self, + prompt: str, + max_tokens: int, + temperature: float, + top_p: float, + stop: list[str] | None, + thinking_budget: int | None, + kv_cache_data: bytes | None = None, + kv_cache_tokens: int = 0, + ) -> str: + """Generate completion from a pre-formatted prompt string. + + This is used when Jinja2 rendering produces a prompt with native tool support. + + Args: + prompt: Pre-formatted prompt string + max_tokens: Maximum tokens to generate + temperature: Sampling temperature + top_p: Nucleus sampling threshold + stop: List of stop sequences + thinking_budget: Maximum tokens for thinking + kv_cache_data: Serialized KV cache state to restore + kv_cache_tokens: Number of tokens in the cached state + + Returns: + Generated text as a string + """ + assert self.llama is not None, "Model not loaded" + + loop = asyncio.get_running_loop() + + # Capture llama reference for nested function (type checker can't see through closures) + llama = self.llama + + def _generate(): + try: + # Set up logits processor for thinking budget if specified + logits_processor = None + if thinking_budget is not None: + from utils.thinking import ThinkingBudgetProcessor + + logits_processor = ThinkingBudgetProcessor( + llama, max_thinking_tokens=thinking_budget + ) + + # Use create_completion for raw prompts (no chat template applied) + return llama.create_completion( + prompt=prompt, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop=stop or [], + logits_processor=logits_processor, + kv_cache_data=kv_cache_data, + kv_cache_tokens=kv_cache_tokens, + ) + except Exception as e: + logger.error( + f"Error during llama-cpp completion: {e}", + exc_info=True, + ) + raise RuntimeError(f"Completion failed: {e}") from e + + try: + # On unified memory platforms (Jetson, Apple Silicon), run synchronously + # to avoid ThreadPoolExecutor overhead in shared memory architecture + if _is_unified_memory_gpu(): + result = _generate() + else: + result = await loop.run_in_executor(self._executor, _generate) + content = result["choices"][0]["message"]["content"] + return content.strip() if content else "" + except Exception as e: + logger.error(f"Error extracting completion result: {e}", exc_info=True) + raise ValueError(f"Unexpected result from completion: {e}") from e + + async def generate( + self, + messages: list[dict], + max_tokens: int | None = None, + temperature: float = 0.7, + top_p: float = 1.0, + stop: list[str] | None = None, + thinking_budget: int | None = None, + tools: list[dict] | None = None, + tool_choice: str | dict | None = None, + kv_cache_data: bytes | None = None, + kv_cache_tokens: int = 0, + ) -> str: + """Generate chat completion (non-streaming). + + For tool calling, this method first tries to use the model's native Jinja2 + template with tool support. If the model doesn't support native tools, + falls back to prompt-based tool injection. + + Args: + messages: List of message dicts with 'role' and 'content' keys + max_tokens: Maximum tokens to generate (default: 512) + temperature: Sampling temperature (0.0 = greedy, higher = more random) + top_p: Nucleus sampling threshold + stop: List of stop sequences to end generation + thinking_budget: Maximum tokens for thinking before forcing + tools: Optional list of tool definitions in OpenAI format + tool_choice: Optional tool choice strategy ("auto", "none", "required") + + Returns: + Generated text as a string + + Raises: + AssertionError: If model not loaded + """ + assert self.llama is not None, "Model not loaded. Call load() first." + + max_tokens = max_tokens or 512 + logger.info(f"[TIMING] generate() start, max_tokens={max_tokens}") + + # Try Jinja2 native tool rendering first (if tools provided) + if tools: + jinja2_prompt = self._render_with_jinja2(messages, tools) + if jinja2_prompt is not None: + # Debug log the full prompt being sent to the LLM + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + f"[generate] Final prompt (Jinja2 rendered, {len(jinja2_prompt)} chars):\n" + f"{'=' * 60}\n{jinja2_prompt}\n{'=' * 60}" + ) + # Use the pre-formatted prompt directly + return await self._generate_from_prompt( + prompt=jinja2_prompt, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop=stop, + thinking_budget=thinking_budget, + kv_cache_data=kv_cache_data, + kv_cache_tokens=kv_cache_tokens, + ) + + # Fallback: use prompt injection + chat completion + prepared_messages = self._prepare_messages_with_tools( + messages, tools, tool_choice + ) + + # Debug log the prepared messages (prompt injection path) + if logger.isEnabledFor(logging.DEBUG): + import json + + logger.debug( + f"[generate] Prepared messages ({len(prepared_messages)} messages):\n" + f"{'=' * 60}\n{json.dumps(prepared_messages, indent=2)}\n{'=' * 60}" + ) + loop = asyncio.get_running_loop() + + def _generate(): + try: + # Set up logits processor for thinking budget if specified + logits_processor = None + if thinking_budget is not None: + from utils.thinking import ThinkingBudgetProcessor + + logits_processor = ThinkingBudgetProcessor( + self.llama, max_thinking_tokens=thinking_budget + ) + + # Use create_chat_completion which applies the model's chat template + return self.llama.create_chat_completion( + messages=prepared_messages, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop=stop or [], + logits_processor=logits_processor, + kv_cache_data=kv_cache_data, + kv_cache_tokens=kv_cache_tokens, + ) + except Exception as e: + logger.error( + f"Error during llama-cpp chat completion: {e}", + exc_info=True, + ) + raise RuntimeError(f"Chat completion failed: {e}") from e + + try: + # On unified memory platforms (Jetson, Apple Silicon), run synchronously + # to avoid ThreadPoolExecutor overhead in shared memory architecture. + # This provides both performance and stability benefits. + if _is_unified_memory_gpu(): + result = _generate() + else: + result = await loop.run_in_executor(self._executor, _generate) + content = result["choices"][0]["message"]["content"] + return content.strip() if content else "" + except Exception as e: + logger.error(f"Error extracting chat completion result: {e}", exc_info=True) + raise ValueError(f"Unexpected result from chat completion: {e}") from e + + async def generate_with_logprobs( + self, + messages: list[dict], + max_tokens: int | None = None, + temperature: float = 0.7, + top_p: float = 1.0, + stop: list[str] | None = None, + thinking_budget: int | None = None, + tools: list[dict] | None = None, + tool_choice: str | dict | None = None, + top_logprobs: int | None = None, + kv_cache_data: bytes | None = None, + kv_cache_tokens: int = 0, + ) -> dict: + """Generate chat completion and include raw logprobs payload when supported.""" + if self.llama is None: + raise RuntimeError("Model not loaded. Call load() first.") + + max_tokens = max_tokens or 512 + + # Keep behavior aligned with generate(): if tools are provided and the model + # supports native Jinja2 rendering, use that path (no logprobs in this path yet). + if tools: + jinja2_prompt = self._render_with_jinja2(messages, tools) + if jinja2_prompt is not None: + content = await self._generate_from_prompt( + prompt=jinja2_prompt, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop=stop, + thinking_budget=thinking_budget, + kv_cache_data=kv_cache_data, + kv_cache_tokens=kv_cache_tokens, + ) + return {"content": content, "logprobs": None} + + prepared_messages = self._prepare_messages_with_tools( + messages, tools, tool_choice + ) + + loop = asyncio.get_running_loop() + + def _generate(): + try: + logits_processor = None + if thinking_budget is not None: + from utils.thinking import ThinkingBudgetProcessor + + logits_processor = ThinkingBudgetProcessor( + self.llama, max_thinking_tokens=thinking_budget + ) + + kwargs = { + "messages": prepared_messages, + "max_tokens": max_tokens, + "temperature": temperature, + "top_p": top_p, + "stop": stop or [], + "logits_processor": logits_processor, + "logprobs": True, + "kv_cache_data": kv_cache_data, + "kv_cache_tokens": kv_cache_tokens, + } + if top_logprobs is not None: + kwargs["top_logprobs"] = top_logprobs + + return self.llama.create_chat_completion(**kwargs) + except Exception as e: + logger.error( + f"Error during llama-cpp chat completion (logprobs): {e}", + exc_info=True, + ) + raise RuntimeError("Chat completion failed") from e + + if _is_unified_memory_gpu(): + result = _generate() + else: + result = await loop.run_in_executor(self._executor, _generate) + + try: + choice = result["choices"][0] + message = choice["message"] + content = message.get("content") + if content is None: + content = choice.get("text", "") + except (KeyError, IndexError, TypeError) as e: + logger.error(f"Error extracting chat completion result: {e}", exc_info=True) + raise ValueError(f"Unexpected result from chat completion: {e}") from e + + return { + "content": content.strip() if isinstance(content, str) else "", + "logprobs": choice.get("logprobs") if isinstance(choice, dict) else None, + } + + async def _stream_from_prompt( + self, + prompt: str, + max_tokens: int, + temperature: float, + top_p: float, + stop: list[str] | None, + thinking_budget: int | None, + kv_cache_data: bytes | None = None, + kv_cache_tokens: int = 0, + ) -> AsyncGenerator[str, None]: + """Stream completion from a pre-formatted prompt string. + + This is used when Jinja2 rendering produces a prompt with native tool support. + + Args: + prompt: Pre-formatted prompt string + max_tokens: Maximum tokens to generate + temperature: Sampling temperature + top_p: Nucleus sampling threshold + stop: List of stop sequences + thinking_budget: Maximum tokens for thinking + kv_cache_data: Serialized KV cache state to restore + kv_cache_tokens: Number of tokens in the cached state + + Yields: + Generated text tokens as strings + """ + assert self.llama is not None, "Model not loaded" + + # Capture llama reference for nested function (type checker can't see through closures) + llama = self.llama + + # On Jetson/Tegra, stream synchronously to avoid thread context switching + # overhead in unified memory architecture + if _is_unified_memory_gpu(): + logits_processor = None + if thinking_budget is not None: + from utils.thinking import ThinkingBudgetProcessor + + logits_processor = ThinkingBudgetProcessor( + llama, max_thinking_tokens=thinking_budget + ) + + for chunk in llama.create_completion( + prompt=prompt, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop=stop or [], + stream=True, + logits_processor=logits_processor, + kv_cache_data=kv_cache_data, + kv_cache_tokens=kv_cache_tokens, + ): + delta = chunk["choices"][0].get("delta", {}) + content = delta.get("content", "") + if content: + yield content + await asyncio.sleep(0) + return + + # Async path: use ThreadPoolExecutor (Apple Silicon, discrete GPUs, CPU) + queue: asyncio.Queue[str | Exception | None] = asyncio.Queue() + loop = asyncio.get_running_loop() + + def _generate_stream(): + """Run completion in separate thread.""" + try: + thinking_tokens = 0 + in_thinking = False + thinking_ended = False + accumulated_text = "" + + # Set up logits processor for thinking budget enforcement + logits_processor = None + if thinking_budget is not None: + from utils.thinking import ThinkingBudgetProcessor + + logits_processor = ThinkingBudgetProcessor( + llama, max_thinking_tokens=thinking_budget + ) + + for chunk in llama.create_completion( + prompt=prompt, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop=stop or [], + stream=True, + logits_processor=logits_processor, + kv_cache_data=kv_cache_data, + kv_cache_tokens=kv_cache_tokens, + ): + delta = chunk["choices"][0].get("delta", {}) + content = delta.get("content", "") + if content: + accumulated_text += content + + # Track thinking state + if "" in accumulated_text.lower() and not in_thinking: + in_thinking = True + if "" in accumulated_text.lower(): + thinking_ended = True + in_thinking = False + + # Count thinking tokens + if in_thinking and not thinking_ended: + thinking_tokens += 1 + + future = asyncio.run_coroutine_threadsafe( + queue.put(content), loop + ) + future.result() + except Exception as e: + logger.error(f"Error in GGUF completion stream: {e}", exc_info=True) + future = asyncio.run_coroutine_threadsafe(queue.put(e), loop) + future.result() + finally: + future = asyncio.run_coroutine_threadsafe(queue.put(None), loop) + future.result() + + loop.run_in_executor(self._executor, _generate_stream) + + # Yield tokens as they arrive, propagate exceptions + while True: + item = await queue.get() + if item is None: + break + elif isinstance(item, Exception): + raise item + else: + yield item + + async def generate_stream( + self, + messages: list[dict], + max_tokens: int | None = None, + temperature: float = 0.7, + top_p: float = 1.0, + stop: list[str] | None = None, + thinking_budget: int | None = None, + tools: list[dict] | None = None, + tool_choice: str | dict | None = None, + kv_cache_data: bytes | None = None, + kv_cache_tokens: int = 0, + ) -> AsyncGenerator[str, None]: + """Generate chat completion with streaming (async generator). + + For tool calling, this method first tries to use the model's native Jinja2 + template with tool support. If the model doesn't support native tools, + falls back to prompt-based tool injection. + + Thinking budget is enforced via logits processor, which forces the model + to generate when the budget is reached. + + Args: + messages: List of message dicts with 'role' and 'content' keys + max_tokens: Maximum tokens to generate (default: 512) + temperature: Sampling temperature (0.0 = greedy, higher = more random) + top_p: Nucleus sampling threshold + stop: List of stop sequences to end generation + thinking_budget: Maximum tokens for thinking before forcing + tools: Optional list of tool definitions in OpenAI format + tool_choice: Optional tool choice strategy ("auto", "none", "required") + + Yields: + Generated text tokens as strings + + Raises: + AssertionError: If model not loaded + """ + assert self.llama is not None, "Model not loaded. Call load() first." + + max_tokens = max_tokens or 512 + + # Try Jinja2 native tool rendering first (if tools provided) + if tools: + jinja2_prompt = self._render_with_jinja2(messages, tools) + if jinja2_prompt is not None: + # Debug log the full prompt being sent to the LLM + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + f"[generate_stream] Final prompt (Jinja2 rendered, {len(jinja2_prompt)} chars):\n" + f"{'=' * 60}\n{jinja2_prompt}\n{'=' * 60}" + ) + # Use the pre-formatted prompt directly + async for token in self._stream_from_prompt( + prompt=jinja2_prompt, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop=stop, + thinking_budget=thinking_budget, + kv_cache_data=kv_cache_data, + kv_cache_tokens=kv_cache_tokens, + ): + yield token + return + + # Fallback: use prompt injection + chat completion + prepared_messages = self._prepare_messages_with_tools( + messages, tools, tool_choice + ) + + # Debug log the prepared messages (prompt injection path) + if logger.isEnabledFor(logging.DEBUG): + import json + + logger.debug( + f"[generate_stream] Prepared messages ({len(prepared_messages)} messages):\n" + f"{'=' * 60}\n{json.dumps(prepared_messages, indent=2)}\n{'=' * 60}" + ) + + # On Jetson/Tegra, stream synchronously to avoid thread context switching + # overhead in unified memory architecture + if _is_unified_memory_gpu(): + logits_processor = None + if thinking_budget is not None: + from utils.thinking import ThinkingBudgetProcessor + + logits_processor = ThinkingBudgetProcessor( + self.llama, max_thinking_tokens=thinking_budget + ) + + for chunk in self.llama.create_chat_completion( + messages=prepared_messages, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop=stop or [], + stream=True, + logits_processor=logits_processor, + kv_cache_data=kv_cache_data, + kv_cache_tokens=kv_cache_tokens, + ): + delta = chunk["choices"][0].get("delta", {}) + content = delta.get("content", "") + if content: + yield content + await asyncio.sleep(0) + return + + # Async path: use ThreadPoolExecutor (Apple Silicon, discrete GPUs, CPU) + queue: asyncio.Queue[str | Exception | None] = asyncio.Queue() + loop = asyncio.get_running_loop() + + def _generate_stream(): + """Run chat completion in separate thread.""" + try: + thinking_tokens = 0 + in_thinking = False + thinking_ended = False + accumulated_text = "" + + # Set up logits processor for thinking budget enforcement + logits_processor = None + if thinking_budget is not None: + from utils.thinking import ThinkingBudgetProcessor + + logits_processor = ThinkingBudgetProcessor( + self.llama, max_thinking_tokens=thinking_budget + ) + + for chunk in self.llama.create_chat_completion( + messages=prepared_messages, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop=stop or [], + stream=True, + logits_processor=logits_processor, + kv_cache_data=kv_cache_data, + kv_cache_tokens=kv_cache_tokens, + ): + delta = chunk["choices"][0].get("delta", {}) + content = delta.get("content", "") + if content: + accumulated_text += content + + # Track thinking state + if "" in accumulated_text.lower() and not in_thinking: + in_thinking = True + if "" in accumulated_text.lower(): + thinking_ended = True + in_thinking = False + + # Count thinking tokens + if in_thinking and not thinking_ended: + thinking_tokens += 1 + + future = asyncio.run_coroutine_threadsafe( + queue.put(content), loop + ) + future.result() + except Exception as e: + logger.error(f"Error in GGUF chat stream: {e}", exc_info=True) + future = asyncio.run_coroutine_threadsafe(queue.put(e), loop) + future.result() + finally: + future = asyncio.run_coroutine_threadsafe(queue.put(None), loop) + future.result() + + loop.run_in_executor(self._executor, _generate_stream) + + # Yield tokens as they arrive, propagate exceptions + while True: + item = await queue.get() + if item is None: + break + elif isinstance(item, Exception): + raise item + else: + yield item + + async def generate_with_audio( + self, + messages: list[dict], + audio_data: bytes, + audio_format: str = "wav", + max_tokens: int | None = None, + temperature: float = 0.7, + top_p: float = 1.0, + stop: list[str] | None = None, + ) -> str: + """Generate chat completion with audio input (non-streaming). + + This method uses the model's native multimodal capabilities to process + audio input directly without STT transcription, enabling audio-to-text + generation for models like Qwen2.5-Omni. + + Args: + messages: List of message dicts. Audio marker in user message content + will be replaced with encoded audio embeddings. + audio_data: Raw audio bytes (WAV, MP3, or PCM format) + audio_format: Format of audio_data ("wav", "mp3", or "pcm") + max_tokens: Maximum tokens to generate (default: 512) + temperature: Sampling temperature + top_p: Nucleus sampling threshold + stop: List of stop sequences + + Returns: + Generated text as a string + + Raises: + RuntimeError: If model doesn't support audio input + AssertionError: If model not loaded + """ + if not self._supports_audio: + raise RuntimeError( + f"Model {self.model_id} does not support audio input. " + "Load with mmproj_path for audio-capable models like Qwen2.5-Omni." + ) + + assert self.llama is not None, "Model not loaded. Call load() first." + + max_tokens = max_tokens or 512 + loop = asyncio.get_running_loop() + + def _generate(): + try: + return self.llama.create_chat_completion_with_audio( + messages=messages, + audio_data=audio_data, + audio_format=audio_format, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop=stop or [], + ) + except Exception as e: + logger.error(f"Error during audio chat completion: {e}", exc_info=True) + raise RuntimeError(f"Audio chat completion failed: {e}") from e + + try: + # On Jetson/Tegra, run synchronously to avoid thread context switching overhead + if _is_unified_memory_gpu(): + result = _generate() + else: + result = await loop.run_in_executor(self._executor, _generate) + content = result["choices"][0]["message"]["content"] + return content.strip() if content else "" + except Exception as e: + logger.error( + f"Error extracting audio completion result: {e}", exc_info=True + ) + raise ValueError(f"Unexpected result from audio completion: {e}") from e + + async def generate_stream_with_audio( + self, + messages: list[dict], + audio_data: bytes, + audio_format: str = "wav", + max_tokens: int | None = None, + temperature: float = 0.7, + top_p: float = 1.0, + stop: list[str] | None = None, + ) -> AsyncGenerator[str, None]: + """Generate chat completion with audio input (streaming). + + This method uses the model's native multimodal capabilities to process + audio input directly and streams the response token by token. + + Args: + messages: List of message dicts with audio markers + audio_data: Raw audio bytes (WAV, MP3, or PCM format) + audio_format: Format of audio_data ("wav", "mp3", or "pcm") + max_tokens: Maximum tokens to generate (default: 512) + temperature: Sampling temperature + top_p: Nucleus sampling threshold + stop: List of stop sequences + + Yields: + Generated text tokens as strings + + Raises: + RuntimeError: If model doesn't support audio input + AssertionError: If model not loaded + """ + if not self._supports_audio: + raise RuntimeError( + f"Model {self.model_id} does not support audio input. " + "Load with mmproj_path for audio-capable models like Qwen2.5-Omni." + ) + + assert self.llama is not None, "Model not loaded. Call load() first." + + max_tokens = max_tokens or 512 + + # On Jetson/Tegra, stream synchronously to avoid thread context switching overhead + if _is_unified_memory_gpu(): + for chunk in self.llama.create_chat_completion_with_audio( + messages=messages, + audio_data=audio_data, + audio_format=audio_format, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop=stop or [], + stream=True, + ): + delta = chunk["choices"][0].get("delta", {}) + content = delta.get("content", "") + if content: + yield content + await asyncio.sleep(0) + return + + # Async path: use ThreadPoolExecutor (Apple Silicon, discrete GPUs, CPU) + queue: asyncio.Queue[str | Exception | None] = asyncio.Queue() + loop = asyncio.get_running_loop() + + def _generate_stream(): + try: + for chunk in self.llama.create_chat_completion_with_audio( + messages=messages, + audio_data=audio_data, + audio_format=audio_format, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop=stop or [], + stream=True, + ): + delta = chunk["choices"][0].get("delta", {}) + content = delta.get("content", "") + if content: + future = asyncio.run_coroutine_threadsafe( + queue.put(content), loop + ) + future.result() + except Exception as e: + logger.error(f"Error in audio chat stream: {e}", exc_info=True) + future = asyncio.run_coroutine_threadsafe(queue.put(e), loop) + future.result() + finally: + future = asyncio.run_coroutine_threadsafe(queue.put(None), loop) + future.result() + + loop.run_in_executor(self._executor, _generate_stream) + + while True: + item = await queue.get() + if item is None: + break + elif isinstance(item, Exception): + raise item + else: + yield item + + async def unload(self) -> None: + """Unload GGUF model and free resources.""" + logger.info(f"Unloading GGUF language model: {self.model_id}") + + # Clear llama-cpp instance + self.llama = None + + # Reset multimodal flags to prevent use-after-free + # If these remain True after unload, callers checking supports_audio/supports_vision + # would see stale values and might attempt to use the freed model + self._supports_audio = False + self._supports_vision = False + + # Shutdown thread pool executor + if hasattr(self, "_executor"): + self._executor.shutdown(wait=True, cancel_futures=True) + self._executor = None + + logger.info(f"GGUF language model unloaded: {self.model_id}") + + def __del__(self): + """Cleanup thread pool executor on deletion.""" + if getattr(self, "_executor", None) is not None: + self._executor.shutdown(wait=False) diff --git a/runtimes/edge/models/hailo_model.py b/runtimes/edge/models/hailo_model.py new file mode 100644 index 000000000..67ab8f544 --- /dev/null +++ b/runtimes/edge/models/hailo_model.py @@ -0,0 +1,414 @@ +"""Hailo-10H YOLO detection model. + +Uses the hailo_platform Python API to run YOLO inference on the Hailo-10H +AI accelerator via pre-compiled .hef models from the Hailo Model Zoo. + +The .hef models include built-in NMS, so the output is already decoded +into bounding boxes, class IDs, and confidence scores. + +Requires: +- Hailo-10H PCIe device (/dev/hailo0) +- hailort Python wheel (provides hailo_platform) +- Pre-compiled .hef model files (e.g., yolov11n.hef) +""" + +from __future__ import annotations + +import asyncio +import logging +import time +from pathlib import Path +from typing import Any + +import numpy as np + +from .vision_base import DetectionBox, DetectionModel, DetectionResult + +logger = logging.getLogger(__name__) + +# COCO class names (80 classes) — standard for YOLO models from Hailo Model Zoo +COCO_CLASS_NAMES = [ + "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", + "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", + "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", + "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", + "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", + "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", + "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", + "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", + "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", + "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", + "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", + "hair drier", "toothbrush", +] + +# Map friendly model names to .hef filenames +HAILO_VARIANTS: dict[str, str] = { + "yolov8n": "yolov8n.hef", + "yolov8s": "yolov8s.hef", + "yolov8m": "yolov8m.hef", + "yolov11n": "yolov11n.hef", + "yolov11s": "yolov11s.hef", +} + +# Default directory for .hef model files +DEFAULT_HEF_DIR = Path("/models") + + +def _letterbox( + image: np.ndarray, + target_size: tuple[int, int], + color: tuple[int, int, int] = (114, 114, 114), +) -> tuple[np.ndarray, float, tuple[int, int]]: + """Resize and letterbox an image to the target size. + + Maintains aspect ratio by padding with the specified color. + + Args: + image: Input RGB image as numpy array (H, W, 3). + target_size: (height, width) of the model input. + color: Padding fill color (default: gray). + + Returns: + Tuple of (letterboxed_image, scale, (pad_x, pad_y)). + """ + h, w = image.shape[:2] + th, tw = target_size + + scale = min(tw / w, th / h) + new_w, new_h = int(w * scale), int(h * scale) + + from PIL import Image + + resized = np.array( + Image.fromarray(image).resize((new_w, new_h), Image.BILINEAR) + ) + + canvas = np.full((th, tw, 3), color, dtype=np.uint8) + pad_x = (tw - new_w) // 2 + pad_y = (th - new_h) // 2 + canvas[pad_y : pad_y + new_h, pad_x : pad_x + new_w] = resized + + return canvas, scale, (pad_x, pad_y) + + +def _parse_nms_output( + output: np.ndarray, + scale: float, + pad: tuple[int, int], + image_width: int, + image_height: int, + confidence_threshold: float, + class_filter: set[int] | None = None, + input_size: tuple[int, int] = (640, 640), +) -> list[DetectionBox]: + """Parse NMS-decoded output from a Hailo .hef YOLO model. + + Hailo Model Zoo YOLO .hef files with built-in NMS produce a flat + per-class buffer. For 80 COCO classes with 100 max detections the + raw shape is ``(40080,)`` = 80 × (1 + 100 × 5). + + Per-class layout (stride = 1 + max_det × 5): + [count, y1, x1, y2, x2, score, y1, x1, y2, x2, score, …] + + ``count`` is the number of valid detections for that class. + Each detection is 5 floats: ``[y_min, x_min, y_max, x_max, score]``. + Coordinates are normalized (0.0–1.0) relative to the letterboxed input. + + Args: + output: Raw float32 output array from Hailo inference. + scale: Scale factor from letterboxing. + pad: (pad_x, pad_y) offset from letterboxing. + image_width: Original image width (for coordinate rescaling). + image_height: Original image height (for coordinate rescaling). + confidence_threshold: Minimum confidence to keep. + class_filter: Optional set of class IDs to keep. + input_size: (height, width) of the model input in pixels. + + Returns: + List of DetectionBox instances in original image coordinates. + """ + boxes: list[DetectionBox] = [] + flat = output.flatten() + total = flat.size + + logger.debug(f"Hailo NMS output shape: {output.shape}, flat size: {total}") + + # Determine num_classes and max_det from buffer size. + # Buffer = num_classes × (1 + max_det × 5). + # COCO models use 80 classes; try common max_det values. + num_classes = 0 + max_det = 0 + for nc in (80,): + if total % nc != 0: + continue + stride = total // nc + # stride = 1 + max_det * 5 → (stride - 1) must be divisible by 5 + if (stride - 1) % 5 == 0: + num_classes = nc + max_det = (stride - 1) // 5 + break + + if num_classes == 0: + logger.warning( + f"Cannot parse Hailo NMS output: flat size {total} does not match " + f"expected num_classes × (1 + max_det × 5) layout." + ) + return boxes + + stride = 1 + max_det * 5 + logger.debug( + f"Hailo NMS: {num_classes} classes, {max_det} max detections per class, " + f"stride {stride}" + ) + + pad_x, pad_y = pad + input_h, input_w = input_size + + for cls_id in range(num_classes): + if class_filter is not None and cls_id not in class_filter: + continue + + class_name = ( + COCO_CLASS_NAMES[cls_id] + if cls_id < len(COCO_CLASS_NAMES) + else f"class_{cls_id}" + ) + + offset = cls_id * stride + n_det = int(flat[offset]) + if n_det <= 0: + continue + n_det = min(n_det, max_det) # safety clamp + + for i in range(n_det): + base = offset + 1 + i * 5 + y1_norm = float(flat[base]) + x1_norm = float(flat[base + 1]) + y2_norm = float(flat[base + 2]) + x2_norm = float(flat[base + 3]) + score = float(flat[base + 4]) + + if score < confidence_threshold: + continue + + logger.debug( + f"Hailo det: class={class_name}({cls_id}) score={score:.4f} " + f"norm=[{y1_norm:.4f}, {x1_norm:.4f}, {y2_norm:.4f}, {x2_norm:.4f}]" + ) + + # Convert normalized coords to pixel space in letterboxed image + x1_px = x1_norm * input_w + y1_px = y1_norm * input_h + x2_px = x2_norm * input_w + y2_px = y2_norm * input_h + + # Remove letterbox padding and rescale to original image + x1 = max(0.0, (x1_px - pad_x) / scale) + y1 = max(0.0, (y1_px - pad_y) / scale) + x2 = min(float(image_width), (x2_px - pad_x) / scale) + y2 = min(float(image_height), (y2_px - pad_y) / scale) + + logger.debug( + f"Hailo mapped: px=({x1_px:.1f},{y1_px:.1f},{x2_px:.1f},{y2_px:.1f}) " + f"-> orig=({x1:.1f},{y1:.1f},{x2:.1f},{y2:.1f})" + ) + + boxes.append( + DetectionBox( + x1=x1, y1=y1, x2=x2, y2=y2, + class_name=class_name, + class_id=cls_id, + confidence=score, + ) + ) + + return boxes + + +class HailoYOLOModel(DetectionModel): + """YOLO detection model running on Hailo-10H AI accelerator. + + Uses pre-compiled .hef models from the Hailo Model Zoo. These models + include built-in NMS so the output is already decoded. + + Requires the hailo_platform package (provided by the hailort wheel). + """ + + def __init__( + self, + model_id: str = "yolov11n", + device: str = "hailo", + confidence_threshold: float = 0.5, + hef_dir: str | Path | None = None, + token: str | None = None, + ): + super().__init__(model_id, device="hailo", confidence_threshold=confidence_threshold, token=token) + self._hef_dir = Path(hef_dir) if hef_dir else DEFAULT_HEF_DIR + self._vdevice: Any = None + self._infer_model: Any = None + self._configured: Any = None + self._input_shape: tuple[int, int] | None = None # (height, width) + self._hef_path: str | None = None + + def _resolve_hef_path(self) -> Path: + """Resolve the .hef file path from model_id.""" + # Check variant map first + hef_name = HAILO_VARIANTS.get(self.model_id) + if hef_name: + path = self._hef_dir / hef_name + if path.exists(): + return path + + # Try model_id directly as filename + if self.model_id.endswith(".hef"): + path = self._hef_dir / Path(self.model_id).name + else: + path = self._hef_dir / f"{self.model_id}.hef" + + if path.exists(): + return path + + # Try VISION_MODELS_DIR fallback + from utils.safe_home import get_data_dir + vision_dir = get_data_dir() / "models" / "vision" + alt_path = vision_dir / path.name + if alt_path.exists(): + return alt_path + + raise FileNotFoundError( + f"HEF model not found: tried {path} and {alt_path}. " + f"Available in {self._hef_dir}: " + f"{[f.name for f in self._hef_dir.glob('*.hef')] if self._hef_dir.exists() else '(dir missing)'}" + ) + + async def load(self) -> None: + if self._loaded: + return + + from hailo_platform import FormatType, VDevice + + logger.info(f"Loading Hailo model {self.model_id}") + start = time.perf_counter() + + hef_path = self._resolve_hef_path() + self._hef_path = str(hef_path) + logger.info(f"HEF file: {hef_path}") + + def _load(): + vdevice = VDevice() + infer_model = vdevice.create_infer_model(str(hef_path)) + infer_model.output().set_format_type(FormatType.FLOAT32) + configured = infer_model.configure() + + # Extract input dimensions from the model + input_vstream = infer_model.input() + shape = input_vstream.shape # (H, W, C) or (C, H, W) + if len(shape) == 3: + if shape[2] == 3: # HWC + input_shape = (shape[0], shape[1]) + else: # CHW + input_shape = (shape[1], shape[2]) + else: + input_shape = (640, 640) # Default YOLO input size + logger.warning(f"Unexpected input shape {shape}, defaulting to 640x640") + + output_shape = infer_model.output().shape + logger.info( + f"Hailo model shapes — input: {shape}, output: {output_shape}" + ) + + return vdevice, infer_model, configured, input_shape + + self._vdevice, self._infer_model, self._configured, self._input_shape = ( + await asyncio.to_thread(_load) + ) + + self.class_names = list(COCO_CLASS_NAMES) + self._loaded = True + elapsed = (time.perf_counter() - start) * 1000 + logger.info( + f"Hailo model loaded in {elapsed:.0f}ms " + f"(input: {self._input_shape[1]}x{self._input_shape[0]}, " + f"{len(self.class_names)} classes)" + ) + + async def unload(self) -> None: + if self._configured is not None: + del self._configured + self._configured = None + if self._infer_model is not None: + del self._infer_model + self._infer_model = None + if self._vdevice is not None: + del self._vdevice + self._vdevice = None + self._loaded = False + logger.info(f"Hailo model unloaded: {self.model_id}") + + async def detect( + self, + image: bytes | np.ndarray, + confidence_threshold: float | None = None, + classes: list[str] | None = None, + ) -> DetectionResult: + if not self._loaded or self._configured is None: + await self.load() + + start = time.perf_counter() + img_array = self._image_to_numpy(image) + height, width = img_array.shape[:2] + conf = confidence_threshold if confidence_threshold is not None else self.confidence_threshold + + # Build class filter + class_filter: set[int] | None = None + if classes: + class_filter = {i for i, n in enumerate(self.class_names) if n in classes} + + # Preprocess: letterbox to model input dimensions + input_h, input_w = self._input_shape or (640, 640) + letterboxed, scale, pad = _letterbox(img_array, (input_h, input_w)) + + # Ensure uint8 RGB contiguous array + input_data = np.ascontiguousarray(letterboxed, dtype=np.uint8) + + # Run inference on Hailo + def _infer(): + bindings = self._configured.create_bindings() + bindings.input().set_buffer(input_data) + output_buffer = np.empty( + self._infer_model.output().shape, dtype=np.float32 + ) + bindings.output().set_buffer(output_buffer) + self._configured.run([bindings], 5000) + return output_buffer + + output = await asyncio.to_thread(_infer) + inference_time = (time.perf_counter() - start) * 1000 + + # Parse NMS output into detection boxes + boxes = _parse_nms_output( + output, scale, pad, width, height, conf, class_filter, + input_size=(input_h, input_w), + ) + + return DetectionResult( + confidence=max((b.confidence for b in boxes), default=0.0), + inference_time_ms=inference_time, + model_name=self.model_id, + boxes=boxes, + class_names=list({b.class_name for b in boxes}), + image_width=width, + image_height=height, + ) + + def get_model_info(self) -> dict: + info = super().get_model_info() + info.update({ + "backend": "hailo", + "variant": self.model_id, + "hef_path": self._hef_path, + "input_shape": self._input_shape, + "num_classes": len(self.class_names), + }) + return info diff --git a/runtimes/edge/models/language_model.py b/runtimes/edge/models/language_model.py new file mode 100644 index 000000000..9c627fcb3 --- /dev/null +++ b/runtimes/edge/models/language_model.py @@ -0,0 +1,223 @@ +""" +Language model wrapper for text generation or embedding. +""" + +from __future__ import annotations + +import asyncio +import logging +from collections.abc import AsyncGenerator +from threading import Thread +from typing import cast + +from .base import BaseModel + +logger = logging.getLogger(__name__) + + +class LanguageModel(BaseModel): + """Wrapper for HuggingFace language models (GPT-style text generation).""" + + def __init__(self, model_id: str, device: str, token: str | None = None): + super().__init__(model_id, device, token=token) + self.model_type = "language" + self.supports_streaming = True + + async def load(self) -> None: + """Load the causal language model. + + All blocking transformers operations are wrapped in asyncio.to_thread() + to avoid blocking the FastAPI event loop during model loading. + """ + from transformers import AutoModelForCausalLM, AutoTokenizer + + logger.info(f"Loading causal LM: {self.model_id}") + + dtype = self.get_dtype() + + # Load tokenizer - wrapped to avoid blocking event loop + self.tokenizer = await asyncio.to_thread( + AutoTokenizer.from_pretrained, + self.model_id, + trust_remote_code=True, + token=self.token, + ) + + # Load model - wrapped to avoid blocking event loop + # This is the heaviest operation (downloads/loads model weights) + self.model = await asyncio.to_thread( + AutoModelForCausalLM.from_pretrained, + self.model_id, + dtype=dtype, + trust_remote_code=True, + device_map="auto" if self.device == "cuda" else None, + token=self.token, + ) + + if self.device != "cuda" and self.model is not None: + # Move to device - wrapped for consistency + self.model = await asyncio.to_thread(self.model.to, self.device) # type: ignore[arg-type] + + logger.info(f"Causal LM loaded on {self.device}") + + def format_messages(self, messages: list[dict]) -> str: + """Format chat messages into a prompt.""" + # Try to use tokenizer's chat template if available + if self.tokenizer and hasattr(self.tokenizer, "apply_chat_template"): + try: + result = self.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + # apply_chat_template with tokenize=False returns str + if isinstance(result, str): + return result + except Exception: + # Fall through to simple concatenation if template fails + logger.debug("Chat template application failed, using fallback", exc_info=True) + + # Fallback to simple concatenation + prompt_parts = [] + for msg in messages: + role = msg["role"] + content = msg["content"] + prompt_parts.append(f"{role.capitalize()}: {content}") + + prompt_parts.append("Assistant:") + return "\n".join(prompt_parts) + + async def generate( + self, + messages: list[dict], + max_tokens: int | None = None, + temperature: float = 0.7, + top_p: float = 1.0, + stop: list[str] | None = None, + thinking_budget: int | None = None, + tools: list[dict] | None = None, + tool_choice: str | dict | None = None, + ) -> str: + """Generate chat completion. + + Uses the tokenizer's chat template to format messages before generation. + + Args: + messages: List of message dicts with 'role' and 'content' keys + max_tokens: Maximum tokens to generate (default: 512) + temperature: Sampling temperature (0.0 = greedy, higher = more random) + top_p: Nucleus sampling threshold + stop: List of stop sequences to end generation + thinking_budget: Not used for transformers models (included for API compatibility) + tools: Not used for transformers models (included for API compatibility) + tool_choice: Not used for transformers models (included for API compatibility) + + Returns: + Generated text as a string + """ + assert self.model is not None, "Model not loaded" + assert self.tokenizer is not None, "Tokenizer not loaded" + + # Format messages using chat template + prompt = self.format_messages(messages) + + import torch + + inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device) + + max_new_tokens = max_tokens or 512 + + with torch.no_grad(): + outputs = self.model.generate( + **inputs, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + do_sample=temperature > 0, + pad_token_id=self.tokenizer.eos_token_id, + ) + + # Decode only the new tokens + generated_text = self.tokenizer.decode( + outputs[0][inputs.input_ids.shape[1] :], skip_special_tokens=True + ) + + return generated_text.strip() + + async def generate_stream( + self, + messages: list[dict], + max_tokens: int | None = None, + temperature: float = 0.7, + top_p: float = 1.0, + stop: list[str] | None = None, + thinking_budget: int | None = None, + tools: list[dict] | None = None, + tool_choice: str | dict | None = None, + kv_cache_data: bytes | None = None, + kv_cache_tokens: int = 0, + ) -> AsyncGenerator[str, None]: + """Generate chat completion with streaming (yields tokens as they're generated). + + Uses the tokenizer's chat template to format messages before generation. + + Args: + messages: List of message dicts with 'role' and 'content' keys + max_tokens: Maximum tokens to generate (default: 512) + temperature: Sampling temperature (0.0 = greedy, higher = more random) + top_p: Nucleus sampling threshold + stop: List of stop sequences to end generation + thinking_budget: Not used for transformers models (included for API compatibility) + tools: Not used for transformers models (included for API compatibility) + tool_choice: Not used for transformers models (included for API compatibility) + + Yields: + Generated text tokens as strings + """ + assert self.model is not None, "Model not loaded" + assert self.tokenizer is not None, "Tokenizer not loaded" + + # Format messages using chat template + prompt = self.format_messages(messages) + + inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device) + max_new_tokens = max_tokens or 512 + + # Create a streamer that will yield tokens as they're generated + from transformers import AutoTokenizer, TextIteratorStreamer + + streamer = TextIteratorStreamer( + cast(AutoTokenizer, self.tokenizer), + skip_prompt=True, + skip_special_tokens=True, + ) + + # Generation kwargs + generation_kwargs = { + **inputs, + "max_new_tokens": max_new_tokens, + "temperature": temperature, + "top_p": top_p, + "do_sample": temperature > 0, + "pad_token_id": self.tokenizer.eos_token_id, + "streamer": streamer, + } + + # Run generation in a separate thread so we can stream the results + thread = Thread(target=self.model.generate, kwargs=generation_kwargs) # type: ignore[arg-type] + thread.start() + + # Yield tokens as they become available + for text in streamer: + # Check for stop sequences + if stop: + for stop_seq in stop: + if stop_seq in text: + # Yield up to the stop sequence + idx = text.index(stop_seq) + if idx > 0: + yield text[:idx] + thread.join() + return + yield text + + # Wait for generation to complete + thread.join() diff --git a/runtimes/edge/models/vision_base.py b/runtimes/edge/models/vision_base.py new file mode 100644 index 000000000..ee3931996 --- /dev/null +++ b/runtimes/edge/models/vision_base.py @@ -0,0 +1,188 @@ +"""Base classes for vision models (detection, classification). + +Simplified MVP — no segmentation, no embedding model base. +""" + +from __future__ import annotations + +import logging +from abc import abstractmethod +from dataclasses import dataclass, field +from typing import Any, Literal + +import numpy as np + +from .base import BaseModel + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Result Dataclasses +# ============================================================================= + + +@dataclass +class VisionResult: + """Base result for all vision operations.""" + confidence: float + inference_time_ms: float + model_name: str + + +@dataclass +class DetectionBox: + """Single detection bounding box.""" + x1: float + y1: float + x2: float + y2: float + class_name: str + class_id: int + confidence: float + + +@dataclass +class DetectionResult(VisionResult): + """Object detection result.""" + boxes: list[DetectionBox] = field(default_factory=list) + class_names: list[str] = field(default_factory=list) + image_width: int = 0 + image_height: int = 0 + + +@dataclass +class ClassificationResult(VisionResult): + """Image classification result.""" + class_name: str = "" + class_id: int = 0 + all_scores: dict[str, float] = field(default_factory=dict) + + +@dataclass +class EmbeddingResult(VisionResult): + """Image/text embedding result.""" + embeddings: list[list[float]] = field(default_factory=list) + dimensions: int = 0 + + +# ============================================================================= +# Base Model Classes +# ============================================================================= + + +class VisionModel(BaseModel): + """Base class for all vision models.""" + + def __init__(self, model_id: str, device: str = "auto", token: str | None = None): + super().__init__(model_id, device, token) + self.model_type = "vision" + self._loaded = False + + def _resolve_device(self, device: str) -> str: + if device != "auto": + return device + try: + import torch + if torch.cuda.is_available(): + return "cuda" + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + return "mps" + except ImportError: + pass # torch not installed — fall back to CPU + return "cpu" + + def _image_to_numpy(self, image: bytes | np.ndarray) -> np.ndarray: + if isinstance(image, np.ndarray): + return image + import io + + from PIL import Image, UnidentifiedImageError + try: + img = Image.open(io.BytesIO(image)) + img.load() # Force eager decode so errors surface here, not lazily later + except UnidentifiedImageError as e: + raise ValueError( + "Cannot identify image format. " + "Ensure the image is a valid JPEG, PNG, BMP, TIFF, or WebP file." + ) from e + except OSError as e: + raise ValueError(f"Failed to decode image data: {e}") from e + if img.mode != "RGB": + img = img.convert("RGB") + return np.array(img) + + def _image_to_pil(self, image: bytes | np.ndarray): + import io + + from PIL import Image, UnidentifiedImageError + if isinstance(image, np.ndarray): + return Image.fromarray(image) + try: + img = Image.open(io.BytesIO(image)) + img.load() # Force eager decode + except UnidentifiedImageError as e: + raise ValueError( + "Cannot identify image format. " + "Ensure the image is a valid JPEG, PNG, BMP, TIFF, or WebP file." + ) from e + except OSError as e: + raise ValueError(f"Failed to decode image data: {e}") from e + if img.mode != "RGB": + img = img.convert("RGB") + return img + + def get_model_info(self) -> dict[str, Any]: + info = super().get_model_info() + info["loaded"] = self._loaded + return info + + +class DetectionModel(VisionModel): + """Base class for object detection models.""" + + def __init__(self, model_id: str, device: str = "auto", + confidence_threshold: float = 0.5, token: str | None = None): + super().__init__(model_id, device, token) + self.confidence_threshold = confidence_threshold + self.class_names: list[str] = [] + + @abstractmethod + async def detect(self, image: bytes | np.ndarray, + confidence_threshold: float | None = None, + classes: list[str] | None = None) -> DetectionResult: + pass + + async def train(self, dataset_path: str, epochs: int = 10, + batch_size: int = 16, **kwargs) -> dict: + raise NotImplementedError(f"{self.__class__.__name__} does not support training") + + async def export(self, format: Literal["onnx", "coreml", "tensorrt", "tflite", "openvino"], + output_path: str, **kwargs) -> str: + raise NotImplementedError(f"{self.__class__.__name__} does not support export to {format}") + + async def load(self) -> None: + raise NotImplementedError + + async def infer(self, image: bytes | np.ndarray, **kwargs) -> VisionResult: + return await self.detect(image, **kwargs) + + +class ClassificationModel(VisionModel): + """Base class for image classification models.""" + + def __init__(self, model_id: str, device: str = "auto", token: str | None = None): + super().__init__(model_id, device, token) + self.class_names: list[str] = [] + + @abstractmethod + async def classify(self, image: bytes | np.ndarray, + classes: list[str] | None = None, + top_k: int = 5) -> ClassificationResult: + pass + + async def load(self) -> None: + raise NotImplementedError + + async def infer(self, image: bytes | np.ndarray, **kwargs) -> VisionResult: + return await self.classify(image, classes=kwargs.get("classes"), top_k=kwargs.get("top_k", 5)) diff --git a/runtimes/edge/models/yolo_model.py b/runtimes/edge/models/yolo_model.py new file mode 100644 index 000000000..e8db694ae --- /dev/null +++ b/runtimes/edge/models/yolo_model.py @@ -0,0 +1,182 @@ +"""YOLO-based object detection model. Supports YOLOv8/v11 via ultralytics.""" + +from __future__ import annotations + +import asyncio +import logging +import os +import time +from pathlib import Path +from typing import TYPE_CHECKING, Literal + +import numpy as np + +from .vision_base import DetectionBox, DetectionModel, DetectionResult + +if TYPE_CHECKING: + from ultralytics import YOLO + +logger = logging.getLogger(__name__) + +YOLO_VARIANTS = { + "yolov8n": "yolov8n.pt", "yolov8s": "yolov8s.pt", "yolov8m": "yolov8m.pt", + "yolov8l": "yolov8l.pt", "yolov8x": "yolov8x.pt", + "yolov11n": "yolo11n.pt", "yolov11s": "yolo11s.pt", "yolov11m": "yolo11m.pt", +} + + +class YOLOModel(DetectionModel): + """YOLO object detection model wrapper.""" + + def __init__(self, model_id: str = "yolov8n", device: str = "auto", + confidence_threshold: float = 0.5, token: str | None = None): + super().__init__(model_id, device, confidence_threshold, token) + self.yolo: YOLO | None = None + self._model_path: str | None = None + + async def load(self) -> None: + if self._loaded: + return + # Suppress missing pi_heif — some ultralytics builds register the HEIF PIL + # plugin unconditionally, causing an unhandled ImportError on first inference + # when the optional `pi_heif` package is not installed. + try: + import pi_heif + pi_heif.register_heif_opener() + except ImportError: + # Optional — continue without HEIF image support + logger.debug("pi_heif not available, HEIF support disabled") + from ultralytics import YOLO + + self.device = self._resolve_device(self.device) + logger.info(f"Loading YOLO model {self.model_id} on {self.device}") + start = time.perf_counter() + + if self.model_id in YOLO_VARIANTS: + self._model_path = YOLO_VARIANTS[self.model_id] + elif ".." not in Path(self.model_id).parts: + # Validate path — must resolve within home/.llamafarm or cwd + resolved = Path(self.model_id).resolve() + allowed_roots = [Path.home() / ".llamafarm", Path.cwd()] + if not any( + str(resolved).startswith(str(r.resolve()) + os.sep) + for r in allowed_roots + ): + raise ValueError(f"Model path outside allowed directories: {self.model_id}") + if not resolved.exists(): + raise FileNotFoundError(f"Model file not found: {self.model_id}") + self._model_path = str(resolved) + else: + # Basename only for dynamic model IDs (no path components) + safe_id = Path(self.model_id).name + if safe_id != self.model_id: + raise ValueError(f"Invalid model_id: {self.model_id}") + self._model_path = f"{safe_id}.pt" + + self.yolo = YOLO(self._model_path) + if self.device != "cpu": + self.yolo.to(self.device) + + self.class_names = list(self.yolo.names.values()) if hasattr(self.yolo, "names") else [] + self._loaded = True + logger.info(f"YOLO loaded in {(time.perf_counter() - start) * 1000:.0f}ms ({len(self.class_names)} classes)") + + async def unload(self) -> None: + if self.yolo is not None: + del self.yolo + self.yolo = None + self._loaded = False + await super().unload() + + async def detect(self, image: bytes | np.ndarray, + confidence_threshold: float | None = None, + classes: list[str] | None = None) -> DetectionResult: + if not self._loaded or self.yolo is None: + await self.load() + + start = time.perf_counter() + img_array = self._image_to_numpy(image) + height, width = img_array.shape[:2] + conf = confidence_threshold if confidence_threshold is not None else self.confidence_threshold + + class_indices = None + if classes: + class_indices = [i for i, n in enumerate(self.class_names) if n in classes] + + results = await asyncio.to_thread( + self.yolo, img_array, conf=conf, classes=class_indices, verbose=False + ) + inference_time = (time.perf_counter() - start) * 1000 + + boxes: list[DetectionBox] = [] + if results and len(results) > 0 and results[0].boxes is not None: + for box in results[0].boxes: + xyxy = box.xyxy[0].cpu().numpy() + cls_id = int(box.cls[0].cpu().numpy()) + boxes.append(DetectionBox( + x1=float(xyxy[0]), y1=float(xyxy[1]), + x2=float(xyxy[2]), y2=float(xyxy[3]), + class_name=self.class_names[cls_id] if cls_id < len(self.class_names) else f"class_{cls_id}", + class_id=cls_id, + confidence=float(box.conf[0].cpu().numpy()), + )) + + return DetectionResult( + confidence=max((b.confidence for b in boxes), default=0.0), + inference_time_ms=inference_time, model_name=self.model_id, + boxes=boxes, class_names=list({b.class_name for b in boxes}), + image_width=width, image_height=height, + ) + + async def train(self, dataset_path: str, epochs: int = 10, + batch_size: int = 16, **kwargs) -> dict: + if not self._loaded or self.yolo is None: + await self.load() + + logger.info(f"Starting YOLO training: {epochs} epochs, batch {batch_size}") + train_args = { + "data": dataset_path, "epochs": epochs, "batch": batch_size, + "device": self.device if self.device != "auto" else None, + "imgsz": kwargs.get("imgsz", 640), + "patience": kwargs.get("patience", 50), + "save": True, "verbose": kwargs.get("verbose", True), + } + results = await asyncio.to_thread(self.yolo.train, **train_args) + + metrics = {} + if hasattr(results, "results_dict"): + metrics = results.results_dict + return { + "metrics": metrics, "epochs": epochs, + "model_path": str(results.save_dir) if hasattr(results, "save_dir") else None, + } + + async def export(self, format: Literal["onnx", "coreml", "tensorrt", "tflite", "openvino"], + output_path: str, **kwargs) -> str: + if not self._loaded or self.yolo is None: + await self.load() + + logger.info(f"Exporting YOLO model to {format}") + format_map = {"onnx": "onnx", "coreml": "coreml", "tensorrt": "engine", + "tflite": "tflite", "openvino": "openvino"} + + export_path = self.yolo.export( + format=format_map.get(format, format), + half=kwargs.get("half", False), + int8=kwargs.get("int8", False), + simplify=kwargs.get("simplify", True), + ) + + if output_path and Path(output_path).is_dir(): + import shutil + final = Path(output_path) / Path(export_path).name + shutil.move(export_path, final) + export_path = str(final) + + return str(export_path) + + def get_model_info(self) -> dict: + info = super().get_model_info() + info.update({"variant": self.model_id, "num_classes": len(self.class_names), + "model_path": self._model_path}) + return info diff --git a/runtimes/edge/openapi.json b/runtimes/edge/openapi.json new file mode 100644 index 000000000..00b622541 --- /dev/null +++ b/runtimes/edge/openapi.json @@ -0,0 +1 @@ +{"openapi":"3.1.0","info":{"title":"LlamaFarm Edge Runtime","description":"Minimal on-device inference API for drones and edge hardware","version":"0.1.0"},"paths":{"/health":{"get":{"tags":["health"],"summary":"Health Check","description":"Health check endpoint with device information.","operationId":"health_check_health_get","responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}}}}},"/v1/models":{"get":{"tags":["health"],"summary":"List Models","description":"List currently loaded models.","operationId":"list_models_v1_models_get","responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}}}}},"/v1/chat/completions":{"post":{"summary":"Chat Completions","description":"OpenAI-compatible chat completions endpoint.\n\nSupports any HuggingFace causal language model.","operationId":"chat_completions_v1_chat_completions_post","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/ChatCompletionRequest"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/v1/vision/detect":{"post":{"tags":["vision","vision-detection"],"summary":"Detect Objects","description":"Detect objects in an image using YOLO.","operationId":"detect_objects_v1_vision_detect_post","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/DetectRequest"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{"$ref":"#/components/schemas/DetectResponse"}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/v1/vision/classify":{"post":{"tags":["vision","vision-classification"],"summary":"Classify Image","description":"Classify an image using CLIP (zero-shot).","operationId":"classify_image_v1_vision_classify_post","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/ClassifyRequest"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{"$ref":"#/components/schemas/ClassifyResponse"}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/v1/vision/detect_classify":{"post":{"tags":["vision","vision-detect-classify"],"summary":"Detect And Classify","description":"Detect objects then classify each crop — single round-trip.\n\nRuns YOLO detection → crops each bounding box → CLIP classifies each crop.\nReturns unified results with both detection and classification info.","operationId":"detect_and_classify_v1_vision_detect_classify_post","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/DetectClassifyRequest"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{"$ref":"#/components/schemas/DetectClassifyResponse"}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/v1/vision/stream/start":{"post":{"tags":["vision","vision-streaming"],"summary":"Start Stream","description":"Start a streaming detection session with cascade config.","operationId":"start_stream_v1_vision_stream_start_post","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/StreamStartRequest"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{"$ref":"#/components/schemas/StreamStartResponse"}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/v1/vision/stream/frame":{"post":{"tags":["vision","vision-streaming"],"summary":"Process Frame","description":"Process a frame through the cascade chain.","operationId":"process_frame_v1_vision_stream_frame_post","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/StreamFrameRequest"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{"$ref":"#/components/schemas/StreamFrameResponse"}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/v1/vision/stream/stop":{"post":{"tags":["vision","vision-streaming"],"summary":"Stop Stream","description":"Stop a streaming session.","operationId":"stop_stream_v1_vision_stream_stop_post","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/StreamStopRequest"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{"$ref":"#/components/schemas/StreamStopResponse"}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/v1/vision/stream/sessions":{"get":{"tags":["vision","vision-streaming"],"summary":"List Sessions","description":"List active streaming sessions.","operationId":"list_sessions_v1_vision_stream_sessions_get","responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{"$ref":"#/components/schemas/SessionsListResponse"}}}}}}},"/v1/models/unload":{"post":{"tags":["models"],"summary":"Unload All Models","description":"Unload all loaded models to free memory.","operationId":"unload_all_models_v1_models_unload_post","responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}}}}}},"components":{"schemas":{"Audio":{"properties":{"id":{"type":"string","title":"Id"}},"type":"object","required":["id"],"title":"Audio","description":"Data about a previous audio response from the model.\n[Learn more](https://platform.openai.com/docs/guides/audio)."},"BoundingBox":{"properties":{"x1":{"type":"number","title":"X1"},"y1":{"type":"number","title":"Y1"},"x2":{"type":"number","title":"X2"},"y2":{"type":"number","title":"Y2"}},"type":"object","required":["x1","y1","x2","y2"],"title":"BoundingBox"},"CascadeConfigRequest":{"properties":{"chain":{"items":{"type":"string"},"type":"array","title":"Chain","description":"Model chain, can include 'remote:http://...'","default":["yolov8n"]},"confidence_threshold":{"type":"number","maximum":1.0,"minimum":0.0,"title":"Confidence Threshold","default":0.7}},"type":"object","title":"CascadeConfigRequest"},"ChatCompletionAssistantMessageParam":{"properties":{"role":{"type":"string","const":"assistant","title":"Role"},"audio":{"anyOf":[{"$ref":"#/components/schemas/Audio"},{"type":"null"}]},"content":{"anyOf":[{"type":"string"},{"items":{"anyOf":[{"$ref":"#/components/schemas/ChatCompletionContentPartTextParam"},{"$ref":"#/components/schemas/ChatCompletionContentPartRefusalParam"}]},"type":"array"},{"type":"null"}],"title":"Content"},"function_call":{"anyOf":[{"$ref":"#/components/schemas/FunctionCall"},{"type":"null"}]},"name":{"type":"string","title":"Name"},"refusal":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Refusal"},"tool_calls":{"items":{"anyOf":[{"$ref":"#/components/schemas/ChatCompletionMessageFunctionToolCallParam"},{"$ref":"#/components/schemas/ChatCompletionMessageCustomToolCallParam"}]},"type":"array","title":"Tool Calls"}},"type":"object","required":["role"],"title":"ChatCompletionAssistantMessageParam","description":"Messages sent by the model in response to user messages."},"ChatCompletionContentPartImageParam":{"properties":{"image_url":{"$ref":"#/components/schemas/ImageURL"},"type":{"type":"string","const":"image_url","title":"Type"}},"type":"object","required":["image_url","type"],"title":"ChatCompletionContentPartImageParam","description":"Learn about [image inputs](https://platform.openai.com/docs/guides/vision)."},"ChatCompletionContentPartInputAudioParam":{"properties":{"input_audio":{"$ref":"#/components/schemas/InputAudio"},"type":{"type":"string","const":"input_audio","title":"Type"}},"type":"object","required":["input_audio","type"],"title":"ChatCompletionContentPartInputAudioParam","description":"Learn about [audio inputs](https://platform.openai.com/docs/guides/audio)."},"ChatCompletionContentPartRefusalParam":{"properties":{"refusal":{"type":"string","title":"Refusal"},"type":{"type":"string","const":"refusal","title":"Type"}},"type":"object","required":["refusal","type"],"title":"ChatCompletionContentPartRefusalParam"},"ChatCompletionContentPartTextParam":{"properties":{"text":{"type":"string","title":"Text"},"type":{"type":"string","const":"text","title":"Type"}},"type":"object","required":["text","type"],"title":"ChatCompletionContentPartTextParam","description":"Learn about [text inputs](https://platform.openai.com/docs/guides/text-generation)."},"ChatCompletionDeveloperMessageParam":{"properties":{"content":{"anyOf":[{"type":"string"},{"items":{"$ref":"#/components/schemas/ChatCompletionContentPartTextParam"},"type":"array"}],"title":"Content"},"role":{"type":"string","const":"developer","title":"Role"},"name":{"type":"string","title":"Name"}},"type":"object","required":["content","role"],"title":"ChatCompletionDeveloperMessageParam","description":"Developer-provided instructions that the model should follow, regardless of\nmessages sent by the user. With o1 models and newer, `developer` messages\nreplace the previous `system` messages."},"ChatCompletionFunctionMessageParam":{"properties":{"content":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Content"},"name":{"type":"string","title":"Name"},"role":{"type":"string","const":"function","title":"Role"}},"type":"object","required":["content","name","role"],"title":"ChatCompletionFunctionMessageParam"},"ChatCompletionFunctionToolParam":{"properties":{"function":{"$ref":"#/components/schemas/FunctionDefinition"},"type":{"type":"string","const":"function","title":"Type"}},"type":"object","required":["function","type"],"title":"ChatCompletionFunctionToolParam","description":"A function tool that can be used to generate a response."},"ChatCompletionMessageCustomToolCallParam":{"properties":{"id":{"type":"string","title":"Id"},"custom":{"$ref":"#/components/schemas/Custom"},"type":{"type":"string","const":"custom","title":"Type"}},"type":"object","required":["id","custom","type"],"title":"ChatCompletionMessageCustomToolCallParam","description":"A call to a custom tool created by the model."},"ChatCompletionMessageFunctionToolCallParam":{"properties":{"id":{"type":"string","title":"Id"},"function":{"$ref":"#/components/schemas/Function"},"type":{"type":"string","const":"function","title":"Type"}},"type":"object","required":["id","function","type"],"title":"ChatCompletionMessageFunctionToolCallParam","description":"A call to a function tool created by the model."},"ChatCompletionRequest":{"properties":{"model":{"type":"string","title":"Model"},"messages":{"items":{"anyOf":[{"$ref":"#/components/schemas/ChatCompletionDeveloperMessageParam"},{"$ref":"#/components/schemas/ChatCompletionSystemMessageParam"},{"$ref":"#/components/schemas/ChatCompletionUserMessageParam"},{"$ref":"#/components/schemas/ChatCompletionAssistantMessageParam"},{"$ref":"#/components/schemas/ChatCompletionToolMessageParam"},{"$ref":"#/components/schemas/ChatCompletionFunctionMessageParam"}]},"type":"array","title":"Messages"},"temperature":{"anyOf":[{"type":"number"},{"type":"null"}],"title":"Temperature","default":1.0},"top_p":{"anyOf":[{"type":"number"},{"type":"null"}],"title":"Top P","default":1.0},"max_tokens":{"anyOf":[{"type":"integer"},{"type":"null"}],"title":"Max Tokens"},"stream":{"anyOf":[{"type":"boolean"},{"type":"null"}],"title":"Stream","default":false},"stop":{"anyOf":[{"type":"string"},{"items":{"type":"string"},"type":"array"},{"type":"null"}],"title":"Stop"},"logprobs":{"anyOf":[{"type":"boolean"},{"type":"null"}],"title":"Logprobs"},"top_logprobs":{"anyOf":[{"type":"integer","maximum":20.0,"minimum":0.0},{"type":"null"}],"title":"Top Logprobs"},"presence_penalty":{"anyOf":[{"type":"number"},{"type":"null"}],"title":"Presence Penalty","default":0.0},"frequency_penalty":{"anyOf":[{"type":"number"},{"type":"null"}],"title":"Frequency Penalty","default":0.0},"user":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"User"},"n_ctx":{"anyOf":[{"type":"integer"},{"type":"null"}],"title":"N Ctx"},"n_batch":{"anyOf":[{"type":"integer"},{"type":"null"}],"title":"N Batch"},"n_gpu_layers":{"anyOf":[{"type":"integer"},{"type":"null"}],"title":"N Gpu Layers"},"n_threads":{"anyOf":[{"type":"integer"},{"type":"null"}],"title":"N Threads"},"flash_attn":{"anyOf":[{"type":"boolean"},{"type":"null"}],"title":"Flash Attn"},"use_mmap":{"anyOf":[{"type":"boolean"},{"type":"null"}],"title":"Use Mmap"},"use_mlock":{"anyOf":[{"type":"boolean"},{"type":"null"}],"title":"Use Mlock"},"cache_type_k":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Cache Type K"},"cache_type_v":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Cache Type V"},"extra_body":{"anyOf":[{"additionalProperties":true,"type":"object"},{"type":"null"}],"title":"Extra Body"},"tools":{"anyOf":[{"items":{"$ref":"#/components/schemas/ChatCompletionFunctionToolParam"},"type":"array"},{"type":"null"}],"title":"Tools"},"tool_choice":{"anyOf":[{"type":"string"},{"additionalProperties":true,"type":"object"},{"type":"null"}],"title":"Tool Choice"},"think":{"anyOf":[{"type":"boolean"},{"type":"null"}],"title":"Think"},"thinking_budget":{"anyOf":[{"type":"integer"},{"type":"null"}],"title":"Thinking Budget"},"cache_key":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Cache Key"},"return_cache_key":{"anyOf":[{"type":"boolean"},{"type":"null"}],"title":"Return Cache Key"},"auto_truncate":{"anyOf":[{"type":"boolean"},{"type":"null"}],"title":"Auto Truncate","default":true},"truncation_strategy":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Truncation Strategy"}},"type":"object","required":["model","messages"],"title":"ChatCompletionRequest","description":"OpenAI-compatible chat completion request."},"ChatCompletionSystemMessageParam":{"properties":{"content":{"anyOf":[{"type":"string"},{"items":{"$ref":"#/components/schemas/ChatCompletionContentPartTextParam"},"type":"array"}],"title":"Content"},"role":{"type":"string","const":"system","title":"Role"},"name":{"type":"string","title":"Name"}},"type":"object","required":["content","role"],"title":"ChatCompletionSystemMessageParam","description":"Developer-provided instructions that the model should follow, regardless of\nmessages sent by the user. With o1 models and newer, use `developer` messages\nfor this purpose instead."},"ChatCompletionToolMessageParam":{"properties":{"content":{"anyOf":[{"type":"string"},{"items":{"$ref":"#/components/schemas/ChatCompletionContentPartTextParam"},"type":"array"}],"title":"Content"},"role":{"type":"string","const":"tool","title":"Role"},"tool_call_id":{"type":"string","title":"Tool Call Id"}},"type":"object","required":["content","role","tool_call_id"],"title":"ChatCompletionToolMessageParam"},"ChatCompletionUserMessageParam":{"properties":{"content":{"anyOf":[{"type":"string"},{"items":{"anyOf":[{"$ref":"#/components/schemas/ChatCompletionContentPartTextParam"},{"$ref":"#/components/schemas/ChatCompletionContentPartImageParam"},{"$ref":"#/components/schemas/ChatCompletionContentPartInputAudioParam"},{"$ref":"#/components/schemas/File"}]},"type":"array"}],"title":"Content"},"role":{"type":"string","const":"user","title":"Role"},"name":{"type":"string","title":"Name"}},"type":"object","required":["content","role"],"title":"ChatCompletionUserMessageParam","description":"Messages sent by an end user, containing prompts or additional context\ninformation."},"ClassifiedDetection":{"properties":{"box":{"$ref":"#/components/schemas/BoundingBox"},"detection_class":{"type":"string","title":"Detection Class"},"detection_confidence":{"type":"number","title":"Detection Confidence"},"classification":{"type":"string","title":"Classification"},"classification_confidence":{"type":"number","title":"Classification Confidence"},"all_scores":{"additionalProperties":{"type":"number"},"type":"object","title":"All Scores"}},"type":"object","required":["box","detection_class","detection_confidence","classification","classification_confidence","all_scores"],"title":"ClassifiedDetection","description":"A detection with classification results."},"ClassifyRequest":{"properties":{"image":{"type":"string","title":"Image","description":"Base64-encoded image"},"model":{"type":"string","title":"Model","default":"clip-vit-base"},"classes":{"items":{"type":"string"},"type":"array","title":"Classes","description":"Classes for zero-shot classification"},"top_k":{"type":"integer","maximum":100.0,"minimum":1.0,"title":"Top K","default":5}},"type":"object","required":["image","classes"],"title":"ClassifyRequest"},"ClassifyResponse":{"properties":{"class_name":{"type":"string","title":"Class Name"},"class_id":{"type":"integer","title":"Class Id"},"confidence":{"type":"number","title":"Confidence"},"all_scores":{"additionalProperties":{"type":"number"},"type":"object","title":"All Scores"},"model":{"type":"string","title":"Model"},"inference_time_ms":{"type":"number","title":"Inference Time Ms"}},"type":"object","required":["class_name","class_id","confidence","all_scores","model","inference_time_ms"],"title":"ClassifyResponse"},"Custom":{"properties":{"input":{"type":"string","title":"Input"},"name":{"type":"string","title":"Name"}},"type":"object","required":["input","name"],"title":"Custom","description":"The custom tool that the model called."},"DetectClassifyRequest":{"properties":{"image":{"type":"string","title":"Image","description":"Base64-encoded image"},"detection_model":{"type":"string","title":"Detection Model","description":"YOLO model for detection","default":"yolov8n"},"classification_model":{"type":"string","title":"Classification Model","description":"CLIP model for classification","default":"clip-vit-base"},"classes":{"items":{"type":"string"},"type":"array","title":"Classes","description":"Classes for zero-shot classification of each crop"},"confidence_threshold":{"type":"number","maximum":1.0,"minimum":0.0,"title":"Confidence Threshold","description":"Detection confidence threshold","default":0.5},"detection_classes":{"anyOf":[{"items":{"type":"string"},"type":"array"},{"type":"null"}],"title":"Detection Classes","description":"Filter detections to these YOLO classes"},"top_k":{"type":"integer","maximum":100.0,"minimum":1.0,"title":"Top K","description":"Top-K classification results per crop","default":3},"min_crop_px":{"type":"integer","minimum":1.0,"title":"Min Crop Px","description":"Minimum crop dimension in pixels (skip tiny detections)","default":16}},"type":"object","required":["image","classes"],"title":"DetectClassifyRequest"},"DetectClassifyResponse":{"properties":{"results":{"items":{"$ref":"#/components/schemas/ClassifiedDetection"},"type":"array","title":"Results"},"total_detections":{"type":"integer","title":"Total Detections"},"classified_count":{"type":"integer","title":"Classified Count"},"detection_model":{"type":"string","title":"Detection Model"},"classification_model":{"type":"string","title":"Classification Model"},"detection_time_ms":{"type":"number","title":"Detection Time Ms"},"classification_time_ms":{"type":"number","title":"Classification Time Ms"},"total_time_ms":{"type":"number","title":"Total Time Ms"}},"type":"object","required":["results","total_detections","classified_count","detection_model","classification_model","detection_time_ms","classification_time_ms","total_time_ms"],"title":"DetectClassifyResponse"},"DetectRequest":{"properties":{"image":{"type":"string","title":"Image","description":"Base64-encoded image"},"model":{"type":"string","title":"Model","default":"yolov8n"},"confidence_threshold":{"type":"number","maximum":1.0,"minimum":0.0,"title":"Confidence Threshold","default":0.5},"classes":{"anyOf":[{"items":{"type":"string"},"type":"array"},{"type":"null"}],"title":"Classes"}},"type":"object","required":["image"],"title":"DetectRequest"},"DetectResponse":{"properties":{"detections":{"items":{"$ref":"#/components/schemas/Detection"},"type":"array","title":"Detections"},"model":{"type":"string","title":"Model"},"inference_time_ms":{"type":"number","title":"Inference Time Ms"}},"type":"object","required":["detections","model","inference_time_ms"],"title":"DetectResponse"},"Detection":{"properties":{"box":{"$ref":"#/components/schemas/BoundingBox"},"class_name":{"type":"string","title":"Class Name"},"class_id":{"type":"integer","title":"Class Id"},"confidence":{"type":"number","title":"Confidence"}},"type":"object","required":["box","class_name","class_id","confidence"],"title":"Detection"},"DetectionItem":{"properties":{"x1":{"type":"number","title":"X1"},"y1":{"type":"number","title":"Y1"},"x2":{"type":"number","title":"X2"},"y2":{"type":"number","title":"Y2"},"class_name":{"type":"string","title":"Class Name"},"class_id":{"type":"integer","title":"Class Id"},"confidence":{"type":"number","title":"Confidence"}},"type":"object","required":["x1","y1","x2","y2","class_name","class_id","confidence"],"title":"DetectionItem"},"File":{"properties":{"file":{"$ref":"#/components/schemas/FileFile"},"type":{"type":"string","const":"file","title":"Type"}},"type":"object","required":["file","type"],"title":"File","description":"Learn about [file inputs](https://platform.openai.com/docs/guides/text) for text generation."},"FileFile":{"properties":{"file_data":{"type":"string","title":"File Data"},"file_id":{"type":"string","title":"File Id"},"filename":{"type":"string","title":"Filename"}},"type":"object","title":"FileFile"},"Function":{"properties":{"arguments":{"type":"string","title":"Arguments"},"name":{"type":"string","title":"Name"}},"type":"object","required":["arguments","name"],"title":"Function","description":"The function that the model called."},"FunctionCall":{"properties":{"arguments":{"type":"string","title":"Arguments"},"name":{"type":"string","title":"Name"}},"type":"object","required":["arguments","name"],"title":"FunctionCall","description":"Deprecated and replaced by `tool_calls`.\n\nThe name and arguments of a function that should be called, as generated by the model."},"FunctionDefinition":{"properties":{"name":{"type":"string","title":"Name"},"description":{"type":"string","title":"Description"},"parameters":{"additionalProperties":true,"type":"object","title":"Parameters"},"strict":{"anyOf":[{"type":"boolean"},{"type":"null"}],"title":"Strict"}},"type":"object","required":["name"],"title":"FunctionDefinition"},"HTTPValidationError":{"properties":{"detail":{"items":{"$ref":"#/components/schemas/ValidationError"},"type":"array","title":"Detail"}},"type":"object","title":"HTTPValidationError"},"ImageURL":{"properties":{"url":{"type":"string","title":"Url"},"detail":{"type":"string","enum":["auto","low","high"],"title":"Detail"}},"type":"object","required":["url"],"title":"ImageURL"},"InputAudio":{"properties":{"data":{"type":"string","title":"Data"},"format":{"type":"string","enum":["wav","mp3"],"title":"Format"}},"type":"object","required":["data","format"],"title":"InputAudio"},"SessionInfo":{"properties":{"session_id":{"type":"string","title":"Session Id"},"frames_processed":{"type":"integer","title":"Frames Processed"},"actions_triggered":{"type":"integer","title":"Actions Triggered"},"escalations":{"type":"integer","title":"Escalations"},"chain":{"items":{"type":"string"},"type":"array","title":"Chain"},"idle_seconds":{"type":"number","title":"Idle Seconds"},"duration_seconds":{"type":"number","title":"Duration Seconds"}},"type":"object","required":["session_id","frames_processed","actions_triggered","escalations","chain","idle_seconds","duration_seconds"],"title":"SessionInfo"},"SessionsListResponse":{"properties":{"sessions":{"items":{"$ref":"#/components/schemas/SessionInfo"},"type":"array","title":"Sessions"},"count":{"type":"integer","title":"Count"}},"type":"object","required":["sessions","count"],"title":"SessionsListResponse"},"StreamFrameRequest":{"properties":{"session_id":{"type":"string","title":"Session Id"},"image":{"type":"string","title":"Image","description":"Base64-encoded image"}},"type":"object","required":["session_id","image"],"title":"StreamFrameRequest"},"StreamFrameResponse":{"properties":{"status":{"type":"string","title":"Status"},"detections":{"anyOf":[{"items":{"$ref":"#/components/schemas/DetectionItem"},"type":"array"},{"type":"null"}],"title":"Detections"},"confidence":{"anyOf":[{"type":"number"},{"type":"null"}],"title":"Confidence"},"resolved_by":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Resolved By"}},"type":"object","required":["status"],"title":"StreamFrameResponse"},"StreamStartRequest":{"properties":{"config":{"$ref":"#/components/schemas/CascadeConfigRequest"},"target_fps":{"type":"number","title":"Target Fps","default":1.0},"action_classes":{"anyOf":[{"items":{"type":"string"},"type":"array"},{"type":"null"}],"title":"Action Classes"},"cooldown_seconds":{"type":"number","title":"Cooldown Seconds","default":5.0}},"type":"object","title":"StreamStartRequest"},"StreamStartResponse":{"properties":{"session_id":{"type":"string","title":"Session Id"}},"type":"object","required":["session_id"],"title":"StreamStartResponse"},"StreamStopRequest":{"properties":{"session_id":{"type":"string","title":"Session Id"}},"type":"object","required":["session_id"],"title":"StreamStopRequest"},"StreamStopResponse":{"properties":{"session_id":{"type":"string","title":"Session Id"},"frames_processed":{"type":"integer","title":"Frames Processed"},"actions_triggered":{"type":"integer","title":"Actions Triggered"},"escalations":{"type":"integer","title":"Escalations"},"duration_seconds":{"type":"number","title":"Duration Seconds"}},"type":"object","required":["session_id","frames_processed","actions_triggered","escalations","duration_seconds"],"title":"StreamStopResponse"},"ValidationError":{"properties":{"loc":{"items":{"anyOf":[{"type":"string"},{"type":"integer"}]},"type":"array","title":"Location"},"msg":{"type":"string","title":"Message"},"type":{"type":"string","title":"Error Type"},"input":{"title":"Input"},"ctx":{"type":"object","title":"Context"}},"type":"object","required":["loc","msg","type"],"title":"ValidationError"}}}} \ No newline at end of file diff --git a/runtimes/edge/project.json b/runtimes/edge/project.json new file mode 100644 index 000000000..15d0e66de --- /dev/null +++ b/runtimes/edge/project.json @@ -0,0 +1,31 @@ +{ + "$schema": "../../node_modules/nx/schemas/project-schema.json", + "name": "edge-runtime", + "projectType": "application", + "sourceRoot": "runtimes/edge", + "targets": { + "start": { + "executor": "nx:run-commands", + "options": { + "command": "uv run python server.py", + "cwd": "runtimes/edge" + } + }, + "sync": { + "executor": "nx:run-commands", + "options": { + "commands": [ + "uv python install 3.12", + "uv sync" + ], + "cwd": "runtimes/edge", + "parallel": false + } + } + }, + "tags": [ + "runtime", + "python", + "edge" + ] +} diff --git a/runtimes/edge/pyproject.toml b/runtimes/edge/pyproject.toml new file mode 100644 index 000000000..dd8ae80ff --- /dev/null +++ b/runtimes/edge/pyproject.toml @@ -0,0 +1,90 @@ +[project] +name = "edge-runtime" +version = "0.1.0" +description = "Minimal on-device inference API for Raspberry Pi, Jetson, and edge hardware" +requires-python = ">=3.10,<3.15" +dependencies = [ + # Web framework + "fastapi>=0.104.0", + "uvicorn[standard]>=0.24.0", + "pydantic>=2.0.0", + "python-multipart>=0.0.22", + # OpenAI-compatible types + "openai>=1.0.0", + # Logging + "structlog>=24.0.0", + # Model format detection and GGUF utilities + "huggingface-hub>=0.20.0", + "llamafarm-common", + "llamafarm-llama", + # GGUF metadata reading + "gguf>=0.17.1", + # Model caching + "cachetools>=6.0.0", + # Chat template rendering + "jinja2>=3.0.0", + # Context calculator dependencies + "psutil>=5.9.0", + "pyyaml>=6.0", + "protobuf>=6.33.5,<7", + "sentencepiece>=0.2.1", + # Image processing (for vision) + "pillow>=10.0.0", + "numpy>=1.24.0", + # HTTP client (for streaming cascade) + "httpx>=0.24.0", + # IPC bus (Zenoh pub/sub over Unix socket) + "eclipse-zenoh>=1.0.0", +] + +[project.optional-dependencies] +# Vision models (YOLO + CLIP) — install when edge device has a camera +vision = [ + "ultralytics>=8.4.14", + "transformers>=4.35.0", + "lapx>=0.5.0", +] +# Transformers language models (non-GGUF) +transformers = [ + "transformers>=4.35.0", + "accelerate>=0.25.0", + "torch>=2.6.0", +] +# GPU acceleration +gpu = ["torch>=2.6.0"] + +[dependency-groups] +dev = [ + "ruff>=0.14.3", + "pytest>=7.0.0", + "pytest-asyncio>=0.21.0", + "httpx>=0.24.0", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["models", "utils", "core", "routers", "services"] +py-modules = ["server"] + +[tool.uv] +environments = [ + "sys_platform == 'linux' and platform_machine == 'aarch64'", + "sys_platform == 'linux' and platform_machine == 'x86_64'", + "sys_platform == 'darwin' and platform_machine == 'arm64'", +] +index-url = "https://pypi.org/simple" + +override-dependencies = [ + "pillow>=10.0.0", + "numpy>=1.24.0,<2.4", +] + +[tool.uv.sources] +llamafarm-common = { path = "../../common", editable = true } +llamafarm-llama = { path = "../../packages/llamafarm-llama", editable = true } + +[tool.ruff] +extend = "../../ruff.toml" diff --git a/runtimes/edge/routers/__init__.py b/runtimes/edge/routers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/runtimes/edge/routers/cache.py b/runtimes/edge/routers/cache.py new file mode 100644 index 000000000..56664ab1b --- /dev/null +++ b/runtimes/edge/routers/cache.py @@ -0,0 +1,243 @@ +"""KV Cache API — prepare, list, evict, stats, and GC endpoints.""" + +import logging +from typing import Any + +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel, Field + +from services.error_handler import handle_endpoint_errors + +logger = logging.getLogger(__name__) +router = APIRouter(tags=["cache"]) + +# ── Dependency injection ──────────────────────────────────────────────────── + +_cache_manager = None +_load_language_fn = None + + +def set_cache_manager(manager: Any) -> None: + global _cache_manager + _cache_manager = manager + + +def set_cache_language_loader(fn: Any) -> None: + global _load_language_fn + _load_language_fn = fn + + +def _get_manager(): + if _cache_manager is None: + raise HTTPException(500, "KV cache manager not initialized") + return _cache_manager + + +# ── Request/Response models ───────────────────────────────────────────────── + + +MAX_PREPARE_MESSAGES = 200 +MAX_PREPARE_TOOLS = 128 +MAX_MESSAGE_CONTENT_CHARS = 200_000 # ~50k tokens + + +class CachePrepareRequest(BaseModel): + model: str = Field(..., description="Model ID to prepare cache for") + messages: list[dict] = Field( + ..., description="Messages to cache (system prompt, etc)" + ) + tools: list[dict] | None = Field( + None, description="Tool definitions to include" + ) + pinned: bool = Field( + False, description="Pin cache (won't be evicted by GC)" + ) + ttl: float | None = Field( + None, description="TTL in seconds (None = use default)" + ) + warm: bool = Field( + True, + description=( + "If true, loads model and pre-computes KV state " + "(slower but instant cache hits). " + "If false, segment-only indexing." + ), + ) + + +class CachePrepareResponse(BaseModel): + cache_key: str + model: str + token_count: int + size_bytes: int + segments: list[dict] + + +class CacheValidateRequest(BaseModel): + cache_key: str + model: str + messages: list[dict] + tools: list[dict] | None = None + + +class CacheValidateResponse(BaseModel): + status: str # hit, partial_hit, miss + cache_key: str + reusable_tokens: int + invalidated_at: str | None + reason: str + + +# ── Endpoints ──────────────────────────────────────────────────────────────── + + +@router.post("/v1/cache/prepare", response_model=CachePrepareResponse) +@handle_endpoint_errors("cache_prepare") +async def prepare_cache(request: CachePrepareRequest) -> CachePrepareResponse: + """Pre-warm KV cache for a message prefix (system prompt, tools, history). + + Loads the model, tokenizes the messages, runs a forward pass to build KV + state, and serializes it. Returns a cache_key that can be passed to + /v1/chat/completions to skip all prefix processing. + + Use this to pre-warm system prompts, RAG context, or tool definitions + at startup so the first user message gets instant TTFT. + + Set warm=false for lightweight segment-only indexing (no model load). + """ + manager = _get_manager() + + # Input validation + if len(request.messages) > MAX_PREPARE_MESSAGES: + raise HTTPException( + 400, + f"Too many messages ({len(request.messages)}), " + f"max {MAX_PREPARE_MESSAGES}", + ) + if request.tools and len(request.tools) > MAX_PREPARE_TOOLS: + raise HTTPException( + 400, + f"Too many tools ({len(request.tools)}), " + f"max {MAX_PREPARE_TOOLS}", + ) + def _content_chars(content: Any) -> int: + """Count characters in message content, handling multimodal lists.""" + if isinstance(content, str): + return len(content) + if isinstance(content, list): + total = 0 + for part in content: + if not isinstance(part, dict): + continue + if part.get("type") == "text": + total += len(part.get("text", "")) + elif part.get("type") == "image_url": + # Count base64 data URI size to prevent bypass + image_url = part.get("image_url") + if isinstance(image_url, dict): + total += len(image_url.get("url", "")) + return total + return 0 + + total_chars = sum( + _content_chars(m.get("content")) for m in request.messages + ) + if total_chars > MAX_MESSAGE_CONTENT_CHARS: + raise HTTPException( + 400, + f"Total message content too large ({total_chars} chars), " + f"max {MAX_MESSAGE_CONTENT_CHARS}", + ) + + model = None + if request.warm: + if _load_language_fn is None: + raise HTTPException(500, "Language model loader not configured") + try: + from utils.model_format import parse_model_with_quantization + model_id, quant = parse_model_with_quantization(request.model) + model_wrapper = await _load_language_fn(model_id, preferred_quantization=quant) + # Get the inner Llama instance (not the GGUFLanguageModel wrapper) + model = getattr(model_wrapper, 'llama', model_wrapper) + except Exception as e: + logger.warning(f"Failed to load model for warm prepare: {e}") + # Fall back to segment-only + + entry = await manager.prepare( + model_id=request.model, + messages=request.messages, + tools=request.tools, + pinned=request.pinned, + ttl=request.ttl, + model=model, + ) + + return CachePrepareResponse( + cache_key=entry.cache_key, + model=entry.model_id, + token_count=entry.token_count, + size_bytes=entry.size_bytes, + segments=[{"type": s["type"], "hash": s["hash"]} for s in entry.segments], + ) + + +@router.post("/v1/cache/validate", response_model=CacheValidateResponse) +@handle_endpoint_errors("cache_validate") +async def validate_cache(request: CacheValidateRequest) -> CacheValidateResponse: + """Validate a cache key against a payload without using it. + + Useful for checking if a cache would hit before sending a full request. + """ + manager = _get_manager() + result = manager.validate_and_match( + cache_key=request.cache_key, + model_id=request.model, + messages=request.messages, + tools=request.tools, + ) + return CacheValidateResponse( + status=result["status"], + cache_key=request.cache_key, + reusable_tokens=result["reusable_tokens"], + invalidated_at=result.get("invalidated_at"), + reason=result["reason"], + ) + + +@router.get("/v1/cache") +@handle_endpoint_errors("cache_list") +async def list_caches() -> dict[str, Any]: + """List all cache entries.""" + manager = _get_manager() + entries = manager.list_entries() + return { + "entries": entries, + "count": len(entries), + } + + +@router.get("/v1/cache/stats") +@handle_endpoint_errors("cache_stats") +async def cache_stats() -> dict[str, Any]: + """Get cache statistics — usage, hit rates, tier breakdown.""" + manager = _get_manager() + return manager.get_stats() + + +@router.delete("/v1/cache/{cache_key}") +@handle_endpoint_errors("cache_evict") +async def evict_cache(cache_key: str) -> dict[str, Any]: + """Evict a specific cache entry.""" + manager = _get_manager() + if manager.evict(cache_key): + return {"status": "evicted", "cache_key": cache_key} + raise HTTPException(404, f"Cache entry not found: {cache_key}") + + +@router.post("/v1/cache/gc") +@handle_endpoint_errors("cache_gc") +async def force_gc() -> dict[str, Any]: + """Force garbage collection — removes expired entries.""" + manager = _get_manager() + removed = manager.gc() + return {"status": "ok", "removed": removed} diff --git a/runtimes/edge/routers/chat_completions/__init__.py b/runtimes/edge/routers/chat_completions/__init__.py new file mode 100644 index 000000000..5bc0c2e6d --- /dev/null +++ b/runtimes/edge/routers/chat_completions/__init__.py @@ -0,0 +1,3 @@ +from .router import router + +__all__ = ["router"] diff --git a/runtimes/edge/routers/chat_completions/router.py b/runtimes/edge/routers/chat_completions/router.py new file mode 100644 index 000000000..abbfb2610 --- /dev/null +++ b/runtimes/edge/routers/chat_completions/router.py @@ -0,0 +1,26 @@ +import logging + +from fastapi import APIRouter + +from .service import ChatCompletionsService +from .types import ChatCompletionRequest + +router = APIRouter() +logger = logging.getLogger(__name__) + + +@router.post("/v1/chat/completions") +async def chat_completions(chat_request: ChatCompletionRequest): + """ + OpenAI-compatible chat completions endpoint. + + Supports any HuggingFace causal language model. + """ + # Debug log the incoming request + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + f"Incoming chat completion request:\n" + f"{chat_request.model_dump_json(indent=2)}" + ) + + return await ChatCompletionsService().chat_completions(chat_request) diff --git a/runtimes/edge/routers/chat_completions/service.py b/runtimes/edge/routers/chat_completions/service.py new file mode 100644 index 000000000..f9846fa9b --- /dev/null +++ b/runtimes/edge/routers/chat_completions/service.py @@ -0,0 +1,1426 @@ +import asyncio +import base64 +import json +import logging +import os +import uuid +from datetime import datetime +from enum import Enum + +from fastapi import HTTPException +from fastapi.responses import StreamingResponse +from openai.types.chat.chat_completion_chunk import ( + ChatCompletionChunk, + ChoiceDelta, + ChoiceDeltaToolCall, + ChoiceDeltaToolCallFunction, +) +from openai.types.chat.chat_completion_chunk import ( + Choice as ChoiceChunk, +) + +from models import GGUFLanguageModel +from utils.context_manager import ( + ContextBudget, + ContextManager, + ContextUsage, + TruncationStrategy, +) +# Edge runtime: heavy management-plane utilities are optional. +# These are not needed for basic chat completion on edge devices. +try: + from utils.context_summarizer import ContextSummarizer +except ImportError: + ContextSummarizer = None # type: ignore[assignment,misc] + +try: + from utils.history_compressor import HistoryCompressor +except ImportError: + HistoryCompressor = None # type: ignore[assignment,misc] + +try: + from utils.thinking import inject_thinking_control, parse_thinking_response +except ImportError: + from dataclasses import dataclass as _dataclass + + @_dataclass + class _FallbackThinkingResponse: + thinking: str | None + content: str + thinking_complete: bool + + def inject_thinking_control(messages, enable_thinking=False): # type: ignore[misc] + return messages + + def parse_thinking_response(text): # type: ignore[misc] + return _FallbackThinkingResponse(thinking=None, content=text, thinking_complete=True) + +try: + from utils.tool_calling import ( + detect_probable_tool_call, + detect_tool_call_in_content, + extract_arguments_progress, + extract_tool_name_from_partial, + is_tool_call_complete, + parse_tool_choice, + strip_tool_call_from_content, + ) +except ImportError: + # No-op stubs — edge doesn't support tool calling + + def detect_probable_tool_call(*a, **kw): # type: ignore[misc] + return False + + def detect_tool_call_in_content(*a, **kw): # type: ignore[misc] + return None + + def extract_arguments_progress(*a, **kw): # type: ignore[misc] + return "" + + def extract_tool_name_from_partial(*a, **kw): # type: ignore[misc] + return None + + def is_tool_call_complete(*a, **kw): # type: ignore[misc] + return False + + def parse_tool_choice(*a, **kw): # type: ignore[misc] + return ("none", None) + + def strip_tool_call_from_content(*a, **kw): # type: ignore[misc] + return a[0] if a else "" + +from .types import ( + ChatCompletionRequest, + ContextUsageInfo, + ThinkingContent, + extract_audio_from_messages, + has_audio_content, + replace_audio_with_text, +) + + +class ToolCallStreamState(Enum): + """State machine states for incremental tool call streaming.""" + + NORMAL = "normal" # Streaming regular content + BUFFERING_START = "buffering_start" # Detected , waiting for name + STREAMING_ARGS = "streaming_args" # Name emitted, streaming arguments + + +logger = logging.getLogger(__name__) + + +class ChatCompletionsService: + @staticmethod + def _normalize_logprobs_payload(logprobs_payload, top_logprobs: int | None = None): + """Normalize backend logprobs into OpenAI chat choice.logprobs shape.""" + if not isinstance(logprobs_payload, dict): + return None + + # Already OpenAI-style from backend + content = logprobs_payload.get("content") + if isinstance(content, list): + return {"content": content} + + tokens = logprobs_payload.get("tokens") + token_logprobs = logprobs_payload.get("token_logprobs") + top_items = logprobs_payload.get("top_logprobs") + + if not isinstance(tokens, list) or not isinstance(token_logprobs, list): + return None + + normalized = [] + for idx, token in enumerate(tokens): + if not isinstance(token, str): + continue + lp = token_logprobs[idx] if idx < len(token_logprobs) else None + entry = { + "token": token, + "logprob": lp, + "bytes": list(token.encode("utf-8", errors="ignore")) or None, + } + + if isinstance(top_items, list) and idx < len(top_items): + token_top = top_items[idx] + if isinstance(token_top, dict): + pairs = list(token_top.items()) + if top_logprobs is not None: + pairs = pairs[:top_logprobs] + entry["top_logprobs"] = [ + { + "token": str(t), + "logprob": float(v) if v is not None else None, + "bytes": list(str(t).encode("utf-8", errors="ignore")) + or None, + } + for t, v in pairs + if v is not None + ] + normalized.append(entry) + + return {"content": normalized} if normalized else None + + def __init__(self): + # import here to avoid circular import + from server import load_language + + self.load_language = load_language + + _cache_manager = None + + @classmethod + def set_cache_manager(cls, manager): + cls._cache_manager = manager + + @classmethod + def _get_cache_manager(cls): + return cls._cache_manager + + async def _transcribe_audio(self, audio_data: bytes, audio_format: str = "wav") -> str: + """Transcribe audio using the STT model. + + This is used as a fallback when the LLM doesn't support direct audio input. + + Args: + audio_data: Base64-decoded audio bytes + audio_format: Audio format (wav, mp3, pcm) + + Returns: + Transcribed text + """ + from server import load_speech + + # Load STT model (default whisper model) + stt_model = await load_speech() + + # Convert audio format if needed + if audio_format == "pcm": + # Convert PCM to WAV for whisper + import io + import wave + wav_buffer = io.BytesIO() + with wave.open(wav_buffer, "wb") as wav_file: + wav_file.setnchannels(1) + wav_file.setsampwidth(2) + wav_file.setframerate(16000) + wav_file.writeframes(audio_data) + audio_data = wav_buffer.getvalue() + + # Transcribe + result = await stt_model.transcribe_audio(audio_data) + return result.get("text", "").strip() + + async def chat_completions(self, chat_request: ChatCompletionRequest): + """ + Chat completions service. + """ + + try: + # Import parsing utility + from utils.model_format import parse_model_with_quantization + + # Get GGUF-specific parameters from request + n_ctx = chat_request.n_ctx + n_batch = chat_request.n_batch + n_gpu_layers = chat_request.n_gpu_layers + n_threads = chat_request.n_threads + flash_attn = chat_request.flash_attn + use_mmap = chat_request.use_mmap + use_mlock = chat_request.use_mlock + cache_type_k = chat_request.cache_type_k + cache_type_v = chat_request.cache_type_v + + # Also check extra_body for these parameters (OpenAI SDK sends custom params there) + if chat_request.extra_body: + if n_ctx is None and "n_ctx" in chat_request.extra_body: + n_ctx = chat_request.extra_body.get("n_ctx") + if n_batch is None and "n_batch" in chat_request.extra_body: + n_batch = chat_request.extra_body.get("n_batch") + if n_gpu_layers is None and "n_gpu_layers" in chat_request.extra_body: + n_gpu_layers = chat_request.extra_body.get("n_gpu_layers") + if n_threads is None and "n_threads" in chat_request.extra_body: + n_threads = chat_request.extra_body.get("n_threads") + if flash_attn is None and "flash_attn" in chat_request.extra_body: + flash_attn = chat_request.extra_body.get("flash_attn") + if use_mmap is None and "use_mmap" in chat_request.extra_body: + use_mmap = chat_request.extra_body.get("use_mmap") + if use_mlock is None and "use_mlock" in chat_request.extra_body: + use_mlock = chat_request.extra_body.get("use_mlock") + if cache_type_k is None and "cache_type_k" in chat_request.extra_body: + cache_type_k = chat_request.extra_body.get("cache_type_k") + if cache_type_v is None and "cache_type_v" in chat_request.extra_body: + cache_type_v = chat_request.extra_body.get("cache_type_v") + + # Parse model name to extract quantization if present + model_id, gguf_quantization = parse_model_with_quantization( + chat_request.model + ) + + # Convert messages to dict format early (needed for audio detection) + messages_dict = [dict(msg) for msg in chat_request.messages] + + # Check for audio content in messages + audio_in_request = has_audio_content(messages_dict) + + # Extract thinking params from extra_body if not set at top level + # (OpenAI SDK sends custom params via extra_body) + think_param = chat_request.think + thinking_budget_param = chat_request.thinking_budget + if chat_request.extra_body: + if think_param is None and "think" in chat_request.extra_body: + think_param = chat_request.extra_body.get("think") + if ( + thinking_budget_param is None + and "thinking_budget" in chat_request.extra_body + ): + thinking_budget_param = chat_request.extra_body.get( + "thinking_budget" + ) + + # Convert tools to dict format if provided (for streaming) + tools_dict = None + if chat_request.tools: + tools_dict = [dict(tool) for tool in chat_request.tools] + tools_for_generation = tools_dict + + async def prepare_generation(): + nonlocal tools_for_generation + model = await self.load_language( + model_id, + n_ctx=n_ctx, + n_batch=n_batch, + n_gpu_layers=n_gpu_layers, + n_threads=n_threads, + flash_attn=flash_attn, + use_mmap=use_mmap, + use_mlock=use_mlock, + cache_type_k=cache_type_k, + cache_type_v=cache_type_v, + preferred_quantization=gguf_quantization, + ) + + # Check if this is a GGUF model - use native chat completion for proper template + # GGUF models have create_chat_completion() which uses the embedded chat template + # This is essential for models like Qwen that use special tokens (<|im_start|>, etc.) + # and thinking tags () + is_gguf = isinstance(model, GGUFLanguageModel) + + # Handle audio content - either native audio or STT transcription + use_native_audio = False + audio_bytes = None + audio_format = "wav" + prepared_messages = messages_dict + + if audio_in_request: + # Check if model supports native audio input + model_supports_audio = is_gguf and getattr( + model, "supports_audio", False + ) + + if model_supports_audio: + # Use native audio input (no transcription needed) + logger.info( + "Model supports native audio input - using direct audio processing" + ) + use_native_audio = True + + # Extract audio data (only first audio part for now) + audio_parts = extract_audio_from_messages(prepared_messages) + if audio_parts: + _, audio_input = audio_parts[0] + audio_bytes = base64.b64decode(audio_input.data) + audio_format = audio_input.format + logger.info( + f"Using native audio: {len(audio_bytes)} bytes, format={audio_format}" + ) + else: + # Fall back to STT transcription + logger.info( + "Audio content detected - transcribing via STT (model doesn't support native audio)" + ) + + # Extract and transcribe all audio parts + audio_parts = extract_audio_from_messages(prepared_messages) + transcriptions: dict[int, str] = {} + + for msg_idx, audio_input in audio_parts: + # Decode base64 audio + audio_bytes_for_stt = base64.b64decode(audio_input.data) + # Transcribe + transcription = await self._transcribe_audio( + audio_bytes_for_stt, audio_input.format + ) + transcriptions[msg_idx] = transcription + logger.debug( + f"Transcribed audio in message {msg_idx}: " + f"'{transcription[:100]}{'...' if len(transcription) > 100 else ''}'" + ) + + # Replace audio content with transcribed text + prepared_messages = replace_audio_with_text( + prepared_messages, transcriptions + ) + logger.info( + f"Replaced {len(audio_parts)} audio parts with transcriptions" + ) + + # Inject thinking control (Qwen soft switch: /think or /no_think) + # Default is OFF - inject /no_think unless explicitly enabled with think=true + if is_gguf: + # think=True -> enable, think=False or None -> disable + enable_thinking = think_param is True + prepared_messages = inject_thinking_control( + prepared_messages, enable_thinking=enable_thinking + ) + logger.info( + f"Thinking mode {'enabled' if enable_thinking else 'disabled'} via soft switch" + ) + + # Calculate total token budget for generation + # - max_tokens: for the final answer (default: 512) + # - thinking_budget: for the thinking process (default: 1024 if thinking enabled) + # Total = thinking_budget + max_tokens (so answer isn't cut short by thinking) + answer_tokens = chat_request.max_tokens or 512 + + # Determine if thinking is enabled (default: OFF for predictable behavior) + # User must explicitly set think=true to enable thinking mode + thinking_enabled = think_param is True + + if thinking_enabled and is_gguf: + # Use provided thinking_budget or default to 1024 + thinking_tokens = thinking_budget_param or 1024 + total_max_tokens = thinking_tokens + answer_tokens + logger.info( + f"Token allocation: {thinking_tokens} for thinking + {answer_tokens} for answer = {total_max_tokens} total" + ) + else: + # No thinking, just use answer tokens + total_max_tokens = answer_tokens + thinking_tokens = 0 + + # Context management for GGUF models + context_usage_info = None + effective_max_tokens = total_max_tokens + + if is_gguf and model.context_manager: + context_manager = model.context_manager + + # Build a request-aware budget so context checks reserve the same + # completion budget we intend to generate (answer + thinking). + if model.token_counter: + base_budget = context_manager.budget + context_manager = ContextManager( + model.token_counter, + ContextBudget.from_context_size( + base_budget.total_context, + max_completion_tokens=total_max_tokens, + ), + ) + + # Apply history compression to reduce token usage + if HistoryCompressor is not None: + compressor = HistoryCompressor(model.token_counter) + prepared_messages = compressor.compress(prepared_messages) + + # If tools are injected into the prompt path, validate against the same + # message shape to avoid undercounting prompt tokens. + messages_for_context = prepared_messages + tools_already_injected = False + native_rendered_prompt: str | None = None + if tools_dict: + ( + messages_for_context, + tools_already_injected, + native_rendered_prompt, + ) = model.prepare_messages_for_context_validation( + prepared_messages, + tools_dict, + chat_request.tool_choice, + ) + if tools_already_injected: + prepared_messages = messages_for_context + tools_for_generation = None + + # Validate context and truncate if needed + if native_rendered_prompt is not None: + if model.token_counter is None: + raise HTTPException( + status_code=400, + detail={ + "error": "context_validation_unavailable", + "message": ( + "Unable to validate native-rendered prompt context " + "because token counting is unavailable." + ), + }, + ) + prompt_tokens = model.token_counter.count_tokens( + native_rendered_prompt + ) + available_for_completion = max( + 0, + context_manager.budget.total_context + - prompt_tokens + - context_manager.budget.safety_margin, + ) + usage = ContextUsage( + total_context=context_manager.budget.total_context, + prompt_tokens=prompt_tokens, + available_for_completion=available_for_completion, + truncated=False, + truncated_messages=0, + strategy_used=None, + ) + else: + usage = context_manager.validate_messages(messages_for_context) + + if usage.prompt_tokens > context_manager.budget.max_prompt_tokens: + auto_truncate = chat_request.auto_truncate + if auto_truncate is None: + auto_truncate = True # Default to auto-truncate + + if not auto_truncate: + raise HTTPException( + status_code=400, + detail={ + "error": "context_length_exceeded", + "message": ( + f"Prompt ({usage.prompt_tokens} tokens) exceeds " + f"context limit ({usage.total_context} tokens). " + "Set auto_truncate=true to automatically truncate." + ), + "context_usage": { + "total_context": usage.total_context, + "prompt_tokens": usage.prompt_tokens, + "available_for_completion": usage.available_for_completion, + }, + }, + ) + + # Native Jinja2 rendering produces a single raw prompt string. + # We cannot safely truncate it with message-based strategies. + if native_rendered_prompt is not None: + raise HTTPException( + status_code=400, + detail={ + "error": "context_length_exceeded", + "message": ( + f"Rendered prompt ({usage.prompt_tokens} tokens) exceeds " + f"context limit ({usage.total_context} tokens). " + "Reduce message/tool size and retry." + ), + "context_usage": { + "total_context": usage.total_context, + "prompt_tokens": usage.prompt_tokens, + "available_for_completion": usage.available_for_completion, + }, + }, + ) + + # Determine truncation strategy + strategy = None + if chat_request.truncation_strategy: + try: + strategy = TruncationStrategy( + chat_request.truncation_strategy + ) + except ValueError: + logger.warning( + f"Unknown truncation strategy: {chat_request.truncation_strategy}, " + "using default (summarize)" + ) + strategy = TruncationStrategy.SUMMARIZE + else: + strategy = TruncationStrategy.SUMMARIZE # Default + + # Sliding-window can drop injected tool instructions (often in + # the first system message). Preserve system messages in this case. + if ( + tools_already_injected + and strategy == TruncationStrategy.SLIDING_WINDOW + ): + logger.info( + "Switching truncation strategy from sliding_window to " + "keep_system to preserve injected tool definitions" + ) + strategy = TruncationStrategy.KEEP_SYSTEM_SLIDING + + # Handle summarization strategy (async, needs special handling) + if strategy == TruncationStrategy.SUMMARIZE: + try: + # Pass the server's load_language for proper caching + summarizer = ContextSummarizer( + load_language=self.load_language + ) + messages_for_context = await summarizer.summarize_messages( + messages_for_context + ) + # Re-validate after summarization + usage = context_manager.validate_messages( + messages_for_context + ) + + # Check if we STILL need truncation after summarization + # (e.g., if recent messages are still too large) + if context_manager.needs_truncation(messages_for_context): + logger.warning( + f"Still over budget after summarization " + f"({usage.prompt_tokens} tokens), applying fallback truncation" + ) + messages_for_context, usage = ( + context_manager.truncate_if_needed( + messages_for_context, + TruncationStrategy.KEEP_SYSTEM_SLIDING, + ) + ) + usage = type(usage)( + total_context=usage.total_context, + prompt_tokens=usage.prompt_tokens, + available_for_completion=usage.available_for_completion, + truncated=True, + truncated_messages=usage.truncated_messages, + strategy_used="summarize+keep_system", + ) + else: + usage = type(usage)( + total_context=usage.total_context, + prompt_tokens=usage.prompt_tokens, + available_for_completion=usage.available_for_completion, + truncated=True, + truncated_messages=0, # Summarized, not removed + strategy_used="summarize", + ) + logger.info( + f"Context summarized: {usage.prompt_tokens} tokens after summarization" + ) + except Exception as e: + logger.warning( + f"Summarization failed: {e}, falling back to keep_system" + ) + messages_for_context, usage = ( + context_manager.truncate_if_needed( + messages_for_context, + TruncationStrategy.KEEP_SYSTEM_SLIDING, + ) + ) + else: + # Use regular truncation strategy + messages_for_context, usage = context_manager.truncate_if_needed( + messages_for_context, strategy + ) + logger.info( + f"Context truncated: {usage.truncated_messages} messages removed, " + f"strategy={usage.strategy_used}" + ) + + # Use the validated/truncated message set for generation. + if native_rendered_prompt is None: + prepared_messages = messages_for_context + + # Track the true remaining completion budget (not reserved target). + real_available_for_completion = max( + 0, + context_manager.budget.total_context + - usage.prompt_tokens + - context_manager.budget.safety_margin, + ) + + # Store context usage for response + context_usage_info = ContextUsageInfo( + total_context=usage.total_context, + prompt_tokens=usage.prompt_tokens, + available_for_completion=real_available_for_completion, + truncated=usage.truncated, + truncated_messages=usage.truncated_messages, + strategy_used=usage.strategy_used, + ) + + # Final safety check: ensure we're actually under budget + if ( + native_rendered_prompt is None + and context_manager.needs_truncation(prepared_messages) + ): + final_usage = context_manager.validate_messages(prepared_messages) + logger.error( + f"CRITICAL: Still over context budget after all truncation: " + f"{final_usage.prompt_tokens} tokens > " + f"{context_manager.budget.max_prompt_tokens} max" + ) + raise HTTPException( + status_code=400, + detail={ + "error": "context_truncation_failed", + "message": ( + f"Failed to reduce context to fit within budget. " + f"Current: {final_usage.prompt_tokens} tokens, " + f"Max: {context_manager.budget.max_prompt_tokens} tokens. " + "Try sending fewer or shorter messages." + ), + "context_usage": { + "total_context": final_usage.total_context, + "prompt_tokens": final_usage.prompt_tokens, + "available_for_completion": final_usage.available_for_completion, + }, + }, + ) + + # Cap generation to what can actually fit after prompt accounting. + effective_max_tokens = min( + total_max_tokens, real_available_for_completion + ) + if effective_max_tokens <= 0: + raise HTTPException( + status_code=400, + detail={ + "error": "context_length_exceeded", + "message": ( + "No completion budget remains after prompt allocation. " + "Try sending fewer or shorter messages." + ), + "context_usage": context_usage_info.model_dump(), + }, + ) + + return ( + model, + is_gguf, + prepared_messages, + use_native_audio, + audio_bytes, + audio_format, + total_max_tokens, + effective_max_tokens, + thinking_tokens, + context_usage_info, + ) + # Handle streaming if requested + if chat_request.stream: + logger.info( + f"Streaming chat completions for model: {chat_request.model}" + ) + + ( + model, + is_gguf, + prepared_messages, + use_native_audio, + audio_bytes, + audio_format, + total_max_tokens, + effective_max_tokens, + thinking_tokens, + context_usage_info, + ) = await prepare_generation() + + # ── KV Cache: check for cache hit (streaming) ──────────── + _stream_kv_data = None + _stream_kv_tokens = 0 + _stream_cache_info = None + _s_return_cache_key = None + _stream_cache_manager = self._get_cache_manager() + if _stream_cache_manager and is_gguf: + _s_cache_key = chat_request.cache_key + if _s_cache_key is None and chat_request.extra_body: + _s_cache_key = chat_request.extra_body.get("cache_key") + + _s_return_cache_key = chat_request.return_cache_key + if _s_return_cache_key is None and chat_request.extra_body: + _s_return_cache_key = chat_request.extra_body.get("return_cache_key") + + if _s_cache_key: + match = _stream_cache_manager.validate_and_match( + cache_key=_s_cache_key, + model_id=chat_request.model, + messages=messages_dict, + tools=tools_dict, + ) + if match["status"] == "hit" and match["entry"]: + entry = match["entry"] + kv_data = entry.kv_data + if not kv_data and entry.disk_path: + from pathlib import Path as _Path + dp = _Path(entry.disk_path) + if dp.exists(): + kv_data = dp.read_bytes() + if kv_data: + _stream_kv_data = kv_data + _stream_kv_tokens = entry.token_count + entry.touch() + _stream_cache_info = { + "hit": True, "status": "hit", + "cache_key": _s_cache_key, + "reused_tokens": entry.token_count, + "has_kv_data": bool(kv_data), + } + logger.info(f"KV cache hit (streaming): {_s_cache_key[:8]}…, kv_data={'yes' if kv_data else 'no'}") + else: + _stream_cache_info = { + "hit": False, "status": match["status"], + "cache_key": _s_cache_key, + "reason": match.get("reason"), + } + + # Return SSE stream + async def generate_sse(): + completion_id = f"chatcmpl-{os.urandom(16).hex()}" + created_time = int(datetime.now().timestamp()) + + # Send initial chunk + initial_chunk = ChatCompletionChunk( + id=completion_id, + object="chat.completion.chunk", + created=created_time, + model=chat_request.model, + choices=[ + ChoiceChunk( + index=0, + delta=ChoiceDelta(role="assistant", content=""), + finish_reason=None, + ) + ], + ) + yield f"data: {initial_chunk.model_dump_json(exclude_none=True)}\n\n".encode() + # Force an immediate flush before any model loading. + await asyncio.sleep(0) + + # Stream tokens - use native audio if supported, otherwise text + if use_native_audio and audio_bytes: + # Use native audio processing (no STT transcription) + token_stream = model.generate_stream_with_audio( + messages=prepared_messages, + audio_data=audio_bytes, + audio_format=audio_format, + max_tokens=effective_max_tokens, + temperature=chat_request.temperature + if chat_request.temperature is not None + else 0.7, + top_p=chat_request.top_p, + stop=chat_request.stop, + ) + else: + # Standard text generation (audio already transcribed if present) + token_stream = model.generate_stream( + messages=prepared_messages, + max_tokens=effective_max_tokens, + temperature=chat_request.temperature + if chat_request.temperature is not None + else 0.7, + top_p=chat_request.top_p, + stop=chat_request.stop, + thinking_budget=(thinking_tokens or None) if is_gguf else None, + tools=tools_for_generation, + tool_choice=chat_request.tool_choice, + kv_cache_data=_stream_kv_data, + kv_cache_tokens=_stream_kv_tokens, + ) + + # State machine for incremental tool call streaming + accumulated_content = "" + tool_state = ToolCallStreamState.NORMAL + buffered_tokens = [] + tool_call_id = None + tool_call_index = 0 + args_emitted_length = 0 + any_tool_calls_emitted = False # Track if we emitted any tool calls + + # Parse tool_choice to determine if we should detect tool calls + # When tool_choice="none", we skip tool detection entirely + tool_choice_mode, _ = parse_tool_choice(chat_request.tool_choice) + should_detect_tools = tools_dict and tool_choice_mode != "none" + + async for token in token_stream: + accumulated_content += token + + # STATE: NORMAL - streaming regular content + if tool_state == ToolCallStreamState.NORMAL: + # Check if we're entering a tool call + if should_detect_tools and detect_probable_tool_call( + accumulated_content + ): + tool_state = ToolCallStreamState.BUFFERING_START + buffered_tokens.append(token) + continue + + # Normal content streaming + chunk = ChatCompletionChunk( + id=completion_id, + object="chat.completion.chunk", + created=created_time, + model=chat_request.model, + choices=[ + ChoiceChunk( + index=0, + delta=ChoiceDelta( + role="assistant", content=token + ), + finish_reason=None, + ) + ], + ) + yield f"data: {chunk.model_dump_json(exclude_none=True)}\n\n".encode() + # CRITICAL: This asyncio.sleep(0) forces the event loop + # to yield, ensuring token-by-token delivery. + await asyncio.sleep(0) + + # STATE: BUFFERING_START - waiting for tool name + elif tool_state == ToolCallStreamState.BUFFERING_START: + buffered_tokens.append(token) + + # Try to extract tool name + tool_name = extract_tool_name_from_partial( + accumulated_content + ) + if tool_name: + # Emit initial tool call chunk with name + tool_call_id = f"call_{uuid.uuid4()}" + initial_tool_chunk = ChatCompletionChunk( + id=completion_id, + object="chat.completion.chunk", + created=created_time, + model=chat_request.model, + choices=[ + ChoiceChunk( + index=0, + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall( + index=tool_call_index, + id=tool_call_id, + type="function", + function=ChoiceDeltaToolCallFunction( + name=tool_name, + arguments="", + ), + ) + ] + ), + finish_reason=None, + ) + ], + ) + yield f"data: {initial_tool_chunk.model_dump_json(exclude_none=True)}\n\n".encode() + await asyncio.sleep(0) + + tool_state = ToolCallStreamState.STREAMING_ARGS + args_emitted_length = 0 + logger.info( + f"Tool call started: {tool_name} (id={tool_call_id})" + ) + + # STATE: STREAMING_ARGS - incrementally streaming arguments + elif tool_state == ToolCallStreamState.STREAMING_ARGS: + # Check if tool call is complete + if is_tool_call_complete(accumulated_content): + # Parse the complete tool call to get final arguments + # We only want the FIRST complete tool call in accumulated_content + tool_calls = detect_tool_call_in_content( + accumulated_content + ) + if tool_calls: + _, final_args = tool_calls[0] + + # Emit remaining arguments (from where we left off) + if len(final_args) > args_emitted_length: + remaining_args = final_args[ + args_emitted_length: + ] + args_chunk = ChatCompletionChunk( + id=completion_id, + object="chat.completion.chunk", + created=created_time, + model=chat_request.model, + choices=[ + ChoiceChunk( + index=0, + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall( + index=tool_call_index, + function=ChoiceDeltaToolCallFunction( + arguments=remaining_args, + ), + ) + ] + ), + finish_reason=None, + ) + ], + ) + yield f"data: {args_chunk.model_dump_json(exclude_none=True)}\n\n".encode() + await asyncio.sleep(0) + + # Log the completed tool call + if tool_calls: + tool_name_completed, tool_args = tool_calls[0] + logger.info( + f"Tool call completed: {tool_name_completed} " + f"(id={tool_call_id}, args={tool_args[:100]}{'...' if len(tool_args) > 100 else ''})" + ) + + # Mark that we've emitted at least one tool call + any_tool_calls_emitted = True + + # Reset state machine for potential next tool call + # Strip the completed tool call from accumulated_content + accumulated_content = strip_tool_call_from_content( + accumulated_content + ) + tool_state = ToolCallStreamState.NORMAL + buffered_tokens = [] + tool_call_id = None + tool_call_index += 1 + args_emitted_length = 0 + + # Check if there's already another tool call starting + # in the remaining content + if should_detect_tools and detect_probable_tool_call( + accumulated_content + ): + tool_state = ToolCallStreamState.BUFFERING_START + + # Continue processing - don't return yet + continue + + # Try to extract arguments progress + args_progress = extract_arguments_progress( + accumulated_content + ) + if args_progress: + _, current_args = args_progress + # Emit new argument characters + if len(current_args) > args_emitted_length: + new_args = current_args[args_emitted_length:] + args_chunk = ChatCompletionChunk( + id=completion_id, + object="chat.completion.chunk", + created=created_time, + model=chat_request.model, + choices=[ + ChoiceChunk( + index=0, + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall( + index=tool_call_index, + function=ChoiceDeltaToolCallFunction( + arguments=new_args, + ), + ) + ] + ), + finish_reason=None, + ) + ], + ) + yield f"data: {args_chunk.model_dump_json(exclude_none=True)}\n\n".encode() + await asyncio.sleep(0) + args_emitted_length = len(current_args) + + # Handle incomplete tool calls at stream end + if ( + tool_state != ToolCallStreamState.NORMAL + and buffered_tokens + and not is_tool_call_complete(accumulated_content) + ): + # Emit buffered tokens as regular content + for buffered_token in buffered_tokens: + chunk = ChatCompletionChunk( + id=completion_id, + object="chat.completion.chunk", + created=created_time, + model=chat_request.model, + choices=[ + ChoiceChunk( + index=0, + delta=ChoiceDelta(content=buffered_token), + finish_reason=None, + ) + ], + ) + yield f"data: {chunk.model_dump_json(exclude_none=True)}\n\n".encode() + await asyncio.sleep(0) + + # Debug log the accumulated streaming response + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + f"Streaming response complete ({len(accumulated_content)} chars):\n" + f"{accumulated_content}" + ) + + # Send final chunk with appropriate finish_reason + # If we emitted any tool calls, use "tool_calls", otherwise "stop" + finish_reason = "tool_calls" if any_tool_calls_emitted else "stop" + final_chunk = ChatCompletionChunk( + id=completion_id, + object="chat.completion.chunk", + created=created_time, + model=chat_request.model, + choices=[ + ChoiceChunk( + index=0, + delta=ChoiceDelta(), + finish_reason=finish_reason, + ) + ], + ) + yield f"data: {final_chunk.model_dump_json(exclude_none=True)}\n\n".encode() + await asyncio.sleep(0) + + # ── KV Cache: save post-generation state (streaming) ── + if _stream_cache_manager and is_gguf and (_s_return_cache_key or _stream_cache_info): + try: + full_msgs = list(messages_dict) + [ + {"role": "assistant", "content": accumulated_content} + ] + new_entry = await _stream_cache_manager.save_after_generation( + model=model.llama, + model_id=chat_request.model, + parent_key=chat_request.cache_key, + messages=full_msgs, + tools=tools_dict, + prompt_tokens=context_usage_info.prompt_tokens if context_usage_info else 0, + ) + cache_event = dict(_stream_cache_info) if _stream_cache_info else {} + cache_event["new_cache_key"] = new_entry.cache_key + cache_event["cached_tokens"] = new_entry.token_count + # Use a named SSE event type so OpenAI SDK clients + # ignore it (they only process default "message" events) + yield f"event: x_cache\ndata: {json.dumps(cache_event)}\n\n".encode() + await asyncio.sleep(0) + except Exception as e: + logger.warning(f"Failed to save streaming post-gen cache: {e}", exc_info=True) + + yield b"data: [DONE]\n\n" + + return StreamingResponse( + generate_sse(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) + + # Non-streaming response - use native audio if supported, otherwise text + ( + model, + is_gguf, + prepared_messages, + use_native_audio, + audio_bytes, + audio_format, + total_max_tokens, + effective_max_tokens, + thinking_tokens, + context_usage_info, + ) = await prepare_generation() + + response_logprobs = None + + # ── KV Cache: check for cache hit ──────────────────────────────── + cache_info = None + return_cache_key = None + _kv_cache_data = None + _kv_cache_tokens = 0 + cache_manager = self._get_cache_manager() + if cache_manager and is_gguf: + import time as _time + _cache_start = _time.time() + + cache_key = chat_request.cache_key + if cache_key is None and chat_request.extra_body: + cache_key = chat_request.extra_body.get("cache_key") + + return_cache_key = chat_request.return_cache_key + if return_cache_key is None and chat_request.extra_body: + return_cache_key = chat_request.extra_body.get("return_cache_key") + + if cache_key: + match = cache_manager.validate_and_match( + cache_key=cache_key, + model_id=chat_request.model, + messages=messages_dict, + tools=tools_dict, + ) + if match["status"] == "hit" and match["entry"]: + entry = match["entry"] + # Load KV data for restore (from ram or disk) + kv_data = entry.kv_data + if not kv_data and entry.disk_path: + from pathlib import Path as _Path + dp = _Path(entry.disk_path) + if dp.exists(): + kv_data = dp.read_bytes() + if kv_data: + _kv_cache_data = kv_data + _kv_cache_tokens = entry.token_count + entry.touch() + cache_info = { + "hit": True, + "status": "hit", + "cache_key": cache_key, + "reused_tokens": entry.token_count, + "has_kv_data": bool(kv_data), + "time_saved_ms": round((_time.time() - _cache_start) * 1000, 2), + } + logger.info( + f"KV cache hit: {cache_key[:8]}…, " + f"{entry.token_count} tokens, " + f"kv_data={'yes' if kv_data else 'no'}" + ) + elif match["status"] == "partial_hit": + cache_info = { + "hit": False, + "status": "partial_hit", + "cache_key": cache_key, + "reused_tokens": match["reusable_tokens"], + "invalidated_at": match.get("invalidated_at"), + "reason": match["reason"], + } + else: + cache_info = { + "hit": False, + "status": "miss", + "cache_key": cache_key, + "reused_tokens": 0, + "reason": match["reason"], + } + + if use_native_audio and audio_bytes: + # Use native audio processing (no STT transcription) + response_text = await model.generate_with_audio( + messages=prepared_messages, + audio_data=audio_bytes, + audio_format=audio_format, + max_tokens=effective_max_tokens, + temperature=chat_request.temperature + if chat_request.temperature is not None + else 0.7, + top_p=chat_request.top_p, + stop=chat_request.stop, + ) + else: + # Standard text generation (audio already transcribed if present) + if is_gguf and chat_request.logprobs: + detailed = await model.generate_with_logprobs( + messages=prepared_messages, + max_tokens=effective_max_tokens, + temperature=chat_request.temperature + if chat_request.temperature is not None + else 0.7, + top_p=chat_request.top_p, + stop=chat_request.stop, + thinking_budget=(thinking_tokens or None), + tools=tools_for_generation, + tool_choice=chat_request.tool_choice, + top_logprobs=chat_request.top_logprobs, + kv_cache_data=_kv_cache_data, + kv_cache_tokens=_kv_cache_tokens, + ) + response_text = detailed.get("content", "") + response_logprobs = detailed.get("logprobs") + else: + response_text = await model.generate( + messages=prepared_messages, + max_tokens=effective_max_tokens, + temperature=chat_request.temperature + if chat_request.temperature is not None + else 0.7, + top_p=chat_request.top_p, + stop=chat_request.stop, + thinking_budget=(thinking_tokens or None) if is_gguf else None, + tools=tools_for_generation, + tool_choice=chat_request.tool_choice, + kv_cache_data=_kv_cache_data, + kv_cache_tokens=_kv_cache_tokens, + ) + + # Debug log the raw response from the model + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + f"Model raw response ({len(response_text)} chars):\n{response_text}" + ) + + # Parse thinking content from response (like Ollama does) + # This separates ... into a separate field + parsed = parse_thinking_response(response_text) + + # Check for tool calls in response (only if tools were provided and tool_choice != "none") + # This is consistent with streaming path which only checks when tools are enabled + tool_calls = None + tool_choice_mode, _ = parse_tool_choice(chat_request.tool_choice) + if tools_dict and tool_choice_mode != "none": + tool_calls = detect_tool_call_in_content(parsed.content) + + normalized_logprobs = self._normalize_logprobs_payload( + response_logprobs, chat_request.top_logprobs + ) + + if tool_calls: + # Log detected tool calls + for name, args in tool_calls: + logger.info( + f"Tool call detected: {name} " + f"(args={args[:100]}{'...' if len(args) > 100 else ''})" + ) + + # Build response with tool calls + prompt_tokens = ( + context_usage_info.prompt_tokens if context_usage_info else 0 + ) + response = { + "id": f"chatcmpl-{os.urandom(16).hex()}", + "object": "chat.completion", + "created": int(datetime.now().timestamp()), + "model": chat_request.model, + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": f"call_{uuid.uuid4()}", + "type": "function", + "function": { + "name": name, + "arguments": args, + }, + } + for name, args in tool_calls + ], + }, + "finish_reason": "tool_calls", + **({"logprobs": normalized_logprobs} if chat_request.logprobs else {}), + } + ], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": 0, # TODO: count completion tokens + "total_tokens": prompt_tokens, + }, + } + # Add context usage info if available + if context_usage_info: + response["x_context_usage"] = context_usage_info.model_dump() + + # Debug log the response with tool calls + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + f"Sending response with tool calls:\n" + f"{json.dumps(response, indent=2, default=str)}" + ) + + # ── KV Cache: save after tool-call generation ──────────────── + if cache_manager and is_gguf and (return_cache_key or cache_info): + try: + # Strip tool call markup from content for cache + clean_content = strip_tool_call_from_content(response_text) + full_messages = list(messages_dict) + [ + {"role": "assistant", "content": clean_content} + ] + _prompt_tokens = ( + context_usage_info.prompt_tokens if context_usage_info else 0 + ) + new_entry = await cache_manager.save_after_generation( + model=model.llama, + model_id=chat_request.model, + parent_key=chat_request.cache_key, + messages=full_messages, + tools=tools_dict, + prompt_tokens=_prompt_tokens, + ) + if cache_info is None: + cache_info = {} + cache_info["new_cache_key"] = new_entry.cache_key + cache_info["cached_tokens"] = new_entry.token_count + except Exception as e: + logger.warning(f"Failed to save tool-call post-gen cache: {e}") + + if cache_info: + response["x_cache"] = cache_info + + return response + + # Build response with optional thinking field (Ollama-compatible) + prompt_tokens = ( + context_usage_info.prompt_tokens if context_usage_info else 0 + ) + response = { + "id": f"chatcmpl-{os.urandom(16).hex()}", + "object": "chat.completion", + "created": int(datetime.now().timestamp()), + "model": chat_request.model, + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": parsed.content}, + "finish_reason": "stop", + **({"logprobs": normalized_logprobs} if chat_request.logprobs else {}), + } + ], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": 0, # TODO: count completion tokens + "total_tokens": prompt_tokens, + }, + } + + # Add thinking field if present (Ollama-compatible) + if parsed.thinking: + response["thinking"] = ThinkingContent( + content=parsed.thinking, + tokens=None, # TODO: count thinking tokens + ).model_dump() + + # Add context usage info if available + if context_usage_info: + response["x_context_usage"] = context_usage_info.model_dump() + + # ── KV Cache: save post-generation state ──────────────────────── + if cache_manager and is_gguf and (return_cache_key or cache_info): + try: + # Build full conversation including the response + # Use messages_dict (original request messages) not prepared_messages + # to avoid segment hash drift from inject_thinking_control + full_messages = list(messages_dict) + [ + {"role": "assistant", "content": parsed.content} + ] + # Get exact prompt token count for KV restore accuracy + _prompt_tokens = ( + context_usage_info.prompt_tokens if context_usage_info else 0 + ) + new_entry = await cache_manager.save_after_generation( + model=model.llama, + model_id=chat_request.model, + parent_key=chat_request.cache_key, + messages=full_messages, + tools=tools_dict, + prompt_tokens=_prompt_tokens, + ) + if cache_info is None: + cache_info = {} + cache_info["new_cache_key"] = new_entry.cache_key + cache_info["cached_tokens"] = new_entry.token_count + except Exception as e: + logger.warning(f"Failed to save post-generation cache: {e}") + + # Add cache info to response + if cache_info: + response["x_cache"] = cache_info + + # Debug log the response + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + f"Sending response:\n{json.dumps(response, indent=2, default=str)}" + ) + + return response + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error in chat_completions: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) from e diff --git a/runtimes/edge/routers/chat_completions/types.py b/runtimes/edge/routers/chat_completions/types.py new file mode 100644 index 000000000..7e699a6a2 --- /dev/null +++ b/runtimes/edge/routers/chat_completions/types.py @@ -0,0 +1,307 @@ +from typing import Literal + +from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam +from pydantic import BaseModel, Field + +# ============================================================================ +# Audio Content Types (for STT transcription) +# ============================================================================ + + +class InputAudio(BaseModel): + """Audio data for input_audio content parts. + + Audio content is automatically transcribed via STT before LLM processing. + """ + + data: str = Field(..., description="Base64-encoded audio data") + format: Literal["wav", "mp3", "pcm"] = Field( + default="wav", description="Audio format (wav recommended for best compatibility)" + ) + + +class AudioContentPart(BaseModel): + """Audio content part for messages with audio. + + Audio is automatically transcribed via STT and the text is passed to the LLM. + """ + + type: Literal["input_audio"] = "input_audio" + input_audio: InputAudio + + +class TextContentPart(BaseModel): + """Text content part for messages.""" + + type: Literal["text"] = "text" + text: str + + +# Union type for content parts in messages (text, audio, etc.) +ContentPart = AudioContentPart | TextContentPart | dict + + +# ============================================================================ +# Tool Calling Types +# ============================================================================ + + +class FunctionCall(BaseModel): + """Function call details within a tool call.""" + + name: str + arguments: str # JSON string of arguments + + +class ToolCall(BaseModel): + """A tool call made by the assistant.""" + + id: str + type: Literal["function"] = "function" + function: FunctionCall + + +class ChatCompletionRequest(BaseModel): + """OpenAI-compatible chat completion request.""" + + model: str + messages: list[ChatCompletionMessageParam] + temperature: float | None = 1.0 + top_p: float | None = 1.0 + max_tokens: int | None = None + stream: bool | None = False + stop: str | list[str] | None = None + logprobs: bool | None = None + top_logprobs: int | None = Field(default=None, ge=0, le=20) + presence_penalty: float | None = 0.0 + frequency_penalty: float | None = 0.0 + user: str | None = None + # GGUF model parameters (llama.cpp specific) + n_ctx: int | None = None # Context window size (affects KV cache memory) + n_batch: int | None = ( + None # Batch size for prompt processing (affects compute buffer) + ) + n_gpu_layers: int | None = None # Number of layers to offload to GPU (-1 = all) + n_threads: int | None = None # CPU thread count (None = auto) + flash_attn: bool | None = None # Enable flash attention for faster inference + use_mmap: bool | None = None # Memory-map model file (True = efficient swapping) + use_mlock: bool | None = ( + None # Lock model in RAM (False = allow OS memory management) + ) + cache_type_k: str | None = None # KV cache key quantization (q4_0, q8_0, f16) + cache_type_v: str | None = None # KV cache value quantization (q4_0, q8_0, f16) + extra_body: dict | None = None + + # Tool/function calling parameters + tools: list[ChatCompletionToolParam] | None = None + tool_choice: str | dict | None = ( + None # "auto", "none", "required", or specific tool + ) + + # Thinking/reasoning model parameters (Ollama-compatible) + # Controls whether thinking models show their reasoning process + think: bool | None = None # None = model default, True = enable, False = disable + # Maximum tokens to spend on thinking before forcing answer generation + # When reached, model is nudged to close and provide answer + thinking_budget: int | None = None + + # KV Cache parameters + cache_key: str | None = None # Cache key from /v1/cache/prepare or previous response + return_cache_key: bool | None = None # Return a cache_key in the response for multi-turn chaining + + # Context management parameters + # Whether to automatically truncate messages if context is exceeded + auto_truncate: bool | None = True + # Truncation strategy: "sliding_window", "keep_system", "middle_out", "summarize" + truncation_strategy: str | None = None + + +class ThinkingContent(BaseModel): + """Thinking/reasoning content from a thinking model.""" + + content: str # The raw thinking content (without tags) + tokens: int | None = None # Number of tokens used for thinking + + +class ContextUsageInfo(BaseModel): + """Context window usage information.""" + + total_context: int # Total context window size in tokens + prompt_tokens: int # Tokens used by the prompt (input) + available_for_completion: int # Remaining tokens for output + truncated: bool = False # Whether truncation was applied + truncated_messages: int = 0 # Number of messages removed + strategy_used: str | None = None # Truncation strategy used (if any) + + +class ChatCompletionResponse(BaseModel): + """Extended chat completion response with thinking support.""" + + id: str + object: Literal["chat.completion"] = "chat.completion" + created: int + model: str + choices: list[dict] + usage: dict + # Ollama-compatible: separate thinking from content + thinking: ThinkingContent | None = None + # Context usage information (extension field) + x_context_usage: ContextUsageInfo | None = None + + +# ============================================================================ +# Audio Content Extraction Utilities +# ============================================================================ + + +def extract_audio_from_messages( + messages: list[ChatCompletionMessageParam], +) -> list[tuple[int, InputAudio]]: + """Extract audio content parts from chat messages. + + Scans messages for input_audio content parts and returns them with + their message index for later replacement if STT fallback is needed. + + Args: + messages: List of chat completion messages + + Returns: + List of (message_index, InputAudio) tuples for each audio part found + """ + audio_parts: list[tuple[int, InputAudio]] = [] + + for idx, message in enumerate(messages): + # Skip if message is a string or has no content + if not isinstance(message, dict): + continue + + content = message.get("content") + if content is None: + continue + + # Handle list of content parts (multimodal message) + if isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "input_audio": + audio_data = part.get("input_audio", {}) + if isinstance(audio_data, dict) and "data" in audio_data: + audio_parts.append( + ( + idx, + InputAudio( + data=audio_data["data"], + format=audio_data.get("format", "wav"), + ), + ) + ) + + return audio_parts + + +def has_audio_content(messages: list[ChatCompletionMessageParam]) -> bool: + """Check if any messages contain audio content. + + Fast check without extracting the actual audio data. + + Args: + messages: List of chat completion messages + + Returns: + True if any message contains input_audio content + """ + for message in messages: + if not isinstance(message, dict): + continue + + content = message.get("content") + if isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "input_audio": + return True + + return False + + +def replace_audio_with_text( + messages: list[ChatCompletionMessageParam], + transcriptions: dict[int, str], +) -> list[dict]: + """Replace audio content parts with transcribed text. + + Used when falling back to STT for models that don't support direct audio. + + Args: + messages: Original messages with audio content + transcriptions: Map of message_index -> transcribed text + + Returns: + New messages list with audio replaced by text + """ + result = [] + + for idx, message in enumerate(messages): + if not isinstance(message, dict): + result.append(message) + continue + + content = message.get("content") + + # If this message had audio and we have a transcription + if idx in transcriptions: + if isinstance(content, list): + # Build new content parts, replacing audio with text + new_parts = [] + for part in content: + if isinstance(part, dict) and part.get("type") == "input_audio": + # Replace with transcribed text + new_parts.append({"type": "text", "text": transcriptions[idx]}) + else: + new_parts.append(part) + + # Consolidate text parts + consolidated = _consolidate_text_parts(new_parts) + result.append({**message, "content": consolidated}) + else: + # Simple string content - shouldn't happen but handle it + result.append(message) + else: + result.append(dict(message) if isinstance(message, dict) else message) + + return result + + +def _consolidate_text_parts(parts: list[dict]) -> str | list[dict]: + """Consolidate adjacent text parts into a single string if possible. + + If the result is all text parts, returns a simple string. + Otherwise returns the list with adjacent text parts merged. + """ + if not parts: + return "" + + # Check if all parts are text + all_text = all( + isinstance(p, dict) and p.get("type") == "text" for p in parts + ) + + if all_text: + # Return simple string + return " ".join(p.get("text", "") for p in parts if isinstance(p, dict)) + + # Otherwise, merge adjacent text parts + result = [] + current_text = [] + + for part in parts: + if isinstance(part, dict) and part.get("type") == "text": + current_text.append(part.get("text", "")) + else: + if current_text: + result.append({"type": "text", "text": " ".join(current_text)}) + current_text = [] + result.append(part) + + if current_text: + result.append({"type": "text", "text": " ".join(current_text)}) + + return result diff --git a/runtimes/edge/routers/completions.py b/runtimes/edge/routers/completions.py new file mode 100644 index 000000000..e7d9f4f7f --- /dev/null +++ b/runtimes/edge/routers/completions.py @@ -0,0 +1,83 @@ +"""OpenAI-compatible text completions endpoint (/v1/completions). + +Accepts a raw prompt string and generates a completion without applying +any chat template. Useful for models that require a specific prompt format +that doesn't align with the GGUF's embedded chat template. +""" + +import logging +import time +import uuid + +from fastapi import APIRouter +from pydantic import BaseModel + +router = APIRouter() +logger = logging.getLogger(__name__) + + +class CompletionRequest(BaseModel): + """OpenAI-compatible completion request.""" + + model: str + prompt: str + temperature: float | None = 1.0 + top_p: float | None = 1.0 + max_tokens: int | None = 512 + stop: str | list[str] | None = None + # GGUF model parameters + n_ctx: int | None = None + n_gpu_layers: int | None = None + + +class CompletionResponse(BaseModel): + """OpenAI-compatible completion response.""" + + id: str + object: str = "text_completion" + created: int + model: str + choices: list[dict] + usage: dict + + +@router.post("/v1/completions") +async def completions(request: CompletionRequest): + """Raw text completions — no chat template applied.""" + from server import load_language + + model = await load_language( + request.model, + n_ctx=request.n_ctx, + n_gpu_layers=request.n_gpu_layers, + ) + + max_tokens = request.max_tokens if request.max_tokens is not None else 512 + stop = request.stop if isinstance(request.stop, list) else ([request.stop] if request.stop else []) + + logger.info(f"[completions] model={request.model} prompt_len={len(request.prompt)} max_tokens={max_tokens}") + + result = await model._generate_from_prompt( + prompt=request.prompt, + max_tokens=max_tokens, + temperature=request.temperature if request.temperature is not None else 1.0, + top_p=request.top_p if request.top_p is not None else 1.0, + stop=stop, + thinking_budget=None, + ) + + return CompletionResponse( + id=f"cmpl-{uuid.uuid4().hex[:8]}", + created=int(time.time()), + model=request.model, + choices=[{ + "index": 0, + "text": result, + "finish_reason": "stop", + }], + usage={ + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + }, + ) diff --git a/runtimes/edge/routers/health/__init__.py b/runtimes/edge/routers/health/__init__.py new file mode 100644 index 000000000..acce0421a --- /dev/null +++ b/runtimes/edge/routers/health/__init__.py @@ -0,0 +1,5 @@ +"""Health router for health check and models list endpoints.""" + +from .router import router, set_device_info_getter, set_models_cache + +__all__ = ["router", "set_models_cache", "set_device_info_getter"] diff --git a/runtimes/edge/routers/health/router.py b/runtimes/edge/routers/health/router.py new file mode 100644 index 000000000..c652881f0 --- /dev/null +++ b/runtimes/edge/routers/health/router.py @@ -0,0 +1,75 @@ +"""Health router for health check and models list endpoints.""" + +import os +from collections.abc import Callable +from datetime import datetime +from typing import Any + +from fastapi import APIRouter, HTTPException + +from core.logging import UniversalRuntimeLogger + +logger = UniversalRuntimeLogger("universal-runtime.health") + +router = APIRouter(tags=["health"]) + +# Dependency injection for models cache and device info +_models: dict | None = None +_get_device_info_fn: Callable[[], dict[str, Any]] | None = None + + +def set_models_cache(models: dict | None) -> None: + """Set the models cache for health check.""" + global _models + _models = models + + +def set_device_info_getter( + get_device_info_fn: Callable[[], dict[str, Any]] | None, +) -> None: + """Set the device info getter function.""" + global _get_device_info_fn + _get_device_info_fn = get_device_info_fn + + +@router.get("/health") +async def health_check(): + """Health check endpoint with device information.""" + if _models is None or _get_device_info_fn is None: + raise HTTPException( + status_code=500, + detail="Health router not initialized. Call set_models_cache() and set_device_info_getter() first.", + ) + + device_info = _get_device_info_fn() + return { + "status": "healthy", + "device": device_info, + "loaded_models": list(_models.keys()), + "timestamp": datetime.utcnow().isoformat(), + "pid": os.getpid(), + } + + +@router.get("/v1/models") +async def list_models(): + """List currently loaded models.""" + if _models is None: + raise HTTPException( + status_code=500, + detail="Health router not initialized. Call set_models_cache() first.", + ) + + models_list = [] + for model_id, model in _models.items(): + models_list.append( + { + "id": model_id, + "object": "model", + "created": int(datetime.now().timestamp()), + "owned_by": "transformers-runtime", + "type": model.model_type, + } + ) + + return {"object": "list", "data": models_list} diff --git a/runtimes/edge/routers/vision/__init__.py b/runtimes/edge/routers/vision/__init__.py new file mode 100644 index 000000000..93251f49c --- /dev/null +++ b/runtimes/edge/routers/vision/__init__.py @@ -0,0 +1,32 @@ +"""Vision routers for edge runtime — detection, classification, and streaming only. + +Excludes: OCR, document extraction, training, evaluation, tracking, sample data, models. +""" + +from fastapi import APIRouter + +from .classification import router as classification_router +from .classification import set_classification_loader +from .detect_classify import router as detect_classify_router +from .detect_classify import set_detect_classify_loaders +from .detection import router as detection_router +from .detection import set_detection_loader +from .streaming import router as streaming_router +from .streaming import set_streaming_detection_loader, start_session_cleanup, stop_session_cleanup + +# Combined router — edge subset only +router = APIRouter(tags=["vision"]) +router.include_router(detection_router) +router.include_router(classification_router) +router.include_router(detect_classify_router) +router.include_router(streaming_router) + +__all__ = [ + "router", + "set_detection_loader", + "set_classification_loader", + "set_detect_classify_loaders", + "set_streaming_detection_loader", + "start_session_cleanup", + "stop_session_cleanup", +] diff --git a/runtimes/edge/routers/vision/classification.py b/runtimes/edge/routers/vision/classification.py new file mode 100644 index 000000000..d6347a674 --- /dev/null +++ b/runtimes/edge/routers/vision/classification.py @@ -0,0 +1,61 @@ +"""Classification router — POST /v1/vision/classify""" + +import logging +import time +from collections.abc import Callable, Coroutine +from typing import Any + +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel, Field + +from services.error_handler import handle_endpoint_errors + +from .utils import decode_base64_image + +logger = logging.getLogger(__name__) +router = APIRouter(tags=["vision-classification"]) + +_load_fn: Callable[..., Coroutine[Any, Any, Any]] | None = None + + +def set_classification_loader(load_fn: Callable[..., Coroutine[Any, Any, Any]] | None) -> None: + global _load_fn + _load_fn = load_fn + + +class ClassifyRequest(BaseModel): + image: str = Field(..., description="Base64-encoded image") + model: str = "clip-vit-base" + classes: list[str] = Field(..., description="Classes for zero-shot classification") + top_k: int = Field(default=5, ge=1, le=100) + +class ClassifyResponse(BaseModel): + class_name: str + class_id: int + confidence: float + all_scores: dict[str, float] + model: str + inference_time_ms: float + + +@router.post("/v1/vision/classify", response_model=ClassifyResponse) +@handle_endpoint_errors("vision_classify") +async def classify_image(request: ClassifyRequest) -> ClassifyResponse: + """Classify an image using CLIP (zero-shot).""" + if _load_fn is None: + raise HTTPException(status_code=500, detail="Classification loader not initialized") + if not request.classes: + raise HTTPException(status_code=400, detail="Classes required for zero-shot classification") + + start = time.perf_counter() + model = await _load_fn(request.model) + image_bytes = decode_base64_image(request.image) + + result = await model.classify(image=image_bytes, classes=request.classes, top_k=request.top_k) + + return ClassifyResponse( + class_name=result.class_name, class_id=result.class_id, + confidence=result.confidence, all_scores=result.all_scores, + model=request.model, + inference_time_ms=(time.perf_counter() - start) * 1000, + ) diff --git a/runtimes/edge/routers/vision/detect_classify.py b/runtimes/edge/routers/vision/detect_classify.py new file mode 100644 index 000000000..5777038eb --- /dev/null +++ b/runtimes/edge/routers/vision/detect_classify.py @@ -0,0 +1,180 @@ +"""Detect+Classify combo endpoint — YOLO detect → crop → CLIP classify per crop.""" + +import io +import logging +import time +from collections.abc import Callable, Coroutine +from typing import Any + +from fastapi import APIRouter, HTTPException +from PIL import Image, UnidentifiedImageError +from pydantic import BaseModel, Field + +from services.error_handler import handle_endpoint_errors + +from .utils import decode_base64_image + +logger = logging.getLogger(__name__) +router = APIRouter(tags=["vision-detect-classify"]) + +_load_detection_fn: Callable[..., Coroutine[Any, Any, Any]] | None = None +_load_classification_fn: Callable[..., Coroutine[Any, Any, Any]] | None = None + + +def set_detect_classify_loaders( + detection_fn: Callable[..., Coroutine[Any, Any, Any]] | None, + classification_fn: Callable[..., Coroutine[Any, Any, Any]] | None, +) -> None: + global _load_detection_fn, _load_classification_fn + _load_detection_fn = detection_fn + _load_classification_fn = classification_fn + + +# ============================================================================= +# Request/Response models +# ============================================================================= + +class BoundingBox(BaseModel): + x1: float + y1: float + x2: float + y2: float + + +class ClassifiedDetection(BaseModel): + """A detection with classification results.""" + box: BoundingBox + detection_class: str + detection_confidence: float + classification: str + classification_confidence: float + all_scores: dict[str, float] + + +class DetectClassifyRequest(BaseModel): + image: str = Field(..., description="Base64-encoded image") + detection_model: str = Field(default="yolov8n", description="YOLO model for detection") + classification_model: str = Field(default="clip-vit-base", description="CLIP model for classification") + classes: list[str] = Field(..., description="Classes for zero-shot classification of each crop") + confidence_threshold: float = Field(default=0.5, ge=0.0, le=1.0, description="Detection confidence threshold") + detection_classes: list[str] | None = Field(default=None, description="Filter detections to these YOLO classes") + top_k: int = Field(default=3, ge=1, le=100, description="Top-K classification results per crop") + min_crop_px: int = Field(default=16, ge=1, description="Minimum crop dimension in pixels (skip tiny detections)") + + +class DetectClassifyResponse(BaseModel): + results: list[ClassifiedDetection] + total_detections: int + classified_count: int + detection_model: str + classification_model: str + detection_time_ms: float + classification_time_ms: float + total_time_ms: float + + +# ============================================================================= +# Endpoint +# ============================================================================= + +@router.post("/v1/vision/detect_classify", response_model=DetectClassifyResponse) +@handle_endpoint_errors("vision_detect_classify") +async def detect_and_classify(request: DetectClassifyRequest) -> DetectClassifyResponse: + """Detect objects then classify each crop — single round-trip. + + Runs YOLO detection → crops each bounding box → CLIP classifies each crop. + Returns unified results with both detection and classification info. + """ + if _load_detection_fn is None or _load_classification_fn is None: + raise HTTPException(status_code=500, detail="Model loaders not initialized") + if not request.classes: + raise HTTPException(status_code=400, detail="Classes required for classification") + + total_start = time.perf_counter() + image_bytes = decode_base64_image(request.image) + + # Step 1: Detect + det_start = time.perf_counter() + det_model = await _load_detection_fn(request.detection_model) + det_result = await det_model.detect( + image=image_bytes, + confidence_threshold=request.confidence_threshold, + classes=request.detection_classes, + ) + det_time = (time.perf_counter() - det_start) * 1000 + + total_detections = len(det_result.boxes) + if total_detections == 0: + return DetectClassifyResponse( + results=[], total_detections=0, classified_count=0, + detection_model=request.detection_model, + classification_model=request.classification_model, + detection_time_ms=det_time, classification_time_ms=0.0, + total_time_ms=(time.perf_counter() - total_start) * 1000, + ) + + # Step 2: Crop each detection and classify + cls_start = time.perf_counter() + cls_model = await _load_classification_fn(request.classification_model) + + # Convert image once for cropping + try: + pil_image = Image.open(io.BytesIO(image_bytes)) + pil_image.load() + except UnidentifiedImageError as e: + raise ValueError( + "Cannot identify image format. " + "Ensure the image is a valid JPEG, PNG, BMP, TIFF, or WebP file." + ) from e + except OSError as e: + raise ValueError(f"Failed to decode image data: {e}") from e + results: list[ClassifiedDetection] = [] + + for box in det_result.boxes: + # Crop the detection region + x1, y1 = max(0, int(box.x1)), max(0, int(box.y1)) + x2, y2 = min(pil_image.width, int(box.x2)), min(pil_image.height, int(box.y2)) + + # Skip tiny crops + if (x2 - x1) < request.min_crop_px or (y2 - y1) < request.min_crop_px: + continue + + crop = pil_image.crop((x1, y1, x2, y2)) + + # Ensure RGB mode for JPEG encoding (handles RGBA, P, L, etc.) + if crop.mode != "RGB": + crop = crop.convert("RGB") + + # Convert crop to bytes for the classifier + buf = io.BytesIO() + crop.save(buf, format="JPEG", quality=90) + crop_bytes = buf.getvalue() + + # Classify the crop + cls_result = await cls_model.classify( + image=crop_bytes, + classes=request.classes, + top_k=request.top_k, + ) + + results.append(ClassifiedDetection( + box=BoundingBox(x1=box.x1, y1=box.y1, x2=box.x2, y2=box.y2), + detection_class=box.class_name, + detection_confidence=box.confidence, + classification=cls_result.class_name, + classification_confidence=cls_result.confidence, + all_scores=cls_result.all_scores, + )) + + cls_time = (time.perf_counter() - cls_start) * 1000 + + return DetectClassifyResponse( + results=results, + total_detections=total_detections, + classified_count=len(results), + detection_model=request.detection_model, + classification_model=request.classification_model, + detection_time_ms=det_time, + classification_time_ms=cls_time, + total_time_ms=(time.perf_counter() - total_start) * 1000, + ) diff --git a/runtimes/edge/routers/vision/detection.py b/runtimes/edge/routers/vision/detection.py new file mode 100644 index 000000000..a5b825af5 --- /dev/null +++ b/runtimes/edge/routers/vision/detection.py @@ -0,0 +1,76 @@ +"""Detection router — POST /v1/vision/detect""" + +import logging +import time +from collections.abc import Callable, Coroutine +from typing import Any + +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel, Field + +from services.error_handler import handle_endpoint_errors + +from .utils import decode_base64_image + +logger = logging.getLogger(__name__) +router = APIRouter(tags=["vision-detection"]) + +_load_fn: Callable[..., Coroutine[Any, Any, Any]] | None = None + + +def set_detection_loader(load_fn: Callable[..., Coroutine[Any, Any, Any]] | None) -> None: + global _load_fn + _load_fn = load_fn + + +class BoundingBox(BaseModel): + x1: float + y1: float + x2: float + y2: float + +class Detection(BaseModel): + box: BoundingBox + class_name: str + class_id: int + confidence: float + +class DetectRequest(BaseModel): + image: str = Field(..., description="Base64-encoded image") + model: str = "yolov8n" + confidence_threshold: float = Field(default=0.5, ge=0.0, le=1.0) + classes: list[str] | None = None + +class DetectResponse(BaseModel): + detections: list[Detection] + model: str + inference_time_ms: float + + +@router.post("/v1/vision/detect", response_model=DetectResponse) +@handle_endpoint_errors("vision_detect") +async def detect_objects(request: DetectRequest) -> DetectResponse: + """Detect objects in an image using YOLO.""" + if _load_fn is None: + raise HTTPException(status_code=500, detail="Detection loader not initialized") + + start = time.perf_counter() + model = await _load_fn(request.model) + image_bytes = decode_base64_image(request.image) + + result = await model.detect( + image=image_bytes, + confidence_threshold=request.confidence_threshold, + classes=request.classes, + ) + + return DetectResponse( + detections=[ + Detection( + box=BoundingBox(x1=b.x1, y1=b.y1, x2=b.x2, y2=b.y2), + class_name=b.class_name, class_id=b.class_id, confidence=b.confidence, + ) for b in result.boxes + ], + model=request.model, + inference_time_ms=(time.perf_counter() - start) * 1000, + ) diff --git a/runtimes/edge/routers/vision/streaming.py b/runtimes/edge/routers/vision/streaming.py new file mode 100644 index 000000000..977b52ead --- /dev/null +++ b/runtimes/edge/routers/vision/streaming.py @@ -0,0 +1,385 @@ +"""Streaming vision router — simplified cascade detection. + +Cascade: if confidence < threshold, try next model in chain. +Chain can include "remote:{url}" entries for Atmosphere readiness. +""" + +import asyncio +import logging +import time +import uuid +from collections.abc import Callable, Coroutine +from dataclasses import dataclass, field +from typing import Any +from urllib.parse import urlparse + +import httpx +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel, Field + +from services.error_handler import handle_endpoint_errors + +from .utils import decode_base64_image + +logger = logging.getLogger(__name__) +router = APIRouter(tags=["vision-streaming"]) + +# Dependency injection +_load_detection_fn: Callable[..., Coroutine[Any, Any, Any]] | None = None + +# SSRF protection: allowlist of remote hosts for cascade +_ALLOWED_REMOTE_HOSTS: set[str] = set() + + +def set_streaming_detection_loader(fn: Callable[..., Coroutine[Any, Any, Any]] | None) -> None: + global _load_detection_fn + _load_detection_fn = fn + + +def set_allowed_remote_hosts(hosts: set[str]) -> None: + """Set allowlist of remote hosts for cascade (SSRF mitigation).""" + global _ALLOWED_REMOTE_HOSTS + _ALLOWED_REMOTE_HOSTS = hosts + + +# ============================================================================= +# Session management +# ============================================================================= + +@dataclass +class CascadeConfig: + """Cascade chain config. Models tried in order.""" + chain: list[str] = field(default_factory=lambda: ["yolov8n"]) + confidence_threshold: float = 0.7 + +@dataclass +class StreamSession: + session_id: str + cascade: CascadeConfig + target_fps: float = 1.0 + action_classes: list[str] | None = None + cooldown_seconds: float = 5.0 + frames_processed: int = 0 + actions_triggered: int = 0 + escalations: int = 0 + created_at: float = field(default_factory=time.time) + last_action_at: float = 0.0 + last_frame_at: float = field(default_factory=time.time) + +_sessions: dict[str, StreamSession] = {} +_http_client: httpx.AsyncClient | None = None +_cleanup_task: asyncio.Task | None = None +SESSION_TTL_SECONDS: float = 60.0 # Auto-expire after no frames for this long + + +async def _session_cleanup_loop() -> None: + """Background task that expires orphaned streaming sessions.""" + while True: + await asyncio.sleep(15) # Check every 15 seconds + now = time.time() + expired = [ + sid for sid, s in _sessions.items() + if (now - s.last_frame_at) > SESSION_TTL_SECONDS + ] + for sid in expired: + session = _sessions.pop(sid, None) + if session: + logger.info( + f"Expired orphaned stream session {sid} " + f"(idle {now - session.last_frame_at:.0f}s, " + f"{session.frames_processed} frames processed)" + ) + + +def start_session_cleanup() -> None: + """Start the background session cleanup task. Call once at server startup.""" + global _cleanup_task + if _cleanup_task is None or _cleanup_task.done(): + _cleanup_task = asyncio.create_task(_session_cleanup_loop()) + + +async def stop_session_cleanup() -> None: + """Cancel the background session cleanup task (call during shutdown).""" + global _cleanup_task + if _cleanup_task is not None and not _cleanup_task.done(): + _cleanup_task.cancel() + import contextlib + with contextlib.suppress(asyncio.CancelledError): + await _cleanup_task + logger.info("Vision session cleanup task stopped") + _cleanup_task = None + + +def _get_http_client() -> httpx.AsyncClient: + global _http_client + if _http_client is None or _http_client.is_closed: + _http_client = httpx.AsyncClient(timeout=10.0) + return _http_client + + +# ============================================================================= +# Request/Response models +# ============================================================================= + +class CascadeConfigRequest(BaseModel): + chain: list[str] = Field(default=["yolov8n"], description="Model chain, can include 'remote:http://...'") + confidence_threshold: float = Field(default=0.7, ge=0.0, le=1.0) + +class StreamStartRequest(BaseModel): + config: CascadeConfigRequest = Field(default_factory=CascadeConfigRequest) + target_fps: float = 1.0 + action_classes: list[str] | None = None + cooldown_seconds: float = 5.0 + +class StreamStartResponse(BaseModel): + session_id: str + +class StreamFrameRequest(BaseModel): + session_id: str + image: str = Field(..., description="Base64-encoded image") + +class DetectionItem(BaseModel): + x1: float + y1: float + x2: float + y2: float + class_name: str + class_id: int + confidence: float + +class StreamFrameResponse(BaseModel): + status: str # "ok", "action", "escalated" + detections: list[DetectionItem] | None = None + confidence: float | None = None + resolved_by: str | None = None + +class StreamStopRequest(BaseModel): + session_id: str + +class StreamStopResponse(BaseModel): + session_id: str + frames_processed: int + actions_triggered: int + escalations: int + duration_seconds: float + + +# ============================================================================= +# Endpoints +# ============================================================================= + +@router.post("/v1/vision/stream/start", response_model=StreamStartResponse) +@handle_endpoint_errors("vision_stream_start") +async def start_stream(request: StreamStartRequest) -> StreamStartResponse: + """Start a streaming detection session with cascade config.""" + # Limit concurrent sessions to prevent memory growth + MAX_SESSIONS = 100 + if len(_sessions) >= MAX_SESSIONS: + raise HTTPException(status_code=429, detail=f"Max {MAX_SESSIONS} concurrent sessions") + sid = str(uuid.uuid4())[:8] + _sessions[sid] = StreamSession( + session_id=sid, + cascade=CascadeConfig( + chain=request.config.chain, + confidence_threshold=request.config.confidence_threshold, + ), + target_fps=request.target_fps, + action_classes=request.action_classes, + cooldown_seconds=request.cooldown_seconds, + ) + return StreamStartResponse(session_id=sid) + + +@router.post("/v1/vision/stream/frame", response_model=StreamFrameResponse) +@handle_endpoint_errors("vision_stream_frame") +async def process_frame(request: StreamFrameRequest) -> StreamFrameResponse: + """Process a frame through the cascade chain.""" + session = _sessions.get(request.session_id) + if not session: + raise HTTPException(status_code=404, detail="Session not found") + if _load_detection_fn is None: + raise HTTPException(status_code=500, detail="Detection loader not initialized") + + session.frames_processed += 1 + session.last_frame_at = time.time() + image_bytes = decode_base64_image(request.image) + + # Try each model in the cascade chain + for i, model_ref in enumerate(session.cascade.chain): + if model_ref.startswith("remote:"): + # Remote model — HTTP POST + url = model_ref[7:] # strip "remote:" + result = await _call_remote(url, image_bytes, session) + if result and result.confidence >= session.cascade.confidence_threshold: + if i > 0: + session.escalations += 1 + return _build_response(result, model_ref, i > 0, session) + else: + # Local model + model = await _load_detection_fn(model_ref) + det_result = await model.detect( + image=image_bytes, + confidence_threshold=0.1, # Low threshold, we check ourselves + classes=session.action_classes, + ) + if det_result.confidence >= session.cascade.confidence_threshold: + if i > 0: + session.escalations += 1 + return _build_response(det_result, model_ref, i > 0, session) + + # No model in chain was confident enough + return StreamFrameResponse(status="ok") + + +@router.post("/v1/vision/stream/stop", response_model=StreamStopResponse) +@handle_endpoint_errors("vision_stream_stop") +async def stop_stream(request: StreamStopRequest) -> StreamStopResponse: + """Stop a streaming session.""" + session = _sessions.pop(request.session_id, None) + if not session: + raise HTTPException(status_code=404, detail="Session not found") + return StreamStopResponse( + session_id=session.session_id, + frames_processed=session.frames_processed, + actions_triggered=session.actions_triggered, + escalations=session.escalations, + duration_seconds=time.time() - session.created_at, + ) + + +class SessionInfo(BaseModel): + session_id: str + frames_processed: int + actions_triggered: int + escalations: int + chain: list[str] + idle_seconds: float + duration_seconds: float + + +class SessionsListResponse(BaseModel): + sessions: list[SessionInfo] + count: int + + +@router.get("/v1/vision/stream/sessions", response_model=SessionsListResponse) +@handle_endpoint_errors("vision_stream_sessions") +async def list_sessions() -> SessionsListResponse: + """List active streaming sessions.""" + now = time.time() + sessions = [ + SessionInfo( + session_id=s.session_id, + frames_processed=s.frames_processed, + actions_triggered=s.actions_triggered, + escalations=s.escalations, + chain=s.cascade.chain, + idle_seconds=round(now - s.last_frame_at, 1), + duration_seconds=round(now - s.created_at, 1), + ) + for s in _sessions.values() + ] + return SessionsListResponse(sessions=sessions, count=len(sessions)) + + +# ============================================================================= +# Helpers +# ============================================================================= + +def _build_response(det_result: Any, model_ref: str, escalated: bool, + session: StreamSession) -> StreamFrameResponse: + """Build response from detection result.""" + # Check cooldown + now = time.time() + if now - session.last_action_at < session.cooldown_seconds: + return StreamFrameResponse(status="ok") + + session.actions_triggered += 1 + session.last_action_at = now + + detections = [] + if hasattr(det_result, "boxes"): + detections = [ + DetectionItem( + x1=b.x1, y1=b.y1, x2=b.x2, y2=b.y2, + class_name=b.class_name, class_id=b.class_id, confidence=b.confidence, + ) for b in det_result.boxes + ] + + return StreamFrameResponse( + status="escalated" if escalated else "action", + detections=detections, + confidence=det_result.confidence if hasattr(det_result, "confidence") else None, + resolved_by=model_ref, + ) + + +async def _call_remote(url: str, image_bytes: bytes, session: StreamSession) -> Any | None: + """Call a remote vision detection endpoint. + + SSRF Protection: Only calls URLs with hosts in the allowlist. + If allowlist is empty, all remote calls are rejected. + """ + import base64 + + # Validate URL against allowlist + try: + parsed = urlparse(url) + if parsed.scheme not in ("http", "https"): + logger.warning(f"Invalid scheme in remote URL: {url}") + return None + if not _ALLOWED_REMOTE_HOSTS: + logger.warning("Remote cascade disabled: no allowed hosts configured") + raise HTTPException(status_code=403, detail="Remote cascade not allowed") + if parsed.hostname not in _ALLOWED_REMOTE_HOSTS: + logger.warning(f"Remote host {parsed.hostname} not in allowlist") + raise HTTPException(status_code=403, detail=f"Remote host not allowed: {parsed.hostname}") + except ValueError as e: + logger.warning(f"Malformed remote URL: {url} - {e}") + return None + + try: + client = _get_http_client() + resp = await client.post(url, json={ + "image": base64.b64encode(image_bytes).decode(), + "confidence_threshold": session.cascade.confidence_threshold, + "classes": session.action_classes, + }) + if resp.status_code == 200: + data = resp.json() + return _RemoteResult(data) + except Exception as e: + logger.warning(f"Remote cascade call to {url} failed: {e}") + return None + + +@dataclass +class _RemoteBox: + """Bounding box from a remote detection result.""" + x1: float + y1: float + x2: float + y2: float + class_name: str + class_id: int + confidence: float + + +class _RemoteResult: + """Simple wrapper for remote detection results.""" + def __init__(self, data: dict): + dets = data.get("detections", []) + self.confidence = max((d.get("confidence", 0) for d in dets), default=0.0) + self.boxes = [] + for d in dets: + box = d.get("box", {}) + try: + self.boxes.append(_RemoteBox( + x1=box.get("x1", 0), y1=box.get("y1", 0), + x2=box.get("x2", 0), y2=box.get("y2", 0), + class_name=d.get("class_name", "unknown"), + class_id=d.get("class_id", 0), + confidence=d.get("confidence", 0), + )) + except (KeyError, TypeError) as e: + logger.warning(f"Skipping malformed remote detection: {e}") diff --git a/runtimes/edge/routers/vision/utils.py b/runtimes/edge/routers/vision/utils.py new file mode 100644 index 000000000..4af496fe7 --- /dev/null +++ b/runtimes/edge/routers/vision/utils.py @@ -0,0 +1,22 @@ +"""Shared utilities for vision routers.""" + +import base64 + +from fastapi import HTTPException + + +def decode_base64_image(image_str: str) -> bytes: + """Decode base64 image string to bytes. Handles data URI format and line-wrapped base64.""" + if image_str.startswith("data:"): + if "," not in image_str: + raise HTTPException(status_code=400, detail="Malformed data URI") + _, base64_data = image_str.split(",", 1) + else: + base64_data = image_str + # Strip whitespace — handles trailing newlines from tools like `jq -Rs` and + # line-wrapped base64 produced by GNU/BSD `base64` commands. + base64_data = "".join(base64_data.split()) + try: + return base64.b64decode(base64_data, validate=True) + except Exception as e: + raise HTTPException(status_code=400, detail="Invalid base64 image data") from e diff --git a/runtimes/edge/server.py b/runtimes/edge/server.py new file mode 100644 index 000000000..77194fca3 --- /dev/null +++ b/runtimes/edge/server.py @@ -0,0 +1,475 @@ +""" +LlamaFarm Edge Runtime + +A stripped-down FastAPI server for on-device inference. +Designed for constrained hardware (Raspberry Pi, Jetson, etc.) + +Supports: +- LLM inference (GGUF via llama.cpp) +- Vision detection (YOLO — Hailo-10H accelerated or CPU fallback) +- Health checks + +This is the "runtime plane" — no RAG, no UI, no model management. +Models are pre-loaded on device. + +Environment Variables: +- MODEL_UNLOAD_TIMEOUT: Seconds of inactivity before unloading models (default: 300) +- CLEANUP_CHECK_INTERVAL: Seconds between cleanup checks (default: 30) +- LF_RUNTIME_PORT: Server port (default: 11540) +- LF_RUNTIME_HOST: Server host (default: 0.0.0.0) +- HAILO_HEF_DIR: Directory containing .hef model files (default: /models) +- FORCE_CPU_VISION: Set to "1" to skip Hailo detection and use CPU (default: unset) +""" + +import asyncio +import functools +import os +import subprocess +import warnings +from contextlib import asynccontextmanager, suppress + +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + +from core.logging import UniversalRuntimeLogger, setup_logging +from models import ( + BaseModel, + GGUFLanguageModel, + LanguageModel, +) +from routers.chat_completions import router as chat_completions_router +from routers.completions import router as completions_router +from routers.chat_completions.service import ChatCompletionsService +from routers.health import ( + router as health_router, + set_device_info_getter, + set_models_cache, +) +from routers.vision import ( + router as vision_router, + set_detection_loader, + set_classification_loader, + set_detect_classify_loaders, + set_streaming_detection_loader, + start_session_cleanup, + stop_session_cleanup, +) +from utils.device import get_device_info, get_optimal_device +from utils.model_cache import ModelCache +from utils.model_format import detect_model_format +from utils.safe_home import get_data_dir +from services.zenoh_ipc import ZenohIPC + +# Suppress spurious warnings +warnings.filterwarnings( + "ignore", + message=r"resource_tracker: There appear to be \d+ leaked semaphore", + category=UserWarning, +) + +# Configure logging +log_file = os.getenv("LOG_FILE", "") +log_level = os.getenv("LOG_LEVEL", "INFO") +json_logs = os.getenv("LOG_JSON_FORMAT", "false").lower() in ("true", "1", "yes") +setup_logging(json_logs=json_logs, log_level=log_level, log_file=log_file) + +logger = UniversalRuntimeLogger("edge-runtime") + + +def _init_llama_backend(): + """Initialize llama.cpp backend in the main thread. + Critical for Jetson/Tegra CUDA stability on unified memory architectures. + """ + try: + from llamafarm_llama._bindings import ensure_backend + logger.info("Initializing llama.cpp backend in main thread...") + ensure_backend() + logger.info("llama.cpp backend initialized successfully") + except ImportError: + logger.debug("llamafarm_llama not installed, skipping backend init") + except Exception as e: + logger.warning(f"Failed to initialize llama.cpp backend: {e}") + + +_init_llama_backend() + + +# Model unload timeout configuration +MODEL_UNLOAD_TIMEOUT = int(os.getenv("MODEL_UNLOAD_TIMEOUT", "300")) +CLEANUP_CHECK_INTERVAL = int(os.getenv("CLEANUP_CHECK_INTERVAL", "30")) + +# Model cache +_models: ModelCache[BaseModel] = ModelCache(ttl=MODEL_UNLOAD_TIMEOUT) +_model_load_lock = asyncio.Lock() +_current_device = None +_cleanup_task: asyncio.Task | None = None +_zenoh_ipc: ZenohIPC | None = None + +# Data directories +_LF_DATA_DIR = get_data_dir() +VISION_MODELS_DIR = _LF_DATA_DIR / "models" / "vision" + + +def get_device(): + """Get the optimal device for the current platform.""" + global _current_device + if _current_device is None: + _current_device = get_optimal_device() + logger.info(f"Using device: {_current_device}") + return _current_device + + +# ============================================================================ +# Hardware Detection +# ============================================================================ + +@functools.lru_cache(maxsize=1) +def _detect_hailo() -> bool: + """Detect if Hailo-10H PCIe device is present. + + Checks for PCI device ID 1e60:45c4 (Hailo-10H) via lspci, + and verifies hailo_platform is importable. + """ + if os.getenv("FORCE_CPU_VISION", "").lower() in ("1", "true", "yes"): + logger.info("Hailo detection skipped (FORCE_CPU_VISION=1)") + return False + + # Check for hailo_platform package + try: + import hailo_platform # noqa: F401 + except ImportError: + logger.info("hailo_platform not installed, using CPU backend for vision") + return False + + # Check for PCIe device + try: + result = subprocess.run( + ["lspci", "-d", "1e60:"], + capture_output=True, text=True, timeout=5, + ) + if result.stdout.strip(): + logger.info("Hailo-10H detected, using Hailo backend for vision") + return True + except (FileNotFoundError, subprocess.TimeoutExpired): + # lspci not available (macOS) or timed out + pass + + # Fallback: check for /dev/hailo0 + if os.path.exists("/dev/hailo0"): + logger.info("Hailo device found at /dev/hailo0, using Hailo backend") + return True + + logger.info("Hailo not detected, using CPU backend for vision") + return False + + +async def _cleanup_idle_models() -> None: + """Background task that periodically unloads idle models.""" + logger.info( + f"Model cleanup task started (timeout={MODEL_UNLOAD_TIMEOUT}s, " + f"check_interval={CLEANUP_CHECK_INTERVAL}s)" + ) + while True: + try: + await asyncio.sleep(CLEANUP_CHECK_INTERVAL) + expired_items = _models.pop_expired() + if expired_items: + logger.info(f"Unloading {len(expired_items)} idle models") + for cache_key, model in expired_items: + try: + await model.unload() + logger.info(f"Successfully unloaded: {cache_key}") + except Exception as e: + logger.error(f"Error unloading model {cache_key}: {e}") + except asyncio.CancelledError: + logger.info("Model cleanup task cancelled") + break + except Exception as e: + logger.error(f"Error in cleanup task: {e}", exc_info=True) + + +# ============================================================================ +# Language Model Loading +# ============================================================================ + + +async def load_language( + model_id: str, + n_ctx: int | None = None, + n_batch: int | None = None, + n_gpu_layers: int | None = None, + n_threads: int | None = None, + flash_attn: bool | None = None, + use_mmap: bool | None = None, + use_mlock: bool | None = None, + cache_type_k: str | None = None, + cache_type_v: str | None = None, + preferred_quantization: str | None = None, +): + """Load a causal language model (GGUF or transformers format).""" + # Reject model IDs with path traversal sequences + if ".." in model_id or model_id.startswith(("/", "\\")) or "\\" in model_id or (len(model_id) > 1 and model_id[1] == ":"): + raise ValueError(f"Invalid model_id: {model_id}") + + quant_key = preferred_quantization or "default" + cache_key = ( + f"language:{model_id}:ctx{n_ctx or 'auto'}:gpu{n_gpu_layers or 'auto'}:" + f"quant{quant_key}" + ) + + if cache_key not in _models: + async with _model_load_lock: + if cache_key not in _models: + logger.info(f"Loading causal LM: {model_id}") + device = get_device() + model_format = detect_model_format(model_id) + logger.info(f"Detected format: {model_format}") + + model: BaseModel + if model_format == "gguf": + model = GGUFLanguageModel( + model_id, device, + n_ctx=n_ctx, n_batch=n_batch, + n_gpu_layers=n_gpu_layers, n_threads=n_threads, + flash_attn=flash_attn, use_mmap=use_mmap, + use_mlock=use_mlock, cache_type_k=cache_type_k, + cache_type_v=cache_type_v, + preferred_quantization=preferred_quantization, + ) + else: + model = LanguageModel(model_id, device) + + await model.load() + _models[cache_key] = model + + return _models.get(cache_key) + + +# ============================================================================ +# Vision Model Loading +# ============================================================================ + + +async def load_detection_model(model_id: str = "yolov8n"): + """Load a YOLO detection model. + + Auto-selects backend: + - Hailo-10H: loads .hef model on the AI accelerator + - CPU fallback: loads .pt model via ultralytics/PyTorch + """ + backend = "hailo" if _detect_hailo() else "cpu" + cache_key = f"vision:detect:{backend}:{model_id}" + + if cache_key not in _models: + async with _model_load_lock: + if cache_key not in _models: + from pathlib import Path as _Path + + safe_id = _Path(model_id).name + if safe_id != model_id or safe_id in (".", ".."): + raise ValueError(f"Invalid model_id: {model_id}") + # Verify resolved path stays within VISION_MODELS_DIR + vision_root = VISION_MODELS_DIR.resolve() + resolved = (VISION_MODELS_DIR / safe_id).resolve() + if not str(resolved).startswith(str(vision_root) + os.sep): + raise ValueError(f"Invalid model_id: {model_id}") + + if backend == "hailo": + from models.hailo_model import HailoYOLOModel + + hef_dir = os.getenv("HAILO_HEF_DIR", "/models") + model = HailoYOLOModel( + model_id=model_id, + confidence_threshold=0.5, + hef_dir=hef_dir, + ) + else: + from models.yolo_model import YOLOModel + + device = get_device() + custom_path = resolved / "current.pt" + mid = str(custom_path) if custom_path.exists() else model_id + model = YOLOModel(model_id=mid, device=device) + + await model.load() + _models[cache_key] = model + + return _models[cache_key] + + +async def load_classification_model(model_id: str = "clip-vit-base"): + """Load a CLIP classification model.""" + # Validate model_id: must be a known variant or a valid HuggingFace repo ID + # (org/model format). Reject path-like IDs that could reach the filesystem. + from models.clip_model import CLIP_VARIANTS + if model_id not in CLIP_VARIANTS: + if "/" not in model_id or model_id.startswith(("/", "\\", ".")) or ".." in model_id or "\\" in model_id or ":" in model_id: + raise ValueError(f"Invalid classification model_id: {model_id}") + + cache_key = f"vision:classify:{model_id}" + if cache_key not in _models: + async with _model_load_lock: + if cache_key not in _models: + from models.clip_model import CLIPModel + device = get_device() + model = CLIPModel(model_id=model_id, device=device) + await model.load() + _models[cache_key] = model + return _models[cache_key] + + +# ============================================================================ +# Zenoh IPC Inference Bridge +# ============================================================================ + + +async def _zenoh_inference(request: dict) -> str: + """Bridge between Zenoh request JSON and the model inference path.""" + model_id = request.get("model", "") + messages = request.get("messages", []) + max_tokens = request.get("max_tokens", 256) + temperature = request.get("temperature", 0.7) + + model = await load_language(model_id) + return await model.generate( + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + ) + + +# ============================================================================ +# Lifecycle +# ============================================================================ + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Manage application lifecycle.""" + global _cleanup_task, _zenoh_ipc + + logger.info("Starting LlamaFarm Edge Runtime") + + _cleanup_task = asyncio.create_task(_cleanup_idle_models()) + + # Start KV cache manager + from utils.kv_cache_manager import ( + KVCacheManager, start_kv_cache_gc, stop_kv_cache_gc, + ) + global _kv_cache_manager + _kv_cache_manager = KVCacheManager() + from routers.cache import set_cache_manager, set_cache_language_loader + set_cache_manager(_kv_cache_manager) + set_cache_language_loader(load_language) + ChatCompletionsService.set_cache_manager(_kv_cache_manager) + start_kv_cache_gc(_kv_cache_manager) + + start_session_cleanup() + + # Start Zenoh IPC interface (non-blocking — falls back to HTTP-only on failure) + _zenoh_ipc = ZenohIPC(inference_fn=_zenoh_inference) + await _zenoh_ipc.start() + + yield + + # Shutdown + logger.info("Shutting down Edge Runtime") + + if _zenoh_ipc is not None: + await _zenoh_ipc.stop() + + await stop_kv_cache_gc() + await stop_session_cleanup() + + if _cleanup_task is not None: + _cleanup_task.cancel() + with suppress(asyncio.CancelledError): + await _cleanup_task + + for cache_key, model in list(_models.items()): + try: + await model.unload() + except Exception as e: + logger.error(f"Error unloading {cache_key}: {e}") + _models.clear() + + logger.info("Shutdown complete") + + +# ============================================================================ +# App +# ============================================================================ + +_kv_cache_manager = None + +app = FastAPI( + title="LlamaFarm Edge Runtime", + description="Minimal on-device inference API for drones and edge hardware", + version="0.1.0", + lifespan=lifespan, +) + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # Edge device — open CORS + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Only the routers the drone needs +app.include_router(health_router) +app.include_router(chat_completions_router) +app.include_router(completions_router) +app.include_router(vision_router) + + +@app.post("/v1/models/unload", tags=["models"]) +async def unload_all_models(): + """Unload all loaded models to free memory.""" + unloaded = [] + for cache_key, model in list(_models.items()): + try: + await model.unload() + unloaded.append(cache_key) + except Exception as e: + logger.error(f"Error unloading {cache_key}: {e}") + _models.clear() + return {"unloaded": len(unloaded), "models": unloaded} + + +# ============================================================================ +# Router Dependency Injection +# ============================================================================ + +set_models_cache(_models) +set_device_info_getter(get_device_info) +set_detection_loader(load_detection_model) +set_classification_loader(load_classification_model) +set_detect_classify_loaders(load_detection_model, load_classification_model) +set_streaming_detection_loader(load_detection_model) + + +# ============================================================================ +# Entry Point +# ============================================================================ + +if __name__ == "__main__": + import uvicorn + from llamafarm_common.pidfile import write_pid + + write_pid("edge-runtime") + + port = int(os.getenv("LF_RUNTIME_PORT", os.getenv("PORT", "11540"))) + host = os.getenv("LF_RUNTIME_HOST", os.getenv("HOST", "0.0.0.0")) + + logger.info(f"Starting LlamaFarm Edge Runtime on {host}:{port}") + logger.info(f"Device: {get_device()}") + + uvicorn.run( + app, + host=host, + port=port, + log_config=None, + access_log=False, + ) diff --git a/runtimes/edge/services/__init__.py b/runtimes/edge/services/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/runtimes/edge/services/error_handler.py b/runtimes/edge/services/error_handler.py new file mode 100644 index 000000000..541be5d0d --- /dev/null +++ b/runtimes/edge/services/error_handler.py @@ -0,0 +1,146 @@ +"""Unified error handling service for Universal Runtime. + +Provides consistent error handling patterns that were previously +duplicated across all endpoint handlers. +""" + +import functools +import logging +from collections.abc import Callable +from typing import Any, TypeVar + +from fastapi import HTTPException + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + + +class UniversalRuntimeError(Exception): + """Base exception for Universal Runtime errors.""" + + def __init__(self, message: str, status_code: int = 500, code: str | None = None): + super().__init__(message) + self.message = message + self.status_code = status_code + self.code = code + + +class ModelNotFoundError(UniversalRuntimeError): + """Raised when a requested model is not found.""" + + def __init__(self, model_id: str, model_type: str = "model"): + super().__init__( + message=f"{model_type.capitalize()} not found: {model_id}", + status_code=404, + code="MODEL_NOT_FOUND", + ) + self.model_id = model_id + self.model_type = model_type + + +class ModelNotFittedError(UniversalRuntimeError): + """Raised when attempting to use an unfitted model.""" + + def __init__(self, model_id: str): + super().__init__( + message=f"Model '{model_id}' not fitted. Call fit() first or load a pre-trained model.", + status_code=400, + code="MODEL_NOT_FITTED", + ) + self.model_id = model_id + + +class ValidationError(UniversalRuntimeError): + """Raised for request validation errors.""" + + def __init__(self, message: str): + super().__init__(message=message, status_code=400, code="VALIDATION_ERROR") + + +class BackendNotInstalledError(UniversalRuntimeError): + """Raised when a required backend is not installed.""" + + def __init__(self, backend: str, install_hint: str | None = None): + message = f"Backend '{backend}' not installed." + if install_hint: + message += f" {install_hint}" + super().__init__(message=message, status_code=400, code="BACKEND_NOT_INSTALLED") + self.backend = backend + + +def handle_endpoint_errors( + endpoint_name: str, +) -> Callable[[Callable[..., T]], Callable[..., T]]: + """Decorator for consistent endpoint error handling. + + Catches and formats errors in a consistent way across all endpoints. + + Args: + endpoint_name: Name of the endpoint for logging + + Returns: + Decorated function with error handling + + Usage: + @app.post("/v1/embeddings") + @handle_endpoint_errors("create_embeddings") + async def create_embeddings(request: EmbeddingRequest): + ... + """ + + def decorator(func: Callable[..., T]) -> Callable[..., T]: + @functools.wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> T: + try: + return await func(*args, **kwargs) + except HTTPException: + # Re-raise FastAPI HTTPExceptions as-is + raise + except UniversalRuntimeError as e: + # Convert our custom errors to HTTPException + logger.warning(f"Error in {endpoint_name}: {e.message}") + raise HTTPException(status_code=e.status_code, detail=e.message) from e + except ImportError as e: + # Handle missing dependencies + logger.error(f"Import error in {endpoint_name}: {e}") + raise HTTPException( + status_code=400, + detail=f"Required dependency not installed: {str(e)}", + ) from e + except ValueError as e: + # Handle validation errors + logger.warning(f"Validation error in {endpoint_name}: {e}") + raise HTTPException(status_code=400, detail=str(e)) from e + except Exception as e: + # Log and wrap unexpected errors + logger.error(f"Error in {endpoint_name}: {e}", exc_info=True) + raise HTTPException( + status_code=500, + detail="An internal server error occurred.", + ) from e + + return wrapper + + return decorator + + +def format_error_response( + message: str, code: str | None = None, details: dict[str, Any] | None = None +) -> dict[str, Any]: + """Format an error response consistently. + + Args: + message: Error message + code: Optional error code + details: Optional additional details + + Returns: + Formatted error response dict + """ + response: dict[str, Any] = {"error": {"message": message}} + if code: + response["error"]["code"] = code + if details: + response["error"]["details"] = details + return response diff --git a/runtimes/edge/services/zenoh_ipc.py b/runtimes/edge/services/zenoh_ipc.py new file mode 100644 index 000000000..486ff400d --- /dev/null +++ b/runtimes/edge/services/zenoh_ipc.py @@ -0,0 +1,204 @@ +""" +Zenoh IPC interface for the edge runtime. + +Allows the orchestrator and other drone services to request LLM inference +over the Zenoh pub/sub bus (Unix socket IPC), matching the communication +pattern used by vision, comms, and flight-control. + +Topics: + local/llm/request — subscribe: incoming inference requests (JSON) + local/llm/response — publish: inference results (JSON) + local/llm/status — publish: periodic heartbeat with model info +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import os +import time + +logger = logging.getLogger("edge-runtime.zenoh") + +ZENOH_ENDPOINT = os.getenv( + "ZENOH_ENDPOINT", "unixsock-stream//run/arc/zenoh.sock" +) +ZENOH_ENABLED = os.getenv("ZENOH_ENABLED", "true").lower() in ("true", "1", "yes") + +TOPIC_REQUEST = "local/llm/request" +TOPIC_RESPONSE = "local/llm/response" +TOPIC_STATUS = "local/llm/status" + +STATUS_INTERVAL_S = 5.0 + + +class ZenohIPC: + """Manages a Zenoh session for LLM inference over IPC.""" + + def __init__(self, inference_fn): + """ + Args: + inference_fn: async callable(request_dict) -> response content string. + Called for each incoming inference request. + """ + self._inference_fn = inference_fn + self._session = None + self._subscriber = None + self._loop: asyncio.AbstractEventLoop | None = None + self._tasks: list[asyncio.Task] = [] + self._pending_futures: list[asyncio.Future] = [] + + async def start(self) -> bool: + """Open Zenoh session and start subscriber + heartbeat tasks. + + Returns True if started successfully, False on failure (graceful degradation). + """ + if not ZENOH_ENABLED: + logger.info("Zenoh IPC disabled (ZENOH_ENABLED=false)") + return False + + try: + import zenoh + except ImportError: + logger.warning( + "eclipse-zenoh package not installed, Zenoh IPC unavailable" + ) + return False + + try: + config = zenoh.Config() + config.insert_json5( + "connect/endpoints", + json.dumps([ZENOH_ENDPOINT]), + ) + config.insert_json5("scouting/multicast/enabled", "false") + + self._session = zenoh.open(config) + logger.info("Zenoh session open (endpoint=%s)", ZENOH_ENDPOINT) + except Exception: + logger.warning( + "Failed to connect to Zenoh at %s — continuing HTTP-only", + ZENOH_ENDPOINT, + exc_info=True, + ) + return False + + self._loop = asyncio.get_event_loop() + self._subscriber = self._session.declare_subscriber( + TOPIC_REQUEST, self._on_request + ) + logger.info("Subscribed to %s", TOPIC_REQUEST) + self._tasks.append(asyncio.create_task(self._heartbeat_loop())) + return True + + async def stop(self): + """Cancel background tasks and close the Zenoh session.""" + for task in self._tasks: + task.cancel() + for task in self._tasks: + try: + await task + except asyncio.CancelledError: + pass # Expected: tasks were explicitly cancelled above + self._tasks.clear() + + # Cancel in-flight request handlers before closing the session + for future in list(self._pending_futures): + future.cancel() + self._pending_futures.clear() + + if self._subscriber is not None: + try: + self._subscriber.undeclare() + except Exception: + logger.warning("Error undeclaring Zenoh subscriber", exc_info=True) + self._subscriber = None + + if self._session is not None: + try: + self._session.close() + except Exception: + logger.warning("Error closing Zenoh session", exc_info=True) + self._session = None + logger.info("Zenoh session closed") + + # ------------------------------------------------------------------ + # Request handler + # ------------------------------------------------------------------ + + def _on_request(self, sample): + """Callback invoked by Zenoh subscriber on each request.""" + try: + payload = json.loads(bytes(sample.payload)) + future = asyncio.run_coroutine_threadsafe( + self._handle_request(payload), self._loop + ) + self._pending_futures.append(future) + def _remove_future(f): + try: + self._pending_futures.remove(f) + except ValueError: + pass # Already cleared by stop() + + future.add_done_callback(_remove_future) + except Exception: + logger.error("Error dispatching Zenoh request", exc_info=True) + + async def _handle_request(self, request: dict): + """Process a single inference request and publish the response.""" + request_id = request.get("request_id", "unknown") + model = request.get("model", "unknown") + t0 = time.monotonic() + + try: + content = await self._inference_fn(request) + inference_ms = int((time.monotonic() - t0) * 1000) + + response = { + "request_id": request_id, + "model": model, + "content": content, + "inference_time_ms": inference_ms, + "timestamp_ms": int(time.time() * 1000), + } + except Exception as exc: + inference_ms = int((time.monotonic() - t0) * 1000) + response = { + "request_id": request_id, + "model": model, + "content": "", + "error": "inference failed", + "inference_time_ms": inference_ms, + "timestamp_ms": int(time.time() * 1000), + } + logger.error("Inference failed for request %s: %s", request_id, exc) + + self._session.put( + TOPIC_RESPONSE, json.dumps(response).encode() + ) + + # ------------------------------------------------------------------ + # Status heartbeat + # ------------------------------------------------------------------ + + async def _heartbeat_loop(self): + """Publish periodic status to local/llm/status.""" + logger.info( + "Status heartbeat started (interval=%.1fs, topic=%s)", + STATUS_INTERVAL_S, + TOPIC_STATUS, + ) + try: + while True: + status = { + "service": "edge-runtime", + "status": "ready", + "timestamp_ms": int(time.time() * 1000), + } + self._session.put( + TOPIC_STATUS, json.dumps(status).encode() + ) + await asyncio.sleep(STATUS_INTERVAL_S) + except asyncio.CancelledError: + raise diff --git a/runtimes/edge/utils/__init__.py b/runtimes/edge/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/runtimes/edge/utils/context_calculator.py b/runtimes/edge/utils/context_calculator.py new file mode 100644 index 000000000..c2d81bc6b --- /dev/null +++ b/runtimes/edge/utils/context_calculator.py @@ -0,0 +1,477 @@ +"""Context size calculator for GGUF models. + +Determines optimal context window size based on: +1. User configuration (highest priority) +2. Available memory and model size +3. Model family defaults from config file +4. Fallback defaults +""" + +import fnmatch +import logging +from pathlib import Path + +import psutil +import yaml + +from utils.gguf_metadata_cache import get_gguf_metadata_cached + +logger = logging.getLogger(__name__) + +# Cache for config file +_config_cache: dict | None = None + + +def get_gguf_metadata(gguf_path: str) -> dict: + """Read GGUF file metadata without loading the full model. + + Uses the shared GGUF metadata cache to avoid redundant file reads. + The cache is populated once per file and reused by context_calculator, + jinja_tools, and other modules. + + Args: + gguf_path: Path to .gguf file + + Returns: + dict with metadata including: + - file_size_bytes: Size of the GGUF file in bytes + - file_size_mb: Size in megabytes (for logging) + - n_ctx_train: Training context size (if available) + + Raises: + FileNotFoundError: If GGUF file doesn't exist + """ + # Use shared cache - single read for all metadata needs + cached = get_gguf_metadata_cached(gguf_path) + + # Return in legacy format for backward compatibility + return { + "file_size_bytes": cached.file_size_bytes, + "file_size_mb": cached.file_size_mb, + "n_ctx_train": cached.n_ctx_train, + "n_layer": cached.n_layer, + "n_head_kv": cached.n_head_kv, + "head_k_size": cached.head_k_size, + "head_v_size": cached.head_v_size, + } + + +def get_available_memory(device: str, gpu_index: int | None = None) -> int: + """Get available memory in bytes for the device. + + Args: + device: Target device ("cuda", "mps", or "cpu") + gpu_index: Specific CUDA GPU index. If None, uses GPU 0. + + Returns: + Available memory in bytes + + Notes: + - For CUDA: Returns free GPU memory on the specified device + - For MPS/CPU: Returns available system RAM + """ + try: + import torch + except ImportError: + torch = None # type: ignore[assignment] + + try: + if torch is not None and device == "cuda" and torch.cuda.is_available(): + idx = gpu_index if gpu_index is not None else 0 + free, total = torch.cuda.mem_get_info(idx) + logger.debug( + f"CUDA GPU {idx} memory: {free / (1024**3):.2f} GB free / " + f"{total / (1024**3):.2f} GB total" + ) + return free + else: + # For CPU and MPS, use system RAM + # Get available (not total) to be conservative + vm = psutil.virtual_memory() + available_memory = vm.available + logger.debug( + f"System memory - Total: {vm.total / (1024**3):.2f} GB, " + f"Available: {available_memory / (1024**3):.2f} GB" + ) + return available_memory + except Exception as e: + logger.warning(f"Error detecting memory for device {device}: {e}") + # Fallback to conservative estimate (4GB) + return 4 * 1024 * 1024 * 1024 + + +def compute_kv_bytes_per_token( + n_layer: int, + n_head_kv: int, + head_k_size: int, + head_v_size: int, +) -> int: + """Compute exact KV cache bytes per token from model architecture. + + The KV cache stores key and value tensors for every layer, using f16 precision + (2 bytes per element). This is the dominant memory cost that scales with context. + + Args: + n_layer: Number of transformer layers (block_count) + n_head_kv: Number of key-value attention heads + head_k_size: Dimension of each key head + head_v_size: Dimension of each value head + + Returns: + Bytes of KV cache needed per token of context + """ + # K cache per token: n_layer * n_head_kv * head_k_size * sizeof(f16) + # V cache per token: n_layer * n_head_kv * head_v_size * sizeof(f16) + return n_layer * n_head_kv * (head_k_size + head_v_size) * 2 + + +# Fallback estimate when GGUF architecture metadata isn't available. +# Deliberately conservative (overestimates cost) to prevent OOM. +# Actual KV cache costs range from ~18 KB/token (1.5B) to ~320 KB/token (70B). +# 256 KB covers most 7B+ models safely; smaller models just get less context than +# they could handle, which is preferable to OOM. +_FALLBACK_BYTES_PER_TOKEN = 256 * 1024 # 256 KB + + +def compute_max_context( + model_size_bytes: int, + available_memory_bytes: int, + memory_factor: float = 0.8, + max_context_cap: int = 131072, + n_layer: int | None = None, + n_head_kv: int | None = None, + head_k_size: int | None = None, + head_v_size: int | None = None, +) -> int: + """Compute maximum safe context size based on available memory. + + Uses model architecture metadata (when available) to compute the exact + KV cache cost per token, rather than relying on a fixed estimate. + + Args: + model_size_bytes: Size of model file in bytes + available_memory_bytes: Available memory on target device + memory_factor: Fraction of available memory to use (default 0.8) + max_context_cap: Hard upper limit for context size (default 131072/128K). + Most models don't support more than 128K context even with + sufficient memory. + n_layer: Number of transformer layers (from GGUF metadata) + n_head_kv: Number of key-value attention heads (from GGUF metadata) + head_k_size: Dimension of each key head (from GGUF metadata) + head_v_size: Dimension of each value head (from GGUF metadata) + + Returns: + Maximum safe context size (number of tokens) + """ + # Calculate usable memory after loading model + usable_memory = (available_memory_bytes * memory_factor) - model_size_bytes + + if usable_memory <= 0: + logger.warning( + f"Model size ({model_size_bytes / (1024**3):.2f} GB) exceeds " + f"available memory budget. Using minimal context size." + ) + return 512 # Minimal context + + # Compute per-token memory cost + has_arch_params = all( + v is not None for v in [n_layer, n_head_kv, head_k_size, head_v_size] + ) + if has_arch_params: + kv_bytes = compute_kv_bytes_per_token( + n_layer, n_head_kv, head_k_size, head_v_size + ) + # Add 30% overhead for compute buffers and activation tensors + bytes_per_token = int(kv_bytes * 1.3) + logger.debug( + f"KV cache from architecture: {kv_bytes} bytes/token " + f"(n_layer={n_layer}, n_head_kv={n_head_kv}, " + f"head_k={head_k_size}, head_v={head_v_size}), " + f"with overhead: {bytes_per_token} bytes/token" + ) + else: + bytes_per_token = _FALLBACK_BYTES_PER_TOKEN + logger.debug( + f"Architecture metadata unavailable, using fallback: " + f"{bytes_per_token} bytes/token" + ) + + max_context = int(usable_memory / bytes_per_token) + + # Apply hard cap - most models don't support extremely large contexts + # even if memory would allow it + if max_context > max_context_cap: + logger.debug( + f"Computed context {max_context} exceeds cap {max_context_cap}, capping" + ) + max_context = max_context_cap + + # Round down to nearest power of 2 for better memory alignment + # Common sizes: 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072 + power_of_2 = 1 + while power_of_2 * 2 <= max_context: + power_of_2 *= 2 + + logger.debug( + f"Memory calculation: available={available_memory_bytes / (1024**3):.2f}GB, " + f"model={model_size_bytes / (1024**3):.2f}GB, " + f"usable={usable_memory / (1024**3):.2f}GB, " + f"bytes_per_token={bytes_per_token}, " + f"max_ctx_computed={max_context}, " + f"max_ctx_aligned={power_of_2}" + ) + + return power_of_2 + + +def load_model_context_config() -> dict: + """Load model_context_defaults.yaml configuration. + + Caches the configuration to avoid repeated file I/O. + + Returns: + dict with 'memory_usage_factor' and 'model_defaults' keys + + Raises: + FileNotFoundError: If config file doesn't exist + yaml.YAMLError: If config file is malformed + """ + global _config_cache + + if _config_cache is not None: + return _config_cache + + # Find config file relative to this module + config_path = ( + Path(__file__).parent.parent / "config" / "model_context_defaults.yaml" + ) + + if not config_path.exists(): + raise FileNotFoundError( + f"Context config file not found: {config_path}. " + "Create config/model_context_defaults.yaml" + ) + + logger.debug(f"Loading context config from: {config_path}") + + with open(config_path) as f: + config = yaml.safe_load(f) + + # Validate config structure + if "model_defaults" not in config or not isinstance(config["model_defaults"], list): + raise ValueError( + "Invalid config: 'model_defaults' must be a list of pattern entries" + ) + + _config_cache = config + logger.debug(f"Loaded {len(config['model_defaults'])} model patterns from config") + return config + + +def match_model_pattern(model_id: str, config: dict) -> int | None: + """Match model_id against patterns in config using fnmatch. + + Patterns are checked in order, with more specific patterns + listed first. Returns the n_ctx for the first matching pattern. + + Args: + model_id: HuggingFace model identifier (e.g., "unsloth/Qwen2.5-Coder-1.5B-Instruct-GGUF") + config: Configuration dict from load_model_context_config() + + Returns: + n_ctx value for first matching pattern, or None if no match + + Examples: + >>> config = load_model_context_config() + >>> match_model_pattern("unsloth/Qwen2.5-Coder-1.5B-Instruct-GGUF", config) + 32768 + >>> match_model_pattern("*/Llama-3-8B-GGUF", config) + 8192 + """ + model_defaults = config.get("model_defaults", []) + + for entry in model_defaults: + pattern = entry.get("pattern") + n_ctx = entry.get("n_ctx") + + if not pattern or n_ctx is None: + logger.warning(f"Invalid config entry: {entry}") + continue + + if fnmatch.fnmatch(model_id, pattern): + notes = entry.get("notes", "") + logger.info( + f"Matched model '{model_id}' to pattern '{pattern}': " + f"n_ctx={n_ctx} ({notes})" + ) + return n_ctx + + logger.warning(f"No pattern match found for model: {model_id}") + return None + + +def get_default_context_size( + model_id: str, + gguf_path: str, + device: str, + config_n_ctx: int | None = None, + gpu_index: int | None = None, + available_memory_override: int | None = None, +) -> tuple[int, list[str]]: + """Determine context size with four-tier priority system. + + Priority order (highest to lowest): + 1. config_n_ctx (from llamafarm.yaml via API) - user's explicit choice + 2. Model's n_ctx_train (training context) - what the model was designed for + 3. Pattern match from model_context_defaults.yaml - known model defaults + 4. Computed max from memory constraints - hardware limitation + 5. Fallback default (2048) - safe conservative value + + All choices are capped by available memory to prevent OOM errors. + + Args: + model_id: HuggingFace model identifier + gguf_path: Path to GGUF file + device: Target device ("cuda", "mps", "cpu") + config_n_ctx: Optional explicit context size from config + gpu_index: Specific CUDA GPU index for memory queries. If None, uses GPU 0. + available_memory_override: Pre-computed available memory in bytes. + When provided, skips the ``get_available_memory()`` query. Used + for multi-GPU splits where the effective memory is the combined + free VRAM across all participating devices. + + Returns: + tuple of (final_n_ctx, warnings_list) + - final_n_ctx: Determined context size to use + - warnings_list: List of warning messages (empty if none) + + Examples: + >>> n_ctx, warnings = get_default_context_size( + ... "unsloth/Qwen2.5-Coder-1.5B-Instruct-GGUF", + ... "/path/to/model.gguf", + ... "mps", + ... config_n_ctx=32768 + ... ) + >>> n_ctx + 32768 # or lower if memory constrained + """ + warnings = [] + + try: + # Load configuration + config = load_model_context_config() + memory_factor = config.get("memory_usage_factor", 0.8) + + # Get model metadata and compute memory constraints + metadata = get_gguf_metadata(gguf_path) + if available_memory_override is not None: + available_memory = available_memory_override + else: + available_memory = get_available_memory(device, gpu_index=gpu_index) + max_context_from_memory = compute_max_context( + metadata["file_size_bytes"], + available_memory, + memory_factor, + n_layer=metadata.get("n_layer"), + n_head_kv=metadata.get("n_head_kv"), + head_k_size=metadata.get("head_k_size"), + head_v_size=metadata.get("head_v_size"), + ) + + logger.info( + f"Memory-based max context for {model_id}: {max_context_from_memory} " + f"(model size: {metadata['file_size_mb']:.1f} MB, " + f"available memory: {available_memory / (1024**3):.2f} GB)" + ) + + # Get model's training context size if available + n_ctx_train = metadata.get("n_ctx_train") + if n_ctx_train: + logger.info(f"Model trained with context size: {n_ctx_train}") + + # Get pattern-based default + pattern_n_ctx = match_model_pattern(model_id, config) + + # Determine final context size based on priority + if config_n_ctx is not None: + # Priority 1: User specified a value - use it but check against memory limit + if config_n_ctx > max_context_from_memory: + warning_msg = ( + f"Requested context size {config_n_ctx} exceeds computed maximum " + f"{max_context_from_memory} based on available memory " + f"({available_memory / (1024**3):.2f} GB). " + f"Using {max_context_from_memory} instead." + ) + warnings.append(warning_msg) + final_n_ctx = max_context_from_memory + else: + final_n_ctx = config_n_ctx + logger.info(f"Using configured context size: {final_n_ctx}") + + elif n_ctx_train is not None: + # Priority 2: Use model's training context, but respect memory limit + if n_ctx_train > max_context_from_memory: + warning_msg = ( + f"Model training context {n_ctx_train} exceeds computed maximum " + f"{max_context_from_memory} based on available memory. " + f"Using {max_context_from_memory} to prevent OOM." + ) + warnings.append(warning_msg) + final_n_ctx = max_context_from_memory + else: + final_n_ctx = n_ctx_train + logger.info(f"Using model's training context size: {final_n_ctx}") + + elif pattern_n_ctx is not None: + # Priority 3: Use pattern match, but respect memory limit + if pattern_n_ctx > max_context_from_memory: + warning_msg = ( + f"Pattern default context size {pattern_n_ctx} exceeds computed maximum " + f"{max_context_from_memory} based on available memory. " + f"Using {max_context_from_memory} instead." + ) + warnings.append(warning_msg) + final_n_ctx = max_context_from_memory + else: + final_n_ctx = pattern_n_ctx + logger.info(f"Using pattern-matched context size: {final_n_ctx}") + + else: + # Priority 4: No other source - use computed max or fallback + if max_context_from_memory >= 2048: + final_n_ctx = max_context_from_memory + logger.info(f"Using computed max context: {final_n_ctx}") + else: + final_n_ctx = 2048 + warning_msg = ( + f"Low memory detected. Using fallback context size: {final_n_ctx}" + ) + warnings.append(warning_msg) + + # Final sanity check - ensure we have at least 512 tokens + if final_n_ctx < 512: + warning_msg = ( + f"Computed context size {final_n_ctx} is very low. " + "Using minimum of 512 tokens." + ) + warnings.append(warning_msg) + final_n_ctx = 512 + + return final_n_ctx, warnings + + except Exception as e: + # If anything fails, use safe fallback + error_msg = f"Error computing context size: {e}. Using fallback of 2048." + logger.error(error_msg, exc_info=True) + warnings.append(error_msg) + return 2048, warnings + + +def clear_config_cache(): + """Clear the configuration cache. + + Useful for testing or when config file is modified at runtime. + """ + global _config_cache + _config_cache = None + logger.debug("Context config cache cleared") diff --git a/runtimes/edge/utils/context_manager.py b/runtimes/edge/utils/context_manager.py new file mode 100644 index 000000000..5f1167c25 --- /dev/null +++ b/runtimes/edge/utils/context_manager.py @@ -0,0 +1,506 @@ +"""Context management and truncation strategies. + +Provides context window management for LLM conversations, including +validation, truncation, and multiple strategies for handling context overflow. +""" + +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass +from enum import Enum +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .token_counter import TokenCounter + +logger = logging.getLogger(__name__) + + +class TruncationStrategy(Enum): + """Available truncation strategies for context overflow.""" + + # Remove oldest messages first (simple sliding window) + SLIDING_WINDOW = "sliding_window" + + # Keep system messages, slide user/assistant messages + KEEP_SYSTEM_SLIDING = "keep_system" + + # Keep system, first exchange, and recent messages; remove middle + MIDDLE_OUT = "middle_out" + + # Summarize older messages using an LLM (requires summarizer) + SUMMARIZE = "summarize" + + +@dataclass +class ContextBudget: + """Token budget allocation for context window. + + Splits the context window into regions: + - prompt: tokens for input messages + - completion: tokens reserved for model output + - safety_margin: buffer to avoid edge cases + """ + + total_context: int + max_prompt_tokens: int + reserved_completion: int + safety_margin: int + + @classmethod + def from_context_size( + cls, + n_ctx: int, + max_completion_tokens: int = 512, + safety_margin_pct: float = 0.05, + ) -> ContextBudget: + """Create a budget from model's context size. + + Args: + n_ctx: Model's total context window size in tokens. + max_completion_tokens: Tokens to reserve for output (default: 512). + safety_margin_pct: Percentage of context as safety buffer (default: 5%). + + Returns: + A ContextBudget instance with calculated allocations. + """ + safe_n_ctx = max(1, n_ctx) + safety_margin = max(0, int(safe_n_ctx * safety_margin_pct)) + + # Keep enough room for both prompt and completion even on small context windows. + max_safe_margin = max(0, safe_n_ctx - 2) + safety_margin = min(safety_margin, max_safe_margin) + + completion_target = max(1, max_completion_tokens) + available_after_safety = max(2, safe_n_ctx - safety_margin) + + # Avoid hard-reserving large completion windows on tiny contexts. + adaptive_completion_cap = max(1, available_after_safety // 2) + reserved_completion = min(completion_target, adaptive_completion_cap) + max_prompt = max(1, safe_n_ctx - safety_margin - reserved_completion) + + return cls( + total_context=safe_n_ctx, + max_prompt_tokens=max_prompt, + reserved_completion=reserved_completion, + safety_margin=safety_margin, + ) + + +@dataclass +class ContextUsage: + """Context usage information for API responses. + + Provides visibility into how the context window is being used, + including whether truncation was applied. + """ + + total_context: int + prompt_tokens: int + available_for_completion: int + truncated: bool = False + truncated_messages: int = 0 + strategy_used: str | None = None + + +class ContextManager: + """Manages context window and applies truncation strategies. + + Validates that messages fit within the context budget and applies + truncation strategies when needed to prevent overflow errors. + """ + + def __init__( + self, + token_counter: TokenCounter, + budget: ContextBudget, + default_strategy: TruncationStrategy = TruncationStrategy.SUMMARIZE, + ): + """Initialize context manager. + + Args: + token_counter: TokenCounter instance for counting tokens. + budget: ContextBudget defining token allocations. + default_strategy: Default truncation strategy to use. + """ + self._counter = token_counter + self._budget = budget + self._default_strategy = default_strategy + + @property + def budget(self) -> ContextBudget: + """Get the context budget.""" + return self._budget + + def _available_for_completion(self, prompt_tokens: int) -> int: + """Calculate completion tokens available under current budget assumptions.""" + available = ( + self._budget.total_context - prompt_tokens - self._budget.safety_margin + ) + return max(0, min(self._budget.reserved_completion, available)) + + def validate_messages(self, messages: list[dict]) -> ContextUsage: + """Validate messages fit within context budget. + + Returns usage info without modifying messages. + + Args: + messages: List of chat messages to validate. + + Returns: + ContextUsage with token counts and overflow status. + """ + prompt_tokens = self._counter.estimate_prompt_tokens(messages) + + return ContextUsage( + total_context=self._budget.total_context, + prompt_tokens=prompt_tokens, + available_for_completion=self._available_for_completion(prompt_tokens), + truncated=False, + truncated_messages=0, + strategy_used=None, + ) + + def needs_truncation(self, messages: list[dict]) -> bool: + """Check if messages exceed the context budget. + + Args: + messages: List of chat messages. + + Returns: + True if truncation is needed. + """ + prompt_tokens = self._counter.estimate_prompt_tokens(messages) + return prompt_tokens > self._budget.max_prompt_tokens + + def truncate_if_needed( + self, + messages: list[dict], + strategy: TruncationStrategy | None = None, + ) -> tuple[list[dict], ContextUsage]: + """Truncate messages to fit context budget if needed. + + Args: + messages: List of chat messages. + strategy: Override default truncation strategy. + + Returns: + Tuple of (possibly truncated messages, context usage info). + """ + strategy = strategy or self._default_strategy + prompt_tokens = self._counter.estimate_prompt_tokens(messages) + + if prompt_tokens <= self._budget.max_prompt_tokens: + # No truncation needed + return messages, ContextUsage( + total_context=self._budget.total_context, + prompt_tokens=prompt_tokens, + available_for_completion=self._available_for_completion(prompt_tokens), + truncated=False, + truncated_messages=0, + strategy_used=None, + ) + + # Deep copy to avoid modifying original + # Use JSON for Pydantic-safe deep copy + messages = json.loads(json.dumps(messages, default=str)) + original_count = len(messages) + + # Apply truncation strategy + if strategy == TruncationStrategy.SLIDING_WINDOW: + truncated = self._sliding_window(messages) + elif strategy == TruncationStrategy.KEEP_SYSTEM_SLIDING: + truncated = self._keep_system_sliding(messages) + elif strategy == TruncationStrategy.MIDDLE_OUT: + truncated = self._middle_out(messages) + elif strategy == TruncationStrategy.SUMMARIZE: + # Summarization is async and handled separately + # Fall back to keep_system_sliding for sync truncation + logger.warning( + "Summarization strategy requires async handling, " + "falling back to keep_system_sliding" + ) + truncated = self._keep_system_sliding(messages) + else: + # Default fallback + truncated = self._keep_system_sliding(messages) + + new_tokens = self._counter.estimate_prompt_tokens(truncated) + messages_removed = original_count - len(truncated) + + logger.info( + f"Context truncated: {original_count} -> {len(truncated)} messages " + f"({prompt_tokens} -> {new_tokens} tokens), strategy={strategy.value}" + ) + + return truncated, ContextUsage( + total_context=self._budget.total_context, + prompt_tokens=new_tokens, + available_for_completion=self._available_for_completion(new_tokens), + truncated=True, + truncated_messages=messages_removed, + strategy_used=strategy.value, + ) + + def _sliding_window(self, messages: list[dict]) -> list[dict]: + """Remove oldest messages until context fits. + + Simple strategy that removes messages from the beginning, + regardless of role. Falls back to content truncation if + needed. + + Args: + messages: List of messages (will be modified). + + Returns: + Truncated messages. + """ + result = list(messages) + + while ( + len(result) > 1 + and self._counter.estimate_prompt_tokens(result) + > self._budget.max_prompt_tokens + ): + result.pop(0) + + # If still over budget (single huge message), truncate content + if ( + self._counter.estimate_prompt_tokens(result) + > self._budget.max_prompt_tokens + ): + logger.warning( + "Message removal insufficient in sliding_window, " + "applying content truncation" + ) + result = self._truncate_message_contents(result) + + return result + + def _keep_system_sliding(self, messages: list[dict]) -> list[dict]: + """Keep system prompts, slide user/assistant messages. + + Preserves all system messages and removes oldest non-system + messages until context fits. If still over budget after removing + all but one message, truncates individual message content. + + Args: + messages: List of messages (will be modified). + + Returns: + Truncated messages. + """ + system_msgs = [m for m in messages if m.get("role") == "system"] + other_msgs = [m for m in messages if m.get("role") != "system"] + + # Calculate tokens for system messages + system_tokens = self._counter.estimate_prompt_tokens(system_msgs) + available_for_others = self._budget.max_prompt_tokens - system_tokens + + # Remove oldest non-system messages until fits + while ( + len(other_msgs) > 1 + and self._counter.estimate_prompt_tokens(other_msgs) > available_for_others + ): + other_msgs.pop(0) + + result = system_msgs + other_msgs + + # If still over budget, apply aggressive content truncation + if ( + self._counter.estimate_prompt_tokens(result) + > self._budget.max_prompt_tokens + ): + logger.warning("Message removal insufficient, applying content truncation") + result = self._truncate_message_contents(result) + + return result + + def _middle_out(self, messages: list[dict]) -> list[dict]: + """Keep system, first exchange, and recent messages; remove middle. + + Useful for preserving initial context (task setup) and recent + conversation while removing less relevant middle content. + Falls back to content truncation if needed. + + Args: + messages: List of messages (will be modified). + + Returns: + Truncated messages. + """ + if len(messages) <= 3: + result = list(messages) + else: + system_msgs = [m for m in messages if m.get("role") == "system"] + other_msgs = [m for m in messages if m.get("role") != "system"] + + if len(other_msgs) <= 2: + result = list(messages) + else: + # Keep first non-system message and last N messages + first_msg = [other_msgs[0]] + remaining = other_msgs[1:] + + # Remove from the beginning of remaining (oldest after first) + # until we fit within budget + while ( + len(remaining) > 1 + and self._counter.estimate_prompt_tokens( + system_msgs + first_msg + remaining + ) + > self._budget.max_prompt_tokens + ): + remaining.pop(0) + + result = system_msgs + first_msg + remaining + + # If still over budget (huge messages), truncate content + if ( + self._counter.estimate_prompt_tokens(result) + > self._budget.max_prompt_tokens + ): + logger.warning( + "Message removal insufficient in middle_out, " + "applying content truncation" + ) + result = self._truncate_message_contents(result) + + return result + + def _truncate_message_contents(self, messages: list[dict]) -> list[dict]: + """Truncate individual message contents to fit context budget. + + This is a last resort when removing whole messages isn't enough + (e.g., when a single message exceeds the entire context budget). + + Strategy: + 1. Calculate how much we're over budget + 2. Find the largest messages and truncate them proportionally + 3. Preserve the last user message as much as possible (most recent query) + + Args: + messages: List of messages to truncate. + + Returns: + Messages with truncated content. + """ + # Use JSON for Pydantic-safe deep copy + result = json.loads(json.dumps(messages, default=str)) + max_tokens = self._budget.max_prompt_tokens + + # Calculate current usage and how much we need to cut + current_tokens = self._counter.estimate_prompt_tokens(result) + if current_tokens <= max_tokens: + return result + + tokens_to_cut = current_tokens - max_tokens + 100 # Extra buffer + + logger.info( + f"Content truncation: need to cut ~{tokens_to_cut} tokens " + f"from {current_tokens} total" + ) + + # Find messages with content, sorted by size (largest first) + # Skip the last user message if possible (it's the current query) + messages_with_size = [] + for i, msg in enumerate(result): + content = msg.get("content") or "" + if not content: + continue + tokens = self._counter.count_tokens(content) + # Mark if this is the last user message + is_last_user = msg.get("role") == "user" and all( + m.get("role") != "user" for m in result[i + 1 :] + ) + messages_with_size.append((i, tokens, is_last_user)) + + # Sort by tokens descending, but keep last user message at end + messages_with_size.sort(key=lambda x: (x[2], -x[1])) + + # Truncate largest messages first + tokens_cut = 0 + for idx, msg_tokens, _is_last_user in messages_with_size: + if tokens_cut >= tokens_to_cut: + break + + content = result[idx].get("content", "") + if not content or msg_tokens < 100: + continue + + # Calculate how much to keep + # For very large messages, be more aggressive + if msg_tokens > 10000: + # Keep at most 10% or 500 tokens + keep_tokens = min(int(msg_tokens * 0.1), 500) + elif msg_tokens > 1000: + # Keep at most 30% or 300 tokens + keep_tokens = min(int(msg_tokens * 0.3), 300) + else: + # Keep at most 50% + keep_tokens = int(msg_tokens * 0.5) + + # Truncate the content + truncated_content = self._counter.truncate_to_tokens(content, keep_tokens) + cut_amount = msg_tokens - self._counter.count_tokens(truncated_content) + + result[idx]["content"] = ( + truncated_content + "\n\n[... content truncated ...]" + ) + + logger.debug( + f"Truncated message {idx} (role={result[idx].get('role')}): " + f"{msg_tokens} -> {keep_tokens} tokens" + ) + + tokens_cut += cut_amount + + # Verify we're now under budget, keep truncating if needed + final_tokens = self._counter.estimate_prompt_tokens(result) + emergency_iterations = 0 + max_emergency_iterations = len(result) * 2 # Safety limit + + while ( + final_tokens > max_tokens + and emergency_iterations < max_emergency_iterations + ): + emergency_iterations += 1 + logger.warning( + f"Content truncation incomplete: {final_tokens} > {max_tokens}. " + f"Emergency truncation iteration {emergency_iterations}." + ) + + # Find the largest message and truncate it aggressively + largest_idx = -1 + largest_tokens = 0 + for i, msg in enumerate(result): + content = msg.get("content") or "" + if content: + tokens = self._counter.count_tokens(content) + if tokens > largest_tokens: + largest_tokens = tokens + largest_idx = i + + if largest_idx < 0 or largest_tokens <= 50: + # No more content to truncate + logger.error( + f"Cannot reduce context further. Remaining: {final_tokens} tokens" + ) + break + + # Truncate the largest message to 50 tokens + content = result[largest_idx]["content"] + result[largest_idx]["content"] = ( + self._counter.truncate_to_tokens(content, 50) + + "\n[... heavily truncated ...]" + ) + + final_tokens = self._counter.estimate_prompt_tokens(result) + logger.info( + f"Emergency truncated message {largest_idx}: " + f"{largest_tokens} -> ~50 tokens. New total: {final_tokens}" + ) + + return result diff --git a/runtimes/edge/utils/context_summarizer.py b/runtimes/edge/utils/context_summarizer.py new file mode 100644 index 000000000..a96be0c06 --- /dev/null +++ b/runtimes/edge/utils/context_summarizer.py @@ -0,0 +1,239 @@ +"""Context summarization using an LLM. + +Provides LLM-based summarization of conversation history to preserve +semantic meaning while dramatically reducing token count. +""" + +from __future__ import annotations + +import logging +from collections.abc import Callable +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from models.gguf_language_model import GGUFLanguageModel + +logger = logging.getLogger(__name__) + + +# Prompt for summarizing conversation history +SUMMARIZE_PROMPT = """Summarize the following conversation concisely, preserving: +- Key facts and decisions made +- Important context the assistant needs to remember +- Any commitments or action items +- Technical details that may be referenced later + +Be concise but complete. Write in third person (e.g., "The user asked about X. The assistant explained Y."). + +Conversation: +{conversation} + +Summary:""" + + +class ContextSummarizer: + """Summarizes conversation history using an LLM. + + When context exceeds budget, this class can compress older messages + into a single summary message, preserving semantic meaning while + dramatically reducing token count. + + Uses the server's model loading mechanism to benefit from caching + and proper lifecycle management. + """ + + # Default model for summarization (small, fast, good at instruction following) + # Qwen3 has better instruction following than Qwen2.5 for summarization tasks + DEFAULT_MODEL = "Qwen/Qwen3-1.7B-GGUF" + DEFAULT_QUANTIZATION = "Q4_K_M" + + # Default number of recent exchanges to preserve + DEFAULT_KEEP_RECENT = 4 + + def __init__( + self, + model_id: str | None = None, + quantization: str | None = None, + keep_recent: int | None = None, + load_language: Callable | None = None, + ): + """Initialize context summarizer. + + Args: + model_id: HuggingFace model ID for summarization (default: Qwen2.5-1.5B). + quantization: GGUF quantization preference (default: Q4_K_M). + keep_recent: Number of recent exchanges to preserve (default: 4). + load_language: Model loader function (uses server's loader for caching). + """ + self._model_id = model_id or self.DEFAULT_MODEL + self._quantization = quantization or self.DEFAULT_QUANTIZATION + # Use explicit None check to allow keep_recent=0 + self._keep_recent = ( + keep_recent if keep_recent is not None else self.DEFAULT_KEEP_RECENT + ) + self._load_language = load_language + self._model: GGUFLanguageModel | None = None + + async def ensure_model_loaded(self) -> None: + """Load the summarization model using the server's caching mechanism.""" + if self._model is not None: + return + + # Use the server's load_language function for proper caching + if self._load_language is not None: + logger.info( + f"Loading summarization model via server cache: {self._model_id}" + ) + self._model = await self._load_language( + self._model_id, + n_ctx=4096, + preferred_quantization=self._quantization, + ) + logger.info("Summarization model loaded (cached by server)") + else: + # Fallback: import server's loader directly + try: + from server import load_language + + logger.info(f"Loading summarization model: {self._model_id}") + self._model = await load_language( + self._model_id, + n_ctx=4096, + preferred_quantization=self._quantization, + ) + logger.info("Summarization model loaded successfully") + except ImportError: + # Last resort: create model directly (won't be cached) + from models.gguf_language_model import GGUFLanguageModel + + logger.warning( + f"Loading summarization model directly (not cached): {self._model_id}" + ) + self._model = GGUFLanguageModel( + model_id=self._model_id, + device="cpu", + n_ctx=4096, + preferred_quantization=self._quantization, + ) + await self._model.load() + + async def summarize_messages( + self, + messages: list[dict], + keep_recent: int | None = None, + ) -> list[dict]: + """Summarize older messages, keeping recent ones intact. + + Args: + messages: List of chat messages. + keep_recent: Number of recent exchanges to keep (default: 4). + + Returns: + Messages with older content summarized into a single message. + """ + # Use explicit None check to allow keep_recent=0 + if keep_recent is None: + keep_recent = self._keep_recent + + # Separate system messages from conversation + system_msgs = [m for m in messages if m.get("role") == "system"] + other_msgs = [m for m in messages if m.get("role") != "system"] + + # Check if we have enough messages to summarize + # keep_recent * 2 because each exchange is user + assistant + min_messages = keep_recent * 2 + if len(other_msgs) <= min_messages: + logger.debug("Not enough messages to summarize") + return messages + + # Split into old (to summarize) and recent (to keep) + # Handle min_messages=0 specially since [:-0] returns [] and [-0:] returns all + if min_messages == 0: + to_summarize = other_msgs + to_keep = [] + else: + to_summarize = other_msgs[:-min_messages] + to_keep = other_msgs[-min_messages:] + + logger.info( + f"Summarizing {len(to_summarize)} messages, keeping {len(to_keep)} recent" + ) + + # Ensure model is loaded + await self.ensure_model_loaded() + + # Generate summary + summary = await self._generate_summary(to_summarize) + + # Create summary message as a system-level context + summary_msg = { + "role": "system", + "content": f"[Conversation Summary]\n{summary}", + } + + # Return: original system + summary + recent messages + return system_msgs + [summary_msg] + to_keep + + async def _generate_summary(self, messages: list[dict]) -> str: + """Generate a summary of the given messages. + + Args: + messages: List of messages to summarize. + + Returns: + Summary text. + """ + if self._model is None: + raise RuntimeError("Summarization model not loaded") + + # Format messages for summarization + conversation_text = self._format_for_summary(messages) + + # Build the summarization prompt + prompt = SUMMARIZE_PROMPT.format(conversation=conversation_text) + + # Generate summary using the model + summary = await self._model.generate( + messages=[{"role": "user", "content": prompt}], + max_tokens=512, # Limit summary length + temperature=0.3, # Lower temperature for more focused summary + ) + + return summary.strip() + + def _format_for_summary(self, messages: list[dict]) -> str: + """Format messages for summarization prompt. + + Args: + messages: List of messages. + + Returns: + Formatted conversation text. + """ + parts = [] + for msg in messages: + role = msg.get("role", "unknown") + content = msg.get("content", "") + + if not content: + continue + + # Capitalize role for readability + role_label = role.capitalize() + if role == "assistant": + role_label = "Assistant" + elif role == "user": + role_label = "User" + elif role == "tool": + role_label = "Tool Result" + + # Truncate very long messages for the summary input + if len(content) > 1000: + content = content[:1000] + "..." + + parts.append(f"{role_label}: {content}") + + return "\n\n".join(parts) + + # Note: No explicit unload() method - the model is managed by the server's + # cache and will be evicted based on the normal cache TTL policy. diff --git a/runtimes/edge/utils/device.py b/runtimes/edge/utils/device.py new file mode 100644 index 000000000..c6a328841 --- /dev/null +++ b/runtimes/edge/utils/device.py @@ -0,0 +1,9 @@ +"""Re-export from llamafarm_common — single source of truth.""" +from llamafarm_common.device import ( + get_device_info, + get_gguf_gpu_layers, + get_optimal_device, + is_torch_available, +) + +__all__ = ["get_optimal_device", "get_device_info", "is_torch_available", "get_gguf_gpu_layers"] diff --git a/runtimes/edge/utils/file_handler.py b/runtimes/edge/utils/file_handler.py new file mode 100644 index 000000000..f1f948dee --- /dev/null +++ b/runtimes/edge/utils/file_handler.py @@ -0,0 +1,213 @@ +""" +Shared file handling utilities for Edge Runtime. + +Provides: +- File upload with automatic base64 encoding +- Temporary file storage with TTL +- Support for images (no PDF support — edge doesn't process PDFs) +""" + +import asyncio +import base64 +import hashlib +import logging +import mimetypes +import time +import uuid +from dataclasses import dataclass, field +from pathlib import Path + +logger = logging.getLogger(__name__) + +# File storage TTL (seconds) - files are cleaned up after this time +FILE_TTL = 300 # 5 minutes + +# Supported file types (no PDF on edge) +IMAGE_TYPES = {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".tiff", ".tif"} +ALL_SUPPORTED = IMAGE_TYPES + + +@dataclass +class StoredFile: + """A file stored in the temporary cache.""" + + id: str + filename: str + content_type: str + size: int + base64_data: str # Base64-encoded content + created_at: float = field(default_factory=time.time) + + +# In-memory file storage +_file_cache: dict[str, StoredFile] = {} +_cleanup_task: asyncio.Task | None = None + + +def _generate_file_id(content: bytes, filename: str) -> str: + """Generate a unique file ID based on content hash and UUID.""" + content_hash = hashlib.sha256(content[:1024]).hexdigest()[:8] + unique_id = uuid.uuid4().hex[:8] + return f"file_{content_hash}_{unique_id}" + + +async def store_file( + content: bytes, + filename: str, + content_type: str | None = None, +) -> StoredFile: + """ + Store a file and return its metadata. + + Args: + content: Raw file bytes + filename: Original filename + content_type: MIME type (auto-detected if not provided) + + Returns: + StoredFile with ID and base64 data + """ + # Auto-detect content type + if content_type is None: + content_type, _ = mimetypes.guess_type(filename) + content_type = content_type or "application/octet-stream" + + # Generate file ID + file_id = _generate_file_id(content, filename) + + # Base64 encode + base64_data = base64.b64encode(content).decode("utf-8") + + # Create stored file + stored = StoredFile( + id=file_id, + filename=filename, + content_type=content_type, + size=len(content), + base64_data=base64_data, + ) + + # Store in cache + _file_cache[file_id] = stored + + # Ensure cleanup task is running + _ensure_cleanup_task() + + logger.info(f"Stored file: {file_id} ({filename}, {len(content)} bytes)") + return stored + + +def get_file(file_id: str) -> StoredFile | None: + """ + Retrieve a stored file by ID. + + Args: + file_id: The file ID returned from store_file + + Returns: + StoredFile or None if not found/expired + """ + stored = _file_cache.get(file_id) + + if stored is None: + return None + + # Check if expired + if time.time() - stored.created_at > FILE_TTL: + _file_cache.pop(file_id, None) + return None + + return stored + + +def get_file_images(file_id: str) -> list[str]: + """ + Get images for a file (the file itself for images). + + Args: + file_id: The file ID + + Returns: + List of base64-encoded images + """ + stored = get_file(file_id) + + if stored is None: + return [] + + # If it's an image file, return the base64 data + suffix = Path(stored.filename).suffix.lower() + if suffix in IMAGE_TYPES: + return [stored.base64_data] + + return [] + + +def delete_file(file_id: str) -> bool: + """ + Delete a stored file. + + Args: + file_id: The file ID + + Returns: + True if deleted, False if not found + """ + return _file_cache.pop(file_id, None) is not None + + +def list_files() -> list[dict]: + """ + List all stored files with their metadata. + + Returns: + List of file metadata dicts + """ + now = time.time() + result = [] + + for file_id, stored in list(_file_cache.items()): + # Check if expired + if now - stored.created_at > FILE_TTL: + _file_cache.pop(file_id, None) + continue + + result.append( + { + "id": stored.id, + "filename": stored.filename, + "content_type": stored.content_type, + "size": stored.size, + "created_at": stored.created_at, + "ttl_remaining": FILE_TTL - (now - stored.created_at), + } + ) + + return result + + +async def _cleanup_expired_files(): + """Background task to clean up expired files.""" + while True: + await asyncio.sleep(60) # Check every minute + + now = time.time() + expired = [ + file_id + for file_id, stored in _file_cache.items() + if now - stored.created_at > FILE_TTL + ] + + for file_id in expired: + _file_cache.pop(file_id, None) + + if expired: + logger.info(f"Cleaned up {len(expired)} expired files") + + +def _ensure_cleanup_task(): + """Ensure the cleanup background task is running.""" + global _cleanup_task + + if _cleanup_task is None or _cleanup_task.done(): + _cleanup_task = asyncio.create_task(_cleanup_expired_files()) diff --git a/runtimes/edge/utils/ggml_logging.py b/runtimes/edge/utils/ggml_logging.py new file mode 100644 index 000000000..4f5f4dab1 --- /dev/null +++ b/runtimes/edge/utils/ggml_logging.py @@ -0,0 +1,184 @@ +""" +GGML logging management utilities. + +Routes llama.cpp/GGML logs through Python's logging system using llama_log_set. +This replaces the default llama-cpp behavior of printing directly to stderr. +""" + +import ctypes +import logging +import os +from typing import Literal + +logger = logging.getLogger("ggml") + +# Environment variable to control GGML output behavior +# Options: "capture" (default), "suppress", "passthrough" +GGML_LOG_MODE_ENV = "GGML_LOG_MODE" + +# GGML log level mapping (from llama.cpp ggml.h) +# enum ggml_log_level { +# GGML_LOG_LEVEL_NONE = 0, +# GGML_LOG_LEVEL_INFO = 1, +# GGML_LOG_LEVEL_WARN = 2, +# GGML_LOG_LEVEL_ERROR = 3, +# GGML_LOG_LEVEL_DEBUG = 4, +# GGML_LOG_LEVEL_CONT = 5, // continue previous log +# }; +GGML_TO_PYTHON_LOG_LEVEL = { + 0: logging.NOTSET, # NONE + 1: logging.INFO, # INFO + 2: logging.WARNING, # WARN + 3: logging.ERROR, # ERROR + 4: logging.DEBUG, # DEBUG + 5: logging.DEBUG, # CONT (continuation) +} + +# Track state for continuation logs +_last_log_level = logging.DEBUG +_log_buffer = "" + +# Store callback reference to prevent garbage collection +_callback_ref = None + +# Messages that llama.cpp logs at ERROR level but are actually informational +# These get downgraded to DEBUG level +_FALSE_ERROR_PATTERNS = [ + "embeddings required but some input tokens were not marked as outputs", + "cannot decode batches with this context", +] + + +def get_ggml_log_mode() -> Literal["suppress", "passthrough", "capture"]: + """Get the GGML logging mode from environment variable. + + Returns: + One of: + - "capture" (default): Route GGML logs through Python's logging system + - "suppress": Silence all GGML output + - "passthrough": Let GGML output flow to stderr normally (llama-cpp default) + """ + mode = os.environ.get(GGML_LOG_MODE_ENV, "capture").lower() + if mode in ("suppress", "passthrough", "capture"): + return mode # type: ignore + logger.warning( + f"Unknown GGML_LOG_MODE '{mode}', defaulting to 'capture'. " + "Valid options: capture, suppress, passthrough" + ) + return "capture" + + +def _create_logging_callback(): + """Create a callback that routes GGML logs through Python logging.""" + from llama_cpp import llama_log_callback + + @llama_log_callback + def logging_callback( + level: int, + text: bytes, + user_data: ctypes.c_void_p, + ): + global _last_log_level, _log_buffer + + try: + msg = text.decode("utf-8", errors="replace") + except Exception: + return + + # Handle continuation logs (level 5) + if level == 5: + python_level = _last_log_level + else: + python_level = GGML_TO_PYTHON_LOG_LEVEL.get(level, logging.DEBUG) + _last_log_level = python_level + + # Buffer partial lines (GGML often sends without newlines) + _log_buffer += msg + + # Only log complete lines + while "\n" in _log_buffer: + line, _log_buffer = _log_buffer.split("\n", 1) + line = line.strip() + if line: + # Downgrade known "false error" messages to DEBUG + effective_level = python_level + if python_level >= logging.WARNING: + for pattern in _FALSE_ERROR_PATTERNS: + if pattern in line: + effective_level = logging.DEBUG + break + logger.log(effective_level, line) + + return logging_callback + + +def _create_suppressing_callback(): + """Create a callback that suppresses all GGML logs.""" + from llama_cpp import llama_log_callback + + @llama_log_callback + def suppressing_callback( + level: int, + text: bytes, + user_data: ctypes.c_void_p, + ): + pass # Silently discard all logs + + return suppressing_callback + + +def setup_ggml_logging(): + """Configure GGML logging based on GGML_LOG_MODE environment variable. + + This should be called once at startup to configure how GGML/llama.cpp + logs are handled. The behavior is controlled by the GGML_LOG_MODE + environment variable: + + - "capture" (default): Routes logs through Python's logging system + with proper log levels. Messages appear as structured logs. + - "suppress": Silences all GGML output completely. + - "passthrough": Uses llama-cpp's default behavior (prints to stderr). + + Example: + # In your server startup: + from utils.ggml_logging import setup_ggml_logging + setup_ggml_logging() + + # Or set environment variable: + # GGML_LOG_MODE=suppress python -m uvicorn ... + """ + global _callback_ref + + mode = get_ggml_log_mode() + + if mode == "passthrough": + # Don't change anything - use llama-cpp's default + logger.debug("GGML logging: passthrough mode (using llama-cpp default)") + return + + try: + from llama_cpp import llama_log_set + except ImportError: + logger.warning("llama-cpp not available, GGML logging not configured") + return + + if mode == "suppress": + _callback_ref = _create_suppressing_callback() + llama_log_set(_callback_ref, ctypes.c_void_p(0)) + logger.debug("GGML logging: suppress mode (all output silenced)") + elif mode == "capture": + _callback_ref = _create_logging_callback() + llama_log_set(_callback_ref, ctypes.c_void_p(0)) + logger.debug("GGML logging: capture mode (routing through Python logging)") + + +def flush_ggml_log_buffer(): + """Flush any remaining content in the GGML log buffer. + + Call this after operations that may leave partial log messages buffered. + """ + global _log_buffer, _last_log_level + + if _log_buffer.strip(): + logger.log(_last_log_level, _log_buffer.strip()) + _log_buffer = "" diff --git a/runtimes/edge/utils/gguf_metadata_cache.py b/runtimes/edge/utils/gguf_metadata_cache.py new file mode 100644 index 000000000..d854a5106 --- /dev/null +++ b/runtimes/edge/utils/gguf_metadata_cache.py @@ -0,0 +1,309 @@ +"""Shared GGUF metadata cache for efficient metadata extraction. + +This module provides a centralized cache for GGUF file metadata to avoid +redundant file reads. GGUF metadata reading is expensive (~4-5 seconds for +large models), so caching significantly improves performance. + +The cache stores: +- File size and context length (for context_calculator) +- Chat template (for jinja_tools) +- Special tokens (for jinja_tools) +""" + +from __future__ import annotations + +import contextlib +import logging +import os +import threading +from dataclasses import dataclass, field +from typing import Any + +logger = logging.getLogger(__name__) + + +@dataclass +class GGUFMetadata: + """Cached metadata from a GGUF file.""" + + file_path: str + file_size_bytes: int + file_size_mb: float + n_ctx_train: int | None = None + chat_template: str | None = None + bos_token: str = "" + eos_token: str = "" + # Architecture params for KV cache size estimation + n_layer: int | None = None + n_head_kv: int | None = None + head_k_size: int | None = None + head_v_size: int | None = None + # Raw fields for any additional lookups + _raw_fields: dict[str, Any] = field(default_factory=dict, repr=False) + + +# Global cache: path -> GGUFMetadata +_metadata_cache: dict[str, GGUFMetadata] = {} +_cache_lock = threading.Lock() + + +def get_gguf_metadata_cached(gguf_path: str) -> GGUFMetadata: + """Get GGUF metadata, using cache if available. + + This function reads the GGUF file once and caches all commonly needed + metadata (file size, context length, chat template, special tokens). + Subsequent calls for the same path return the cached data instantly. + + Args: + gguf_path: Absolute path to the GGUF file + + Returns: + GGUFMetadata with all extracted information + + Raises: + FileNotFoundError: If GGUF file doesn't exist + """ + # Normalize path for consistent cache keys + normalized_path = os.path.normpath(os.path.abspath(gguf_path)) + + with _cache_lock: + if normalized_path in _metadata_cache: + logger.debug(f"Using cached GGUF metadata for: {normalized_path}") + return _metadata_cache[normalized_path] + + # Not in cache - read from file (outside lock to avoid blocking) + logger.info(f"Reading GGUF metadata (will be cached): {normalized_path}") + metadata = _read_gguf_metadata(normalized_path) + + with _cache_lock: + _metadata_cache[normalized_path] = metadata + + return metadata + + +def _read_gguf_metadata(gguf_path: str) -> GGUFMetadata: + """Read all metadata from a GGUF file in a single pass. + + This is an internal function that performs the actual file reading. + Use get_gguf_metadata_cached() for cached access. + """ + # Reject paths with traversal sequences (check path segments, not raw string) + from pathlib import PurePosixPath, PureWindowsPath + if ".." in PurePosixPath(gguf_path).parts or ".." in PureWindowsPath(gguf_path).parts: + raise ValueError(f"Invalid GGUF path: {gguf_path}") + + if not os.path.exists(gguf_path): + raise FileNotFoundError(f"GGUF file not found: {gguf_path}") + + file_size = os.path.getsize(gguf_path) + + metadata = GGUFMetadata( + file_path=gguf_path, + file_size_bytes=file_size, + file_size_mb=file_size / (1024 * 1024), + ) + + try: + import gc + + from gguf import GGUFReader + + # Temporarily disable GC during GGUF parsing to avoid segfault + # on Python 3.13 aarch64 (GC during gguf_reader causes crash) + gc_was_enabled = gc.isenabled() + gc.disable() + reader = None + try: + reader = GGUFReader(gguf_path) + except (ValueError, KeyError) as e: + # Some GGUF files use newer quantization types (e.g. Q6_K_XL = type 39) + # that the Python gguf library doesn't support yet. The error occurs + # during tensor parsing, but metadata fields are already parsed by then. + # Monkey-patch to skip tensor building and retry. + logger.warning( + f"GGUF tensor parsing failed ({e}), retrying with metadata-only read" + ) + try: + # Use lock to prevent concurrent monkey-patch conflicts + with _cache_lock: + _orig_build_tensors = GGUFReader._build_tensors + GGUFReader._build_tensors = lambda self, *a, **kw: None + try: + reader = GGUFReader(gguf_path) + finally: + GGUFReader._build_tensors = _orig_build_tensors + if reader is not None: + reader.tensors = [] + except Exception as inner_e: + logger.warning(f"Metadata-only GGUF read also failed: {inner_e}") + finally: + if gc_was_enabled: + gc.enable() + + if reader is None: + return metadata + + # Extract all needed metadata in a single pass through fields + bos_id = None + eos_id = None + tokens_data = None + + for key, field in reader.fields.items(): + # Store raw fields for debugging + metadata._raw_fields[key] = field + + # Context length fields + context_field_names = ["context_length", "n_ctx_train", "n_ctx"] + if any(target in key for target in context_field_names) and field.data: + try: + n_ctx_train = field.parts[field.data[0]] + if n_ctx_train: + metadata.n_ctx_train = int(n_ctx_train) + logger.debug( + f"Found context size in field '{key}': {n_ctx_train}" + ) + except (IndexError, ValueError, TypeError) as e: + logger.debug("Could not parse context size from field %s: %s", key, e) + + # Architecture params for KV cache estimation + # Keys are prefixed by architecture (e.g., qwen3.block_count), + # so we match by suffix. + _arch_field_map = { + ".block_count": "n_layer", + ".attention.head_count_kv": "n_head_kv", + ".attention.key_length": "head_k_size", + ".attention.value_length": "head_v_size", + } + for suffix, attr in _arch_field_map.items(): + if key.endswith(suffix) and field.data: + try: + val = int(field.parts[field.data[0]]) + setattr(metadata, attr, val) + except (IndexError, ValueError, TypeError) as e: + logger.debug("Could not parse GGUF field %s: %s", key, e) + + # Chat template + if key == "tokenizer.chat_template": + if hasattr(field, "parts") and field.parts: + # Use only the last part which contains the actual string data + # GGUF field.parts structure for strings: + # parts[0]: field name length (8 bytes) + # parts[1]: field name (bytes) + # parts[2]: type indicator (4 bytes) + # parts[3]: string length (8 bytes) + # parts[-1]: the actual string data + try: + template_bytes = bytes(field.parts[-1]) + metadata.chat_template = template_bytes.decode("utf-8") + logger.debug( + f"Found chat template ({len(metadata.chat_template)} chars)" + ) + except (IndexError, UnicodeDecodeError) as e: + logger.warning(f"Failed to decode chat template: {e}") + elif hasattr(field, "data"): + # Older format fallback + try: + metadata.chat_template = bytes(field.data).decode("utf-8") + except UnicodeDecodeError as e: + logger.warning( + f"Failed to decode chat template (fallback): {e}" + ) + + # BOS token ID + if key == "tokenizer.ggml.bos_token_id": + if hasattr(field, "parts") and field.parts: + with contextlib.suppress(IndexError, ValueError, TypeError): + bos_id = int(field.parts[0][0]) + elif hasattr(field, "data"): + with contextlib.suppress(IndexError, ValueError, TypeError): + bos_id = int(field.data[0]) + + # EOS token ID + if key == "tokenizer.ggml.eos_token_id": + if hasattr(field, "parts") and field.parts: + with contextlib.suppress(IndexError, ValueError, TypeError): + eos_id = int(field.parts[0][0]) + elif hasattr(field, "data"): + with contextlib.suppress(IndexError, ValueError, TypeError): + eos_id = int(field.data[0]) + + # Tokens array (for resolving BOS/EOS IDs to strings) + if key == "tokenizer.ggml.tokens": + if hasattr(field, "parts"): + tokens_data = field.parts + elif hasattr(field, "data"): + tokens_data = field.data + + # Resolve token IDs to strings + if tokens_data is not None: + if bos_id is not None and bos_id < len(tokens_data): + try: + token_bytes = tokens_data[bos_id] + if isinstance(token_bytes, (bytes, bytearray)): + metadata.bos_token = token_bytes.decode( + "utf-8", errors="replace" + ) + elif isinstance(token_bytes, str): + metadata.bos_token = token_bytes + except (IndexError, UnicodeDecodeError): + # Non-critical: leave bos_token as None if decode fails + logger.debug("Failed to decode BOS token (id=%s, tokens=%d)", bos_id, len(tokens_data)) + + if eos_id is not None and eos_id < len(tokens_data): + try: + token_bytes = tokens_data[eos_id] + if isinstance(token_bytes, (bytes, bytearray)): + metadata.eos_token = token_bytes.decode( + "utf-8", errors="replace" + ) + elif isinstance(token_bytes, str): + metadata.eos_token = token_bytes + except (IndexError, UnicodeDecodeError): + # Non-critical: leave eos_token as None if decode fails + logger.debug("Failed to decode EOS token (id=%s, tokens=%d)", eos_id, len(tokens_data)) + + logger.debug( + f"GGUF metadata extracted: n_ctx={metadata.n_ctx_train}, " + f"template={len(metadata.chat_template or '')} chars, " + f"bos='{metadata.bos_token}', eos='{metadata.eos_token}'" + ) + + except ImportError: + logger.warning("gguf package not installed, limited metadata available") + except Exception as e: + logger.warning(f"Error reading GGUF metadata: {e}") + + return metadata + + +def clear_metadata_cache(gguf_path: str | None = None) -> None: + """Clear the GGUF metadata cache. + + Args: + gguf_path: If provided, only clear cache for this specific path. + If None, clear the entire cache. + """ + global _metadata_cache + + with _cache_lock: + if gguf_path: + normalized_path = os.path.normpath(os.path.abspath(gguf_path)) + if normalized_path in _metadata_cache: + del _metadata_cache[normalized_path] + logger.debug(f"Cleared GGUF metadata cache for: {normalized_path}") + else: + _metadata_cache = {} + logger.debug("Cleared all GGUF metadata cache") + + +def get_cache_stats() -> dict: + """Get statistics about the metadata cache. + + Returns: + Dict with cache statistics (entry count, paths cached) + """ + with _cache_lock: + return { + "entry_count": len(_metadata_cache), + "cached_paths": list(_metadata_cache.keys()), + } diff --git a/runtimes/edge/utils/gpu_allocator.py b/runtimes/edge/utils/gpu_allocator.py new file mode 100644 index 000000000..97ed553b4 --- /dev/null +++ b/runtimes/edge/utils/gpu_allocator.py @@ -0,0 +1,349 @@ +"""GPU allocation for multi-model, multi-GPU GGUF inference. + +Selects the optimal GPU for each model load, preferring single-GPU placement +(split_mode=NONE) over multi-GPU splitting. Prevents OOM crashes by estimating +VRAM requirements before loading. + +llama.cpp's default split_mode=LAYER distributes every model across ALL visible +Vulkan/CUDA devices proportionally. This is problematic for multi-model scenarios: +a second model may OOM on a weaker GPU that's already partially filled. + +This module queries actual free VRAM per device via torch.cuda.mem_get_info() +and routes each model to the single GPU with the most headroom. Multi-GPU split +is only used as a fallback when no single GPU can fit the model. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass + +from utils.context_calculator import compute_kv_bytes_per_token + +logger = logging.getLogger(__name__) + +# Split mode constants (match llama.cpp enum llama_split_mode) +SPLIT_MODE_NONE = 0 # Entire model on main_gpu +SPLIT_MODE_LAYER = 1 # Split layers across GPUs +SPLIT_MODE_ROW = 2 # Split rows within layers + + +class InsufficientVRAMError(RuntimeError): + """Raised when no GPU configuration has enough VRAM for the model. + + The user-facing message is intentionally generic to avoid exposing + internal GPU inventory details. Detailed diagnostics are stored in + ``self.gpu_details`` for server-side logging. + """ + + def __init__(self, message: str, gpu_details: str = ""): + super().__init__(message) + self.gpu_details = gpu_details + + +@dataclass +class GPUDevice: + """Information about a single GPU device.""" + + index: int + name: str + total_vram: int # bytes + free_vram: int # bytes + + +@dataclass +class GPUAllocation: + """Result of GPU allocation for a model load.""" + + gpu_index: int # Primary GPU (-1 if CPU) + split_mode: int # SPLIT_MODE_* constant + main_gpu: int # main_gpu param for llama.cpp + tensor_split: list[float] | None # Proportions for multi-GPU split + estimated_vram: int # Estimated VRAM usage in bytes + total_free_vram: int # Combined free VRAM across viable GPUs (bytes) + + +def enumerate_gpus() -> list[GPUDevice]: + """Enumerate available CUDA/Vulkan GPUs with free VRAM. + + Uses torch.cuda APIs which reflect the same physical devices as Vulkan + (Vulkan0 = cuda:0, Vulkan1 = cuda:1, etc.). + + Returns: + List of GPUDevice, empty if no CUDA GPUs or torch unavailable. + """ + try: + import torch + + if not torch.cuda.is_available(): + return [] + + devices = [] + for i in range(torch.cuda.device_count()): + free, total = torch.cuda.mem_get_info(i) + name = torch.cuda.get_device_name(i) + devices.append( + GPUDevice(index=i, name=name, total_vram=total, free_vram=free) + ) + logger.debug( + f"GPU {i} ({name}): {free / (1024**3):.2f} GiB free / " + f"{total / (1024**3):.2f} GiB total" + ) + + return devices + except ImportError: + logger.debug("PyTorch not available, skipping GPU enumeration") + return [] + except Exception as e: + logger.warning(f"Error enumerating GPUs: {e}") + return [] + + +def estimate_model_vram( + model_size_bytes: int, + n_ctx: int, + n_gpu_layers: int, + total_layers: int | None = None, + n_layer: int | None = None, + n_head_kv: int | None = None, + head_k_size: int | None = None, + head_v_size: int | None = None, +) -> int: + """Estimate total VRAM needed for a GGUF model. + + Args: + model_size_bytes: GGUF file size in bytes (approximates GPU weight size). + n_ctx: Context window size (tokens). + n_gpu_layers: Number of layers to offload (-1 or 999 = all). + total_layers: Total layer count in the model (for partial offload). + n_layer: Number of transformer layers (from GGUF metadata). + n_head_kv: Number of KV attention heads (from GGUF metadata). + head_k_size: Key head dimension (from GGUF metadata). + head_v_size: Value head dimension (from GGUF metadata). + + Returns: + Estimated VRAM in bytes. + """ + # Model weights on GPU + if n_gpu_layers == 0: + # CPU only + return 0 + + if total_layers and n_gpu_layers > 0 and n_gpu_layers < 999: + # Partial offload: scale weight size proportionally + gpu_weight_bytes = int(model_size_bytes * (n_gpu_layers / total_layers)) + else: + # Full offload + gpu_weight_bytes = model_size_bytes + + # KV cache + has_arch = all( + v is not None for v in [n_layer, n_head_kv, head_k_size, head_v_size] + ) + if has_arch: + kv_bytes_per_token = compute_kv_bytes_per_token( + n_layer, n_head_kv, head_k_size, head_v_size + ) + else: + # Conservative fallback: 256 KB/token (matches context_calculator) + kv_bytes_per_token = 256 * 1024 + + kv_cache_bytes = kv_bytes_per_token * n_ctx + + # Total with 20% overhead for compute buffers and scratch space + total = int((gpu_weight_bytes + kv_cache_bytes) * 1.2) + + logger.debug( + f"VRAM estimate: weights={gpu_weight_bytes / (1024**3):.2f} GiB, " + f"KV cache={kv_cache_bytes / (1024**3):.2f} GiB, " + f"total (with overhead)={total / (1024**3):.2f} GiB" + ) + + return total + + +def allocate_gpu(estimated_vram: int, gpus: list[GPUDevice]) -> GPUAllocation: + """Select the optimal GPU(s) for a model load. + + Strategy: + 1. Try single-GPU placement on the GPU with most free VRAM (split_mode=NONE). + 2. Fall back to multi-GPU layer splitting if no single GPU fits. + 3. Raise InsufficientVRAMError if even combined VRAM is insufficient. + + Args: + estimated_vram: Estimated VRAM needed (from estimate_model_vram). + gpus: Available GPUs (from enumerate_gpus). + + Returns: + GPUAllocation with parameters to pass to llama.cpp. + + Raises: + InsufficientVRAMError: If no GPU configuration can fit the model. + """ + if not gpus: + raise InsufficientVRAMError("No GPUs available") + + # 10% safety margin for driver overhead and estimation error + required = int(estimated_vram * 1.1) + + # Sort by free VRAM descending + sorted_gpus = sorted(gpus, key=lambda g: g.free_vram, reverse=True) + + # Strategy 1: Single GPU placement + best = sorted_gpus[0] + if best.free_vram >= required: + logger.info( + f"Allocating model to GPU {best.index} ({best.name}): " + f"{estimated_vram / (1024**3):.2f} GiB needed, " + f"{best.free_vram / (1024**3):.2f} GiB free" + ) + return GPUAllocation( + gpu_index=best.index, + split_mode=SPLIT_MODE_NONE, + main_gpu=best.index, + tensor_split=None, + estimated_vram=estimated_vram, + total_free_vram=best.free_vram, + ) + + # Strategy 2: Multi-GPU split + # Exclude GPUs with too little free VRAM to carry their share of + # non-splittable overhead (compute buffers, scratch space). A GPU + # needs at least 512 MiB free to participate usefully in a split. + min_participation = 512 * 1024**2 # 512 MiB + # Per-GPU fixed overhead for compute buffers and scratch space that + # llama.cpp allocates on each device regardless of split fraction. + per_gpu_overhead = 256 * 1024**2 # 256 MiB + viable_gpus = [g for g in gpus if g.free_vram >= min_participation] + + # Iteratively prune GPUs whose free VRAM cannot cover their + # proportional share of the model plus per-device fixed overhead. + # Use estimated_vram (not required) for per-device checks — the 10% + # safety margin is already enforced globally via total_free >= required. + pruned = True + while pruned and len(viable_gpus) > 1: + pruned = False + total_free = sum(g.free_vram for g in viable_gpus) + if total_free < required: + break + for g in viable_gpus: + share = (g.free_vram / total_free) * estimated_vram + if share + per_gpu_overhead > g.free_vram: + viable_gpus = [v for v in viable_gpus if v.index != g.index] + pruned = True + break + + total_free = sum(g.free_vram for g in viable_gpus) + + if total_free >= required and len(viable_gpus) > 1: + # Build split proportions only for viable GPUs, zero out excluded ones + by_index = sorted(gpus, key=lambda g: g.index) + viable_indices = {g.index for g in viable_gpus} + raw_split = [ + float(g.free_vram) if g.index in viable_indices else 0.0 for g in by_index + ] + total = sum(raw_split) + tensor_split = [v / total for v in raw_split] + + gpu_desc = ", ".join( + f"GPU {g.index} ({g.name}): {g.free_vram / (1024**3):.2f} GiB free" + for g in viable_gpus + ) + logger.info( + f"Model requires multi-GPU split: " + f"{estimated_vram / (1024**3):.2f} GiB needed, " + f"no single GPU has enough. Splitting across: {gpu_desc}" + ) + return GPUAllocation( + gpu_index=sorted_gpus[0].index, + split_mode=SPLIT_MODE_LAYER, + main_gpu=sorted_gpus[0].index, + tensor_split=tensor_split, + estimated_vram=estimated_vram, + total_free_vram=total_free, + ) + + # Strategy 3: Insufficient VRAM + gpu_desc = "\n".join( + f" GPU {g.index} ({g.name}): {g.free_vram / (1024**3):.2f} GiB free / " + f"{g.total_vram / (1024**3):.2f} GiB total" + for g in sorted_gpus + ) + details = ( + f"Estimated VRAM needed: {estimated_vram / (1024**3):.2f} GiB\n" + f"Available GPUs:\n{gpu_desc}\n" + f"Combined free: {total_free / (1024**3):.2f} GiB" + ) + logger.error(f"Insufficient GPU memory to load model.\n{details}") + raise InsufficientVRAMError( + "Insufficient GPU memory to load model. " + "Consider unloading other models, reducing context size, " + "or using a smaller quantization.", + gpu_details=details, + ) + + +def get_llama_gpu_params( + model_size_bytes: int, + n_ctx: int, + n_gpu_layers: int, + total_layers: int | None = None, + n_layer: int | None = None, + n_head_kv: int | None = None, + head_k_size: int | None = None, + head_v_size: int | None = None, +) -> dict: + """Get GPU parameters for a Llama() constructor call. + + Convenience function that enumerates GPUs, estimates VRAM, and allocates. + Returns a dict of keyword arguments to pass to Llama(). + + Args: + model_size_bytes: GGUF file size in bytes. + n_ctx: Context window size. + n_gpu_layers: Number of layers to offload. + total_layers: Total layers in the model. + n_layer: Transformer layer count (GGUF metadata). + n_head_kv: KV head count (GGUF metadata). + head_k_size: Key head dimension (GGUF metadata). + head_v_size: Value head dimension (GGUF metadata). + + Returns: + Dict with keys: main_gpu, split_mode, tensor_split, gpu_index. + Empty dict if CUDA is not available (caller should use defaults). + + Raises: + InsufficientVRAMError: If no GPU can fit the model. + """ + if n_gpu_layers == 0: + # CPU-only, no GPU allocation needed + return {} + + gpus = enumerate_gpus() + if not gpus: + # No CUDA GPUs - llama.cpp will use Vulkan/Metal/CPU on its own + return {} + + estimated_vram = estimate_model_vram( + model_size_bytes=model_size_bytes, + n_ctx=n_ctx, + n_gpu_layers=n_gpu_layers, + total_layers=total_layers, + n_layer=n_layer, + n_head_kv=n_head_kv, + head_k_size=head_k_size, + head_v_size=head_v_size, + ) + + allocation = allocate_gpu(estimated_vram, gpus) + + result = { + "main_gpu": allocation.main_gpu, + "split_mode": allocation.split_mode, + "gpu_index": allocation.gpu_index, + "total_free_vram": allocation.total_free_vram, + } + if allocation.tensor_split is not None: + result["tensor_split"] = allocation.tensor_split + + return result diff --git a/runtimes/edge/utils/history_compressor.py b/runtimes/edge/utils/history_compressor.py new file mode 100644 index 000000000..979548210 --- /dev/null +++ b/runtimes/edge/utils/history_compressor.py @@ -0,0 +1,259 @@ +"""History compression utilities. + +Provides lossless and near-lossless compression techniques for +conversation history to reduce token usage before truncation. +""" + +from __future__ import annotations + +import hashlib +import json +import logging +import re +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .token_counter import TokenCounter + +logger = logging.getLogger(__name__) + + +class HistoryCompressor: + """Compresses conversation history to reduce token usage. + + Applies multiple compression techniques that preserve meaning + while reducing token count: + - Whitespace normalization + - Tool result truncation + - Code block compression + - Repetition removal + """ + + # Number of recent messages to leave untouched + PRESERVE_RECENT = 4 + + # Maximum tokens for old tool results + MAX_TOOL_RESULT_TOKENS = 200 + + # Maximum lines for code blocks in old messages + MAX_CODE_BLOCK_LINES = 20 + + def __init__(self, token_counter: TokenCounter | None = None): + """Initialize history compressor. + + Args: + token_counter: Optional TokenCounter for token-based compression. + If not provided, some compression features are limited. + """ + self._counter = token_counter + + def compress( + self, + messages: list[dict], + preserve_recent: int | None = None, + ) -> list[dict]: + """Apply all compression techniques to messages. + + Compresses older messages while preserving the most recent + messages untouched for immediate context. + + Args: + messages: List of chat messages. + preserve_recent: Number of recent messages to preserve (default: 4). + + Returns: + Compressed messages (deep copy, original unchanged). + """ + if not messages: + return messages + + # Use explicit None check to allow preserve_recent=0 + if preserve_recent is None: + preserve_recent = self.PRESERVE_RECENT + + # Deep copy to avoid modifying original + # Use JSON for Pydantic-safe deep copy + messages = json.loads(json.dumps(messages, default=str)) + + # Split into old (to compress) and recent (to preserve) + if len(messages) <= preserve_recent: + # Not enough to compress, just normalize whitespace + return self._normalize_all_whitespace(messages) + + old_msgs = messages[:-preserve_recent] + recent_msgs = messages[-preserve_recent:] + + # Apply compression pipeline to old messages + old_msgs = self._normalize_all_whitespace(old_msgs) + old_msgs = self._compress_tool_results(old_msgs) + old_msgs = self._compress_code_blocks(old_msgs) + old_msgs = self._remove_repetitions(old_msgs) + + return old_msgs + recent_msgs + + def _normalize_all_whitespace(self, messages: list[dict]) -> list[dict]: + """Normalize whitespace in all message contents. + + Args: + messages: List of messages. + + Returns: + Messages with normalized whitespace. + """ + for msg in messages: + content = msg.get("content") + if content and isinstance(content, str): + msg["content"] = self._normalize_whitespace(content) + return messages + + def _normalize_whitespace(self, content: str) -> str: + """Collapse excessive whitespace. + + Args: + content: Text content to normalize. + + Returns: + Normalized content. + """ + # Collapse multiple newlines to max 2 + content = re.sub(r"\n{3,}", "\n\n", content) + # Collapse multiple spaces to single (but preserve indentation at line start) + content = re.sub(r"(?<=\S) +", " ", content) + return content.strip() + + def _compress_tool_results(self, messages: list[dict]) -> list[dict]: + """Truncate verbose tool call results. + + Tool results (file contents, API responses) are often very long. + After the assistant has processed them, the full content is + less important for context. + + Args: + messages: List of messages. + + Returns: + Messages with compressed tool results. + """ + for msg in messages: + if msg.get("role") == "tool": + content = msg.get("content", "") + if not content: + continue + + # Check token count if counter available + if self._counter: + token_count = self._counter.count_tokens(content) + if token_count > self.MAX_TOOL_RESULT_TOKENS: + truncated = self._counter.truncate_to_tokens( + content, self.MAX_TOOL_RESULT_TOKENS + ) + msg["content"] = truncated + "\n\n[... result truncated ...]" + else: + # Fallback: use character count (rough estimate: 4 chars per token) + max_chars = self.MAX_TOOL_RESULT_TOKENS * 4 + if len(content) > max_chars: + msg["content"] = ( + content[:max_chars] + "\n\n[... result truncated ...]" + ) + + return messages + + def _compress_code_blocks(self, messages: list[dict]) -> list[dict]: + """Compress large code blocks in messages. + + Large code blocks in assistant responses can be condensed + after they've been seen. + + Args: + messages: List of messages. + + Returns: + Messages with compressed code blocks. + """ + code_block_pattern = re.compile( + r"```(\w*)\n(.*?)```", + re.DOTALL, + ) + + for msg in messages: + if msg.get("role") != "assistant": + continue + + content = msg.get("content", "") + if not content or "```" not in content: + continue + + def compress_block(match: re.Match) -> str: + language = match.group(1) or "code" + code = match.group(2) + lines = code.split("\n") + + if len(lines) <= self.MAX_CODE_BLOCK_LINES: + return match.group(0) # Keep original + + # Create summary + first_lines = "\n".join(lines[:5]) + summary = ( + f"```{language}\n" + f"{first_lines}\n" + f"# ... ({len(lines)} lines total) ...\n" + f"```" + ) + return summary + + msg["content"] = code_block_pattern.sub(compress_block, content) + + return messages + + def _remove_repetitions(self, messages: list[dict]) -> list[dict]: + """Remove duplicate or near-duplicate content. + + If the same content appears multiple times in history, + keep only the most recent occurrence. + + Args: + messages: List of messages. + + Returns: + Messages with repetitions removed. + """ + seen_hashes: set[str] = set() + result: list[dict] = [] + + # Process in reverse to keep most recent + for msg in reversed(messages): + content = msg.get("content", "") + + # Skip empty or very short messages + if not content or len(content) < 50: + result.append(msg) + continue + + # Create hash of normalized content + normalized = self._normalize_for_hash(content) + content_hash = hashlib.md5(normalized.encode()).hexdigest() + + if content_hash in seen_hashes: + logger.debug(f"Removing duplicate message: {content[:50]}...") + continue + + seen_hashes.add(content_hash) + result.append(msg) + + # Restore original order + return list(reversed(result)) + + def _normalize_for_hash(self, content: str) -> str: + """Normalize content for duplicate detection. + + Args: + content: Text content. + + Returns: + Normalized content for hashing. + """ + # Lowercase, collapse whitespace, remove punctuation + normalized = content.lower() + normalized = re.sub(r"\s+", " ", normalized) + normalized = re.sub(r"[^\w\s]", "", normalized) + return normalized.strip() diff --git a/runtimes/edge/utils/jinja_tools.py b/runtimes/edge/utils/jinja_tools.py new file mode 100644 index 000000000..0990cec7f --- /dev/null +++ b/runtimes/edge/utils/jinja_tools.py @@ -0,0 +1,192 @@ +""" +Jinja2 template utilities for tool-aware chat template rendering. + +This module provides functions to extract chat templates from GGUF files +and render them with tool definitions using Python's Jinja2. + +Uses the shared GGUF metadata cache to avoid redundant file reads when +extracting chat templates and special tokens. +""" + +from __future__ import annotations + +import json +import logging +from typing import Any + +from jinja2 import TemplateError, Undefined +from jinja2.sandbox import ImmutableSandboxedEnvironment, SandboxedEnvironment +from jinja2.utils import Namespace + +from utils.gguf_metadata_cache import get_gguf_metadata_cached + +logger = logging.getLogger(__name__) + + +class RaiseExceptionUndefined(Undefined): + """Jinja2 Undefined that raises an exception when used. + + Some chat templates use `raise_exception` to signal errors. + This class provides that functionality. + """ + + def __str__(self) -> str: + raise TemplateError(f"Undefined variable: {self._undefined_name}") + + def __iter__(self): + raise TemplateError(f"Undefined variable: {self._undefined_name}") + + def __bool__(self): + return False + + +def _raise_exception(message: str) -> None: + """Template function to raise an exception.""" + raise TemplateError(message) + + +def _tojson(value: Any, indent: int | None = None) -> str: + """Template filter to convert value to JSON string.""" + return json.dumps(value, indent=indent, ensure_ascii=False) + + +def get_chat_template_from_gguf(model_path: str) -> str | None: + """Extract chat_template from GGUF file metadata. + + Uses the shared GGUF metadata cache to avoid redundant file reads. + The cache is populated once per file and reused by all modules. + + Args: + model_path: Path to the GGUF model file. + + Returns: + The chat template string, or None if not found. + """ + try: + cached = get_gguf_metadata_cached(model_path) + return cached.chat_template + except FileNotFoundError: + logger.debug(f"GGUF file not found: {model_path}") + return None + except Exception as e: + logger.debug(f"Failed to extract chat template from {model_path}: {e}") + return None + + +def get_special_tokens_from_gguf(model_path: str) -> dict[str, str]: + """Extract BOS and EOS tokens from GGUF file metadata. + + Uses the shared GGUF metadata cache to avoid redundant file reads. + The cache is populated once per file and reused by all modules. + + Args: + model_path: Path to the GGUF model file. + + Returns: + Dictionary with 'bos_token' and 'eos_token' keys. + Values default to empty strings if not found. + """ + try: + cached = get_gguf_metadata_cached(model_path) + return { + "bos_token": cached.bos_token, + "eos_token": cached.eos_token, + } + except FileNotFoundError: + logger.debug(f"GGUF file not found: {model_path}") + return {"bos_token": "", "eos_token": ""} + except Exception as e: + logger.debug(f"Failed to extract special tokens from {model_path}: {e}") + return {"bos_token": "", "eos_token": ""} + + +def supports_native_tools(template: str) -> bool: + """Check if a chat template has native tool support. + + A template supports tools if it references the 'tools' variable. + + Args: + template: The Jinja2 chat template string. + + Returns: + True if the template references tools, False otherwise. + """ + # Simple heuristic: check if 'tools' appears in the template + # This catches patterns like {% if tools %}, {{ tools }}, etc. + return "tools" in template + + +def create_jinja_environment() -> SandboxedEnvironment: + """Create a sandboxed Jinja2 environment configured for chat templates. + + Uses SandboxedEnvironment to prevent arbitrary code execution from + potentially malicious templates in GGUF files. + + Returns: + Configured Jinja2 SandboxedEnvironment. + """ + env = ImmutableSandboxedEnvironment( + # Use undefined that returns False for boolean checks + undefined=RaiseExceptionUndefined, + # Keep trailing newlines + keep_trailing_newline=True, + # Auto-escape disabled (we're not generating HTML) + autoescape=False, + ) + + # Add template functions used by various chat templates + env.globals["raise_exception"] = _raise_exception + # Use Jinja2's built-in Namespace which properly handles attribute assignment + env.globals["namespace"] = Namespace + + # Add filters + env.filters["tojson"] = _tojson + + return env + + +def render_chat_with_tools( + template: str, + messages: list[dict], + tools: list[dict] | None = None, + add_generation_prompt: bool = True, + bos_token: str = "", + eos_token: str = "", +) -> str: + """Render a chat template with Jinja2 including tool definitions. + + This function mimics what llama.cpp's Jinja-based template rendering does, + allowing us to pass tools to models that have tool-aware templates. + + Args: + template: The Jinja2 chat template string. + messages: List of chat messages (role, content dicts). + tools: Optional list of tool definitions (OpenAI format). + add_generation_prompt: Whether to add the assistant prompt at the end. + bos_token: Beginning of sequence token. + eos_token: End of sequence token. + + Returns: + The rendered prompt string. + + Raises: + TemplateError: If template rendering fails. + """ + env = create_jinja_environment() + + try: + template_obj = env.from_string(template) + except Exception as e: + raise TemplateError(f"Failed to parse chat template: {e}") from e + + try: + rendered = template_obj.render( + messages=messages, + tools=tools, + add_generation_prompt=add_generation_prompt, + bos_token=bos_token, + eos_token=eos_token, + ) + return rendered + except Exception as e: + raise TemplateError(f"Failed to render chat template: {e}") from e diff --git a/runtimes/edge/utils/kv_cache_manager.py b/runtimes/edge/utils/kv_cache_manager.py new file mode 100644 index 000000000..f0d94193e --- /dev/null +++ b/runtimes/edge/utils/kv_cache_manager.py @@ -0,0 +1,704 @@ +"""KV Cache Manager — server-side multi-agent KV cache with tiered storage. + +Manages named KV cache slots so multiple agents can share a model without +evicting each other's cached prefixes. Supports segment-level validation +(system prompt, tools, history turns) so partial hits are possible when +only part of the payload has changed. + +Tiers: vram (in llama.cpp context) → ram (serialized bytes) → disk → evict +""" + +from __future__ import annotations + +import asyncio +import contextlib +import hashlib +import json +import logging +import os +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + + +# ── Segment Hashing ───────────────────────────────────────────────────────── + + +def hash_segment(content: str) -> str: + """Deterministic hash of a content segment.""" + return hashlib.sha256(content.encode("utf-8")).hexdigest()[:16] + + +def hash_messages_segments(messages: list[dict], tools: list[dict] | None = None) -> list[dict]: + """Break messages + tools into hashable segments. + + Returns a list of segment dicts: + [{"type": "system", "hash": "...", "content": "..."}, + {"type": "tools", "hash": "...", "content": "..."}, + {"type": "turn", "hash": "...", "content": "...", "index": 0}, ...] + + The content is the raw string used for hashing (for recomputation on miss). + """ + segments: list[dict] = [] + + # Extract system prompt + system_parts = [] + non_system: list[dict] = [] + for msg in messages: + if msg.get("role") == "system": + system_parts.append(msg.get("content", "")) + else: + non_system.append(msg) + + if system_parts: + system_content = "\n".join(system_parts) + segments.append({ + "type": "system", + "hash": hash_segment(system_content), + "content": system_content, + }) + + # Tools as a segment (canonical order for deterministic hashing) + if tools: + sorted_tools = sorted( + tools, + key=lambda t: ( + t.get("type", ""), + t.get("function", {}).get("name", ""), + ), + ) + tools_content = json.dumps(sorted_tools, sort_keys=True, separators=(",", ":")) + segments.append({ + "type": "tools", + "hash": hash_segment(tools_content), + "content": tools_content, + }) + + # Conversation turns (pair user+assistant as one segment) + turn_idx = 0 + i = 0 + while i < len(non_system): + turn_parts = [] + # Collect one turn: user + optional assistant response + msg = non_system[i] + turn_parts.append(f"{msg.get('role', '')}:{msg.get('content', '')}") + i += 1 + # If next is assistant, include it in same turn + if i < len(non_system) and non_system[i].get("role") == "assistant": + turn_parts.append(f"assistant:{non_system[i].get('content', '')}") + i += 1 + turn_content = "|".join(turn_parts) + segments.append({ + "type": "turn", + "hash": hash_segment(turn_content), + "content": turn_content, + "index": turn_idx, + }) + turn_idx += 1 + + return segments + + +def compare_segments( + cached_segments: list[dict], incoming_segments: list[dict] +) -> tuple[int, str | None]: + """Compare cached vs incoming segments. Returns (match_count, invalidated_at). + + match_count: how many leading segments match + invalidated_at: type of first mismatched segment (None if all match) + """ + for i, (cached, incoming) in enumerate(zip(cached_segments, incoming_segments, strict=False)): + if cached["hash"] != incoming["hash"]: + return i, cached.get("type", "unknown") + + # All compared segments match + if len(cached_segments) <= len(incoming_segments): + return len(cached_segments), None + else: + # Cached has more segments than incoming (history truncated?) + return len(incoming_segments), "truncated" + + +# ── Cache Entry ────────────────────────────────────────────────────────────── + + +@dataclass +class CacheEntry: + """A cached KV state with segment metadata.""" + cache_key: str + model_id: str + segments: list[dict] # segment hashes for validation + content_hash: str # hash of all segments combined + token_count: int # number of tokens in the cached prefix + created_at: float = field(default_factory=time.time) + last_used: float = field(default_factory=time.time) + hit_count: int = 0 + pinned: bool = False + ttl: float | None = None # seconds, None = no expiry + tier: str = "ram" # "vram" | "ram" | "disk" + seq_id: int = -1 # llama.cpp sequence ID if in vram + # Serialized KV state (when in ram tier) + kv_data: bytes = b"" + # Disk path (when in disk tier) + disk_path: str | None = None + size_bytes: int = 0 + + @property + def is_expired(self) -> bool: + if self.ttl is None: + return False + return time.time() - self.last_used > self.ttl + + def touch(self) -> None: + self.last_used = time.time() + self.hit_count += 1 + + def to_dict(self) -> dict: + return { + "cache_key": self.cache_key, + "model_id": self.model_id, + "segments": [{"type": s["type"], "hash": s["hash"]} for s in self.segments], + "content_hash": self.content_hash, + "token_count": self.token_count, + "tier": self.tier, + "size_bytes": self.size_bytes, + "hit_count": self.hit_count, + "pinned": self.pinned, + "last_used": self.last_used, + "created_at": self.created_at, + "is_expired": self.is_expired, + } + + +# ── KV Cache Manager ──────────────────────────────────────────────────────── + + +def _generate_cache_key() -> str: + """Generate a unique cache key (24 hex chars = 96 bits of entropy).""" + return hashlib.sha256(os.urandom(32)).hexdigest()[:24] + + +@dataclass +class CacheBudget: + """Budget limits for each tier.""" + max_vram_entries: int = 8 # max sequences in llama.cpp context + max_ram_bytes: int = 2 * 1024 * 1024 * 1024 # 2GB + max_disk_bytes: int = 10 * 1024 * 1024 * 1024 # 10GB + default_ttl: float = 1800.0 # 30 minutes + + +class KVCacheManager: + """Manages KV cache entries with tiered storage and GC. + + Lifecycle: + 1. prepare() — tokenize + forward pass a prefix, save KV state + 2. lookup() — find cache entry by key, validate segments + 3. restore() — load KV state back into model context + 4. save_after_generation() — update cache with new conversation state + """ + + def __init__( + self, + cache_dir: Path | None = None, + budget: CacheBudget | None = None, + ): + self._entries: dict[str, CacheEntry] = {} # cache_key → entry + self._content_index: dict[str, str] = {} # content_hash → cache_key (dedup) + self._budget = budget or CacheBudget() + self._cache_dir = cache_dir or Path.home() / ".llamafarm" / "cache" / "kv" + self._cache_dir.mkdir(parents=True, exist_ok=True) + self._lock = asyncio.Lock() + # Stats + self._total_hits = 0 + self._total_misses = 0 + self._total_partial_hits = 0 + + # ── Core Operations ────────────────────────────────────────────────── + + async def prepare( + self, + model_id: str, + messages: list[dict], + tools: list[dict] | None = None, + pinned: bool = False, + ttl: float | None = None, + model: Any = None, # Llama instance — if provided, does real KV serialization + ) -> CacheEntry: + """Pre-compute and serialize KV cache for a message prefix. + + If `model` is provided: tokenizes the messages through the model's chat + template, runs a forward pass to build KV state, and serializes it. + Future requests with this cache_key skip all prompt processing. + + If `model` is None: indexes segments for validation only. Real KV state + is serialized lazily after the first completion via save_after_generation(). + """ + segments = hash_messages_segments(messages, tools) + content_hash = hash_segment(json.dumps([s["hash"] for s in segments])) + + # Quick dedup check (under lock) + async with self._lock: + if content_hash in self._content_index: + existing_key = self._content_index[content_hash] + if existing_key in self._entries: + entry = self._entries[existing_key] + entry.touch() + logger.info(f"Cache dedup hit: {entry.cache_key[:8]}… (content_hash={content_hash[:8]})") + return entry + + kv_data = b"" + size_bytes = 0 + token_count = 0 + + if model is not None: + # Real KV serialization: tokenize → decode → serialize + # Run blocking model ops in a thread to avoid blocking the event loop + try: + import time as _time + t0 = _time.perf_counter() + + def _prepare_kv(): + prompt = model._apply_chat_template(messages, add_generation_prompt=True) + tokens = model.tokenize(prompt, add_special=False, parse_special=True) + tc = len(tokens) + model._lib.llama_memory_clear(model._memory, True) + if not model._decode_batch(tokens): + raise RuntimeError(f"Failed to decode {tc} prefix tokens") + kv = model.state_seq_save(0) + return kv, tc + + kv_data, token_count = await asyncio.to_thread(_prepare_kv) + size_bytes = len(kv_data) + + t1 = _time.perf_counter() + logger.info( + f"Cache prepare (warm): {token_count} tokens, " + f"{size_bytes / 1024:.1f}KB KV state, " + f"{(t1 - t0) * 1000:.1f}ms" + ) + except Exception as e: + logger.error(f"KV serialization failed during prepare: {e}") + # Fall back to segment-only + kv_data = b"" + size_bytes = 0 + total_chars = sum(len(s.get("content", "")) for s in segments) + token_count = max(1, total_chars // 4) + logger.info(f"Falling back to segment-only prepare: ~{token_count} tokens") + else: + # Segment-only: estimate tokens, real KV saved after first completion + total_chars = sum(len(s.get("content", "")) for s in segments) + token_count = max(1, total_chars // 4) + logger.info(f"Cache prepare (segment-only): ~{token_count} tokens indexed") + + cache_key = _generate_cache_key() + entry = CacheEntry( + cache_key=cache_key, + model_id=model_id, + segments=segments, + content_hash=content_hash, + token_count=token_count, + pinned=pinned, + ttl=ttl if ttl is not None else (None if pinned else self._budget.default_ttl), + tier="ram", + kv_data=kv_data, + size_bytes=size_bytes, + ) + + async with self._lock: + # Re-check dedup inside lock to prevent TOCTOU race + if content_hash in self._content_index: + existing_key = self._content_index[content_hash] + if existing_key in self._entries: + existing = self._entries[existing_key] + existing.touch() + logger.info(f"Cache dedup hit (re-check): {existing.cache_key[:8]}…") + return existing + self._entries[cache_key] = entry + self._content_index[content_hash] = cache_key + self._enforce_budget() + + logger.info( + f"Prepared cache {cache_key[:8]}…: {token_count} tokens, " + f"{size_bytes / 1024:.1f}KB, warm={'yes' if kv_data else 'no'}, " + f"segments={[s['type'] for s in segments]}" + ) + return entry + + def lookup(self, cache_key: str) -> CacheEntry | None: + """Look up a cache entry by key. Returns None if not found or expired.""" + entry = self._entries.get(cache_key) + if entry is None: + return None + if entry.is_expired: + logger.debug(f"Cache {cache_key[:8]}… expired") + return None + return entry + + def validate_and_match( + self, + cache_key: str, + model_id: str, + messages: list[dict], + tools: list[dict] | None = None, + ) -> dict: + """Validate a cache key against incoming payload. + + Returns a dict with: + status: "hit" | "partial_hit" | "miss" + entry: CacheEntry or None + reusable_tokens: number of tokens that can be reused + invalidated_at: segment type where mismatch occurred + reason: human-readable reason + """ + entry = self.lookup(cache_key) + if entry is None: + self._total_misses += 1 + return { + "status": "miss", + "entry": None, + "reusable_tokens": 0, + "invalidated_at": None, + "reason": "cache_key_not_found", + } + + # Model must match + if entry.model_id != model_id: + self._total_misses += 1 + return { + "status": "miss", + "entry": None, + "reusable_tokens": 0, + "invalidated_at": "model", + "reason": f"model_mismatch: cached={entry.model_id}, requested={model_id}", + } + + # Compare segments + incoming_segments = hash_messages_segments(messages, tools) + match_count, invalidated_at = compare_segments(entry.segments, incoming_segments) + + if invalidated_at is None and match_count == len(entry.segments): + # Full hit + entry.touch() + self._total_hits += 1 + return { + "status": "hit", + "entry": entry, + "reusable_tokens": entry.token_count, + "invalidated_at": None, + "reason": "full_match", + } + + if match_count > 0: + # Partial hit — some leading segments match + entry.touch() + self._total_partial_hits += 1 + # Estimate reusable tokens (proportional to matched segments) + ratio = match_count / max(len(entry.segments), 1) + reusable_tokens = int(entry.token_count * ratio) + return { + "status": "partial_hit", + "entry": entry, + "reusable_tokens": reusable_tokens, + "invalidated_at": invalidated_at, + "reason": f"{invalidated_at}_changed", + } + + # Complete miss + self._total_misses += 1 + return { + "status": "miss", + "entry": None, + "reusable_tokens": 0, + "invalidated_at": invalidated_at, + "reason": f"{invalidated_at}_changed" if invalidated_at else "no_match", + } + + async def restore(self, entry: CacheEntry, model: Any, seq_id: int = 0) -> bool: + """Restore a cache entry's KV state into the model. + + If the entry has serialized KV data, loads it into the model context. + If no KV data (segment-only validation), returns True as a signal + that the prefix is validated — the caller can optimize accordingly. + + Returns True on success, False on failure. + """ + try: + # Segment-only entry (from prepare without serialization) + if not entry.kv_data and not entry.disk_path: + entry.touch() + logger.info( + f"Cache validated (segment-only): {entry.cache_key[:8]}…, " + f"{entry.token_count} prefix tokens confirmed unchanged" + ) + return True + + if entry.tier == "disk": + if entry.disk_path and Path(entry.disk_path).exists(): + # Use thread pool to avoid blocking event loop on large files + entry.kv_data = await asyncio.to_thread( + Path(entry.disk_path).read_bytes + ) + entry.tier = "ram" + async with self._lock: + self._enforce_budget() + else: + logger.warning(f"Cache {entry.cache_key[:8]}… disk path missing") + return False + + if not entry.kv_data: + logger.warning(f"Cache {entry.cache_key[:8]}… has no KV data") + return False + + # Run blocking model ops in a thread to avoid blocking the event loop + kv_data = entry.kv_data + + def _restore_kv(): + model.memory_seq_rm(seq_id) + return model.state_seq_load(kv_data, seq_id) + + consumed = await asyncio.to_thread(_restore_kv) + if consumed == 0: + logger.error(f"Failed to restore cache {entry.cache_key[:8]}…") + return False + + entry.touch() + logger.info( + f"Restored cache {entry.cache_key[:8]}…: {entry.token_count} tokens, " + f"{consumed} bytes into seq_id={seq_id}" + ) + return True + + except Exception as e: + logger.error(f"Failed to restore cache {entry.cache_key[:8]}…: {e}") + return False + + async def save_after_generation( + self, + model: Any, + model_id: str, + parent_key: str | None, + messages: list[dict], + tools: list[dict] | None = None, + seq_id: int = 0, + prompt_tokens: int = 0, + ) -> CacheEntry: + """Save the current KV state after generation as a new cache entry. + + This creates a new cache_key that includes the full conversation + (system + tools + all turns including the latest response). + The parent_key is informational only. + + Args: + prompt_tokens: Exact prompt token count from the model (for KV restore). + """ + segments = hash_messages_segments(messages, tools) + content_hash = hash_segment(json.dumps([s["hash"] for s in segments])) + + # Quick dedup check + async with self._lock: + if content_hash in self._content_index: + existing = self._entries.get(self._content_index[content_hash]) + if existing: + existing.touch() + return existing + + # Serialize current KV state (blocking model op → run in thread) + def _serialize_kv(): + kv = model.state_seq_save(seq_id) + tc = prompt_tokens + if tc <= 0: + try: + prompt_text = model._apply_chat_template( + [dict(m) if not isinstance(m, dict) else m for m in messages], + add_generation_prompt=True, + ) + toks = model.tokenize(prompt_text, add_special=False, parse_special=True) + tc = len(toks) + except Exception as e: + logger.warning(f"Failed to get exact token count: {e}, using estimate") + tc = 0 + for seg in segments: + tc += max(1, len(seg.get("content", "")) // 4) + return kv, tc + + kv_data, token_count = await asyncio.to_thread(_serialize_kv) + + cache_key = _generate_cache_key() + entry = CacheEntry( + cache_key=cache_key, + model_id=model_id, + segments=segments, + content_hash=content_hash, + token_count=token_count, + ttl=self._budget.default_ttl, + tier="ram", + kv_data=kv_data, + size_bytes=len(kv_data), + ) + + async with self._lock: + # Re-check dedup inside lock to prevent TOCTOU race + if content_hash in self._content_index: + existing = self._entries.get(self._content_index[content_hash]) + if existing: + existing.touch() + return existing + self._entries[cache_key] = entry + self._content_index[content_hash] = cache_key + self._enforce_budget() + + logger.info(f"Saved post-generation cache {cache_key[:8]}…: ~{token_count} tokens, {len(kv_data) / 1024:.1f}KB") + return entry + + # ── Cache Management ───────────────────────────────────────────────── + + def list_entries(self) -> list[dict]: + """List all cache entries.""" + return [e.to_dict() for e in self._entries.values()] + + def get_stats(self) -> dict: + """Get cache statistics.""" + entries = list(self._entries.values()) + ram_bytes = sum(e.size_bytes for e in entries if e.tier == "ram") + disk_bytes = sum(e.size_bytes for e in entries if e.tier == "disk") + total_requests = self._total_hits + self._total_misses + self._total_partial_hits + return { + "total_entries": len(entries), + "by_tier": { + "ram": len([e for e in entries if e.tier == "ram"]), + "disk": len([e for e in entries if e.tier == "disk"]), + }, + "ram_bytes": ram_bytes, + "disk_bytes": disk_bytes, + "total_hits": self._total_hits, + "total_partial_hits": self._total_partial_hits, + "total_misses": self._total_misses, + "hit_rate": self._total_hits / max(total_requests, 1), + "pinned_entries": len([e for e in entries if e.pinned]), + } + + def evict(self, cache_key: str) -> bool: + """Evict a specific cache entry. + + Note: Callers should hold self._lock when calling from async context, + or use evict_async() instead. + """ + entry = self._entries.pop(cache_key, None) + if entry is None: + return False + # Clean up content index + if entry.content_hash in self._content_index and self._content_index[entry.content_hash] == cache_key: + del self._content_index[entry.content_hash] + # Clean up disk file + if entry.disk_path: + with contextlib.suppress(Exception): + Path(entry.disk_path).unlink(missing_ok=True) + # Clear kv_data to free memory even if other references exist + entry.kv_data = b"" + logger.info(f"Evicted cache {cache_key[:8]}…") + return True + + async def evict_async(self, cache_key: str) -> bool: + """Thread-safe eviction of a cache entry.""" + async with self._lock: + return self.evict(cache_key) + + def gc(self) -> int: + """Run garbage collection. Returns number of entries removed. + + Note: Called from the GC background task. Uses dict snapshot + to avoid mutation during iteration. + """ + removed = 0 + expired_keys = [ + k for k, e in list(self._entries.items()) + if e.is_expired and not e.pinned + ] + for key in expired_keys: + self.evict(key) + removed += 1 + if removed: + logger.info(f"GC removed {removed} expired cache entries") + return removed + + def _enforce_budget(self) -> None: + """Enforce budget limits by demoting/evicting entries.""" + # Demote ram entries to disk if over budget + ram_entries = [e for e in self._entries.values() if e.tier == "ram" and not e.pinned] + ram_bytes = sum(e.size_bytes for e in self._entries.values() if e.tier == "ram") + + if ram_bytes > self._budget.max_ram_bytes: + # Sort by last_used (oldest first) + ram_entries.sort(key=lambda e: e.last_used) + for entry in ram_entries: + if ram_bytes <= self._budget.max_ram_bytes: + break + self._demote_to_disk(entry) + ram_bytes -= entry.size_bytes + + # Evict disk entries if over budget + disk_entries = [e for e in self._entries.values() if e.tier == "disk" and not e.pinned] + disk_bytes = sum(e.size_bytes for e in self._entries.values() if e.tier == "disk") + + if disk_bytes > self._budget.max_disk_bytes: + disk_entries.sort(key=lambda e: e.last_used) + for entry in disk_entries: + if disk_bytes <= self._budget.max_disk_bytes: + break + self.evict(entry.cache_key) + disk_bytes -= entry.size_bytes + + def _demote_to_disk(self, entry: CacheEntry) -> None: + """Move a ram entry to disk. + + Note: This performs synchronous disk I/O. When called from _enforce_budget() + under the async lock, it blocks the event loop briefly. + """ + if not entry.kv_data: + return + disk_path = self._cache_dir / f"{entry.cache_key}.kvstate" + try: + disk_path.write_bytes(entry.kv_data) + entry.disk_path = str(disk_path) + entry.kv_data = b"" # Free RAM + entry.tier = "disk" + logger.debug(f"Demoted cache {entry.cache_key[:8]}… to disk: {disk_path}") + except Exception as e: + logger.error(f"Failed to demote cache {entry.cache_key[:8]}… to disk: {e}") + + +# ── Background GC Task ────────────────────────────────────────────────────── + +_gc_task: asyncio.Task | None = None + + +async def _gc_loop(manager: KVCacheManager, interval: float = 60.0) -> None: + """Periodic GC sweep.""" + while True: + await asyncio.sleep(interval) + try: + async with manager._lock: + manager.gc() + except Exception as e: + logger.error(f"KV cache GC error: {e}") + + +def start_kv_cache_gc(manager: KVCacheManager) -> None: + """Start background GC task.""" + global _gc_task + if _gc_task is None or _gc_task.done(): + _gc_task = asyncio.create_task(_gc_loop(manager)) + + +async def stop_kv_cache_gc() -> None: + """Cancel the background GC task (call during shutdown).""" + global _gc_task + if _gc_task is not None and not _gc_task.done(): + _gc_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await _gc_task + logger.info("KV cache GC task stopped") + _gc_task = None diff --git a/runtimes/edge/utils/model_cache.py b/runtimes/edge/utils/model_cache.py new file mode 100644 index 000000000..a1b07d637 --- /dev/null +++ b/runtimes/edge/utils/model_cache.py @@ -0,0 +1,4 @@ +"""Re-export from llamafarm_common — single source of truth.""" +from llamafarm_common.model_cache import ModelCache + +__all__ = ["ModelCache"] diff --git a/runtimes/edge/utils/model_format.py b/runtimes/edge/utils/model_format.py new file mode 100644 index 000000000..0426736e0 --- /dev/null +++ b/runtimes/edge/utils/model_format.py @@ -0,0 +1,24 @@ +"""Re-export from llamafarm_common — single source of truth.""" +from llamafarm_common.model_format import ( + GGUF_QUANTIZATION_PREFERENCE_ORDER, + clear_format_cache, + detect_model_format, + get_gguf_file_path, + list_gguf_files, + parse_model_with_quantization, + parse_quantization_from_filename, + select_gguf_file, + select_gguf_file_with_logging, +) + +__all__ = [ + "GGUF_QUANTIZATION_PREFERENCE_ORDER", + "parse_model_with_quantization", + "parse_quantization_from_filename", + "select_gguf_file", + "select_gguf_file_with_logging", + "detect_model_format", + "list_gguf_files", + "get_gguf_file_path", + "clear_format_cache", +] diff --git a/runtimes/edge/utils/safe_home.py b/runtimes/edge/utils/safe_home.py new file mode 100644 index 000000000..b24399a54 --- /dev/null +++ b/runtimes/edge/utils/safe_home.py @@ -0,0 +1,4 @@ +"""Re-export from llamafarm_common — single source of truth.""" +from llamafarm_common.safe_home import get_data_dir, safe_home + +__all__ = ["safe_home", "get_data_dir"] diff --git a/runtimes/edge/utils/thinking.py b/runtimes/edge/utils/thinking.py new file mode 100644 index 000000000..ddc2d66d4 --- /dev/null +++ b/runtimes/edge/utils/thinking.py @@ -0,0 +1,272 @@ +""" +Thinking/reasoning model utilities. + +Provides support for models like Qwen3 that use ... tags +for chain-of-thought reasoning. +""" + +import logging +import re +from dataclasses import dataclass + +logger = logging.getLogger(__name__) + + +@dataclass +class ParsedThinkingResponse: + """Parsed response from a thinking model.""" + + thinking: str | None # Content inside ... tags + content: ( + str # Final answer (content after or full response if no thinking) + ) + thinking_complete: bool # Whether thinking was properly closed with + + +def parse_thinking_response(response: str) -> ParsedThinkingResponse: + """Parse a response that may contain ... tags. + + Extracts thinking content and final answer from model responses. + Handles cases where thinking is incomplete (no closing tag). + + Args: + response: Raw model response text + + Returns: + ParsedThinkingResponse with thinking and content separated + + Examples: + >>> parse_thinking_response("Let me think...The answer is 42.") + ParsedThinkingResponse(thinking="Let me think...", content="The answer is 42.", thinking_complete=True) + + >>> parse_thinking_response("Still thinking...") + ParsedThinkingResponse(thinking="Still thinking...", content="", thinking_complete=False) + + >>> parse_thinking_response("No thinking here, just answer.") + ParsedThinkingResponse(thinking=None, content="No thinking here, just answer.", thinking_complete=True) + """ + # Pattern to match ... with content after + think_pattern = re.compile( + r"\s*(.*?)\s*\s*(.*)", + re.DOTALL | re.IGNORECASE, + ) + + match = think_pattern.match(response) + if match: + thinking = match.group(1).strip() + content = match.group(2).strip() + # Recursively clean any remaining tags from content + # (model sometimes outputs multiple closing tags in /no_think mode) + content = re.sub(r"^\s*\s*", "", content, flags=re.IGNORECASE) + return ParsedThinkingResponse( + thinking=thinking if thinking else None, + content=content.strip(), + thinking_complete=True, + ) + + # Check for stray anywhere (happens with /no_think mode) + # The model outputs empty think block or just closing tag(s) + # Remove ALL tags from the response + cleaned = re.sub(r"\s*", "", response, flags=re.IGNORECASE) + if cleaned != response: + # We removed some tags + return ParsedThinkingResponse( + thinking=None, + content=cleaned.strip(), + thinking_complete=True, + ) + + # Check for incomplete thinking (opening tag but no closing) + incomplete_pattern = re.compile(r"\s*(.*)", re.DOTALL | re.IGNORECASE) + incomplete_match = incomplete_pattern.match(response) + if incomplete_match: + thinking = incomplete_match.group(1).strip() + return ParsedThinkingResponse( + thinking=thinking if thinking else None, + content="", + thinking_complete=False, + ) + + # No thinking tags at all + return ParsedThinkingResponse( + thinking=None, + content=response.strip(), + thinking_complete=True, + ) + + +def inject_thinking_control( + messages: list[dict], + enable_thinking: bool, +) -> list[dict]: + """Inject thinking control into messages using Qwen's soft switch. + + Qwen3 models support /think and /no_think soft switches in prompts + to control whether the model uses thinking mode. + + Handles both text-only messages (content is string) and multimodal + messages (content is list of content parts). + + Args: + messages: List of chat messages + enable_thinking: True to force thinking, False to disable + + Returns: + Modified messages list with thinking control injected + """ + + # Make a copy to avoid modifying the original + messages = [dict(m) for m in messages] + control_token = "/think" if enable_thinking else "/no_think" + + # Find the last user message and append the control + for i in range(len(messages) - 1, -1, -1): + if messages[i].get("role") == "user": + content = messages[i].get("content", "") + + # Check for simple string content FIRST (most common case) + # This avoids triggering iteration/validation on complex types + if isinstance(content, str): + if "/think" not in content and "/no_think" not in content: + messages[i]["content"] = f"{content} {control_token}" + else: + # Handle multimodal messages (content is a list/iterable of parts) + # Convert to list to safely iterate without triggering pydantic validation + try: + content_list = list(content) if not isinstance(content, list) else content + except Exception: + # If we can't convert to list, just append control as new content + messages[i]["content"] = [ + {"type": "text", "text": control_token} + ] + break + + # Check if any text parts already contain control tokens + has_control = False + for part in content_list: + if isinstance(part, dict) and part.get("type") == "text": + text = part.get("text", "") + if "/think" in text or "/no_think" in text: + has_control = True + break + + if not has_control: + # Append control token as a new text part + content_list = list(content_list) # Make a copy + content_list.append({"type": "text", "text": control_token}) + messages[i]["content"] = content_list + break + + return messages + + +class ThinkingBudgetProcessor: + """Logits processor that enforces a thinking token budget. + + When the thinking budget is reached, this processor forces the model + to generate and proceed to the answer. + + This is used with llama-cpp's logits_processor parameter. + Uses numpy arrays. + """ + + def __init__( + self, + llama, + max_thinking_tokens: int, + think_end_tokens: list[int] | None = None, + ): + """Initialize the thinking budget processor. + + Args: + llama: The Llama instance (for tokenization) + max_thinking_tokens: Maximum tokens to allow for thinking + think_end_tokens: Token IDs for (auto-detected if None) + """ + self.llama = llama + self.max_thinking_tokens = max_thinking_tokens + self.thinking_tokens = 0 # Only counts tokens INSIDE + self.in_thinking = False + self.thinking_ended = False + self.forcing_end = False # True while forcing sequence + + # Try to get the token IDs for + if think_end_tokens is None: + try: + # Tokenize to get its token IDs + self.think_end_tokens = llama.tokenize( + b"", add_bos=False, special=True + ) + except Exception: + # Fallback - will use soft switch instead + self.think_end_tokens = None + else: + self.think_end_tokens = think_end_tokens + + self._force_token_idx = 0 + + def __call__(self, input_ids, scores): + """Process logits to enforce thinking budget. + + Args: + input_ids: numpy array of token IDs generated so far + scores: numpy array of logits for next token (modified in-place) + + Returns: + Modified scores array (numpy) + """ + import numpy as np + + # Convert to numpy if needed (for compatibility) + if not isinstance(scores, np.ndarray): + scores = np.array(scores) + + # Check current state by looking at generated text + if not self.thinking_ended and not self.forcing_end: + try: + # Convert input_ids to list if numpy array + ids = ( + input_ids.tolist() + if hasattr(input_ids, "tolist") + else list(input_ids) + ) + text = self.llama.detokenize(ids).decode("utf-8", errors="ignore") + + if "" in text.lower() and not self.in_thinking: + self.in_thinking = True + if "" in text.lower(): + self.thinking_ended = True + self.in_thinking = False + except Exception: + # Per-token hook — suppress to avoid breaking generation + logger.debug("Think-tag detection failed in logits processor", exc_info=True) + + # Count tokens only while in thinking mode + if self.in_thinking and not self.thinking_ended: + self.thinking_tokens += 1 + + # If in thinking, over budget, and have end tokens - start forcing + if ( + self.in_thinking + and not self.thinking_ended + and self.thinking_tokens >= self.max_thinking_tokens + and self.think_end_tokens + and not self.forcing_end + ): + self.forcing_end = True + + # If we are actively forcing the end token sequence + if self.forcing_end and self._force_token_idx < len(self.think_end_tokens): + target_token = self.think_end_tokens[self._force_token_idx] + self._force_token_idx += 1 + + # Set all logits to -inf except the target token + scores[:] = -np.inf + if target_token < len(scores): + scores[target_token] = 0.0 + elif self.forcing_end and self._force_token_idx >= len(self.think_end_tokens): + # Finalize state when forcing completes + self.thinking_ended = True + self.in_thinking = False + + return scores diff --git a/runtimes/edge/utils/token_counter.py b/runtimes/edge/utils/token_counter.py new file mode 100644 index 000000000..8581d1c81 --- /dev/null +++ b/runtimes/edge/utils/token_counter.py @@ -0,0 +1,153 @@ +"""Token counting utilities for context management. + +Provides token counting functionality using the model's tokenizer +for accurate context window management. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from llamafarm_llama import Llama + +logger = logging.getLogger(__name__) + + +class TokenCounter: + """Token counter using the model's tokenizer. + + Provides methods for counting tokens in text and messages, + enabling accurate context window management. + """ + + # Estimated token overhead per message for role markers and formatting + MESSAGE_OVERHEAD = 4 + + # Chat template overhead factor (10% buffer for template markers) + TEMPLATE_OVERHEAD_FACTOR = 1.10 + + def __init__(self, llama: Llama): + """Initialize token counter with a Llama model instance. + + Args: + llama: A loaded Llama model instance with tokenize() method. + """ + self._llama = llama + + def count_tokens(self, text: str) -> int: + """Count tokens in a text string. + + Args: + text: The text to tokenize. + + Returns: + Number of tokens in the text. + """ + if not text: + return 0 + + tokens = self._llama.tokenize(text, add_special=False, parse_special=True) + return len(tokens) + + def count_message_tokens(self, message: dict) -> int: + """Count tokens for a single message including role overhead. + + The overhead accounts for role markers (e.g., "<|user|>") and + other formatting added by chat templates. + + Handles both text-only messages (content is string) and multimodal + messages (content is list of content parts). + + Args: + message: A message dict with 'role' and 'content' keys. + + Returns: + Estimated token count for the message. + """ + content = message.get("content") or "" + + # Handle multimodal messages (content is a list of parts) + if isinstance(content, list): + total_tokens = 0 + for part in content: + if isinstance(part, dict): + part_type = part.get("type", "") + if part_type == "text": + # Count tokens in text parts + text = part.get("text", "") + total_tokens += self.count_tokens(text) + elif part_type == "input_audio": + # Audio parts don't contribute text tokens + # Use a small estimate for the audio marker/placeholder + total_tokens += 10 + elif part_type == "image_url": + # Image parts - use a moderate estimate + total_tokens += 50 + # Skip other unknown types + return total_tokens + self.MESSAGE_OVERHEAD + + # Handle simple string content + return self.count_tokens(content) + self.MESSAGE_OVERHEAD + + def count_messages_tokens(self, messages: list[dict]) -> int: + """Count total tokens for a list of messages. + + Args: + messages: List of message dicts with 'role' and 'content' keys. + + Returns: + Total token count for all messages. + """ + return sum(self.count_message_tokens(m) for m in messages) + + def estimate_prompt_tokens( + self, + messages: list[dict], + include_template_overhead: bool = True, + ) -> int: + """Estimate total prompt tokens including chat template overhead. + + This is an estimate because the exact token count after template + application depends on the specific model's chat template. The + 10% overhead is a conservative buffer that works for most templates. + + Args: + messages: List of message dicts. + include_template_overhead: Whether to add 10% overhead for chat + template markers (BOS token, role tokens, etc.). + + Returns: + Estimated token count for the prompt. + """ + base_tokens = self.count_messages_tokens(messages) + + if include_template_overhead: + return int(base_tokens * self.TEMPLATE_OVERHEAD_FACTOR) + + return base_tokens + + def truncate_to_tokens(self, text: str, max_tokens: int) -> str: + """Truncate text to a maximum number of tokens. + + Useful for truncating long tool results or code blocks. + + Args: + text: The text to truncate. + max_tokens: Maximum number of tokens to keep. + + Returns: + Truncated text (may be the original if within limits). + """ + if not text: + return text + + tokens = self._llama.tokenize(text, add_special=False, parse_special=True) + + if len(tokens) <= max_tokens: + return text + + # Truncate tokens and detokenize + truncated_tokens = tokens[:max_tokens] + return self._llama.detokenize(truncated_tokens) diff --git a/runtimes/edge/utils/tool_calling.py b/runtimes/edge/utils/tool_calling.py new file mode 100644 index 000000000..54180b773 --- /dev/null +++ b/runtimes/edge/utils/tool_calling.py @@ -0,0 +1,555 @@ +""" +Prompt-based tool calling utilities. + +This module provides functions for injecting tool definitions into prompts +and detecting tool calls in model outputs using XML tags. +""" + +from __future__ import annotations + +import copy +import json +import logging +import re + +logger = logging.getLogger(__name__) + +# Pre-compiled regex patterns for better performance +# Pattern to extract tool calls from ... tags +TOOL_CALL_PATTERN = re.compile(r"(.*?)", re.DOTALL) + +# Pattern to strip tool call tags from content +TOOL_CALL_STRIP_PATTERN = re.compile(r".*?", re.DOTALL) + +# Pattern to extract tool name from partial JSON +TOOL_NAME_PATTERN = re.compile(r'"name"\s*:\s*"([^"]+)"') + + +# ============================================================================= +# Prompt templates for different tool_choice modes +# ============================================================================= + +# tool_choice="auto" (default) - model may call tools if helpful +TOOLS_PREFIX_AUTO = """ + +You may call one or more tools to assist with the user query. +You are provided with function signatures within XML tags: + +""" + +TOOLS_SUFFIX_AUTO = """ +For each tool call, return a json object with function name and arguments within XML tags: +{"name": , "arguments": }. +If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request. +""" + +# tool_choice="required" - model MUST call at least one tool +TOOLS_PREFIX_REQUIRED = """ + +You MUST call one or more tools to respond to the user query. Do not respond with text alone. +You are provided with function signatures within XML tags: + +""" + +TOOLS_SUFFIX_REQUIRED = """ +You MUST use at least one of these tools. Return a json object with function name and arguments within XML tags: +{"name": , "arguments": }. +""" + +# tool_choice={"type": "function", "function": {"name": "X"}} - model MUST call specific function +TOOLS_PREFIX_SPECIFIC = """ + +You MUST call the function "{function_name}" to respond to this query. +The function is defined within XML tags: + +""" + +TOOLS_SUFFIX_SPECIFIC = """ +You MUST call the "{function_name}" function. Return a json object with the function name and arguments within XML tags: +{{"name": "{function_name}", "arguments": }}. +""" + +# Legacy aliases for backward compatibility +TOOLS_SYSTEM_MESSAGE_PREFIX = TOOLS_PREFIX_AUTO +TOOLS_SYSTEM_MESSAGE_SUFFIX = TOOLS_SUFFIX_AUTO + + +def format_tool_for_prompt(tool: dict) -> str: + """Format a single tool definition for injection into the prompt. + + Args: + tool: OpenAI-format tool definition with 'type' and 'function' keys. + + Returns: + JSON string representation of the tool. + """ + return json.dumps(tool, ensure_ascii=False) + + +def validate_tool_schema(tool: dict) -> list[str]: + """Validate a tool definition schema. + + Args: + tool: Tool definition in OpenAI format. + + Returns: + List of validation error messages (empty if valid). + """ + errors = [] + + if not isinstance(tool, dict): + errors.append(f"Tool must be a dict, got {type(tool).__name__}") + return errors + + # Check required top-level fields + if "type" not in tool: + errors.append("Tool missing required 'type' field") + elif tool["type"] != "function": + errors.append(f"Tool type must be 'function', got '{tool['type']}'") + + if "function" not in tool: + errors.append("Tool missing required 'function' field") + return errors + + func = tool["function"] + if not isinstance(func, dict): + errors.append(f"Tool 'function' must be a dict, got {type(func).__name__}") + return errors + + # Check required function fields + if "name" not in func: + errors.append("Tool function missing required 'name' field") + elif not isinstance(func["name"], str) or not func["name"]: + errors.append("Tool function 'name' must be a non-empty string") + + # Check optional but commonly expected fields + if "parameters" in func: + params = func["parameters"] + if not isinstance(params, dict): + errors.append( + f"Tool parameters must be a dict, got {type(params).__name__}" + ) + + return errors + + +def parse_tool_choice(tool_choice: str | dict | None) -> tuple[str, str | None]: + """Parse tool_choice into a mode and optional function name. + + Args: + tool_choice: Tool choice parameter from the API request. + - None or "auto": Model decides whether to call tools + - "none": Model should not call any tools + - "required": Model must call at least one tool + - {"type": "function", "function": {"name": "X"}}: Model must call function X + + Returns: + Tuple of (mode, function_name) where mode is one of: + "auto", "none", "required", "specific" + and function_name is set only when mode is "specific". + """ + if tool_choice is None or tool_choice == "auto": + return ("auto", None) + elif tool_choice == "none": + return ("none", None) + elif tool_choice == "required": + return ("required", None) + elif isinstance(tool_choice, dict): + # Handle {"type": "function", "function": {"name": "X"}} + if tool_choice.get("type") == "function": + func_info = tool_choice.get("function", {}) + func_name = func_info.get("name") + if func_name: + return ("specific", func_name) + # Fallback if dict format is unexpected + logger.warning( + f"Unexpected tool_choice dict format: {tool_choice}, using 'auto'" + ) + return ("auto", None) + else: + logger.warning(f"Unknown tool_choice value: {tool_choice}, using 'auto'") + return ("auto", None) + + +def inject_tools_into_messages( + messages: list[dict], + tools: list[dict], + tool_choice: str | dict | None = None, +) -> list[dict]: + """Inject tool definitions into the system message. + + If no system message exists, one is created. The tools are appended + to the system message content using XML tags. + + Args: + messages: List of chat messages (will not be modified). + tools: List of tool definitions in OpenAI format. + tool_choice: Tool choice strategy: + - None or "auto": Model may call tools (default) + - "none": Model should not call tools (returns messages unchanged) + - "required": Model must call at least one tool + - {"type": "function", "function": {"name": "X"}}: Must call specific function + + Returns: + New list of messages with tools injected into system message. + """ + if not tools: + return messages + + # Validate tool schemas before injection + valid_tools = [] + for i, tool in enumerate(tools): + errors = validate_tool_schema(tool) + if errors: + tool_name = tool.get("function", {}).get("name", f"tool[{i}]") + logger.warning( + f"Skipping malformed tool '{tool_name}': {'; '.join(errors)}" + ) + else: + valid_tools.append(tool) + + if not valid_tools: + logger.warning("No valid tools after validation, returning original messages") + return messages + + tools = valid_tools + + # Parse tool_choice to determine mode + mode, specific_func = parse_tool_choice(tool_choice) + + # "none" means don't inject tools at all + if mode == "none": + logger.debug("tool_choice='none', skipping tool injection") + return messages + + # Deep copy to avoid modifying original + messages = copy.deepcopy(messages) + + # Filter tools if a specific function is requested + tools_to_inject = tools + if mode == "specific" and specific_func: + tools_to_inject = [ + t for t in tools if t.get("function", {}).get("name") == specific_func + ] + if not tools_to_inject: + logger.warning( + f"tool_choice specified function '{specific_func}' but it was not found " + f"in provided tools. Available: {[t.get('function', {}).get('name') for t in tools]}" + ) + # Fall back to auto mode with all tools + mode = "auto" + tools_to_inject = tools + + # Select prefix and suffix based on mode + if mode == "required": + prefix = TOOLS_PREFIX_REQUIRED + suffix = TOOLS_SUFFIX_REQUIRED + elif mode == "specific" and specific_func: + prefix = TOOLS_PREFIX_SPECIFIC.format(function_name=specific_func) + suffix = TOOLS_SUFFIX_SPECIFIC.format(function_name=specific_func) + else: # "auto" or fallback + prefix = TOOLS_PREFIX_AUTO + suffix = TOOLS_SUFFIX_AUTO + + # Build tools section + tools_section = prefix + for tool in tools_to_inject: + tools_section += f"{format_tool_for_prompt(tool)}\n" + tools_section += suffix + + # Find system message and append tools + system_found = False + for msg in messages: + if msg.get("role") == "system": + content = msg.get("content", "") + if isinstance(content, str): + msg["content"] = content + tools_section + system_found = True + break + # Non-string content (e.g., multimodal) - can't inject tools here + # Continue searching for a string-content system message + + # If no system message, create one + if not system_found: + messages.insert(0, {"role": "system", "content": tools_section.strip()}) + + return messages + + +def detect_tool_call_in_content(content: str) -> list[tuple[str, str]] | None: + """Extract tool calls from content using XML tags. + + Looks for ... patterns and extracts + the tool name and arguments from each. + + Args: + content: The model's response content. + + Returns: + List of (tool_name, arguments_json) tuples, or None if no tool calls found. + """ + if not content: + return None + + matches = TOOL_CALL_PATTERN.findall(content) + + if not matches: + return None + + results = [] + parse_errors = [] + for i, match in enumerate(matches): + try: + # Parse the JSON inside the tool_call tags + tool_call_json = json.loads(match.strip()) + tool_name = tool_call_json.get("name") + tool_args = tool_call_json.get("arguments", {}) + + if tool_name: + # Re-serialize arguments to ensure consistent JSON format + args_json = json.dumps(tool_args, ensure_ascii=False) + results.append((tool_name, args_json)) + else: + parse_errors.append(f"Tool call {i + 1}: missing 'name' field") + except json.JSONDecodeError as e: + parse_errors.append( + f"Tool call {i + 1}: JSON parse error - {e}, content: {match[:100]!r}" + ) + + # Log summary of parsing results + if parse_errors: + logger.error( + f"Failed to parse {len(parse_errors)}/{len(matches)} tool call(s): " + f"{'; '.join(parse_errors)}" + ) + + return results if results else None + + +def detect_probable_tool_call(content: str) -> bool: + """Check if content likely contains an incomplete tool call. + + Used during streaming to detect when we should start buffering + instead of emitting tokens. + + Args: + content: Accumulated content so far. + + Returns: + True if content contains an opening tag. + """ + return "" in content + + +def strip_tool_call_from_content(content: str) -> str: + """Remove tool call XML tags from content. + + Args: + content: The model's response content. + + Returns: + Content with tool call tags removed. + """ + return TOOL_CALL_STRIP_PATTERN.sub("", content).strip() + + +# ============================================================================= +# Incremental streaming utilities +# ============================================================================= + + +def extract_tool_name_from_partial(content: str) -> str | None: + """Extract tool name from incomplete tool call JSON. + + Used during streaming to detect the tool name before the entire + tool call JSON is complete. This enables emitting the initial + tool call chunk early. + + Looks for patterns like: + - {"name": "get_weather" + - {"name":"get_weather", + + Args: + content: Accumulated content that may contain a partial tool call. + + Returns: + Tool name if found and complete, None otherwise. + """ + if not content or "" not in content: + return None + + # Find the start of the tool call JSON + start_idx = content.find("") + if start_idx == -1: + return None + + # Extract everything after + json_start = start_idx + len("") + partial_json = content[json_start:] + + # Use regex to extract a complete "name" value + # Matches: "name": "value" or "name":"value" + # The name value must be complete (closing quote found) + match = TOOL_NAME_PATTERN.search(partial_json) + + if match: + return match.group(1) + + return None + + +def extract_arguments_progress(content: str) -> tuple[int, str] | None: + """Extract the arguments JSON string progress from a partial tool call. + + Used during streaming to extract how much of the "arguments" value + we have so far, enabling incremental streaming of arguments. + + Args: + content: Accumulated content containing a partial tool call. + + Returns: + Tuple of (start_position, arguments_so_far) where start_position + is the character index where arguments value begins in the content, + and arguments_so_far is the accumulated arguments string. + Returns None if arguments section not yet started. + """ + if not content or "" not in content: + return None + + # Find the start of the tool call JSON + tool_start = content.find("") + if tool_start == -1: + return None + + json_start = tool_start + len("") + partial_json = content[json_start:] + + # Find "arguments": or "arguments" : + args_pattern = r'"arguments"\s*:\s*' + match = re.search(args_pattern, partial_json) + + if not match: + return None + + # Position where the arguments value starts (after the colon and whitespace) + args_value_start = json_start + match.end() + + # Extract everything from there + remaining = content[args_value_start:] + + # Track brace depth to find the end of the arguments JSON value + # Arguments is a JSON object, so we need to find where it closes + args_content = _extract_json_value(remaining) + + if not args_content: + return None + + return (args_value_start, args_content) + + +def _extract_json_value(content: str) -> str: + """Extract a JSON value (object or array) from the start of content. + + Tracks brace/bracket depth to find where the JSON value ends. + Handles incomplete JSON by returning what we have so far. + + Args: + content: String starting with a JSON value. + + Returns: + The JSON value string (possibly incomplete). + """ + if not content: + return "" + + content = content.strip() + if not content: + return "" + + # Determine the opening bracket type + if content[0] == "{": + open_char, close_char = "{", "}" + elif content[0] == "[": + open_char, close_char = "[", "]" + else: + # Not a JSON object/array, might be a primitive + # For tool calls, arguments should always be an object + return content + + depth = 0 + in_string = False + escape_next = False + end_pos = len(content) + + for i, char in enumerate(content): + if escape_next: + escape_next = False + continue + + if char == "\\": + escape_next = True + continue + + if char == '"' and not escape_next: + in_string = not in_string + continue + + if in_string: + continue + + if char == open_char: + depth += 1 + elif char == close_char: + depth -= 1 + if depth == 0: + # Found the matching closing bracket + end_pos = i + 1 + break + + # Return the JSON value (complete or partial) + result = content[:end_pos] + + # Clean up any trailing content after the closing bracket + # (like the closing brace of the outer object or ) + return result + + +def is_tool_call_complete(content: str) -> bool: + """Check if content contains a complete tool call with closing tag. + + Args: + content: Accumulated content that may contain a tool call. + + Returns: + True if a complete ... is found. + """ + if not content: + return False + + return "" in content + + +def get_tool_call_content_after_tag(content: str) -> str | None: + """Extract the content inside ... tags. + + Args: + content: Content containing tool call tags. + + Returns: + The content between the tags, or None if not found. + """ + if not content or "" not in content: + return None + + start_idx = content.find("") + if start_idx == -1: + return None + + json_start = start_idx + len("") + end_idx = content.find("", json_start) + + if end_idx == -1: + # No closing tag yet, return everything after opening tag + return content[json_start:] + + return content[json_start:end_idx] diff --git a/runtimes/universal/tests/test_model_format.py b/runtimes/universal/tests/test_model_format.py index 3b72f0666..c6202668b 100644 --- a/runtimes/universal/tests/test_model_format.py +++ b/runtimes/universal/tests/test_model_format.py @@ -8,8 +8,8 @@ class TestDetectModelFormat: """Test model format detection (runtime-specific).""" - @patch("utils.model_format._check_local_cache_for_model") - @patch("utils.model_format.HfApi") + @patch("llamafarm_common.model_format._check_local_cache_for_model") + @patch("llamafarm_common.model_format.HfApi") def test_detect_model_format_gguf(self, mock_hf_api_class, mock_check_local_cache): """Test detecting GGUF format.""" from utils.model_format import clear_format_cache, detect_model_format @@ -35,8 +35,8 @@ def test_detect_model_format_gguf(self, mock_hf_api_class, mock_check_local_cach # Verify assert result == "gguf" - @patch("utils.model_format._check_local_cache_for_model") - @patch("utils.model_format.HfApi") + @patch("llamafarm_common.model_format._check_local_cache_for_model") + @patch("llamafarm_common.model_format.HfApi") def test_detect_model_format_transformers( self, mock_hf_api_class, mock_check_local_cache ): @@ -64,8 +64,8 @@ def test_detect_model_format_transformers( # Verify assert result == "transformers" - @patch("utils.model_format._check_local_cache_for_model") - @patch("utils.model_format.HfApi") + @patch("llamafarm_common.model_format._check_local_cache_for_model") + @patch("llamafarm_common.model_format.HfApi") def test_detect_model_format_strips_quantization_suffix( self, mock_hf_api_class, mock_check_local_cache ): @@ -100,8 +100,8 @@ def test_detect_model_format_strips_quantization_suffix( # Verify correct format was detected assert result == "gguf" - @patch("utils.model_format._check_local_cache_for_model") - @patch("utils.model_format.HfApi") + @patch("llamafarm_common.model_format._check_local_cache_for_model") + @patch("llamafarm_common.model_format.HfApi") def test_caching_with_quantization_suffix( self, mock_hf_api_class, mock_check_local_cache ): @@ -139,8 +139,8 @@ def test_caching_with_quantization_suffix( assert result3 == "gguf" assert mock_api.list_repo_files.call_count == 1 # Still 1, cache was used - @patch("utils.model_format._check_local_cache_for_model") - @patch("utils.model_format.HfApi") + @patch("llamafarm_common.model_format._check_local_cache_for_model") + @patch("llamafarm_common.model_format.HfApi") def test_detect_model_format_uses_local_cache( self, mock_hf_api_class, mock_check_local_cache ): @@ -175,8 +175,8 @@ def test_detect_model_format_uses_local_cache( # Verify HF API was NOT called (used local cache instead) mock_api.list_repo_files.assert_not_called() - @patch("utils.model_format._check_local_cache_for_model") - @patch("utils.model_format.HfApi") + @patch("llamafarm_common.model_format._check_local_cache_for_model") + @patch("llamafarm_common.model_format.HfApi") def test_detect_model_format_local_cache_transformers( self, mock_hf_api_class, mock_check_local_cache ): diff --git a/runtimes/universal/utils/device.py b/runtimes/universal/utils/device.py index 5f85288ec..c6a328841 100644 --- a/runtimes/universal/utils/device.py +++ b/runtimes/universal/utils/device.py @@ -1,195 +1,9 @@ -""" -Device detection and optimization utilities. - -PyTorch is optional - this module provides fallback behavior for GGUF-only -deployments where torch is not installed. llama.cpp has its own GPU detection -independent of PyTorch. -""" - -from __future__ import annotations - -import logging -import platform -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - import torch as torch_type - -logger = logging.getLogger(__name__) - -# Cached torch module reference (lazy loaded) -_torch: torch_type | None = None -_torch_available: bool | None = None - - -def _get_torch() -> torch_type | None: - """Lazy-load torch module. Returns None if not installed.""" - global _torch, _torch_available - - if _torch_available is None: - try: - import torch - - _torch = torch - _torch_available = True - logger.debug(f"PyTorch {torch.__version__} loaded successfully") - except ImportError: - _torch = None - _torch_available = False - logger.info("PyTorch not installed - encoder models will not be available") - - return _torch - - -def is_torch_available() -> bool: - """Check if PyTorch is available without importing it.""" - _get_torch() - return _torch_available or False - - -def get_optimal_device() -> str: - """ - Detect the optimal device for the current platform. - - Returns: - str: Device name ("cuda", "mps", or "cpu") - - Note: - If PyTorch is not installed, always returns "cpu". - This allows GGUF models to still use GPU via llama.cpp's own detection. - """ - import os - - # Allow forcing CPU via environment variable - force_cpu = os.environ.get("TRANSFORMERS_FORCE_CPU", "").lower() in ( - "1", - "true", - "yes", - ) - if force_cpu: - logger.info("Forcing CPU device (TRANSFORMERS_FORCE_CPU=1)") - return "cpu" - - # Try to use PyTorch for device detection - torch = _get_torch() - if torch is None: - logger.info("PyTorch not available - using CPU for encoder models") - return "cpu" - - # Check for CUDA - if torch.cuda.is_available(): - logger.info(f"CUDA available: {torch.cuda.get_device_name(0)}") - return "cuda" - - # Check for MPS (Apple Silicon) - # Note: MPS has a 4GB temporary buffer limit which can cause issues with some models - if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): - # Check if user wants to skip MPS due to known limitations - skip_mps = os.environ.get("TRANSFORMERS_SKIP_MPS", "").lower() in ( - "1", - "true", - "yes", - ) - if skip_mps: - logger.info("Skipping MPS (TRANSFORMERS_SKIP_MPS=1), using CPU") - return "cpu" - logger.info("MPS (Apple Silicon) available") - logger.warning( - "MPS has a 4GB temporary buffer limit. Set TRANSFORMERS_SKIP_MPS=1 to use CPU if you encounter errors." - ) - return "mps" - - # Fallback to CPU - logger.info("Using CPU (no GPU acceleration)") - return "cpu" - - -def get_device_info() -> dict: - """ - Get detailed device information. - - Returns: - dict: Device information including platform, acceleration, memory - """ - device = get_optimal_device() - torch = _get_torch() - - info = { - "device": device, - "platform": platform.system(), - "python_version": platform.python_version(), - "torch_version": torch.__version__ if torch else "not installed", - "torch_available": torch is not None, - } - - if torch is not None: - if device == "cuda": - gpu_count = torch.cuda.device_count() - # Primary GPU info (backward compatible) - free_0, total_0 = torch.cuda.mem_get_info(0) - info.update( - { - "gpu_name": torch.cuda.get_device_name(0), - "gpu_memory_total": total_0, - "gpu_memory_free": free_0, - "gpu_memory_allocated": torch.cuda.memory_allocated(0), - "gpu_count": gpu_count, - } - ) - # Per-GPU details for multi-GPU systems - if gpu_count > 1: - gpus = [] - for i in range(gpu_count): - free, total = torch.cuda.mem_get_info(i) - gpus.append( - { - "index": i, - "name": torch.cuda.get_device_name(i), - "memory_total": total, - "memory_free": free, - "memory_allocated": torch.cuda.memory_allocated(i), - } - ) - info["gpus"] = gpus - elif device == "mps": - info.update( - { - "gpu_name": "Apple Silicon (MPS)", - "architecture": platform.machine(), - } - ) - - return info - - -def get_gguf_gpu_layers() -> int: - """ - Get the number of GPU layers to use for GGUF models. - - IMPORTANT: llama.cpp has its own GPU detection (CUDA, Metal, Vulkan, etc.) - that is independent of PyTorch. We should always try to use GPU layers (-1) - and let llama.cpp fall back to CPU if no GPU backend is available. - This allows users with CPU-only PyTorch but GPU llama.cpp to get acceleration. - - Returns: - int: Number of GPU layers (-1 for all layers on GPU, 0 for CPU only) - """ - import os - - force_cpu = os.environ.get("LLAMAFARM_GGUF_FORCE_CPU", "").lower() in ( - "1", - "true", - "yes", - ) - - if force_cpu: - logger.info("Configuring for CPU-only inference (LLAMAFARM_GGUF_FORCE_CPU=1)") - return 0 - - # Use all layers on GPU - llama.cpp will use whatever backend is available - # (CUDA, Metal, Vulkan, etc.) and fall back to CPU if none are available - logger.info( - "Configuring for GPU acceleration (all layers on GPU, llama.cpp will " - "auto-detect available backends)" - ) - return -1 +"""Re-export from llamafarm_common — single source of truth.""" +from llamafarm_common.device import ( + get_device_info, + get_gguf_gpu_layers, + get_optimal_device, + is_torch_available, +) + +__all__ = ["get_optimal_device", "get_device_info", "is_torch_available", "get_gguf_gpu_layers"] diff --git a/runtimes/universal/utils/model_cache.py b/runtimes/universal/utils/model_cache.py index 0e7b832f3..a1b07d637 100644 --- a/runtimes/universal/utils/model_cache.py +++ b/runtimes/universal/utils/model_cache.py @@ -1,188 +1,4 @@ -"""TTL-based model cache using cachetools. +"""Re-export from llamafarm_common — single source of truth.""" +from llamafarm_common.model_cache import ModelCache -Provides a cache that: -- Automatically tracks last access time -- Refreshes TTL on access (not just on write) -- Supports async cleanup callbacks before expiration -""" - -import time -from collections.abc import Iterator -from typing import Generic, TypeVar - -from cachetools import TTLCache - -T = TypeVar("T") - - -class ModelCache(Generic[T]): - """TTL-based cache for models with async cleanup support. - - Uses cachetools.TTLCache internally but refreshes TTL on read access - (not just write), and provides methods for async cleanup before items - expire. - - This is designed for ML model caching where: - - Models should stay loaded while being actively used - - Idle models should be unloaded after a timeout - - Unloading requires calling an async cleanup method - - Example: - cache = ModelCache[BaseModel](ttl=300) # 5 minute TTL - - # Set a model - cache["encoder:model-id"] = model - - # Get model (refreshes TTL) - model = cache.get("encoder:model-id") - - # In cleanup task: - for key, model in cache.pop_expired(): - await model.unload() - """ - - def __init__(self, ttl: float, maxsize: int = 1000): - """Initialize the cache. - - Args: - ttl: Time-to-live in seconds. Items are considered expired - after this many seconds of inactivity (no read or write). - maxsize: Maximum number of items to store. - """ - self._ttl = ttl - self._maxsize = maxsize - # Internal TTLCache with very long TTL - we manage expiry ourselves - # to support async callbacks before removal - self._cache: TTLCache[str, T] = TTLCache(maxsize=maxsize, ttl=ttl * 10) - # Track access times ourselves for TTL-on-read behavior - self._timer = time.monotonic - self._access: dict[str, float] = {} - - @property - def ttl(self) -> float: - """Get the TTL in seconds.""" - return self._ttl - - def __contains__(self, key: str) -> bool: - return key in self._cache - - def __len__(self) -> int: - return len(self._cache) - - def __iter__(self) -> Iterator[str]: - return iter(self._cache) - - def get(self, key: str, default: T | None = None) -> T | None: - """Get item and refresh its TTL. - - Args: - key: Cache key - default: Value to return if key not found - - Returns: - The cached item, or default if not found - """ - if key not in self._cache: - return default - self._access[key] = self._timer() - return self._cache[key] - - def __getitem__(self, key: str) -> T: - """Get item and refresh TTL. Raises KeyError if not found.""" - if key not in self._cache: - raise KeyError(key) - self._access[key] = self._timer() - return self._cache[key] - - def __setitem__(self, key: str, value: T) -> None: - """Set item with fresh TTL.""" - self._cache[key] = value - self._access[key] = self._timer() - - def __delitem__(self, key: str) -> None: - """Remove item from cache.""" - del self._cache[key] - self._access.pop(key, None) - - def pop(self, key: str, *args) -> T: - """Remove and return item. - - Args: - key: Cache key - *args: Optional default value - - Returns: - The removed item, or default if provided and key not found - """ - self._access.pop(key, None) - return self._cache.pop(key, *args) - - def keys(self): - """Return view of cache keys.""" - return self._cache.keys() - - def values(self): - """Return view of cache values.""" - return self._cache.values() - - def items(self): - """Return view of cache items.""" - return self._cache.items() - - def clear(self) -> None: - """Clear all items from cache.""" - self._cache.clear() - self._access.clear() - - def get_idle_time(self, key: str) -> float | None: - """Get seconds since last access for a key. - - Args: - key: Cache key - - Returns: - Seconds since last access, or None if key not found - """ - if key not in self._access: - return None - return self._timer() - self._access[key] - - def is_expired(self, key: str) -> bool: - """Check if an item has exceeded its TTL. - - Args: - key: Cache key - - Returns: - True if item exists and is expired, False otherwise - """ - idle_time = self.get_idle_time(key) - return idle_time is not None and idle_time > self._ttl - - def get_expired_keys(self) -> list[str]: - """Get list of keys that have exceeded their TTL. - - Returns: - List of expired cache keys - """ - now = self._timer() - cutoff = now - self._ttl - return [k for k, t in self._access.items() if t < cutoff] - - def pop_expired(self) -> list[tuple[str, T]]: - """Remove and return all expired items. - - This is the main method for cleanup tasks. It returns all expired - items so the caller can perform async cleanup (like calling unload()). - - Returns: - List of (key, value) tuples for expired items - """ - expired_keys = self.get_expired_keys() - result = [] - for key in expired_keys: - if key in self._cache: - value = self._cache.pop(key) - self._access.pop(key, None) - result.append((key, value)) - return result +__all__ = ["ModelCache"] diff --git a/runtimes/universal/utils/model_format.py b/runtimes/universal/utils/model_format.py index 6e93e990d..0426736e0 100644 --- a/runtimes/universal/utils/model_format.py +++ b/runtimes/universal/utils/model_format.py @@ -1,21 +1,8 @@ -"""Model format detection utilities for Universal Runtime. - -Detects whether a HuggingFace model repository contains GGUF or transformers format files. - -Note: Core GGUF utilities (list_gguf_files, select_gguf_file, get_gguf_file_path, etc.) -are provided by llamafarm_common and re-exported here for backward compatibility. - -Performance optimizations: -- Results are cached to avoid repeated API calls within a session -- Checks local HuggingFace cache before making network requests -""" - -import logging - -from huggingface_hub import HfApi, scan_cache_dir -from huggingface_hub.utils import HFCacheInfo -from llamafarm_common import ( +"""Re-export from llamafarm_common — single source of truth.""" +from llamafarm_common.model_format import ( GGUF_QUANTIZATION_PREFERENCE_ORDER, + clear_format_cache, + detect_model_format, get_gguf_file_path, list_gguf_files, parse_model_with_quantization, @@ -24,15 +11,6 @@ select_gguf_file_with_logging, ) -logger = logging.getLogger(__name__) - -# Cache detection results to avoid repeated filesystem checks -_format_cache: dict[str, str] = {} - -# Cache for local repo info to avoid repeated cache scans -_local_cache_info: HFCacheInfo | None = None - -# Re-export commonly used functions for backward compatibility __all__ = [ "GGUF_QUANTIZATION_PREFERENCE_ORDER", "parse_model_with_quantization", @@ -44,129 +22,3 @@ "get_gguf_file_path", "clear_format_cache", ] - - -def _check_local_cache_for_model(model_id: str) -> list[str] | None: - """Check if model files are available in local HuggingFace cache. - - This avoids making network requests when we can determine format locally. - - Args: - model_id: HuggingFace model identifier - - Returns: - List of cached filenames if model is cached, None otherwise - """ - global _local_cache_info - - try: - # Scan cache once and reuse (scanning is ~10-50ms) - if _local_cache_info is None: - _local_cache_info = scan_cache_dir() - - # Look for this model in cache - for repo in _local_cache_info.repos: - if repo.repo_id == model_id and repo.repo_type == "model": - # Found cached repo - collect all filenames across revisions - filenames = set() - for revision in repo.revisions: - for file in revision.files: - filenames.add(file.file_name) - if filenames: - logger.debug( - f"Found {len(filenames)} files in local cache for {model_id}" - ) - return list(filenames) - - return None - - except Exception as e: - logger.debug(f"Could not scan local cache: {e}") - return None - - -def detect_model_format(model_id: str, token: str | None = None) -> str: - """ - Detect if a HuggingFace model is GGUF or transformers format. - - This function first checks if the model is in the local HuggingFace cache, - and only makes API calls if not cached locally. Results are cached in memory - to avoid repeated checks within a session. - - Args: - model_id: HuggingFace model identifier (e.g., "unsloth/Qwen3-0.6B-GGUF" or "unsloth/Qwen3-0.6B-GGUF:Q4_K_M") - token: Optional HuggingFace authentication token for gated models - - Returns: - "gguf" if model contains .gguf files, "transformers" otherwise - - Raises: - Exception: If model cannot be accessed - - Examples: - >>> detect_model_format("unsloth/Qwen3-0.6B-GGUF") - "gguf" - >>> detect_model_format("unsloth/Qwen3-0.6B-GGUF:Q4_K_M") - "gguf" - >>> detect_model_format("google/gemma-3-1b-it") - "transformers" - """ - # Parse model ID to remove quantization suffix if present - base_model_id, _ = parse_model_with_quantization(model_id) - - # Check memory cache first (fastest) - if base_model_id in _format_cache: - logger.debug( - f"Using cached format for {base_model_id}: {_format_cache[base_model_id]}" - ) - return _format_cache[base_model_id] - - logger.info(f"Detecting format for model: {base_model_id}") - - # Try local cache first to avoid API call - local_files = _check_local_cache_for_model(base_model_id) - if local_files is not None: - has_gguf = any(f.endswith(".gguf") for f in local_files) - if has_gguf: - logger.info("Detected GGUF format from local cache (found .gguf files)") - _format_cache[base_model_id] = "gguf" - return "gguf" - else: - logger.info( - "Detected transformers format from local cache (no .gguf files)" - ) - _format_cache[base_model_id] = "transformers" - return "transformers" - - # Not in local cache - must query API - try: - api = HfApi() - all_files = api.list_repo_files(repo_id=base_model_id, token=token) - - # Check if any .gguf files exist - has_gguf = any(f.endswith(".gguf") for f in all_files) - - if has_gguf: - logger.info("Detected GGUF format (found .gguf files)") - _format_cache[base_model_id] = "gguf" - return "gguf" - - # No GGUF files found - assume transformers format - logger.info("Detected transformers format (no .gguf files found)") - _format_cache[base_model_id] = "transformers" - return "transformers" - - except Exception as e: - logger.error(f"Error detecting model format for {base_model_id}: {e}") - raise - - -def clear_format_cache(): - """Clear the format detection cache. - - Useful for testing or when model repositories are updated. - """ - global _format_cache, _local_cache_info - _format_cache = {} - _local_cache_info = None - logger.debug("Format detection cache cleared") diff --git a/runtimes/universal/utils/safe_home.py b/runtimes/universal/utils/safe_home.py index 28c004c02..b24399a54 100644 --- a/runtimes/universal/utils/safe_home.py +++ b/runtimes/universal/utils/safe_home.py @@ -1,34 +1,4 @@ -"""Safe home directory resolution for embedded Python environments. +"""Re-export from llamafarm_common — single source of truth.""" +from llamafarm_common.safe_home import get_data_dir, safe_home -Path.home() raises RuntimeError in PyApp-embedded Python on Windows -when HOME/USERPROFILE env vars are absent during bootstrap. -""" - -import os -from pathlib import Path - - -def safe_home() -> Path: - """Return the user's home directory with fallback for embedded Python.""" - try: - return Path.home() - except RuntimeError: - fb = ( - os.environ.get("USERPROFILE") - or os.environ.get("APPDATA") - or os.environ.get("LOCALAPPDATA") - ) - if fb: - return Path(fb) - try: - return Path.cwd() - except OSError: - return Path(".") - - -def get_data_dir() -> Path: - """Return the LlamaFarm data directory (LF_DATA_DIR or ~/.llamafarm).""" - env = os.environ.get("LF_DATA_DIR") - if env: - return Path(env) - return safe_home() / ".llamafarm" +__all__ = ["safe_home", "get_data_dir"] diff --git a/runtimes/universal/uv.lock b/runtimes/universal/uv.lock index 3e52ae9e8..1f6cd3162 100644 --- a/runtimes/universal/uv.lock +++ b/runtimes/universal/uv.lock @@ -1981,6 +1981,7 @@ name = "llamafarm-common" version = "0.1.0" source = { editable = "../../common" } dependencies = [ + { name = "cachetools", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'AMD64' and sys_platform == 'win32')" }, { name = "filelock", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'AMD64' and sys_platform == 'win32')" }, { name = "hf-transfer", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'AMD64' and sys_platform == 'win32')" }, { name = "huggingface-hub", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'AMD64' and sys_platform == 'win32')" }, @@ -1988,6 +1989,7 @@ dependencies = [ [package.metadata] requires-dist = [ + { name = "cachetools", specifier = ">=6.0.0" }, { name = "filelock", specifier = ">=3.16.1" }, { name = "hf-transfer", specifier = ">=0.1.9" }, { name = "huggingface-hub", specifier = ">=0.24.0" }, @@ -5685,6 +5687,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0f/8b/4b61d6e13f7108f36910df9ab4b58fd389cc2520d54d81b88660804aad99/torch-2.10.0-2-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:418997cb02d0a0f1497cf6a09f63166f9f5df9f3e16c8a716ab76a72127c714f", size = 79423467, upload-time = "2026-02-10T21:44:48.711Z" }, { url = "https://files.pythonhosted.org/packages/d3/54/a2ba279afcca44bbd320d4e73675b282fcee3d81400ea1b53934efca6462/torch-2.10.0-2-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:13ec4add8c3faaed8d13e0574f5cd4a323c11655546f91fbe6afa77b57423574", size = 79498202, upload-time = "2026-02-10T21:44:52.603Z" }, { url = "https://files.pythonhosted.org/packages/ec/23/2c9fe0c9c27f7f6cb865abcea8a4568f29f00acaeadfc6a37f6801f84cb4/torch-2.10.0-2-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:e521c9f030a3774ed770a9c011751fb47c4d12029a3d6522116e48431f2ff89e", size = 79498254, upload-time = "2026-02-10T21:44:44.095Z" }, + { url = "https://files.pythonhosted.org/packages/16/ee/efbd56687be60ef9af0c9c0ebe106964c07400eade5b0af8902a1d8cd58c/torch-2.10.0-3-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:a1ff626b884f8c4e897c4c33782bdacdff842a165fee79817b1dd549fdda1321", size = 915510070, upload-time = "2026-03-11T14:16:39.386Z" }, + { url = "https://files.pythonhosted.org/packages/36/ab/7b562f1808d3f65414cd80a4f7d4bb00979d9355616c034c171249e1a303/torch-2.10.0-3-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:ac5bdcbb074384c66fa160c15b1ead77839e3fe7ed117d667249afce0acabfac", size = 915518691, upload-time = "2026-03-11T14:15:43.147Z" }, + { url = "https://files.pythonhosted.org/packages/b3/7a/abada41517ce0011775f0f4eacc79659bc9bc6c361e6bfe6f7052a6b9363/torch-2.10.0-3-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:98c01b8bb5e3240426dcde1446eed6f40c778091c8544767ef1168fc663a05a6", size = 915622781, upload-time = "2026-03-11T14:17:11.354Z" }, + { url = "https://files.pythonhosted.org/packages/ab/c6/4dfe238342ffdcec5aef1c96c457548762d33c40b45a1ab7033bb26d2ff2/torch-2.10.0-3-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:80b1b5bfe38eb0e9f5ff09f206dcac0a87aadd084230d4a36eea5ec5232c115b", size = 915627275, upload-time = "2026-03-11T14:16:11.325Z" }, + { url = "https://files.pythonhosted.org/packages/d8/f0/72bf18847f58f877a6a8acf60614b14935e2f156d942483af1ffc081aea0/torch-2.10.0-3-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:46b3574d93a2a8134b3f5475cfb98e2eb46771794c57015f6ad1fb795ec25e49", size = 915523474, upload-time = "2026-03-11T14:17:44.422Z" }, { url = "https://files.pythonhosted.org/packages/0c/1a/c61f36cfd446170ec27b3a4984f072fd06dab6b5d7ce27e11adb35d6c838/torch-2.10.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:5276fa790a666ee8becaffff8acb711922252521b28fbce5db7db5cf9cb2026d", size = 145992962, upload-time = "2026-01-21T16:24:14.04Z" }, { url = "https://files.pythonhosted.org/packages/b5/60/6662535354191e2d1555296045b63e4279e5a9dbad49acf55a5d38655a39/torch-2.10.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:aaf663927bcd490ae971469a624c322202a2a1e68936eb952535ca4cd3b90444", size = 915599237, upload-time = "2026-01-21T16:23:25.497Z" }, { url = "https://files.pythonhosted.org/packages/40/b8/66bbe96f0d79be2b5c697b2e0b187ed792a15c6c4b8904613454651db848/torch-2.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:a4be6a2a190b32ff5c8002a0977a25ea60e64f7ba46b1be37093c141d9c49aeb", size = 113720931, upload-time = "2026-01-21T16:24:23.743Z" }, diff --git a/server/uv.lock b/server/uv.lock index 221458afb..b03ac5d77 100644 --- a/server/uv.lock +++ b/server/uv.lock @@ -234,6 +234,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/87/8bab77b323f16d67be364031220069f79159117dd5e43eeb4be2fef1ac9b/billiard-4.2.4-py3-none-any.whl", hash = "sha256:525b42bdec68d2b983347ac312f892db930858495db601b5836ac24e6477cde5", size = 87070, upload-time = "2025-11-30T13:28:47.016Z" }, ] +[[package]] +name = "cachetools" +version = "7.0.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/af/dd/57fe3fdb6e65b25a5987fd2cdc7e22db0aef508b91634d2e57d22928d41b/cachetools-7.0.5.tar.gz", hash = "sha256:0cd042c24377200c1dcd225f8b7b12b0ca53cc2c961b43757e774ebe190fd990", size = 37367, upload-time = "2026-03-09T20:51:29.451Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/06/f3/39cf3367b8107baa44f861dc802cbf16263c945b62d8265d36034fc07bea/cachetools-7.0.5-py3-none-any.whl", hash = "sha256:46bc8ebefbe485407621d0a4264b23c080cedd913921bad7ac3ed2f26c183114", size = 13918, upload-time = "2026-03-09T20:51:27.33Z" }, +] + [[package]] name = "celery" version = "5.6.2" @@ -1046,6 +1055,7 @@ name = "llamafarm-common" version = "0.1.0" source = { editable = "../common" } dependencies = [ + { name = "cachetools", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'AMD64' and sys_platform == 'win32')" }, { name = "filelock", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'AMD64' and sys_platform == 'win32')" }, { name = "hf-transfer", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'AMD64' and sys_platform == 'win32')" }, { name = "huggingface-hub", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'AMD64' and sys_platform == 'win32')" }, @@ -1053,6 +1063,7 @@ dependencies = [ [package.metadata] requires-dist = [ + { name = "cachetools", specifier = ">=6.0.0" }, { name = "filelock", specifier = ">=3.16.1" }, { name = "hf-transfer", specifier = ">=0.1.9" }, { name = "huggingface-hub", specifier = ">=0.24.0" }, @@ -1901,6 +1912,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" }, ] +[[package]] +name = "python-discovery" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "filelock", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'AMD64' and sys_platform == 'win32')" }, + { name = "platformdirs", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'AMD64' and sys_platform == 'win32')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9c/90/bcce6b46823c9bec1757c964dc37ed332579be512e17a30e9698095dcae4/python_discovery-1.2.0.tar.gz", hash = "sha256:7d33e350704818b09e3da2bd419d37e21e7c30db6e0977bb438916e06b41b5b1", size = 58055, upload-time = "2026-03-19T01:43:08.248Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c2/3c/2005227cb951df502412de2fa781f800663cccbef8d90ec6f1b371ac2c0d/python_discovery-1.2.0-py3-none-any.whl", hash = "sha256:1e108f1bbe2ed0ef089823d28805d5ad32be8e734b86a5f212bf89b71c266e4a", size = 31524, upload-time = "2026-03-19T01:43:07.045Z" }, +] + [[package]] name = "python-dotenv" version = "1.2.1" @@ -2529,16 +2553,17 @@ wheels = [ [[package]] name = "virtualenv" -version = "20.37.0" +version = "21.2.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "distlib", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'AMD64' and sys_platform == 'win32')" }, { name = "filelock", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'AMD64' and sys_platform == 'win32')" }, { name = "platformdirs", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'AMD64' and sys_platform == 'win32')" }, + { name = "python-discovery", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'AMD64' and sys_platform == 'win32')" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c1/ef/d9d4ce633df789bf3430bd81fb0d8b9d9465dfc1d1f0deb3fb62cd80f5c2/virtualenv-20.37.0.tar.gz", hash = "sha256:6f7e2064ed470aa7418874e70b6369d53b66bcd9e9fd5389763e96b6c94ccb7c", size = 5864710, upload-time = "2026-02-16T16:17:59.42Z" } +sdist = { url = "https://files.pythonhosted.org/packages/aa/92/58199fe10049f9703c2666e809c4f686c54ef0a68b0f6afccf518c0b1eb9/virtualenv-21.2.0.tar.gz", hash = "sha256:1720dc3a62ef5b443092e3f499228599045d7fea4c79199770499df8becf9098", size = 5840618, upload-time = "2026-03-09T17:24:38.013Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/42/4b/6cf85b485be7ec29db837ec2a1d8cd68bc1147b1abf23d8636c5bd65b3cc/virtualenv-20.37.0-py3-none-any.whl", hash = "sha256:5d3951c32d57232ae3569d4de4cc256c439e045135ebf43518131175d9be435d", size = 5837480, upload-time = "2026-02-16T16:17:57.341Z" }, + { url = "https://files.pythonhosted.org/packages/c6/59/7d02447a55b2e55755011a647479041bc92a82e143f96a8195cb33bd0a1c/virtualenv-21.2.0-py3-none-any.whl", hash = "sha256:1bd755b504931164a5a496d217c014d098426cddc79363ad66ac78125f9d908f", size = 5825084, upload-time = "2026-03-09T17:24:35.378Z" }, ] [[package]]