From 6fe86f074f65e790fc2ce84e31c8d17b2f6f74e1 Mon Sep 17 00:00:00 2001 From: Abinand Nallathambi Date: Sat, 2 Aug 2025 15:43:50 -0400 Subject: [PATCH 1/2] feat: Add code search reranking support via external services - Implement pluggable reranker architecture with HTTP API communication - Add LocalReranker implementation for self-hosted reranking services - Include reference Python reranker service with Docker support - Add comprehensive UI configuration options in settings - Implement proper error handling and fallback mechanisms - Add extensive test coverage for all reranking functionality - Support internationalization for all 18 languages - Update documentation with feature overview and usage instructions The extension now supports optional reranking of code search results through external services, improving search relevance while maintaining user privacy and control. --- README.md | 9 + packages/types/src/codebase-index.ts | 9 + packages/types/src/global-settings.ts | 1 + reranker-service/Dockerfile | 48 ++ reranker-service/README.md | 199 ++++++ reranker-service/app.py | 187 ++++++ reranker-service/config.py | 70 +++ reranker-service/docker-compose.yml | 65 ++ reranker-service/models/__init__.py | 0 reranker-service/models/reranker.py | 204 ++++++ reranker-service/models/schemas.py | 102 +++ reranker-service/requirements.txt | 26 + src/core/webview/ClineProvider.ts | 19 + src/core/webview/webviewMessageHandler.ts | 16 + src/extension.ts | 3 + .../__tests__/config-manager.spec.ts | 582 +++++++++++++++++- .../code-index/__tests__/manager.spec.ts | 3 + .../__tests__/search-service.spec.ts | 542 ++++++++++++++++ src/services/code-index/config-manager.ts | 133 ++++ src/services/code-index/interfaces/config.ts | 7 + src/services/code-index/interfaces/index.ts | 1 + .../code-index/interfaces/reranker.ts | 68 ++ src/services/code-index/manager.ts | 21 +- src/services/code-index/rerankers/base.ts | 103 ++++ src/services/code-index/rerankers/factory.ts | 164 +++++ 
src/services/code-index/rerankers/index.ts | 11 + src/services/code-index/rerankers/local.ts | 227 +++++++ src/services/code-index/search-service.ts | 95 ++- src/services/code-index/service-factory.ts | 40 +- src/shared/WebviewMessage.ts | 10 + .../code-index/config-manager.test.ts | 394 ++++++++++++ .../code-index/rerankers/factory.test.ts | 465 ++++++++++++++ .../code-index/rerankers/local.test.ts | 581 +++++++++++++++++ .../search-service-reranking.test.ts | 516 ++++++++++++++++ .../src/components/chat/CodeIndexPopover.tsx | 351 ++++++++++- .../src/context/ExtensionStateContext.tsx | 8 + webview-ui/src/i18n/locales/ca/settings.json | 36 +- webview-ui/src/i18n/locales/de/settings.json | 36 +- webview-ui/src/i18n/locales/en/common.json | 2 + webview-ui/src/i18n/locales/en/settings.json | 36 +- webview-ui/src/i18n/locales/es/settings.json | 36 +- webview-ui/src/i18n/locales/fr/settings.json | 37 +- webview-ui/src/i18n/locales/hi/settings.json | 37 +- webview-ui/src/i18n/locales/id/settings.json | 37 +- webview-ui/src/i18n/locales/it/settings.json | 37 +- webview-ui/src/i18n/locales/ja/settings.json | 37 +- webview-ui/src/i18n/locales/ko/settings.json | 37 +- webview-ui/src/i18n/locales/nl/settings.json | 37 +- webview-ui/src/i18n/locales/pl/settings.json | 37 +- .../src/i18n/locales/pt-BR/settings.json | 37 +- webview-ui/src/i18n/locales/ru/settings.json | 37 +- webview-ui/src/i18n/locales/tr/settings.json | 37 +- webview-ui/src/i18n/locales/vi/settings.json | 37 +- .../src/i18n/locales/zh-CN/settings.json | 37 +- .../src/i18n/locales/zh-TW/settings.json | 37 +- 55 files changed, 5881 insertions(+), 63 deletions(-) create mode 100644 reranker-service/Dockerfile create mode 100644 reranker-service/README.md create mode 100644 reranker-service/app.py create mode 100644 reranker-service/config.py create mode 100644 reranker-service/docker-compose.yml create mode 100644 reranker-service/models/__init__.py create mode 100644 reranker-service/models/reranker.py create 
mode 100644 reranker-service/models/schemas.py create mode 100644 reranker-service/requirements.txt create mode 100644 src/services/code-index/__tests__/search-service.spec.ts create mode 100644 src/services/code-index/interfaces/reranker.ts create mode 100644 src/services/code-index/rerankers/base.ts create mode 100644 src/services/code-index/rerankers/factory.ts create mode 100644 src/services/code-index/rerankers/index.ts create mode 100644 src/services/code-index/rerankers/local.ts create mode 100644 src/tests/services/code-index/config-manager.test.ts create mode 100644 src/tests/services/code-index/rerankers/factory.test.ts create mode 100644 src/tests/services/code-index/rerankers/local.test.ts create mode 100644 src/tests/services/code-index/search-service-reranking.test.ts diff --git a/README.md b/README.md index 08f8f81806..18c189bd66 100644 --- a/README.md +++ b/README.md @@ -97,6 +97,15 @@ Roo Code comes with powerful [tools](https://docs.roocode.com/basic-usage/how-to MCP extends Roo Code's capabilities by allowing you to add unlimited custom tools. Integrate with external APIs, connect to databases, or create specialized development tools - MCP provides the framework to expand Roo Code's functionality to meet your specific needs. +### Search Enhancement + +Roo Code now supports **semantic search reranking** to improve code search results: + +- **What is reranking:** Reranking uses advanced AI models to reorganize search results based on semantic relevance, ensuring the most contextually appropriate code appears first. This dramatically improves the accuracy of code discovery across your project. +- **How to enable it:** Enable reranking through the Roo Code settings by configuring a reranking provider. Once enabled, all codebase searches will automatically benefit from improved result ordering. +- **Supported providers:** Currently supports local (self-hosted) reranking models for privacy and offline use. 
+- **Setup instructions:** For detailed setup and configuration instructions, see the [reranker service documentation](reranker-service/README.md). + ### Customization Make Roo Code work your way with: diff --git a/packages/types/src/codebase-index.ts b/packages/types/src/codebase-index.ts index 89d5b168d7..556ffccd38 100644 --- a/packages/types/src/codebase-index.ts +++ b/packages/types/src/codebase-index.ts @@ -34,6 +34,14 @@ export const codebaseIndexConfigSchema = z.object({ // OpenAI Compatible specific fields codebaseIndexOpenAiCompatibleBaseUrl: z.string().optional(), codebaseIndexOpenAiCompatibleModelDimension: z.number().optional(), + // Reranker configuration + codebaseIndexRerankerEnabled: z.boolean().optional(), + codebaseIndexRerankerProvider: z.enum(["local", "cohere", "openai", "custom"]).optional(), + codebaseIndexRerankerUrl: z.string().optional(), + codebaseIndexRerankerModel: z.string().optional(), + codebaseIndexRerankerTopN: z.number().min(10).max(500).optional(), + codebaseIndexRerankerTopK: z.number().min(5).max(100).optional(), + codebaseIndexRerankerTimeout: z.number().min(1000).max(30000).optional(), }) export type CodebaseIndexConfig = z.infer @@ -64,6 +72,7 @@ export const codebaseIndexProviderSchema = z.object({ codebaseIndexOpenAiCompatibleModelDimension: z.number().optional(), codebaseIndexGeminiApiKey: z.string().optional(), codebaseIndexMistralApiKey: z.string().optional(), + codebaseIndexRerankerApiKey: z.string().optional(), }) export type CodebaseIndexProvider = z.infer diff --git a/packages/types/src/global-settings.ts b/packages/types/src/global-settings.ts index 6de4d7413f..0d01291ee3 100644 --- a/packages/types/src/global-settings.ts +++ b/packages/types/src/global-settings.ts @@ -189,6 +189,7 @@ export const SECRET_STATE_KEYS = [ "codebaseIndexOpenAiCompatibleApiKey", "codebaseIndexGeminiApiKey", "codebaseIndexMistralApiKey", + "codebaseIndexRerankerApiKey", "huggingFaceApiKey", "sambaNovaApiKey", ] as const satisfies 
readonly (keyof ProviderSettings)[] diff --git a/reranker-service/Dockerfile b/reranker-service/Dockerfile new file mode 100644 index 0000000000..94cd741fd0 --- /dev/null +++ b/reranker-service/Dockerfile @@ -0,0 +1,48 @@ +FROM python:3.10-slim + +# Set working directory +WORKDIR /app + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + build-essential \ + curl \ + git \ + && rm -rf /var/lib/apt/lists/* + +# Copy requirements first for better caching +COPY requirements.txt . + +# Install Python dependencies +RUN pip install --no-cache-dir --upgrade pip && \ + pip install --no-cache-dir -r requirements.txt + +# Copy application code +COPY . . + +# Create cache directory for models +RUN mkdir -p /app/.cache/models + +# Download the model during build to cache it +RUN python -c "from sentence_transformers import CrossEncoder; CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', cache_folder='/app/.cache/models')" + +# Create a non-root user to run the application +RUN useradd -m -u 1000 appuser && \ + chown -R appuser:appuser /app + +# Switch to non-root user +USER appuser + +# Expose port +EXPOSE 8080 + +# Set environment variables +ENV PYTHONUNBUFFERED=1 +ENV MODEL_CACHE_DIR=/app/.cache/models + +# Health check +HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \ + CMD curl -f http://localhost:8080/health || exit 1 + +# Run the application +CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8080", "--workers", "1"] \ No newline at end of file diff --git a/reranker-service/README.md b/reranker-service/README.md new file mode 100644 index 0000000000..08ff95c8ba --- /dev/null +++ b/reranker-service/README.md @@ -0,0 +1,199 @@ +# Code Reranker Service + +A FastAPI-based service for reranking code search results using cross-encoder models. This service is designed to improve the relevance of search results in the Roo-Code codebase indexing feature. 
+ +## Overview + +The reranker service uses sentence-transformers with cross-encoder models to rerank code search results based on query-document relevance. It provides a simple REST API that accepts a query and a list of candidate documents, then returns them ordered by relevance. + +## Prerequisites + +- Python 3.10 or higher +- Docker and Docker Compose (for containerized deployment) +- CUDA-capable GPU (optional, for improved performance) + +## Quick Start + +### Using Docker Compose (Recommended) + +1. Navigate to the reranker service directory: + + ```bash + cd reranker-service + ``` + +2. Build and start the service: + + ```bash + docker-compose up --build + ``` + +3. The service will be available at `http://localhost:8080` + +### Using Python Directly + +1. Create a virtual environment: + + ```bash + python -m venv venv + source venv/bin/activate # On Windows: venv\Scripts\activate + ``` + +2. Install dependencies: + + ```bash + pip install -r requirements.txt + ``` + +3. Run the service: + ```bash + uvicorn app:app --host 0.0.0.0 --port 8080 + ``` + +## API Endpoints + +### Health Check + +``` +GET /health +``` + +Returns the service health status and model information. + +### Rerank + +``` +POST /rerank +``` + +Reranks documents based on query relevance. 
+ +**Request Body:** + +```json +{ + "query": "implement user authentication", + "documents": [ + { + "id": "doc1", + "content": "def authenticate_user(username, password):", + "metadata": { + "filePath": "src/auth.py", + "startLine": 10, + "endLine": 20 + } + } + ], + "max_results": 20 +} +``` + +**Response:** + +```json +[ + { + "id": "doc1", + "score": 0.95, + "rank": 1 + } +] +``` + +### API Documentation + +- Swagger UI: `http://localhost:8080/docs` +- ReDoc: `http://localhost:8080/redoc` + +## Configuration + +The service can be configured using environment variables: + +| Variable | Description | Default | +| ----------------- | ---------------------------------------- | -------------------------------------- | +| `MODEL_NAME` | Cross-encoder model to use | `cross-encoder/ms-marco-MiniLM-L-6-v2` | +| `API_PORT` | Port to run the service on | `8080` | +| `API_WORKERS` | Number of worker processes | `1` | +| `REQUEST_TIMEOUT` | Request timeout in seconds | `30` | +| `BATCH_SIZE` | Batch size for model inference | `32` | +| `LOG_LEVEL` | Logging level | `INFO` | +| `FORCE_CPU` | Force CPU usage even if GPU is available | `false` | +| `WARMUP_ON_START` | Warm up model on startup | `true` | + +## Development + +### Running Tests + +```bash +pytest tests/ +``` + +### Building Docker Image + +```bash +docker build -t code-reranker . +``` + +### Development Mode + +For development, you can mount your local code into the container: + +```bash +docker-compose -f docker-compose.yml up +``` + +This will mount the source files as volumes, allowing you to make changes without rebuilding the image. + +## Model Information + +The default model (`cross-encoder/ms-marco-MiniLM-L-6-v2`) is a lightweight cross-encoder optimized for passage reranking. It provides a good balance between performance and accuracy. 
+ +### Supported Models + +- `cross-encoder/ms-marco-MiniLM-L-6-v2` (default) +- `cross-encoder/ms-marco-MiniLM-L-12-v2` (higher accuracy, slower) +- `cross-encoder/ms-marco-TinyBERT-L-2-v2` (faster, lower accuracy) + +## Performance Considerations + +1. **GPU Usage**: The service will automatically use CUDA if available. For CPU-only deployment, set `FORCE_CPU=true`. + +2. **Model Caching**: Models are downloaded and cached in `/app/.cache/models` during the Docker build process. + +3. **Batch Processing**: Adjust `BATCH_SIZE` based on your hardware capabilities and memory constraints. + +4. **Resource Limits**: The Docker Compose configuration sets memory limits (2GB max, 1GB reserved). Adjust these based on your needs. + +## Troubleshooting + +### Service won't start + +- Check logs: `docker-compose logs reranker` +- Ensure port 8080 is not already in use +- Verify Docker daemon is running + +### Out of memory errors + +- Reduce `BATCH_SIZE` +- Increase Docker memory limits in `docker-compose.yml` +- Use a smaller model + +### Slow performance + +- Enable GPU support by ensuring CUDA is available +- Use a smaller model for faster inference +- Increase `API_WORKERS` for parallel processing + +## Next Steps + +This is a placeholder implementation. The actual implementation should: + +1. Integrate the real CrossEncoder model from sentence-transformers +2. Add proper error handling and validation +3. Implement request queuing for high load +4. Add metrics and monitoring +5. Implement model versioning and updates + +## License + +This service is part of the Roo-Code project. 
diff --git a/reranker-service/app.py b/reranker-service/app.py
new file mode 100644
index 0000000000..3a24bbbdc0
--- /dev/null
+++ b/reranker-service/app.py
@@ -0,0 +1,187 @@
+from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
+from typing import List, Optional
+import logging
+import time
+from datetime import datetime, timezone
+
+from models.reranker import CrossEncoderReranker
+from models.schemas import RerankRequest, RerankResponse, HealthResponse
+from config import LOG_LEVEL, LOG_FORMAT, validate_config
+
+# Configure logging
+logging.basicConfig(
+    level=getattr(logging, LOG_LEVEL),
+    format=LOG_FORMAT
+)
+logger = logging.getLogger(__name__)
+
+# Track startup time for uptime calculation
+startup_time = time.time()
+
+# Initialize FastAPI app
+app = FastAPI(
+    title="Code Reranker API",
+    version="1.0.0",
+    description="A FastAPI service for reranking code search results using cross-encoder models"
+)
+
+# Configure CORS middleware for localhost
+app.add_middleware(
+    CORSMiddleware,
+    # Entries in allow_origins are compared literally against the Origin header,
+    # so "http://localhost:*"-style values can never match a real origin.
+    # A regex is used instead to accept any port on the loopback hosts
+    # (localhost, 127.0.0.1, 0.0.0.0).
+    allow_origin_regex=r"^http://(localhost|127\.0\.0\.1|0\.0\.0\.0)(:\d+)?$",
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Validate configuration on startup; an invalid configuration aborts the import
+try:
+    validate_config()
+    logger.info("Configuration validated successfully")
+except Exception as e:
+    logger.error(f"Configuration validation failed: {str(e)}")
+    raise
+
+# Initialize reranker; on failure the app still starts with reranker = None,
+# and the endpoints report the service as unavailable/unhealthy
+try:
+    reranker = CrossEncoderReranker()
+    logger.info("Reranker initialized successfully")
+except Exception as e:
+    logger.error(f"Failed to initialize reranker: {str(e)}")
+    reranker = None
+
+@app.get("/")
+async def root():
+    """Root endpoint providing API information"""
+    return {
+        "name": "Code Reranker API",
+        "version": "1.0.0",
+        "status": "online",
+        "endpoints": {
+            "health": "/health",
+            "rerank": "/rerank",
+            "docs": "/docs"
+        }
+    }
+
+@app.post("/rerank",
response_model=List[RerankResponse])
+async def rerank(request: RerankRequest):
+    """
+    Rerank code search results based on query relevance.
+
+    Args:
+        request: RerankRequest containing query, documents, and max_results
+
+    Returns:
+        List of RerankResponse objects with id, score, and rank
+
+    Raises:
+        HTTPException: 503 if the reranker is not initialized, 400 for an
+            empty query/document list or a ValueError from the reranker,
+            500 for any other failure during reranking
+    """
+    if not reranker:
+        raise HTTPException(
+            status_code=503,
+            detail="Reranker service is not available"
+        )
+
+    # Validate request before entering the try block: HTTPException is an
+    # Exception subclass, so raising these 400s inside the try would have them
+    # caught by the generic handler below and re-reported as 500s.
+    if not request.query:
+        raise HTTPException(
+            status_code=400,
+            detail="Query cannot be empty"
+        )
+
+    if not request.documents:
+        raise HTTPException(
+            status_code=400,
+            detail="Documents list cannot be empty"
+        )
+
+    try:
+        # Convert Document objects to dictionaries for reranker
+        documents = [doc.model_dump() for doc in request.documents]
+
+        # Perform reranking
+        results = await reranker.rerank(
+            query=request.query,
+            documents=documents,
+            max_results=request.max_results or 20
+        )
+
+        return results
+
+    except ValueError as e:
+        logger.error(f"Invalid request: {str(e)}")
+        raise HTTPException(status_code=400, detail=str(e))
+    except Exception as e:
+        logger.error(f"Reranking error: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+
+@app.get("/health")
+async def health():
+    """
+    Health check endpoint to verify service status.
+ + Returns: + Extended health information including uptime + """ + # Calculate uptime + current_time = time.time() + uptime_seconds = int(current_time - startup_time) + uptime_hours = uptime_seconds // 3600 + uptime_minutes = (uptime_seconds % 3600) // 60 + uptime_str = f"{uptime_hours}h {uptime_minutes}m {uptime_seconds % 60}s" + + if not reranker: + return { + "status": "unhealthy", + "model": "not loaded", + "device": "unknown", + "uptime": uptime_str, + "timestamp": datetime.now(timezone.utc).isoformat(), + "error": "Reranker not initialized" + } + + try: + # Perform model validation + model_valid = reranker.validate_model() + model_name = getattr(reranker, 'model_name', 'unknown') + device = getattr(reranker, 'device', 'unknown') + + return { + "status": "healthy" if model_valid else "degraded", + "model": model_name, + "device": device, + "uptime": uptime_str, + "timestamp": datetime.now(timezone.utc).isoformat(), + "model_valid": model_valid + } + except Exception as e: + logger.error(f"Health check failed: {str(e)}") + return { + "status": "unhealthy", + "model": "error", + "device": "error", + "uptime": uptime_str, + "timestamp": datetime.now(timezone.utc).isoformat(), + "error": str(e) + } + +# Add startup and shutdown events +@app.on_event("startup") +async def startup_event(): + """Initialize resources on startup""" + logger.info("Code Reranker API starting up...") + logger.info(f"Model: {reranker.model_name if reranker else 'Not loaded'}") + logger.info(f"Device: {reranker.device if reranker else 'Unknown'}") + +@app.on_event("shutdown") +async def shutdown_event(): + """Clean up resources on shutdown""" + logger.info("Code Reranker API shutting down...") + if reranker: + # Cleanup will be handled by the reranker's __del__ method + logger.info("Cleaning up reranker resources...") \ No newline at end of file diff --git a/reranker-service/config.py b/reranker-service/config.py new file mode 100644 index 0000000000..152dea8cc9 --- /dev/null +++ 
b/reranker-service/config.py @@ -0,0 +1,70 @@ +""" +Configuration constants and settings for the reranker service. +""" + +import os +from typing import Optional + +# Model configuration +DEFAULT_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2" +MODEL_CACHE_DIR = os.getenv("MODEL_CACHE_DIR", "/app/.cache/models") + +# API configuration +API_HOST = os.getenv("API_HOST", "0.0.0.0") +API_PORT = int(os.getenv("API_PORT", "8080")) +API_WORKERS = int(os.getenv("API_WORKERS", "1")) + +# Reranking configuration +DEFAULT_MAX_RESULTS = 20 +MAX_ALLOWED_RESULTS = 100 +MIN_ALLOWED_RESULTS = 1 +MAX_DOCUMENT_LENGTH = 10000 # Maximum characters per document + +# Performance configuration +REQUEST_TIMEOUT = int(os.getenv("REQUEST_TIMEOUT", "30")) # seconds +BATCH_SIZE = int(os.getenv("BATCH_SIZE", "32")) + +# Logging configuration +LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO") +LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + +# CORS configuration +CORS_ORIGINS = os.getenv("CORS_ORIGINS", "http://localhost:*,http://127.0.0.1:*").split(",") + +# Device configuration +FORCE_CPU = os.getenv("FORCE_CPU", "false").lower() == "true" + +# Model warmup configuration +WARMUP_ON_START = os.getenv("WARMUP_ON_START", "true").lower() == "true" + +# Health check configuration +HEALTH_CHECK_TIMEOUT = 5 # seconds + +def get_model_name() -> str: + """Get the model name from environment or use default.""" + return os.getenv("MODEL_NAME", DEFAULT_MODEL_NAME) + +def get_device() -> str: + """Determine the device to use for model inference.""" + if FORCE_CPU: + return "cpu" + + try: + import torch + return "cuda" if torch.cuda.is_available() else "cpu" + except ImportError: + return "cpu" + +def validate_config() -> None: + """Validate configuration settings.""" + if API_PORT < 1 or API_PORT > 65535: + raise ValueError(f"Invalid API_PORT: {API_PORT}") + + if API_WORKERS < 1: + raise ValueError(f"Invalid API_WORKERS: {API_WORKERS}") + + if REQUEST_TIMEOUT < 1: + raise 
ValueError(f"Invalid REQUEST_TIMEOUT: {REQUEST_TIMEOUT}") + + if BATCH_SIZE < 1: + raise ValueError(f"Invalid BATCH_SIZE: {BATCH_SIZE}") \ No newline at end of file diff --git a/reranker-service/docker-compose.yml b/reranker-service/docker-compose.yml new file mode 100644 index 0000000000..45f5307de3 --- /dev/null +++ b/reranker-service/docker-compose.yml @@ -0,0 +1,65 @@ +version: '3.8' + +services: + reranker: + build: + context: . + dockerfile: Dockerfile + container_name: code-reranker + ports: + - "8080:8080" + environment: + # Model configuration + - MODEL_NAME=cross-encoder/ms-marco-MiniLM-L-6-v2 + - MODEL_CACHE_DIR=/app/.cache/models + + # API configuration + - API_HOST=0.0.0.0 + - API_PORT=8080 + - API_WORKERS=1 + + # Performance configuration + - REQUEST_TIMEOUT=30 + - BATCH_SIZE=32 + + # Logging + - LOG_LEVEL=INFO + + # CORS configuration + - CORS_ORIGINS=http://localhost:*,http://127.0.0.1:* + + # Device configuration (set to true to force CPU usage) + - FORCE_CPU=false + + # Model warmup + - WARMUP_ON_START=true + + volumes: + # Mount model cache to persist downloaded models + - model-cache:/app/.cache/models + + # For development: mount source code + - ./app.py:/app/app.py:ro + - ./models:/app/models:ro + - ./config.py:/app/config.py:ro + + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8080/health"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 40s + + restart: unless-stopped + + # Resource limits + deploy: + resources: + limits: + memory: 2G + reservations: + memory: 1G + +volumes: + model-cache: + driver: local \ No newline at end of file diff --git a/reranker-service/models/__init__.py b/reranker-service/models/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/reranker-service/models/reranker.py b/reranker-service/models/reranker.py new file mode 100644 index 0000000000..d388ca5a66 --- /dev/null +++ b/reranker-service/models/reranker.py @@ -0,0 +1,204 @@ +from typing import List, Dict, Any 
+import asyncio
+import logging
+from sentence_transformers import CrossEncoder
+import torch
+from concurrent.futures import ThreadPoolExecutor
+import time
+
+from config import (
+    get_model_name,
+    get_device,
+    BATCH_SIZE,
+    MODEL_CACHE_DIR,
+    WARMUP_ON_START
+)
+
+logger = logging.getLogger(__name__)
+
+
+class CrossEncoderReranker:
+    """
+    Cross-encoder based reranker using sentence-transformers.
+
+    This class provides reranking functionality for code search results
+    using pre-trained cross-encoder models.
+    """
+
+    def __init__(self, model_name: str = None):
+        """
+        Initialize the reranker with specified model.
+
+        Args:
+            model_name: Name of the cross-encoder model to use.
+                If None, uses the model from config.
+
+        Raises:
+            Exception: Any error raised while loading the model is logged
+                and re-raised unchanged.
+        """
+        self.model_name = model_name or get_model_name()
+        self.device = get_device()
+        self._executor = ThreadPoolExecutor(max_workers=1)
+        # Handle to the background warmup task; stays None when none is scheduled
+        self._warmup_task = None
+
+        logger.info(f"Initializing reranker with model: {self.model_name}")
+        logger.info(f"Device: {self.device}")
+        logger.info(f"Model cache directory: {MODEL_CACHE_DIR}")
+
+        try:
+            # Initialize CrossEncoder model
+            # NOTE(review): requirements.txt pins sentence-transformers==2.2.2;
+            # confirm that release accepts the trust_remote_code keyword.
+            self.model = CrossEncoder(
+                model_name=self.model_name,
+                device=self.device,
+                max_length=512,  # Maximum sequence length
+                trust_remote_code=False
+            )
+            logger.info(f"Successfully loaded model: {self.model_name}")
+
+            # Perform warmup if configured
+            if WARMUP_ON_START:
+                try:
+                    loop = asyncio.get_running_loop()
+                except RuntimeError:
+                    # Scheduling a task requires a running event loop. When the
+                    # reranker is constructed from synchronous code there is
+                    # none, so skip the background warmup instead of failing
+                    # initialization.
+                    logger.info("No running event loop; skipping background warmup")
+                else:
+                    # Keep a reference so the task is not garbage-collected mid-run
+                    self._warmup_task = loop.create_task(self.warmup())
+
+        except Exception as e:
+            logger.error(f"Failed to load model {self.model_name}: {str(e)}")
+            raise
+
+    async def rerank(
+        self,
+        query: str,
+        documents: List[Dict[str, Any]],
+        max_results: int = 20
+    ) -> List[Dict[str, Any]]:
+        """
+        Rerank documents based on query relevance using cross-encoder.
+ + Args: + query: The search query + documents: List of documents to rerank, each with 'id' and 'content' + max_results: Maximum number of results to return + + Returns: + List of reranked documents with scores and ranks + """ + if not documents: + return [] + + start_time = time.time() + logger.info(f"Reranking {len(documents)} documents for query: '{query}'") + + try: + # Extract content and create query-document pairs + pairs = [[query, doc["content"]] for doc in documents] + + # Run scoring in executor to avoid blocking + loop = asyncio.get_event_loop() + scores = await loop.run_in_executor( + self._executor, + self._score_pairs, + pairs + ) + + # Create result objects with scores + scored_docs = [] + for i, (doc, score) in enumerate(zip(documents, scores)): + scored_docs.append({ + "id": doc["id"], + "score": float(score), + "rank": 0 # Will be assigned after sorting + }) + + # Sort by score in descending order + scored_docs.sort(key=lambda x: x["score"], reverse=True) + + # Assign ranks and limit results + results = [] + for i, doc in enumerate(scored_docs[:max_results]): + doc["rank"] = i + 1 + results.append(doc) + + elapsed_time = time.time() - start_time + logger.info( + f"Reranking complete in {elapsed_time:.2f}s. " + f"Returning {len(results)} results" + ) + + return results + + except Exception as e: + logger.error(f"Error during reranking: {str(e)}") + raise + + def _score_pairs(self, pairs: List[List[str]]) -> List[float]: + """ + Score query-document pairs using the cross-encoder model. + + This method is called in a separate thread to avoid blocking. 
+
+        Args:
+            pairs: List of [query, document] pairs
+
+        Returns:
+            List of relevance scores
+        """
+        try:
+            # Process in batches if needed
+            all_scores = []
+
+            for i in range(0, len(pairs), BATCH_SIZE):
+                batch = pairs[i:i + BATCH_SIZE]
+                batch_scores = self.model.predict(batch)
+                all_scores.extend(batch_scores)
+
+            return all_scores
+
+        except Exception as e:
+            logger.error(f"Error scoring pairs: {str(e)}")
+            raise
+
+    def validate_model(self) -> bool:
+        """
+        Validate that the model is properly loaded.
+
+        Runs a single test prediction and checks that it yields exactly one
+        score that converts to a non-NaN float.
+
+        Returns:
+            True if model is valid, False otherwise (including when the
+            prediction raises)
+        """
+        try:
+            # Check if model is loaded
+            if self.model is None:
+                return False
+
+            # Try a simple prediction to ensure model works
+            test_pairs = [["test", "test"]]
+            scores = self.model.predict(test_pairs)
+
+            if len(scores) != 1:
+                return False
+
+            # Convert rather than isinstance-check against (float, int):
+            # predict() presumably returns NumPy scalars (e.g. float32), which
+            # are not instances of the builtin float - verify against the
+            # installed sentence-transformers version.
+            score = float(scores[0])
+
+            # NaN is the only float that is not equal to itself
+            return score == score
+
+        except Exception as e:
+            logger.error(f"Model validation failed: {str(e)}")
+            return False
+
+    async def warmup(self):
+        """
+        Warm up the model with a sample query.
+        This helps ensure the model is ready for production use.
+        """
+        logger.info("Warming up reranker model...")
+
+        try:
+            sample_docs = [
+                {
+                    "id": "warmup1",
+                    "content": "def authenticate_user(username, password): return True"
+                },
+                {
+                    "id": "warmup2",
+                    "content": "class UserAuth: def login(self, user, pwd): pass"
+                }
+            ]
+
+            results = await self.rerank("user authentication", sample_docs, max_results=2)
+            logger.info(f"Warmup complete.
Processed {len(results)} results") + + except Exception as e: + logger.error(f"Warmup failed: {str(e)}") + + def __del__(self): + """Clean up resources when the reranker is destroyed.""" + if hasattr(self, '_executor'): + self._executor.shutdown(wait=False) \ No newline at end of file diff --git a/reranker-service/models/schemas.py b/reranker-service/models/schemas.py new file mode 100644 index 0000000000..ac00d8fc24 --- /dev/null +++ b/reranker-service/models/schemas.py @@ -0,0 +1,102 @@ +from pydantic import BaseModel, Field +from typing import List, Dict, Any, Optional + + +class Document(BaseModel): + """Document model for reranking""" + id: str = Field(..., description="Unique identifier for the document") + content: str = Field(..., description="The text content of the document") + metadata: Optional[Dict[str, Any]] = Field( + default=None, + description="Optional metadata about the document" + ) + + +class RerankRequest(BaseModel): + """Request model for reranking endpoint""" + query: str = Field( + ..., + description="The search query to use for reranking", + min_length=1 + ) + documents: List[Document] = Field( + ..., + description="List of documents to rerank", + min_items=1 + ) + max_results: Optional[int] = Field( + default=20, + description="Maximum number of results to return", + ge=1, + le=100 + ) + + class Config: + json_schema_extra = { + "example": { + "query": "implement user authentication", + "documents": [ + { + "id": "doc1", + "content": "def authenticate_user(username, password):\n # Implementation here", + "metadata": { + "filePath": "src/auth.py", + "startLine": 10, + "endLine": 20 + } + }, + { + "id": "doc2", + "content": "class UserAuth:\n def login(self, user, pass):", + "metadata": { + "filePath": "src/models/user.py", + "startLine": 45, + "endLine": 50 + } + } + ], + "max_results": 10 + } + } + + +class RerankResponse(BaseModel): + """Response model for reranking results""" + id: str = Field(..., description="Document identifier") + 
score: float = Field(..., description="Relevance score from the reranker") + rank: int = Field(..., description="Rank position (1-based)") + + class Config: + json_schema_extra = { + "example": { + "id": "doc1", + "score": 0.95, + "rank": 1 + } + } + + +class HealthResponse(BaseModel): + """Response model for health check endpoint""" + status: str = Field( + ..., + description="Health status of the service", + pattern="^(healthy|unhealthy)$" + ) + model: str = Field( + ..., + description="Name of the loaded model" + ) + device: str = Field( + ..., + description="Device being used (cpu/cuda)" + ) + + class Config: + json_schema_extra = { + "example": { + "status": "healthy", + "model": "cross-encoder/ms-marco-MiniLM-L-6-v2", + "device": "cuda" + } + } \ No newline at end of file diff --git a/reranker-service/requirements.txt b/reranker-service/requirements.txt new file mode 100644 index 0000000000..cb161ffac3 --- /dev/null +++ b/reranker-service/requirements.txt @@ -0,0 +1,26 @@ +# Core FastAPI dependencies +fastapi==0.109.0 +uvicorn[standard]==0.27.0 +pydantic==2.5.3 +python-multipart==0.0.6 + +# ML/Reranking dependencies +sentence-transformers==2.2.2 +torch==2.1.2 +transformers==4.36.2 + +# Additional dependencies +numpy==1.24.3 +scipy==1.11.4 +scikit-learn==1.3.2 + +# Async support +aiofiles==23.2.1 + +# Monitoring and logging +python-json-logger==2.0.7 + +# Development dependencies (optional) +pytest==7.4.4 +pytest-asyncio==0.23.3 +httpx==0.26.0 \ No newline at end of file diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index ed8f8a27d1..442d43263e 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -1748,6 +1748,14 @@ export class ClineProvider codebaseIndexOpenAiCompatibleBaseUrl: codebaseIndexConfig?.codebaseIndexOpenAiCompatibleBaseUrl, codebaseIndexSearchMaxResults: codebaseIndexConfig?.codebaseIndexSearchMaxResults, codebaseIndexSearchMinScore: 
codebaseIndexConfig?.codebaseIndexSearchMinScore, + // Reranker settings + codebaseIndexRerankerEnabled: codebaseIndexConfig?.codebaseIndexRerankerEnabled ?? false, + codebaseIndexRerankerProvider: codebaseIndexConfig?.codebaseIndexRerankerProvider ?? "local", + codebaseIndexRerankerUrl: codebaseIndexConfig?.codebaseIndexRerankerUrl ?? "http://localhost:8003", + codebaseIndexRerankerModel: codebaseIndexConfig?.codebaseIndexRerankerModel ?? "Qwen/Qwen3-Reranker-8B", + codebaseIndexRerankerTopN: codebaseIndexConfig?.codebaseIndexRerankerTopN ?? 100, + codebaseIndexRerankerTopK: codebaseIndexConfig?.codebaseIndexRerankerTopK ?? 20, + codebaseIndexRerankerTimeout: codebaseIndexConfig?.codebaseIndexRerankerTimeout ?? 10000, }, mdmCompliant: this.checkMdmCompliance(), profileThresholds: profileThresholds ?? {}, @@ -1938,6 +1946,17 @@ export class ClineProvider stateValues.codebaseIndexConfig?.codebaseIndexOpenAiCompatibleBaseUrl, codebaseIndexSearchMaxResults: stateValues.codebaseIndexConfig?.codebaseIndexSearchMaxResults, codebaseIndexSearchMinScore: stateValues.codebaseIndexConfig?.codebaseIndexSearchMinScore, + // Reranker settings + codebaseIndexRerankerEnabled: stateValues.codebaseIndexConfig?.codebaseIndexRerankerEnabled ?? false, + codebaseIndexRerankerProvider: + stateValues.codebaseIndexConfig?.codebaseIndexRerankerProvider ?? "local", + codebaseIndexRerankerUrl: + stateValues.codebaseIndexConfig?.codebaseIndexRerankerUrl ?? "http://localhost:8003", + codebaseIndexRerankerModel: + stateValues.codebaseIndexConfig?.codebaseIndexRerankerModel ?? "Qwen/Qwen3-Reranker-8B", + codebaseIndexRerankerTopN: stateValues.codebaseIndexConfig?.codebaseIndexRerankerTopN ?? 100, + codebaseIndexRerankerTopK: stateValues.codebaseIndexConfig?.codebaseIndexRerankerTopK ?? 20, + codebaseIndexRerankerTimeout: stateValues.codebaseIndexConfig?.codebaseIndexRerankerTimeout ?? 10000, }, profileThresholds: stateValues.profileThresholds ?? 
{}, // Add diagnostic message settings diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index fdb7e90425..2cd0bc2daa 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -2016,6 +2016,14 @@ export const webviewMessageHandler = async ( codebaseIndexOpenAiCompatibleBaseUrl: settings.codebaseIndexOpenAiCompatibleBaseUrl, codebaseIndexSearchMaxResults: settings.codebaseIndexSearchMaxResults, codebaseIndexSearchMinScore: settings.codebaseIndexSearchMinScore, + // Reranker settings + codebaseIndexRerankerEnabled: settings.codebaseIndexRerankerEnabled, + codebaseIndexRerankerProvider: settings.codebaseIndexRerankerProvider, + codebaseIndexRerankerUrl: settings.codebaseIndexRerankerUrl, + codebaseIndexRerankerModel: settings.codebaseIndexRerankerModel, + codebaseIndexRerankerTopN: settings.codebaseIndexRerankerTopN, + codebaseIndexRerankerTopK: settings.codebaseIndexRerankerTopK, + codebaseIndexRerankerTimeout: settings.codebaseIndexRerankerTimeout, } // Save global state first @@ -2046,6 +2054,12 @@ export const webviewMessageHandler = async ( settings.codebaseIndexMistralApiKey, ) } + if (settings.codebaseIndexRerankerApiKey !== undefined) { + await provider.contextProxy.storeSecret( + "codebaseIndexRerankerApiKey", + settings.codebaseIndexRerankerApiKey, + ) + } // Send success response first - settings are saved regardless of validation await provider.postMessageToWebview({ @@ -2167,6 +2181,7 @@ export const webviewMessageHandler = async ( )) const hasGeminiApiKey = !!(await provider.context.secrets.get("codebaseIndexGeminiApiKey")) const hasMistralApiKey = !!(await provider.context.secrets.get("codebaseIndexMistralApiKey")) + const hasRerankerApiKey = !!(await provider.context.secrets.get("codebaseIndexRerankerApiKey")) provider.postMessageToWebview({ type: "codeIndexSecretStatus", @@ -2176,6 +2191,7 @@ export const webviewMessageHandler = async ( 
hasOpenAiCompatibleApiKey, hasGeminiApiKey, hasMistralApiKey, + hasRerankerApiKey, }, }) break diff --git a/src/extension.ts b/src/extension.ts index 60c61aada7..389b9f3e63 100644 --- a/src/extension.ts +++ b/src/extension.ts @@ -74,6 +74,9 @@ export async function activate(context: vscode.ExtensionContext) { // Create logger for cloud services const cloudLogger = createDualLogger(createOutputChannelLogger(outputChannel)) + // Set logger for CodeIndexManager + CodeIndexManager.setLogger(cloudLogger) + // Initialize Roo Code Cloud service. const cloudService = await CloudService.createInstance(context, cloudLogger) const postStateListener = () => { diff --git a/src/services/code-index/__tests__/config-manager.spec.ts b/src/services/code-index/__tests__/config-manager.spec.ts index 2d6e704d76..ae6624fbab 100644 --- a/src/services/code-index/__tests__/config-manager.spec.ts +++ b/src/services/code-index/__tests__/config-manager.spec.ts @@ -1292,14 +1292,26 @@ describe("CodeIndexConfigManager", () => { isConfigured: true, embedderProvider: "openai", modelId: "text-embedding-3-large", + modelDimension: undefined, openAiOptions: { openAiNativeApiKey: "test-openai-key" }, ollamaOptions: { ollamaBaseUrl: undefined }, - geminiOptions: undefined, openAiCompatibleOptions: undefined, + geminiOptions: undefined, + mistralOptions: undefined, qdrantUrl: "http://qdrant.local", qdrantApiKey: "test-qdrant-key", searchMinScore: 0.4, searchMaxResults: 50, + rerankerConfig: { + enabled: false, + provider: "local", + url: "http://localhost:8003", + apiKey: "", + model: "Qwen/Qwen3-Reranker-8B", + topN: 100, + topK: 20, + timeout: 10000, + }, }) }) @@ -1810,5 +1822,573 @@ describe("CodeIndexConfigManager", () => { expect(configManager.currentModelDimension).toBe(undefined) }) }) + + describe("currentSearchMinScore", () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it("should return user-configured score when set", async () => { + 
mockContextProxy.getGlobalState.mockReturnValue({ + codebaseIndexEnabled: true, + codebaseIndexEmbedderProvider: "openai", + codebaseIndexSearchMinScore: 0.7, + codebaseIndexQdrantUrl: "http://localhost:6333", + }) + mockContextProxy.getSecret.mockImplementation((key: string) => { + if (key === "codeIndexOpenAiKey") return "test-key" + return undefined + }) + + configManager = new CodeIndexConfigManager(mockContextProxy) + await configManager.loadConfiguration() + + expect(configManager.currentSearchMinScore).toBe(0.7) + }) + + it("should return model-specific threshold when user score not set", async () => { + // Mock getModelScoreThreshold to return a model-specific threshold + mockedGetModelScoreThreshold.mockReturnValue(0.6) + + mockContextProxy.getGlobalState.mockReturnValue({ + codebaseIndexEnabled: true, + codebaseIndexEmbedderProvider: "openai", + codebaseIndexEmbedderModelId: "text-embedding-3-small", + // No searchMinScore set + codebaseIndexQdrantUrl: "http://localhost:6333", + }) + mockContextProxy.getSecret.mockImplementation((key: string) => { + if (key === "codeIndexOpenAiKey") return "test-key" + return undefined + }) + + configManager = new CodeIndexConfigManager(mockContextProxy) + await configManager.loadConfiguration() + + expect(configManager.currentSearchMinScore).toBe(0.6) + expect(mockedGetModelScoreThreshold).toHaveBeenCalledWith("openai", "text-embedding-3-small") + }) + + it("should return default score when neither user nor model threshold is available", async () => { + // Mock getModelScoreThreshold to return undefined + mockedGetModelScoreThreshold.mockReturnValue(undefined) + + mockContextProxy.getGlobalState.mockReturnValue({ + codebaseIndexEnabled: true, + codebaseIndexEmbedderProvider: "openai", + // No searchMinScore set + codebaseIndexQdrantUrl: "http://localhost:6333", + }) + mockContextProxy.getSecret.mockImplementation((key: string) => { + if (key === "codeIndexOpenAiKey") return "test-key" + return undefined + }) + + 
configManager = new CodeIndexConfigManager(mockContextProxy) + await configManager.loadConfiguration() + + // Should return DEFAULT_SEARCH_MIN_SCORE (0.4) + expect(configManager.currentSearchMinScore).toBe(0.4) + }) + }) + + describe("currentSearchMaxResults", () => { + it("should return user-configured max results when set", async () => { + mockContextProxy.getGlobalState.mockReturnValue({ + codebaseIndexEnabled: true, + codebaseIndexEmbedderProvider: "openai", + codebaseIndexSearchMaxResults: 50, + codebaseIndexQdrantUrl: "http://localhost:6333", + }) + mockContextProxy.getSecret.mockImplementation((key: string) => { + if (key === "codeIndexOpenAiKey") return "test-key" + return undefined + }) + + configManager = new CodeIndexConfigManager(mockContextProxy) + await configManager.loadConfiguration() + + expect(configManager.currentSearchMaxResults).toBe(50) + }) + + it("should return default max results when not configured", async () => { + mockContextProxy.getGlobalState.mockReturnValue({ + codebaseIndexEnabled: true, + codebaseIndexEmbedderProvider: "openai", + // No searchMaxResults set + codebaseIndexQdrantUrl: "http://localhost:6333", + }) + mockContextProxy.getSecret.mockImplementation((key: string) => { + if (key === "codeIndexOpenAiKey") return "test-key" + return undefined + }) + + configManager = new CodeIndexConfigManager(mockContextProxy) + await configManager.loadConfiguration() + + // Should return DEFAULT_MAX_SEARCH_RESULTS (50) + expect(configManager.currentSearchMaxResults).toBe(50) + }) + }) + }) + + describe("Reranker Configuration", () => { + describe("isRerankerEnabled", () => { + it("should return true when reranker is enabled and feature is enabled", () => { + mockContextProxy.getGlobalState.mockReturnValue({ + codebaseIndexEnabled: true, + codebaseIndexRerankerEnabled: true, + codebaseIndexEmbedderProvider: "openai", + codebaseIndexQdrantUrl: "http://localhost:6333", + }) + mockContextProxy.getSecret.mockImplementation((key: string) => { + 
if (key === "codeIndexOpenAiKey") return "test-key" + return undefined + }) + + configManager = new CodeIndexConfigManager(mockContextProxy) + expect(configManager.isRerankerEnabled).toBe(true) + }) + + it("should return false when reranker is disabled", () => { + mockContextProxy.getGlobalState.mockReturnValue({ + codebaseIndexEnabled: true, + codebaseIndexRerankerEnabled: false, + codebaseIndexEmbedderProvider: "openai", + codebaseIndexQdrantUrl: "http://localhost:6333", + }) + mockContextProxy.getSecret.mockImplementation((key: string) => { + if (key === "codeIndexOpenAiKey") return "test-key" + return undefined + }) + + configManager = new CodeIndexConfigManager(mockContextProxy) + expect(configManager.isRerankerEnabled).toBe(false) + }) + + it("should return false when main feature is disabled even if reranker is enabled", () => { + mockContextProxy.getGlobalState.mockReturnValue({ + codebaseIndexEnabled: false, + codebaseIndexRerankerEnabled: true, + codebaseIndexEmbedderProvider: "openai", + codebaseIndexQdrantUrl: "http://localhost:6333", + }) + mockContextProxy.getSecret.mockImplementation((key: string) => { + if (key === "codeIndexOpenAiKey") return "test-key" + return undefined + }) + + configManager = new CodeIndexConfigManager(mockContextProxy) + expect(configManager.isRerankerEnabled).toBe(false) + }) + }) + + describe("getRerankerConfig", () => { + it("should return complete reranker configuration with custom values", () => { + mockContextProxy.getGlobalState.mockReturnValue({ + codebaseIndexEnabled: true, + codebaseIndexRerankerEnabled: true, + codebaseIndexRerankerProvider: "remote", + codebaseIndexRerankerUrl: "https://api.reranker.com", + codebaseIndexRerankerModel: "custom-reranker-model", + codebaseIndexRerankerTopN: 150, + codebaseIndexRerankerTopK: 30, + codebaseIndexRerankerTimeout: 15000, + codebaseIndexEmbedderProvider: "openai", + codebaseIndexQdrantUrl: "http://localhost:6333", + }) + mockContextProxy.getSecret.mockImplementation((key: 
string) => { + if (key === "codeIndexOpenAiKey") return "test-key" + if (key === "codebaseIndexRerankerApiKey") return "reranker-api-key" + return undefined + }) + + configManager = new CodeIndexConfigManager(mockContextProxy) + const rerankerConfig = configManager.getRerankerConfig() + + expect(rerankerConfig).toEqual({ + enabled: true, + provider: "remote", + url: "https://api.reranker.com", + apiKey: "reranker-api-key", + model: "custom-reranker-model", + topN: 150, + topK: 30, + timeout: 15000, + }) + }) + + it("should return default reranker configuration when not configured", () => { + mockContextProxy.getGlobalState.mockReturnValue({ + codebaseIndexEnabled: true, + // No reranker settings provided + codebaseIndexEmbedderProvider: "openai", + codebaseIndexQdrantUrl: "http://localhost:6333", + }) + mockContextProxy.getSecret.mockImplementation((key: string) => { + if (key === "codeIndexOpenAiKey") return "test-key" + return undefined + }) + + configManager = new CodeIndexConfigManager(mockContextProxy) + const rerankerConfig = configManager.getRerankerConfig() + + expect(rerankerConfig).toEqual({ + enabled: false, + provider: "local", + url: "http://localhost:8003", + apiKey: "", + model: "Qwen/Qwen3-Reranker-8B", + topN: 100, + topK: 20, + timeout: 10000, + }) + }) + }) + + describe("rerankerTopN and rerankerTopK getters", () => { + it("should return custom topN and topK values", () => { + mockContextProxy.getGlobalState.mockReturnValue({ + codebaseIndexEnabled: true, + codebaseIndexRerankerEnabled: true, + codebaseIndexRerankerTopN: 200, + codebaseIndexRerankerTopK: 40, + codebaseIndexEmbedderProvider: "openai", + codebaseIndexQdrantUrl: "http://localhost:6333", + }) + mockContextProxy.getSecret.mockImplementation((key: string) => { + if (key === "codeIndexOpenAiKey") return "test-key" + return undefined + }) + + configManager = new CodeIndexConfigManager(mockContextProxy) + + expect(configManager.rerankerTopN).toBe(200) + 
expect(configManager.rerankerTopK).toBe(40) + }) + + it("should return default topN and topK values when not configured", () => { + mockContextProxy.getGlobalState.mockReturnValue({ + codebaseIndexEnabled: true, + // No topN/topK configured + codebaseIndexEmbedderProvider: "openai", + codebaseIndexQdrantUrl: "http://localhost:6333", + }) + mockContextProxy.getSecret.mockImplementation((key: string) => { + if (key === "codeIndexOpenAiKey") return "test-key" + return undefined + }) + + configManager = new CodeIndexConfigManager(mockContextProxy) + + expect(configManager.rerankerTopN).toBe(100) + expect(configManager.rerankerTopK).toBe(20) + }) + }) + + describe("Reranker configuration loading", () => { + it("should load reranker configuration from globalState with all fields", async () => { + const mockGlobalState = { + codebaseIndexEnabled: true, + codebaseIndexQdrantUrl: "http://qdrant.local", + codebaseIndexEmbedderProvider: "openai", + codebaseIndexRerankerEnabled: true, + codebaseIndexRerankerProvider: "remote", + codebaseIndexRerankerUrl: "https://reranker.example.com", + codebaseIndexRerankerModel: "advanced-reranker", + codebaseIndexRerankerTopN: 250, + codebaseIndexRerankerTopK: 50, + codebaseIndexRerankerTimeout: 20000, + } + mockContextProxy.getGlobalState.mockReturnValue(mockGlobalState) + + setupSecretMocks({ + codeIndexOpenAiKey: "test-openai-key", + codeIndexQdrantApiKey: "test-qdrant-key", + codebaseIndexRerankerApiKey: "test-reranker-key", + }) + + await configManager.loadConfiguration() + const rerankerConfig = configManager.getRerankerConfig() + + expect(rerankerConfig).toEqual({ + enabled: true, + provider: "remote", + url: "https://reranker.example.com", + apiKey: "test-reranker-key", + model: "advanced-reranker", + topN: 250, + topK: 50, + timeout: 20000, + }) + }) + + it("should use default values for missing reranker fields", async () => { + const mockGlobalState = { + codebaseIndexEnabled: true, + codebaseIndexQdrantUrl:
"http://qdrant.local", + codebaseIndexEmbedderProvider: "openai", + codebaseIndexRerankerEnabled: true, + // Missing other reranker fields + } + mockContextProxy.getGlobalState.mockReturnValue(mockGlobalState) + + setupSecretMocks({ + codeIndexOpenAiKey: "test-openai-key", + }) + + await configManager.loadConfiguration() + const rerankerConfig = configManager.getRerankerConfig() + + expect(rerankerConfig).toEqual({ + enabled: true, + provider: "local", + url: "http://localhost:8003", + apiKey: "", + model: "Qwen/Qwen3-Reranker-8B", + topN: 100, + topK: 20, + timeout: 10000, + }) + }) + }) + + describe("Reranker state change detection", () => { + it("should require restart when reranker is enabled", async () => { + // Initial state: reranker disabled + mockContextProxy.getGlobalState.mockReturnValue({ + codebaseIndexEnabled: true, + codebaseIndexRerankerEnabled: false, + codebaseIndexEmbedderProvider: "openai", + codebaseIndexQdrantUrl: "http://localhost:6333", + }) + mockContextProxy.getSecret.mockImplementation((key: string) => { + if (key === "codeIndexOpenAiKey") return "test-key" + return undefined + }) + configManager = new CodeIndexConfigManager(mockContextProxy) + + // Get initial state + await configManager.loadConfiguration() + + // Enable reranker + mockContextProxy.getGlobalState.mockReturnValue({ + codebaseIndexEnabled: true, + codebaseIndexRerankerEnabled: true, + codebaseIndexEmbedderProvider: "openai", + codebaseIndexQdrantUrl: "http://localhost:6333", + }) + + const { requiresRestart } = await configManager.loadConfiguration() + expect(requiresRestart).toBe(true) + }) + + it("should require restart when reranker provider changes", async () => { + // Initial state: local reranker + mockContextProxy.getGlobalState.mockReturnValue({ + codebaseIndexEnabled: true, + codebaseIndexRerankerEnabled: true, + codebaseIndexRerankerProvider: "local", + codebaseIndexEmbedderProvider: "openai", +
codebaseIndexQdrantUrl: "http://localhost:6333", + }) + mockContextProxy.getSecret.mockImplementation((key: string) => { + if (key === "codeIndexOpenAiKey") return "test-key" + return undefined + }) + configManager = new CodeIndexConfigManager(mockContextProxy) + + await configManager.loadConfiguration() + + // Change to remote provider + mockContextProxy.getGlobalState.mockReturnValue({ + codebaseIndexEnabled: true, + codebaseIndexRerankerEnabled: true, + codebaseIndexRerankerProvider: "remote", + codebaseIndexEmbedderProvider: "openai", + codebaseIndexQdrantUrl: "http://localhost:6333", + }) + + const { requiresRestart } = await configManager.loadConfiguration() + expect(requiresRestart).toBe(true) + }) + + it("should require restart when reranker URL changes", async () => { + // Initial state + mockContextProxy.getGlobalState.mockReturnValue({ + codebaseIndexEnabled: true, + codebaseIndexRerankerEnabled: true, + codebaseIndexRerankerUrl: "http://localhost:8003", + codebaseIndexEmbedderProvider: "openai", + codebaseIndexQdrantUrl: "http://localhost:6333", + }) + mockContextProxy.getSecret.mockImplementation((key: string) => { + if (key === "codeIndexOpenAiKey") return "test-key" + return undefined + }) + configManager = new CodeIndexConfigManager(mockContextProxy) + + await configManager.loadConfiguration() + + // Change URL + mockContextProxy.getGlobalState.mockReturnValue({ + codebaseIndexEnabled: true, + codebaseIndexRerankerEnabled: true, + codebaseIndexRerankerUrl: "https://new-reranker.com", + codebaseIndexEmbedderProvider: "openai", + codebaseIndexQdrantUrl: "http://localhost:6333", + }) + + const { requiresRestart } = await configManager.loadConfiguration() + expect(requiresRestart).toBe(true) + }) + + it("should require restart when reranker model changes", async () => { + // Initial state + mockContextProxy.getGlobalState.mockReturnValue({ + codebaseIndexEnabled:
true, + codebaseIndexRerankerEnabled: true, + codebaseIndexRerankerModel: "model-v1", + codebaseIndexEmbedderProvider: "openai", + codebaseIndexQdrantUrl: "http://localhost:6333", + }) + mockContextProxy.getSecret.mockImplementation((key: string) => { + if (key === "codeIndexOpenAiKey") return "test-key" + return undefined + }) + configManager = new CodeIndexConfigManager(mockContextProxy) + + await configManager.loadConfiguration() + + // Change model + mockContextProxy.getGlobalState.mockReturnValue({ + codebaseIndexEnabled: true, + codebaseIndexRerankerEnabled: true, + codebaseIndexRerankerModel: "model-v2", + codebaseIndexEmbedderProvider: "openai", + codebaseIndexQdrantUrl: "http://localhost:6333", + }) + + const { requiresRestart } = await configManager.loadConfiguration() + expect(requiresRestart).toBe(true) + }) + + it("should require restart when reranker API key changes", async () => { + // Initial state + mockContextProxy.getGlobalState.mockReturnValue({ + codebaseIndexEnabled: true, + codebaseIndexRerankerEnabled: true, + codebaseIndexEmbedderProvider: "openai", + codebaseIndexQdrantUrl: "http://localhost:6333", + }) + setupSecretMocks({ + codeIndexOpenAiKey: "test-key", + codebaseIndexRerankerApiKey: "old-api-key", + }) + configManager = new CodeIndexConfigManager(mockContextProxy) + + await configManager.loadConfiguration() + + // Change API key + setupSecretMocks({ + codeIndexOpenAiKey: "test-key", + codebaseIndexRerankerApiKey: "new-api-key", + }) + + const { requiresRestart } = await configManager.loadConfiguration() + expect(requiresRestart).toBe(true) + }) + + it("should not require restart when only reranker topN/topK changes", async () => { + // Initial state + mockContextProxy.getGlobalState.mockReturnValue({ + codebaseIndexEnabled: true, + codebaseIndexRerankerEnabled: true, + codebaseIndexRerankerTopN: 100, + codebaseIndexRerankerTopK: 20, +
codebaseIndexEmbedderProvider: "openai", + codebaseIndexQdrantUrl: "http://localhost:6333", + }) + mockContextProxy.getSecret.mockImplementation((key: string) => { + if (key === "codeIndexOpenAiKey") return "test-key" + return undefined + }) + configManager = new CodeIndexConfigManager(mockContextProxy) + + await configManager.loadConfiguration() + + // Change only topN and topK + mockContextProxy.getGlobalState.mockReturnValue({ + codebaseIndexEnabled: true, + codebaseIndexRerankerEnabled: true, + codebaseIndexRerankerTopN: 150, + codebaseIndexRerankerTopK: 30, + codebaseIndexEmbedderProvider: "openai", + codebaseIndexQdrantUrl: "http://localhost:6333", + }) + + const { requiresRestart } = await configManager.loadConfiguration() + expect(requiresRestart).toBe(false) + }) + }) + + describe("Reranker and feature integration", () => { + it("should include reranker config in getConfig() output", () => { + mockContextProxy.getGlobalState.mockReturnValue({ + codebaseIndexEnabled: true, + codebaseIndexRerankerEnabled: true, + codebaseIndexRerankerProvider: "local", + codebaseIndexRerankerUrl: "http://localhost:8003", + codebaseIndexRerankerModel: "test-model", + codebaseIndexRerankerTopN: 80, + codebaseIndexRerankerTopK: 15, + codebaseIndexRerankerTimeout: 5000, + codebaseIndexEmbedderProvider: "openai", + codebaseIndexQdrantUrl: "http://localhost:6333", + }) + mockContextProxy.getSecret.mockImplementation((key: string) => { + if (key === "codeIndexOpenAiKey") return "test-key" + if (key === "codebaseIndexRerankerApiKey") return "reranker-key" + return undefined + }) + + configManager = new CodeIndexConfigManager(mockContextProxy) + const config = configManager.getConfig() + + expect(config.rerankerConfig).toBeDefined() + expect(config.rerankerConfig).toEqual({ + enabled: true, + provider: "local", + url: "http://localhost:8003", + apiKey: "reranker-key", + model: "test-model", + topN: 80, + topK: 15, + timeout: 5000, + }) +
}) + + it("should disable reranker when main feature is disabled", () => { + mockContextProxy.getGlobalState.mockReturnValue({ + codebaseIndexEnabled: false, // Main feature disabled + codebaseIndexRerankerEnabled: true, // Reranker enabled + codebaseIndexEmbedderProvider: "openai", + codebaseIndexQdrantUrl: "http://localhost:6333", + }) + mockContextProxy.getSecret.mockImplementation((key: string) => { + if (key === "codeIndexOpenAiKey") return "test-key" + return undefined + }) + + configManager = new CodeIndexConfigManager(mockContextProxy) + + // isRerankerEnabled should be false + expect(configManager.isRerankerEnabled).toBe(false) + + // But the config should still show the actual state + const rerankerConfig = configManager.getRerankerConfig() + expect(rerankerConfig.enabled).toBe(true) // Actual config state + }) + }) }) }) diff --git a/src/services/code-index/__tests__/manager.spec.ts b/src/services/code-index/__tests__/manager.spec.ts index 8c64c2fdc6..272aa34111 100644 --- a/src/services/code-index/__tests__/manager.spec.ts +++ b/src/services/code-index/__tests__/manager.spec.ts @@ -140,6 +140,7 @@ describe("CodeIndexManager - handleSettingsChange regression", () => { }, }), validateEmbedder: vi.fn().mockResolvedValue({ valid: true }), + createReranker: vi.fn().mockResolvedValue(null), } MockedCodeIndexServiceFactory.mockImplementation(() => mockServiceFactoryInstance as any) @@ -214,6 +215,7 @@ describe("CodeIndexManager - handleSettingsChange regression", () => { }, }), validateEmbedder: vi.fn().mockResolvedValue({ valid: true }), + createReranker: vi.fn().mockResolvedValue(null), } MockedCodeIndexServiceFactory.mockImplementation(() => mockServiceFactoryInstance as any) @@ -267,6 +269,7 @@ describe("CodeIndexManager - handleSettingsChange regression", () => { scanner: mockScanner, fileWatcher: mockFileWatcher, }), + createReranker: vi.fn().mockResolvedValue(undefined), validateEmbedder: vi.fn(), } diff --git 
a/src/services/code-index/__tests__/search-service.spec.ts b/src/services/code-index/__tests__/search-service.spec.ts new file mode 100644 index 0000000000..ae371e18ed --- /dev/null +++ b/src/services/code-index/__tests__/search-service.spec.ts @@ -0,0 +1,542 @@ +// npx vitest services/code-index/__tests__/search-service.spec.ts + +import { describe, it, expect, beforeEach, vi, Mock } from "vitest" +import { CodeIndexSearchService } from "../search-service" +import { CodeIndexConfigManager } from "../config-manager" +import { CodeIndexStateManager } from "../state-manager" +import { IEmbedder } from "../interfaces/embedder" +import { IVectorStore, VectorStoreSearchResult } from "../interfaces/vector-store" +import { IReranker, RerankCandidate, RerankResult } from "../interfaces/reranker" +import { TelemetryService } from "@roo-code/telemetry" +import { TelemetryEventName } from "@roo-code/types" + +// Mock dependencies +vi.mock("@roo-code/telemetry", () => ({ + TelemetryService: { + instance: { + captureEvent: vi.fn(), + }, + }, +})) + +describe("CodeIndexSearchService", () => { + let service: CodeIndexSearchService + let mockConfigManager: CodeIndexConfigManager + let mockStateManager: CodeIndexStateManager + let mockEmbedder: IEmbedder + let mockVectorStore: IVectorStore + let mockReranker: IReranker + let mockLogger: Mock + + // Sample test data + const testQuery = "find authentication logic" + const testEmbedding = [0.1, 0.2, 0.3, 0.4, 0.5] + const testVectorResults: VectorStoreSearchResult[] = [ + { + id: "1", + score: 0.9, + payload: { + filePath: "auth.ts", + codeChunk: "function authenticate()", + startLine: 10, + endLine: 20, + }, + }, + { + id: "2", + score: 0.8, + payload: { + filePath: "login.ts", + codeChunk: "function login()", + startLine: 5, + endLine: 15, + }, + }, + { + id: "3", + score: 0.7, + payload: { + filePath: "user.ts", + codeChunk: "class User", + startLine: 1, + endLine: 50, + }, + }, + ] + + const testRerankedResults: RerankResult[] = [ 
+ { id: "2", score: 0.95 }, // login.ts reranked higher + { id: "1", score: 0.85 }, // auth.ts second + // user.ts filtered out by topK + ] + + beforeEach(() => { + vi.clearAllMocks() + + // Setup mocks + mockLogger = vi.fn() + + // Create mockConfigManager with getter properties that can be modified + mockConfigManager = { + get isFeatureEnabled() { + return true + }, + get isFeatureConfigured() { + return true + }, + get isRerankerEnabled() { + return true + }, + get currentSearchMinScore() { + return 0.5 + }, + get currentSearchMaxResults() { + return 10 + }, + get rerankerTopN() { + return 20 + }, + get rerankerTopK() { + return 5 + }, + } as any + + mockStateManager = { + getCurrentStatus: vi.fn().mockReturnValue({ + systemStatus: "Indexed", + }), + setSystemState: vi.fn(), + } as any + + mockEmbedder = { + createEmbeddings: vi.fn().mockResolvedValue({ + embeddings: [testEmbedding], + }), + } as any + + mockVectorStore = { + search: vi.fn().mockResolvedValue(testVectorResults), + } as any + + mockReranker = { + rerank: vi.fn().mockResolvedValue(testRerankedResults), + } as any + + // Create service with reranker + service = new CodeIndexSearchService( + mockConfigManager, + mockStateManager, + mockEmbedder, + mockVectorStore, + mockReranker, + mockLogger, + ) + }) + + describe("Reranking Enabled Scenarios", () => { + it("should successfully rerank when reranker is enabled and functional", async () => { + const results = await service.searchIndex(testQuery) + + // Verify embedder was called + expect(mockEmbedder.createEmbeddings).toHaveBeenCalledWith([testQuery]) + + // Verify vector search was called with topN limit + expect(mockVectorStore.search).toHaveBeenCalledWith( + testEmbedding, + undefined, + 0.5, + 20, // topN for reranking candidates + ) + + // Verify reranker was called with correct candidates + expect(mockReranker.rerank).toHaveBeenCalledTimes(1) + const rerankCall = (mockReranker.rerank as Mock).mock.calls[0] + 
expect(rerankCall[0]).toBe(testQuery) + expect(rerankCall[1]).toHaveLength(3) // All vector results + expect(rerankCall[2]).toBe(5) // topK limit + + // Verify results are reranked and ordered correctly + expect(results).toHaveLength(2) + expect(results[0].id).toBe("2") // login.ts first + expect(results[0].score).toBe(0.95) // reranked score + expect(results[1].id).toBe("1") // auth.ts second + expect(results[1].score).toBe(0.85) // reranked score + }) + + it("should pass correct number of candidates (topN) to reranker", async () => { + // Test with many vector results + const manyResults = Array.from({ length: 30 }, (_, i) => ({ + id: `${i}`, + score: 0.9 - i * 0.01, + payload: { + filePath: `file${i}.ts`, + codeChunk: `code${i}`, + startLine: i * 10, + endLine: i * 10 + 5, + }, + })) + + ;(mockVectorStore.search as Mock).mockResolvedValue(manyResults) + + await service.searchIndex(testQuery) + + // Verify vector search requested topN results + expect(mockVectorStore.search).toHaveBeenCalledWith( + testEmbedding, + undefined, + 0.5, + 20, // topN + ) + }) + + it("should limit final results to topK after reranking", async () => { + // Mock reranker to return more than topK results + const manyRerankedResults = Array.from({ length: 10 }, (_, i) => ({ + id: `${i}`, + score: 0.95 - i * 0.05, + })) + + ;(mockReranker.rerank as Mock).mockResolvedValue(manyRerankedResults) + + const results = await service.searchIndex(testQuery) + + // Verify reranker was asked to limit to topK + expect(mockReranker.rerank).toHaveBeenCalledWith( + testQuery, + expect.any(Array), + 5, // topK + ) + + // Results should be limited by what reranker returns + expect(results.length).toBeLessThanOrEqual(10) + }) + + it("should properly map reranked results back to original format", async () => { + const results = await service.searchIndex(testQuery) + + // Verify original payload is preserved but score is updated + expect(results[0]).toEqual({ + id: "2", + score: 0.95, // reranked score + 
payload: { + filePath: "login.ts", + codeChunk: "function login()", + startLine: 5, + endLine: 15, + }, + }) + }) + + it("should handle directory prefix filtering", async () => { + const directoryPrefix = "src/auth" + + await service.searchIndex(testQuery, directoryPrefix) + + // Verify vector store search was called with normalized prefix + expect(mockVectorStore.search).toHaveBeenCalledWith( + testEmbedding, + "src/auth", // normalized + 0.5, + 20, + ) + }) + }) + + describe("Reranking Error Handling", () => { + it("should fallback to vector search results when reranking fails", async () => { + // Mock reranker to throw error + const rerankError = new Error("Reranker service unavailable") + ;(mockReranker.rerank as Mock).mockRejectedValue(rerankError) + + const results = await service.searchIndex(testQuery) + + // Should log the error + expect(mockLogger).toHaveBeenCalledWith( + "[CodeIndexSearchService] Reranking failed, falling back to vector search results:", + rerankError, + ) + + // Should capture telemetry + expect(TelemetryService.instance.captureEvent).toHaveBeenCalledWith(TelemetryEventName.CODE_INDEX_ERROR, { + error: "Reranker service unavailable", + stack: expect.any(String), + location: "searchIndex-reranking", + }) + + // Should return original vector results limited to topK + expect(results).toHaveLength(3) // All original results fit within topK=5 + expect(results[0].id).toBe("1") // Original order preserved + expect(results[0].score).toBe(0.9) // Original scores preserved + }) + + it("should limit fallback results to topK when reranking fails", async () => { + // Create many vector results + const manyResults = Array.from({ length: 10 }, (_, i) => ({ + id: `${i}`, + score: 0.9 - i * 0.01, + payload: { filePath: `file${i}.ts`, codeChunk: `code${i}`, startLine: i, endLine: i + 5 }, + })) + + ;(mockVectorStore.search as Mock).mockResolvedValue(manyResults) + ;(mockReranker.rerank as Mock).mockRejectedValue(new Error("Rerank failed")) + + const 
results = await service.searchIndex(testQuery) + + // Should return topK results only + expect(results).toHaveLength(5) // topK = 5 + expect(results[0].id).toBe("0") // First result + expect(results[4].id).toBe("4") // Fifth result + }) + + it("should continue to work even if reranker throws during processing", async () => { + // Mock a more complex error scenario + ;(mockReranker.rerank as Mock).mockImplementation(() => { + throw new TypeError("Cannot read property 'map' of undefined") + }) + + const results = await service.searchIndex(testQuery) + + // Should handle gracefully and return vector results + expect(results).toHaveLength(3) + expect(results[0].score).toBe(0.9) // Original scores + }) + }) + + describe("Reranking Disabled Scenarios", () => { + it("should skip reranking when disabled in config", async () => { + // Use Object.defineProperty to change the getter value + Object.defineProperty(mockConfigManager, "isRerankerEnabled", { + get: () => false, + configurable: true, + }) + + const results = await service.searchIndex(testQuery) + + // Verify vector search used regular maxResults + expect(mockVectorStore.search).toHaveBeenCalledWith( + testEmbedding, + undefined, + 0.5, + 10, // currentSearchMaxResults, not topN + ) + + // Verify reranker was not called + expect(mockReranker.rerank).not.toHaveBeenCalled() + + // Should return original vector results + expect(results).toEqual(testVectorResults) + }) + + it("should skip reranking when reranker instance is not available", async () => { + // Create service without reranker + const serviceWithoutReranker = new CodeIndexSearchService( + mockConfigManager, + mockStateManager, + mockEmbedder, + mockVectorStore, + undefined, // no reranker + mockLogger, + ) + + const results = await serviceWithoutReranker.searchIndex(testQuery) + + // Verify vector search used regular maxResults + expect(mockVectorStore.search).toHaveBeenCalledWith( + testEmbedding, + undefined, + 0.5, + 10, // currentSearchMaxResults + ) + 
+ // Should return original vector results + expect(results).toEqual(testVectorResults) + }) + + it("should use regular maxResults limit when reranking is disabled", async () => { + // Use Object.defineProperty to change the getter values + Object.defineProperty(mockConfigManager, "isRerankerEnabled", { + get: () => false, + configurable: true, + }) + Object.defineProperty(mockConfigManager, "currentSearchMaxResults", { + get: () => 3, + configurable: true, + }) + + // Create more results than limit + const manyResults = Array.from({ length: 10 }, (_, i) => ({ + id: `${i}`, + score: 0.9 - i * 0.01, + payload: { filePath: `file${i}.ts`, codeChunk: `code${i}`, startLine: i, endLine: i + 5 }, + })) + + ;(mockVectorStore.search as Mock).mockResolvedValue(manyResults) + + await service.searchIndex(testQuery) + + // Verify correct limit was used + expect(mockVectorStore.search).toHaveBeenCalledWith( + testEmbedding, + undefined, + 0.5, + 3, // currentSearchMaxResults + ) + }) + }) + + describe("Performance Tracking", () => { + it("should log timing for vector search when reranking is disabled", async () => { + // Use Object.defineProperty to change the getter value + Object.defineProperty(mockConfigManager, "isRerankerEnabled", { + get: () => false, + configurable: true, + }) + + await service.searchIndex(testQuery) + + // Should log vector search timing + expect(mockLogger).toHaveBeenCalledWith( + expect.stringMatching(/\[CodeIndexSearchService\] Vector search completed in \d+ms\. Results: 3/), + ) + }) + + it("should log timing for both vector search and reranking when enabled", async () => { + await service.searchIndex(testQuery) + + // Should log reranking timing + expect(mockLogger).toHaveBeenCalledWith( + expect.stringMatching(/\[CodeIndexSearchService\] Reranking completed in \d+ms\. 
Input: 3, Output: 2/), + ) + }) + + it("should log timing even when reranking fails", async () => { + ;(mockReranker.rerank as Mock).mockRejectedValue(new Error("Rerank failed")) + + await service.searchIndex(testQuery) + + // Should log the reranking failure and fallback message (no timing log is asserted in this test) + expect(mockLogger).toHaveBeenCalledWith( + "[CodeIndexSearchService] Reranking failed, falling back to vector search results:", + expect.any(Error), + ) + }) + }) + + describe("Error Handling", () => { + it("should throw error when feature is disabled", async () => { + // Use Object.defineProperty to change the getter value + Object.defineProperty(mockConfigManager, "isFeatureEnabled", { + get: () => false, + configurable: true, + }) + + await expect(service.searchIndex(testQuery)).rejects.toThrow( + "Code index feature is disabled or not configured.", + ) + }) + + it("should throw error when feature is not configured", async () => { + // Use Object.defineProperty to change the getter value + Object.defineProperty(mockConfigManager, "isFeatureConfigured", { + get: () => false, + configurable: true, + }) + + await expect(service.searchIndex(testQuery)).rejects.toThrow( + "Code index feature is disabled or not configured.", + ) + }) + + it("should throw error when system is not ready", async () => { + ;(mockStateManager.getCurrentStatus as Mock).mockReturnValue({ + systemStatus: "Idle", + }) + + await expect(service.searchIndex(testQuery)).rejects.toThrow( + "Code index is not ready for search. 
Current state: Idle", + ) + }) + + it("should allow search during indexing state", async () => { + ;(mockStateManager.getCurrentStatus as Mock).mockReturnValue({ + systemStatus: "Indexing", + }) + + const results = await service.searchIndex(testQuery) + + expect(results).toBeDefined() + expect(mockVectorStore.search).toHaveBeenCalled() + }) + + it("should handle embedding generation failure", async () => { + ;(mockEmbedder.createEmbeddings as Mock).mockResolvedValue({ + embeddings: [], + }) + + await expect(service.searchIndex(testQuery)).rejects.toThrow("Failed to generate embedding for query.") + + // Should set error state + expect(mockStateManager.setSystemState).toHaveBeenCalledWith( + "Error", + "Search failed: Failed to generate embedding for query.", + ) + + // Should capture telemetry + expect(TelemetryService.instance.captureEvent).toHaveBeenCalledWith(TelemetryEventName.CODE_INDEX_ERROR, { + error: "Failed to generate embedding for query.", + stack: expect.any(String), + location: "searchIndex", + }) + }) + }) + + describe("Reranker Integration", () => { + it("should convert vector results to reranker format correctly", async () => { + await service.searchIndex(testQuery) + + const rerankCall = (mockReranker.rerank as Mock).mock.calls[0] + const candidates: RerankCandidate[] = rerankCall[1] + + expect(candidates[0]).toEqual({ + id: "1", + content: "function authenticate()", + metadata: { + filePath: "auth.ts", + startLine: 10, + endLine: 20, + score: 0.9, + }, + }) + }) + + it("should handle empty vector results", async () => { + ;(mockVectorStore.search as Mock).mockResolvedValue([]) + + const results = await service.searchIndex(testQuery) + + // Should not call reranker with empty results + expect(mockReranker.rerank).not.toHaveBeenCalled() + expect(results).toEqual([]) + }) + + it("should handle missing payload in vector results", async () => { + const resultsWithMissingPayload: VectorStoreSearchResult[] = [ + { + id: "1", + score: 0.9, + payload: 
null, + }, + ] + + ;(mockVectorStore.search as Mock).mockResolvedValue(resultsWithMissingPayload) + + await service.searchIndex(testQuery) + + // Should handle gracefully + const rerankCall = (mockReranker.rerank as Mock).mock.calls[0] + const candidates: RerankCandidate[] = rerankCall[1] + + expect(candidates[0].content).toBe("") // Default empty content + }) + }) +}) diff --git a/src/services/code-index/config-manager.ts b/src/services/code-index/config-manager.ts index 1723f1c2a0..2ab1a25491 100644 --- a/src/services/code-index/config-manager.ts +++ b/src/services/code-index/config-manager.ts @@ -2,6 +2,7 @@ import { ApiHandlerOptions } from "../../shared/api" import { ContextProxy } from "../../core/config/ContextProxy" import { EmbedderProvider } from "./interfaces/manager" import { CodeIndexConfig, PreviousConfigSnapshot } from "./interfaces/config" +import { RerankerConfig, RerankerProvider } from "./interfaces/reranker" import { DEFAULT_SEARCH_MIN_SCORE, DEFAULT_MAX_SEARCH_RESULTS } from "./constants" import { getDefaultModelId, getModelDimension, getModelScoreThreshold } from "../../shared/embeddingModels" @@ -24,6 +25,16 @@ export class CodeIndexConfigManager { private searchMinScore?: number private searchMaxResults?: number + // Reranker configuration + private rerankerEnabled: boolean = false + private rerankerProvider: RerankerProvider = "local" + private rerankerUrl?: string + private rerankerModel?: string + private rerankerTimeout: number = 10000 + private rerankerApiKey?: string + private _rerankerTopN: number = 100 + private _rerankerTopK: number = 20 + constructor(private readonly contextProxy: ContextProxy) { // Initialize with current configuration to avoid false restart triggers this._loadAndSetConfiguration() @@ -50,6 +61,14 @@ export class CodeIndexConfigManager { codebaseIndexEmbedderModelId: "", codebaseIndexSearchMinScore: undefined, codebaseIndexSearchMaxResults: undefined, + // Reranker defaults + codebaseIndexRerankerEnabled: false, + 
codebaseIndexRerankerProvider: "local", + codebaseIndexRerankerUrl: "http://localhost:8003", + codebaseIndexRerankerModel: "Qwen/Qwen3-Reranker-8B", + codebaseIndexRerankerTopN: 100, + codebaseIndexRerankerTopK: 20, + codebaseIndexRerankerTimeout: 10000, } const { @@ -60,6 +79,14 @@ export class CodeIndexConfigManager { codebaseIndexEmbedderModelId, codebaseIndexSearchMinScore, codebaseIndexSearchMaxResults, + // Reranker settings + codebaseIndexRerankerEnabled, + codebaseIndexRerankerProvider, + codebaseIndexRerankerUrl, + codebaseIndexRerankerModel, + codebaseIndexRerankerTopN, + codebaseIndexRerankerTopK, + codebaseIndexRerankerTimeout, } = codebaseIndexConfig const openAiKey = this.contextProxy?.getSecret("codeIndexOpenAiKey") ?? "" @@ -69,6 +96,7 @@ export class CodeIndexConfigManager { const openAiCompatibleApiKey = this.contextProxy?.getSecret("codebaseIndexOpenAiCompatibleApiKey") ?? "" const geminiApiKey = this.contextProxy?.getSecret("codebaseIndexGeminiApiKey") ?? "" const mistralApiKey = this.contextProxy?.getSecret("codebaseIndexMistralApiKey") ?? "" + const rerankerApiKey = this.contextProxy?.getSecret("codebaseIndexRerankerApiKey") ?? "" // Update instance variables with configuration this.codebaseIndexEnabled = codebaseIndexEnabled ?? true @@ -77,6 +105,18 @@ export class CodeIndexConfigManager { this.searchMinScore = codebaseIndexSearchMinScore this.searchMaxResults = codebaseIndexSearchMaxResults + // Load reranker configuration + this.rerankerEnabled = codebaseIndexRerankerEnabled ?? false + this.rerankerProvider = codebaseIndexRerankerProvider ?? "local" + // Fall back to the default only when the value is undefined; any other stored value (including null or an empty string) is kept as-is + this.rerankerUrl = codebaseIndexRerankerUrl !== undefined ? codebaseIndexRerankerUrl : "http://localhost:8003" + this.rerankerModel = + codebaseIndexRerankerModel !== undefined ? codebaseIndexRerankerModel : "Qwen/Qwen3-Reranker-8B" + this._rerankerTopN = codebaseIndexRerankerTopN ?? 
100 + this._rerankerTopK = codebaseIndexRerankerTopK ?? 20 + this.rerankerTimeout = codebaseIndexRerankerTimeout ?? 10000 + this.rerankerApiKey = rerankerApiKey + // Validate and set model dimension const rawDimension = codebaseIndexConfig.codebaseIndexEmbedderModelDimension if (rawDimension !== undefined && rawDimension !== null) { @@ -162,6 +202,11 @@ export class CodeIndexConfigManager { mistralApiKey: this.mistralOptions?.apiKey ?? "", qdrantUrl: this.qdrantUrl ?? "", qdrantApiKey: this.qdrantApiKey ?? "", + rerankerEnabled: this.rerankerEnabled, + rerankerProvider: this.rerankerProvider, + rerankerUrl: this.rerankerUrl, + rerankerModel: this.rerankerModel, + rerankerApiKey: this.rerankerApiKey, } // Refresh secrets from VSCode storage to ensure we have the latest values @@ -257,6 +302,11 @@ export class CodeIndexConfigManager { const prevMistralApiKey = prev?.mistralApiKey ?? "" const prevQdrantUrl = prev?.qdrantUrl ?? "" const prevQdrantApiKey = prev?.qdrantApiKey ?? "" + const prevRerankerEnabled = prev?.rerankerEnabled ?? false + const prevRerankerProvider = prev?.rerankerProvider ?? "local" + const prevRerankerUrl = prev?.rerankerUrl ?? "" + const prevRerankerModel = prev?.rerankerModel ?? "" + const prevRerankerApiKey = prev?.rerankerApiKey ?? "" // 1. Transition from disabled/unconfigured to enabled/configured if ((!prevEnabled || !prevConfigured) && this.codebaseIndexEnabled && nowConfigured) { @@ -294,6 +344,11 @@ export class CodeIndexConfigManager { const currentMistralApiKey = this.mistralOptions?.apiKey ?? "" const currentQdrantUrl = this.qdrantUrl ?? "" const currentQdrantApiKey = this.qdrantApiKey ?? 
"" + const currentRerankerEnabled = this.rerankerEnabled + const currentRerankerProvider = this.rerankerProvider + const currentRerankerUrl = this.rerankerUrl + const currentRerankerModel = this.rerankerModel + const currentRerankerApiKey = this.rerankerApiKey if (prevOpenAiKey !== currentOpenAiKey) { return true @@ -303,6 +358,26 @@ export class CodeIndexConfigManager { return true } + if (prevRerankerEnabled !== currentRerankerEnabled) { + return true + } + + if (prevRerankerProvider !== currentRerankerProvider) { + return true + } + + if (prevRerankerUrl !== currentRerankerUrl) { + return true + } + + if (prevRerankerModel !== currentRerankerModel) { + return true + } + + if (prevRerankerApiKey !== currentRerankerApiKey) { + return true + } + if ( prevOpenAiCompatibleBaseUrl !== currentOpenAiCompatibleBaseUrl || prevOpenAiCompatibleApiKey !== currentOpenAiCompatibleApiKey @@ -327,6 +402,26 @@ export class CodeIndexConfigManager { return true } + if (prevRerankerEnabled !== currentRerankerEnabled) { + return true + } + + if (prevRerankerProvider !== currentRerankerProvider) { + return true + } + + if (prevRerankerUrl !== currentRerankerUrl) { + return true + } + + if (prevRerankerModel !== currentRerankerModel) { + return true + } + + if (prevRerankerApiKey !== currentRerankerApiKey) { + return true + } + // Vector dimension changes (still important for compatibility) if (this._hasVectorDimensionChanged(prevProvider, prev?.modelId)) { return true @@ -379,6 +474,7 @@ export class CodeIndexConfigManager { qdrantApiKey: this.qdrantApiKey, searchMinScore: this.currentSearchMinScore, searchMaxResults: this.currentSearchMaxResults, + rerankerConfig: this.getRerankerConfig(), } } @@ -460,4 +556,41 @@ export class CodeIndexConfigManager { public get currentSearchMaxResults(): number { return this.searchMaxResults ?? 
DEFAULT_MAX_SEARCH_RESULTS } + + /** + * Gets whether the reranker is enabled + */ + public get isRerankerEnabled(): boolean { + return this.rerankerEnabled && this.isFeatureEnabled + } + + /** + * Gets the complete reranker configuration + */ + public getRerankerConfig(): RerankerConfig { + return { + enabled: this.rerankerEnabled, + provider: this.rerankerProvider, + url: this.rerankerUrl, + apiKey: this.rerankerApiKey, + model: this.rerankerModel, + topN: this._rerankerTopN, + topK: this._rerankerTopK, + timeout: this.rerankerTimeout, + } + } + + /** + * Gets the reranker topN value (number of candidates to send for reranking) + */ + public get rerankerTopN(): number { + return this._rerankerTopN + } + + /** + * Gets the reranker topK value (number of final results to return) + */ + public get rerankerTopK(): number { + return this._rerankerTopK + } } diff --git a/src/services/code-index/interfaces/config.ts b/src/services/code-index/interfaces/config.ts index 9098a60091..cc8f33a35d 100644 --- a/src/services/code-index/interfaces/config.ts +++ b/src/services/code-index/interfaces/config.ts @@ -1,5 +1,6 @@ import { ApiHandlerOptions } from "../../../shared/api" // Adjust path if needed import { EmbedderProvider } from "./manager" +import { RerankerConfig } from "./reranker" /** * Configuration state for the code indexing feature @@ -18,6 +19,7 @@ export interface CodeIndexConfig { qdrantApiKey?: string searchMinScore?: number searchMaxResults?: number + rerankerConfig?: RerankerConfig } /** @@ -37,4 +39,9 @@ export type PreviousConfigSnapshot = { mistralApiKey?: string qdrantUrl?: string qdrantApiKey?: string + rerankerEnabled?: boolean + rerankerProvider?: string + rerankerUrl?: string + rerankerModel?: string + rerankerApiKey?: string } diff --git a/src/services/code-index/interfaces/index.ts b/src/services/code-index/interfaces/index.ts index 20dd55ad89..ddb9e033e9 100644 --- a/src/services/code-index/interfaces/index.ts +++ 
b/src/services/code-index/interfaces/index.ts @@ -2,3 +2,4 @@ export * from "./embedder" export * from "./vector-store" export * from "./file-processor" export * from "./manager" +export * from "./reranker" diff --git a/src/services/code-index/interfaces/reranker.ts b/src/services/code-index/interfaces/reranker.ts new file mode 100644 index 0000000000..62ca67d0db --- /dev/null +++ b/src/services/code-index/interfaces/reranker.ts @@ -0,0 +1,68 @@ +/** + * Reranker provider types + */ +export type RerankerProvider = "local" | "cohere" | "openai" | "custom" + +/** + * Configuration for the reranker + */ +export interface RerankerConfig { + enabled: boolean + provider: RerankerProvider + url?: string + apiKey?: string + model?: string + topN: number + topK: number + timeout: number +} + +/** + * Candidate document for reranking + */ +export interface RerankCandidate { + id: string + content: string + metadata?: { + filePath?: string + startLine?: number + endLine?: number + score?: number + [key: string]: any + } +} + +/** + * Result from reranking + */ +export interface RerankResult { + id: string + score: number + originalScore?: number +} + +/** + * Interface for reranking implementations + */ +export interface IReranker { + /** + * Rerank the given candidates based on the query + * @param query The search query + * @param candidates The candidate documents to rerank + * @param maxResults Optional maximum number of results to return + * @returns Reranked results with scores + */ + rerank(query: string, candidates: RerankCandidate[], maxResults?: number): Promise + + /** + * Validates reranker configuration + * @returns Promise resolving to validation result + */ + validateConfiguration(): Promise<{ valid: boolean; error?: string }> + + /** + * Gets reranker health status + * @returns Promise resolving to health status + */ + healthCheck(): Promise +} diff --git a/src/services/code-index/manager.ts b/src/services/code-index/manager.ts index 18e0752c34..7d8de43f62 
100644 --- a/src/services/code-index/manager.ts +++ b/src/services/code-index/manager.ts @@ -15,10 +15,12 @@ import path from "path" import { t } from "../../i18n" import { TelemetryService } from "@roo-code/telemetry" import { TelemetryEventName } from "@roo-code/types" +import { LogFunction } from "../../utils/outputChannelLogger" export class CodeIndexManager { // --- Singleton Implementation --- private static instances = new Map() // Map workspace path to instance + private static loggerFunction: LogFunction | undefined // Specialized class instances private _configManager: CodeIndexConfigManager | undefined @@ -53,14 +55,20 @@ export class CodeIndexManager { CodeIndexManager.instances.clear() } + public static setLogger(logger: LogFunction): void { + CodeIndexManager.loggerFunction = logger + } + private readonly workspacePath: string private readonly context: vscode.ExtensionContext + private readonly logger: LogFunction // Private constructor for singleton pattern private constructor(workspacePath: string, context: vscode.ExtensionContext) { this.workspacePath = workspacePath this.context = context this._stateManager = new CodeIndexStateManager() + this.logger = CodeIndexManager.loggerFunction || ((...args: unknown[]) => console.log(...args)) } // --- Public API --- @@ -234,6 +242,7 @@ export class CodeIndexManager { this._configManager!, this.workspacePath, this._cacheManager!, + this.logger, ) const ignoreInstance = ignore() @@ -274,6 +283,14 @@ export class CodeIndexManager { throw new Error(errorMessage) } + // Create reranker instance if enabled + const reranker = await this._serviceFactory.createReranker() + if (reranker) { + this.logger("[CodeIndexManager] Reranker successfully created") + } else if (this._configManager!.isRerankerEnabled) { + this.logger("[CodeIndexManager] Reranker is enabled but failed to create instance") + } + // (Re)Initialize orchestrator this._orchestrator = new CodeIndexOrchestrator( this._configManager!, @@ -285,12 +302,14 
@@ export class CodeIndexManager { fileWatcher, ) - // (Re)Initialize search service + // (Re)Initialize search service with optional reranker this._searchService = new CodeIndexSearchService( this._configManager!, this._stateManager, embedder, vectorStore, + reranker, // Pass the reranker instance (may be undefined) + this.logger, ) // Clear any error state after successful recreation diff --git a/src/services/code-index/rerankers/base.ts b/src/services/code-index/rerankers/base.ts new file mode 100644 index 0000000000..1d22a6d93e --- /dev/null +++ b/src/services/code-index/rerankers/base.ts @@ -0,0 +1,103 @@ +import { IReranker, RerankCandidate, RerankResult, RerankerConfig } from "../interfaces/reranker" + +/** + * Abstract base class for reranker implementations + * Provides common functionality and structure for all rerankers + */ +export abstract class BaseReranker implements IReranker { + protected readonly provider: string + protected readonly config: RerankerConfig + protected readonly logger: Console + + constructor(provider: string, config: RerankerConfig) { + this.provider = provider + this.config = config + this.logger = console + } + + /** + * Reranks search results based on query relevance + * @param query The search query + * @param results Candidate results to rerank + * @param maxResults Maximum number of results to return + * @returns Promise resolving to reranked results + */ + abstract rerank(query: string, results: RerankCandidate[], maxResults?: number): Promise + + /** + * Validates reranker configuration + * @returns Promise resolving to validation result + */ + abstract validateConfiguration(): Promise<{ valid: boolean; error?: string }> + + /** + * Gets reranker health status + * @returns Promise resolving to health status + */ + abstract healthCheck(): Promise + + /** + * Common error handler for reranker operations + * @param error The error to handle + * @param operation The operation that failed + * @throws Error with formatted 
message + */ + protected handleError(error: unknown, operation: string): never { + const errorMessage = error instanceof Error ? error.message : String(error) + const fullMessage = `${this.provider} reranker ${operation} failed: ${errorMessage}` + + this.logger.error(fullMessage, error) + throw new Error(fullMessage) + } + + /** + * Validates common configuration requirements + * @returns Validation result + */ + protected validateCommonConfig(): { valid: boolean; error?: string } { + if (!this.config.enabled) { + return { valid: false, error: "Reranker is not enabled" } + } + + if (this.config.topN <= 0) { + return { valid: false, error: "topN must be greater than 0" } + } + + if (this.config.topK <= 0) { + return { valid: false, error: "topK must be greater than 0" } + } + + if (this.config.topK > this.config.topN) { + return { valid: false, error: "topK cannot be greater than topN" } + } + + return { valid: true } + } + + /** + * Filters and limits results based on configuration + * @param results The reranked results + * @param maxResults Maximum number of results requested + * @returns Filtered results + */ + protected filterResults(results: RerankResult[], maxResults?: number): RerankResult[] { + const limit = maxResults ?? 
this.config.topK + return results.slice(0, Math.min(limit, results.length)) + } + + /** + * Assigns ranks to results based on scores + * @param results Results with scores + * @returns Results with assigned ranks + */ + protected assignRanks(results: RerankResult[]): RerankResult[] { + // Sort by score descending + const sorted = [...results].sort((a, b) => b.score - a.score) + + // Assign ranks + return sorted.map((result, index) => ({ + ...result, + rank: index + 1, + })) + } +} diff --git a/src/services/code-index/rerankers/factory.ts b/src/services/code-index/rerankers/factory.ts new file mode 100644 index 0000000000..c0bd4e3318 --- /dev/null +++ b/src/services/code-index/rerankers/factory.ts @@ -0,0 +1,164 @@ +import { IReranker, RerankerConfig } from "../interfaces/reranker" +import { LocalReranker } from "./local" + +/** + * Factory class for creating reranker instances based on configuration + */ +export class RerankerFactory { + /** + * Creates a reranker instance based on the provided configuration + * @param config The reranker configuration + * @returns IReranker instance or undefined if configuration is invalid + */ + static async create(config: RerankerConfig): Promise { + try { + // Check if reranking is enabled + if (!config.enabled) { + console.log("Reranking is disabled in configuration") + return undefined + } + + // Create appropriate reranker based on provider + let reranker: IReranker | undefined + + switch (config.provider) { + case "local": + reranker = new LocalReranker(config) + break + + case "cohere": + // TODO: Implement Cohere reranker + console.warn("Cohere reranker not yet implemented") + return undefined + + case "openai": + // TODO: Implement OpenAI reranker + console.warn("OpenAI reranker not yet implemented") + return undefined + + case "custom": + // TODO: Implement custom reranker + console.warn("Custom reranker not yet implemented") + return undefined + + default: + console.error(`Unknown reranker provider: 
${config.provider}`) + return undefined + } + + // Validate the configuration + const validation = await reranker.validateConfiguration() + if (!validation.valid) { + console.error(`Reranker configuration validation failed: ${validation.error}`) + return undefined + } + + // Perform initial health check + const isHealthy = await reranker.healthCheck() + if (!isHealthy) { + console.warn("Reranker health check failed, but continuing with initialization") + } + + console.log(`Successfully created ${config.provider} reranker`) + return reranker + } catch (error) { + const errorMessage = error instanceof Error ? error.message : String(error) + console.error(`Failed to create reranker: ${errorMessage}`) + return undefined + } + } + + /** + * Validates a reranker configuration without creating an instance + * @param config The reranker configuration to validate + * @returns Validation result + */ + static validateConfig(config: RerankerConfig): { valid: boolean; error?: string } { + // Check required fields + if (!config.provider) { + return { valid: false, error: "Provider is required" } + } + + if (!config.enabled) { + return { valid: true } // Disabled is valid + } + + // Check provider-specific requirements + switch (config.provider) { + case "local": + if (!config.url) { + return { valid: false, error: "Local reranker requires a URL" } + } + if (!config.apiKey) { + return { valid: false, error: "Local reranker requires an API key" } + } + break + + case "cohere": + if (!config.apiKey) { + return { valid: false, error: "Cohere reranker requires an API key" } + } + break + + case "openai": + if (!config.apiKey) { + return { valid: false, error: "OpenAI reranker requires an API key" } + } + break + + case "custom": + if (!config.url) { + return { valid: false, error: "Custom reranker requires a URL" } + } + break + + default: + return { valid: false, error: `Unknown provider: ${config.provider}` } + } + + // Validate numeric fields + if (config.topN !== undefined && 
config.topN <= 0) { + return { valid: false, error: "topN must be greater than 0" } + } + + if (config.topK !== undefined && config.topK <= 0) { + return { valid: false, error: "topK must be greater than 0" } + } + + if (config.topN !== undefined && config.topK !== undefined && config.topK > config.topN) { + return { valid: false, error: "topK cannot be greater than topN" } + } + + if (config.timeout !== undefined && config.timeout <= 0) { + return { valid: false, error: "timeout must be greater than 0" } + } + + return { valid: true } + } + + /** + * Gets the list of supported reranker providers + * @returns Array of supported provider names + */ + static getSupportedProviders(): string[] { + return ["local", "cohere", "openai", "custom"] + } + + /** + * Checks if a provider is currently implemented + * @param provider The provider name to check + * @returns True if the provider is implemented + */ + static isProviderImplemented(provider: string): boolean { + switch (provider) { + case "local": + return true + case "cohere": + case "openai": + case "custom": + return false + default: + return false + } + } +} diff --git a/src/services/code-index/rerankers/index.ts b/src/services/code-index/rerankers/index.ts new file mode 100644 index 0000000000..c081301c7d --- /dev/null +++ b/src/services/code-index/rerankers/index.ts @@ -0,0 +1,11 @@ +// Export base class +export { BaseReranker } from "./base" + +// Export implementations +export { LocalReranker } from "./local" + +// Export factory +export { RerankerFactory } from "./factory" + +// Re-export interfaces and types from the interfaces module for convenience +export type { IReranker, RerankCandidate, RerankResult, RerankerConfig, RerankerProvider } from "../interfaces/reranker" diff --git a/src/services/code-index/rerankers/local.ts b/src/services/code-index/rerankers/local.ts new file mode 100644 index 0000000000..1a66887f01 --- /dev/null +++ b/src/services/code-index/rerankers/local.ts @@ -0,0 +1,227 @@ +import 
axios, { AxiosInstance, AxiosError } from "axios" +import { BaseReranker } from "./base" +import { RerankCandidate, RerankResult, RerankerConfig } from "../interfaces/reranker" + +/** + * Local reranker implementation that communicates with a user's specific reranker API + */ +export class LocalReranker extends BaseReranker { + private readonly axiosInstance: AxiosInstance + private readonly baseUrl: string + private readonly apiKey: string + private readonly model?: string + + constructor(config: RerankerConfig) { + super("local", config) + + if (!config.url) { + throw new Error("Local reranker requires a base URL") + } + + if (!config.apiKey) { + throw new Error("Local reranker requires an API key") + } + + this.baseUrl = config.url.replace(/\/$/, "") // Remove trailing slash + this.apiKey = config.apiKey + this.model = config.model + + // Create axios instance with default configuration + this.axiosInstance = axios.create({ + baseURL: this.baseUrl, + timeout: config.timeout ?? 30000, // Default 30 seconds + headers: { + Authorization: `Bearer ${this.apiKey}`, + "Content-Type": "application/json", + }, + }) + } + + /** + * Reranks search results using the local reranker API + */ + async rerank(query: string, results: RerankCandidate[], maxResults?: number): Promise { + try { + // Validate inputs + if (!query || query.trim().length === 0) { + throw new Error("Query cannot be empty") + } + + if (!results || results.length === 0) { + return [] + } + + // Limit candidates to topN from config + const candidatesToRerank = results.slice(0, this.config.topN) + + // Convert RerankCandidate[] to API format + const documents = candidatesToRerank.map((candidate) => candidate.content) + + // Prepare request payload + const payload: any = { + query, + documents, + } + + // Add model if specified + if (this.model) { + payload.model = this.model + } + + // Add max_results if specified (using topK as default) + payload.max_results = maxResults ?? 
this.config.topK + + this.logger.log(`Reranking ${documents.length} documents for query: "${query}"`) + + // Make the API request + const response = await this.axiosInstance.post("/rerank", payload) + + // Validate response + if (!response.data || !Array.isArray(response.data)) { + throw new Error("Invalid response format from reranker API") + } + + // Map response back to RerankResult[] format + const rerankResults: RerankResult[] = response.data.map((item: any, index: number) => { + // Pair each response item with the candidate at the same array position. NOTE(review): this assumes the API returns items in the same order as the submitted documents; since max_results is also sent, a sorted or truncated response would attach scores to the wrong ids - confirm against the service contract + const originalCandidate = candidatesToRerank[index] + + if (!originalCandidate) { + throw new Error(`No candidate found for index ${index}`) + } + + return { + id: originalCandidate.id, + score: item.score ?? 0, + rank: item.rank ?? index + 1, + } + }) + + // Sort by score descending and assign proper ranks + const rankedResults = this.assignRanks(rerankResults) + + // Filter results based on maxResults or topK + return this.filterResults(rankedResults, maxResults) + } catch (error) { + if (axios.isAxiosError(error)) { + const axiosError = error as AxiosError + + if (axiosError.response) { + // The request was made and the server responded with a status code + // that falls out of the range of 2xx + const status = axiosError.response.status + const data = axiosError.response.data + + if (status === 401) { + this.handleError(new Error("Invalid API key"), "authentication") + } else if (status === 404) { + this.handleError(new Error(`Rerank endpoint not found at ${this.baseUrl}/rerank`), "endpoint") + } else if (status === 429) { + this.handleError(new Error("Rate limit exceeded"), "rate-limit") + } else { + this.handleError(new Error(`API error (${status}): ${JSON.stringify(data)}`), "rerank") + } + } else if (axiosError.request) { + // The request was made but no response was received + this.handleError(new Error(`No response from reranker API at ${this.baseUrl}`), "connection") + } else { + // Something happened in setting up the 
request + this.handleError(error, "request setup") + } + } + + this.handleError(error, "rerank") + } + } + + /** + * Validates the reranker configuration by making a test request + */ + async validateConfiguration(): Promise<{ valid: boolean; error?: string }> { + try { + // First validate common config + const commonValidation = this.validateCommonConfig() + if (!commonValidation.valid) { + return commonValidation + } + + // Test the rerank endpoint with minimal data + const testQuery = "test" + const testDocuments = ["test document"] + + const payload: any = { + query: testQuery, + documents: testDocuments, + max_results: 1, + } + + if (this.model) { + payload.model = this.model + } + + const response = await this.axiosInstance.post("/rerank", payload) + + // Validate response structure + if (!response.data || !Array.isArray(response.data)) { + return { + valid: false, + error: "Invalid response format from reranker API", + } + } + + if (response.data.length > 0) { + const firstResult = response.data[0] + if (typeof firstResult.score !== "number") { + return { + valid: false, + error: 'Reranker API response missing required "score" field', + } + } + } + + return { valid: true } + } catch (error) { + if (axios.isAxiosError(error)) { + const axiosError = error as AxiosError + + if (axiosError.response?.status === 401) { + return { valid: false, error: "Invalid API key" } + } else if (axiosError.response?.status === 404) { + return { valid: false, error: `Rerank endpoint not found at ${this.baseUrl}/rerank` } + } else if (axiosError.request) { + return { valid: false, error: `Cannot connect to reranker API at ${this.baseUrl}` } + } + } + + const errorMessage = error instanceof Error ? 
error.message : String(error) + return { valid: false, error: `Configuration validation failed: ${errorMessage}` } + } + } + + /** + * Performs a health check on the reranker API + */ + async healthCheck(): Promise { + try { + // Try a minimal rerank request + const payload: any = { + query: "health check", + documents: ["test"], + max_results: 1, + } + + if (this.model) { + payload.model = this.model + } + + const response = await this.axiosInstance.post("/rerank", payload, { + timeout: 5000, // 5 second timeout for health check + }) + + return response.status === 200 && Array.isArray(response.data) + } catch (error) { + this.logger.error("Health check failed:", error) + return false + } + } +} diff --git a/src/services/code-index/search-service.ts b/src/services/code-index/search-service.ts index a56f5cc674..11418810d3 100644 --- a/src/services/code-index/search-service.ts +++ b/src/services/code-index/search-service.ts @@ -2,21 +2,29 @@ import * as path from "path" import { VectorStoreSearchResult } from "./interfaces" import { IEmbedder } from "./interfaces/embedder" import { IVectorStore } from "./interfaces/vector-store" +import { IReranker, RerankCandidate } from "./interfaces/reranker" import { CodeIndexConfigManager } from "./config-manager" import { CodeIndexStateManager } from "./state-manager" import { TelemetryService } from "@roo-code/telemetry" import { TelemetryEventName } from "@roo-code/types" +import { LogFunction } from "../../utils/outputChannelLogger" /** * Service responsible for searching the code index. 
*/ export class CodeIndexSearchService { + private readonly logger: LogFunction + constructor( private readonly configManager: CodeIndexConfigManager, private readonly stateManager: CodeIndexStateManager, private readonly embedder: IEmbedder, private readonly vectorStore: IVectorStore, - ) {} + private readonly reranker?: IReranker, // Add optional reranker + logger?: LogFunction, + ) { + this.logger = logger || ((...args: unknown[]) => console.log(...args)) + } /** * Searches the code index for relevant content. @@ -31,9 +39,6 @@ export class CodeIndexSearchService { throw new Error("Code index feature is disabled or not configured.") } - const minScore = this.configManager.currentSearchMinScore - const maxResults = this.configManager.currentSearchMaxResults - const currentState = this.stateManager.getCurrentStatus().systemStatus if (currentState !== "Indexed" && currentState !== "Indexing") { // Allow search during Indexing too @@ -54,11 +59,52 @@ export class CodeIndexSearchService { normalizedPrefix = path.normalize(directoryPrefix) } - // Perform search + // Determine if we should use reranking + const useReranking = this.configManager.isRerankerEnabled && this.reranker + + // Get search parameters + const minScore = this.configManager.currentSearchMinScore + const maxResults = useReranking + ? this.configManager.rerankerTopN // Get more candidates for reranking + : this.configManager.currentSearchMaxResults + + // Perform vector search + const startTime = Date.now() const results = await this.vectorStore.search(vector, normalizedPrefix, minScore, maxResults) + const vectorSearchTime = Date.now() - startTime + + // Apply reranking if enabled + if (useReranking && this.reranker && results.length > 0) { + const rerankStartTime = Date.now() + try { + const rerankedResults = await this.applyReranking(query, results) + const rerankTime = Date.now() - rerankStartTime + this.logger( + `[CodeIndexSearchService] Reranking completed in ${rerankTime}ms. 
Input: ${results.length}, Output: ${rerankedResults.length}`, + ) + return rerankedResults + } catch (rerankError) { + // Log error but don't fail the search + this.logger( + "[CodeIndexSearchService] Reranking failed, falling back to vector search results:", + rerankError, + ) + TelemetryService.instance.captureEvent(TelemetryEventName.CODE_INDEX_ERROR, { + error: (rerankError as Error).message, + stack: (rerankError as Error).stack, + location: "searchIndex-reranking", + }) + // Return original results limited to topK + return results.slice(0, this.configManager.rerankerTopK) + } + } + + this.logger( + `[CodeIndexSearchService] Vector search completed in ${vectorSearchTime}ms. Results: ${results.length}`, + ) return results } catch (error) { - console.error("[CodeIndexSearchService] Error during search:", error) + this.logger("[CodeIndexSearchService] Error during search:", error) this.stateManager.setSystemState("Error", `Search failed: ${(error as Error).message}`) // Capture telemetry for the error @@ -71,4 +117,41 @@ export class CodeIndexSearchService { throw error // Re-throw the error after setting state } } + + /** + * Applies reranking to search results + * @param query The original search query + * @param results The vector search results to rerank + * @returns Reranked and filtered results + */ + private async applyReranking( + query: string, + results: VectorStoreSearchResult[], + ): Promise { + // Convert to reranker format + const candidates: RerankCandidate[] = results.map((r) => ({ + id: r.id.toString(), + content: r.payload?.codeChunk || "", + metadata: { + filePath: r.payload?.filePath, + startLine: r.payload?.startLine, + endLine: r.payload?.endLine, + score: r.score, + }, + })) + + // Rerank results + const rerankedResults = await this.reranker!.rerank(query, candidates, this.configManager.rerankerTopK) + + // Map back to original format, preserving payload + const resultMap = new Map(results.map((r) => [r.id.toString(), r])) + + return 
rerankedResults.map((reranked) => { + const original = resultMap.get(reranked.id)! + return { + ...original, + score: reranked.score, // Use reranked score + } + }) + } } diff --git a/src/services/code-index/service-factory.ts b/src/services/code-index/service-factory.ts index 68b0f5c0bc..d802ecbfbd 100644 --- a/src/services/code-index/service-factory.ts +++ b/src/services/code-index/service-factory.ts @@ -7,13 +7,15 @@ import { MistralEmbedder } from "./embedders/mistral" import { EmbedderProvider, getDefaultModelId, getModelDimension } from "../../shared/embeddingModels" import { QdrantVectorStore } from "./vector-store/qdrant-client" import { codeParser, DirectoryScanner, FileWatcher } from "./processors" -import { ICodeParser, IEmbedder, IFileWatcher, IVectorStore } from "./interfaces" +import { ICodeParser, IEmbedder, IFileWatcher, IVectorStore, IReranker } from "./interfaces" import { CodeIndexConfigManager } from "./config-manager" import { CacheManager } from "./cache-manager" import { Ignore } from "ignore" import { t } from "../../i18n" import { TelemetryService } from "@roo-code/telemetry" import { TelemetryEventName } from "@roo-code/types" +import { RerankerFactory } from "./rerankers/factory" +import { LogFunction } from "../../utils/outputChannelLogger" /** * Factory class responsible for creating and configuring code indexing service dependencies. @@ -23,6 +25,7 @@ export class CodeIndexServiceFactory { private readonly configManager: CodeIndexConfigManager, private readonly workspacePath: string, private readonly cacheManager: CacheManager, + private readonly logger: LogFunction = (...args: unknown[]) => console.log(...args), ) {} /** @@ -140,6 +143,41 @@ export class CodeIndexServiceFactory { return new QdrantVectorStore(this.workspacePath, config.qdrantUrl, vectorSize, config.qdrantApiKey) } + /** + * Creates a reranker instance based on the current configuration. 
+	 * @returns Promise resolving to IReranker instance or undefined if disabled/invalid; never rejects (errors are logged, reported to telemetry and mapped to undefined)
+	 */
+	public async createReranker(): Promise<IReranker | undefined> {
+		try {
+			const rerankerConfig = this.configManager.getRerankerConfig()
+
+			if (!rerankerConfig.enabled) {
+				this.logger("Reranker is disabled in configuration")
+				return undefined
+			}
+
+			const reranker = await RerankerFactory.create(rerankerConfig)
+
+			if (reranker) {
+				this.logger(`Successfully created ${rerankerConfig.provider} reranker`)
+			} else {
+				this.logger("Failed to create reranker instance")
+			}
+
+			return reranker
+		} catch (error) {
+			// Capture telemetry for the error
+			TelemetryService.instance.captureEvent(TelemetryEventName.CODE_INDEX_ERROR, {
+				error: error instanceof Error ? error.message : String(error),
+				stack: error instanceof Error ? error.stack : undefined,
+				location: "createReranker",
+			})
+
+			this.logger("Error creating reranker:", error)
+			return undefined // A reranker failure must not block service creation; callers treat undefined as "no reranking"
+		}
+	}
+
 	/**
 	 * Creates a directory scanner instance with its required dependencies.
*/ diff --git a/src/shared/WebviewMessage.ts b/src/shared/WebviewMessage.ts index cb8759d851..c1c07314f7 100644 --- a/src/shared/WebviewMessage.ts +++ b/src/shared/WebviewMessage.ts @@ -265,12 +265,22 @@ export interface WebviewMessage { codebaseIndexSearchMaxResults?: number codebaseIndexSearchMinScore?: number + // Reranker settings + codebaseIndexRerankerEnabled?: boolean + codebaseIndexRerankerProvider?: "local" | "cohere" | "openai" | "custom" + codebaseIndexRerankerUrl?: string + codebaseIndexRerankerModel?: string + codebaseIndexRerankerTopN?: number + codebaseIndexRerankerTopK?: number + codebaseIndexRerankerTimeout?: number + // Secret settings codeIndexOpenAiKey?: string codeIndexQdrantApiKey?: string codebaseIndexOpenAiCompatibleApiKey?: string codebaseIndexGeminiApiKey?: string codebaseIndexMistralApiKey?: string + codebaseIndexRerankerApiKey?: string } } diff --git a/src/tests/services/code-index/config-manager.test.ts b/src/tests/services/code-index/config-manager.test.ts new file mode 100644 index 0000000000..cf898c53fd --- /dev/null +++ b/src/tests/services/code-index/config-manager.test.ts @@ -0,0 +1,394 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest" +import { CodeIndexConfigManager } from "../../../services/code-index/config-manager" +import { ContextProxy } from "../../../core/config/ContextProxy" +import { RerankerProvider } from "../../../services/code-index/interfaces/reranker" + +// Mock the ContextProxy +vi.mock("../../../core/config/ContextProxy") + +// Mock the embeddingModels module +vi.mock("../../../shared/embeddingModels", () => ({ + getDefaultModelId: vi.fn().mockReturnValue("text-embedding-ada-002"), + getModelDimension: vi.fn().mockReturnValue(1536), + getModelScoreThreshold: vi.fn().mockReturnValue(0.3), +})) + +describe("CodeIndexConfigManager - Reranker Configuration", () => { + let configManager: CodeIndexConfigManager + let mockContextProxy: any + + const mockGlobalState = { + codebaseIndexEnabled: 
true, + codebaseIndexQdrantUrl: "http://localhost:6333", + codebaseIndexEmbedderProvider: "openai", + codebaseIndexEmbedderBaseUrl: "", + codebaseIndexEmbedderModelId: "", + codebaseIndexSearchMinScore: undefined, + codebaseIndexSearchMaxResults: undefined, + // Reranker configuration + codebaseIndexRerankerEnabled: true, + codebaseIndexRerankerProvider: "local", + codebaseIndexRerankerUrl: "http://localhost:8080", + codebaseIndexRerankerModel: "ms-marco-MiniLM-L-6-v2", + codebaseIndexRerankerTopN: 100, + codebaseIndexRerankerTopK: 20, + codebaseIndexRerankerTimeout: 10000, + } + + beforeEach(() => { + vi.clearAllMocks() + + // Create mock context proxy + mockContextProxy = { + getGlobalState: vi.fn().mockReturnValue(mockGlobalState), + getSecret: vi.fn().mockImplementation((key: string) => { + const secrets: any = { + codeIndexOpenAiKey: "test-openai-key", + codeIndexQdrantApiKey: "test-qdrant-key", + codebaseIndexRerankerApiKey: "test-reranker-key", + } + return secrets[key] || "" + }), + refreshSecrets: vi.fn().mockResolvedValue(undefined), + } + + // Create config manager with mock + configManager = new CodeIndexConfigManager(mockContextProxy as any) + }) + + afterEach(() => { + vi.restoreAllMocks() + }) + + describe("isRerankerEnabled", () => { + it("should return true when reranker and feature are enabled", () => { + expect(configManager.isRerankerEnabled).toBe(true) + }) + + it("should return false when reranker is disabled", () => { + const disabledState = { + ...mockGlobalState, + codebaseIndexRerankerEnabled: false, + } + mockContextProxy.getGlobalState.mockReturnValue(disabledState) + configManager = new CodeIndexConfigManager(mockContextProxy as any) + + expect(configManager.isRerankerEnabled).toBe(false) + }) + + it("should return false when feature is disabled even if reranker is enabled", () => { + const disabledFeatureState = { + ...mockGlobalState, + codebaseIndexEnabled: false, + } + 
mockContextProxy.getGlobalState.mockReturnValue(disabledFeatureState) + configManager = new CodeIndexConfigManager(mockContextProxy as any) + + expect(configManager.isRerankerEnabled).toBe(false) + }) + + it("should use default false when reranker enabled is undefined", () => { + const undefinedState = { + ...mockGlobalState, + codebaseIndexRerankerEnabled: undefined, + } + mockContextProxy.getGlobalState.mockReturnValue(undefinedState) + configManager = new CodeIndexConfigManager(mockContextProxy as any) + + expect(configManager.isRerankerEnabled).toBe(false) + }) + }) + + describe("getRerankerConfig", () => { + it("should return complete reranker configuration", () => { + const config = configManager.getRerankerConfig() + + expect(config).toEqual({ + enabled: true, + provider: "local", + url: "http://localhost:8080", + apiKey: "test-reranker-key", + model: "ms-marco-MiniLM-L-6-v2", + topN: 100, + topK: 20, + timeout: 10000, + }) + }) + + it("should return config with defaults when values are undefined", () => { + const minimalState = { + ...mockGlobalState, + codebaseIndexRerankerEnabled: undefined, + codebaseIndexRerankerProvider: undefined, + codebaseIndexRerankerUrl: undefined, + codebaseIndexRerankerModel: undefined, + codebaseIndexRerankerTopN: undefined, + codebaseIndexRerankerTopK: undefined, + codebaseIndexRerankerTimeout: undefined, + } + mockContextProxy.getGlobalState.mockReturnValue(minimalState) + mockContextProxy.getSecret.mockReturnValue("") + configManager = new CodeIndexConfigManager(mockContextProxy as any) + + const config = configManager.getRerankerConfig() + + expect(config).toEqual({ + enabled: false, + provider: "local", + url: "http://localhost:8003", // Default value + apiKey: "", + model: "Qwen/Qwen3-Reranker-8B", // Default value + topN: 100, + topK: 20, + timeout: 10000, + }) + }) + + it("should handle different provider types", () => { + const providers: RerankerProvider[] = ["local", "cohere", "openai", "custom"] + + 
providers.forEach((provider) => { + const providerState = { + ...mockGlobalState, + codebaseIndexRerankerProvider: provider, + } + mockContextProxy.getGlobalState.mockReturnValue(providerState) + const cm = new CodeIndexConfigManager(mockContextProxy as any) + + const config = cm.getRerankerConfig() + expect(config.provider).toBe(provider) + }) + }) + + it("should load API key from secrets", () => { + mockContextProxy.getSecret.mockImplementation((key: string) => { + if (key === "codebaseIndexRerankerApiKey") { + return "super-secret-api-key" + } + return "" + }) + configManager = new CodeIndexConfigManager(mockContextProxy as any) + + const config = configManager.getRerankerConfig() + expect(config.apiKey).toBe("super-secret-api-key") + }) + }) + + describe("rerankerTopN getter", () => { + it("should return configured topN value", () => { + expect(configManager.rerankerTopN).toBe(100) + }) + + it("should return default topN when undefined", () => { + const undefinedState = { + ...mockGlobalState, + codebaseIndexRerankerTopN: undefined, + } + mockContextProxy.getGlobalState.mockReturnValue(undefinedState) + configManager = new CodeIndexConfigManager(mockContextProxy as any) + + expect(configManager.rerankerTopN).toBe(100) + }) + + it("should handle custom topN values", () => { + const customState = { + ...mockGlobalState, + codebaseIndexRerankerTopN: 250, + } + mockContextProxy.getGlobalState.mockReturnValue(customState) + configManager = new CodeIndexConfigManager(mockContextProxy as any) + + expect(configManager.rerankerTopN).toBe(250) + }) + }) + + describe("rerankerTopK getter", () => { + it("should return configured topK value", () => { + expect(configManager.rerankerTopK).toBe(20) + }) + + it("should return default topK when undefined", () => { + const undefinedState = { + ...mockGlobalState, + codebaseIndexRerankerTopK: undefined, + } + mockContextProxy.getGlobalState.mockReturnValue(undefinedState) + configManager = new 
CodeIndexConfigManager(mockContextProxy as any) + + expect(configManager.rerankerTopK).toBe(20) + }) + + it("should handle custom topK values", () => { + const customState = { + ...mockGlobalState, + codebaseIndexRerankerTopK: 50, + } + mockContextProxy.getGlobalState.mockReturnValue(customState) + configManager = new CodeIndexConfigManager(mockContextProxy as any) + + expect(configManager.rerankerTopK).toBe(50) + }) + }) + + describe("loadConfiguration with reranker settings", () => { + it("should load reranker configuration from storage", async () => { + const result = await configManager.loadConfiguration() + + expect(mockContextProxy.refreshSecrets).toHaveBeenCalled() + expect(mockContextProxy.getGlobalState).toHaveBeenCalledWith("codebaseIndexConfig") + + // Verify reranker config is loaded + const rerankerConfig = configManager.getRerankerConfig() + expect(rerankerConfig.enabled).toBe(true) + expect(rerankerConfig.provider).toBe("local") + expect(rerankerConfig.url).toBe("http://localhost:8080") + }) + + it("should handle missing reranker configuration gracefully", async () => { + mockContextProxy.getGlobalState.mockReturnValue(null) + configManager = new CodeIndexConfigManager(mockContextProxy as any) + + const result = await configManager.loadConfiguration() + + // Should use defaults + const rerankerConfig = configManager.getRerankerConfig() + expect(rerankerConfig.enabled).toBe(false) + expect(rerankerConfig.provider).toBe("local") + expect(rerankerConfig.topN).toBe(100) + expect(rerankerConfig.topK).toBe(20) + expect(rerankerConfig.timeout).toBe(10000) + }) + }) + + describe("configuration validation", () => { + it("should validate reranker timeout values", () => { + const invalidTimeoutState = { + ...mockGlobalState, + codebaseIndexRerankerTimeout: -1000, + } + mockContextProxy.getGlobalState.mockReturnValue(invalidTimeoutState) + configManager = new CodeIndexConfigManager(mockContextProxy as any) + + const config = configManager.getRerankerConfig() + 
// Config manager doesn't validate negative timeout, just passes it through + expect(config.timeout).toBe(-1000) + }) + + it("should handle zero timeout value", () => { + const zeroTimeoutState = { + ...mockGlobalState, + codebaseIndexRerankerTimeout: 0, + } + mockContextProxy.getGlobalState.mockReturnValue(zeroTimeoutState) + configManager = new CodeIndexConfigManager(mockContextProxy as any) + + const config = configManager.getRerankerConfig() + expect(config.timeout).toBe(0) + }) + + it("should handle string provider values correctly", () => { + const stringProviderState = { + ...mockGlobalState, + codebaseIndexRerankerProvider: "LOCAL", // Wrong case + } + mockContextProxy.getGlobalState.mockReturnValue(stringProviderState) + configManager = new CodeIndexConfigManager(mockContextProxy as any) + + const config = configManager.getRerankerConfig() + // Config manager passes through as-is + expect(config.provider).toBe("LOCAL" as any) + }) + }) + + describe("integration with base config", () => { + it("should properly integrate reranker config with base code index config", () => { + const baseConfig = configManager.getConfig() + const rerankerConfig = configManager.getRerankerConfig() + + // Base config should not include reranker settings + expect(baseConfig).not.toHaveProperty("rerankerEnabled") + expect(baseConfig).not.toHaveProperty("rerankerProvider") + + // Reranker config should be separate + expect(rerankerConfig).toBeDefined() + expect(rerankerConfig.enabled).toBe(true) + }) + + it("should maintain consistency between feature enabled and reranker enabled", () => { + expect(configManager.isFeatureEnabled).toBe(true) + expect(configManager.isRerankerEnabled).toBe(true) + + // Disable feature + const disabledFeatureState = { + ...mockGlobalState, + codebaseIndexEnabled: false, + } + mockContextProxy.getGlobalState.mockReturnValue(disabledFeatureState) + configManager = new CodeIndexConfigManager(mockContextProxy as any) + + 
expect(configManager.isFeatureEnabled).toBe(false) + expect(configManager.isRerankerEnabled).toBe(false) // Should also be false + }) + }) + + describe("reranker URL handling", () => { + it("should handle URLs with different protocols", () => { + const httpsState = { + ...mockGlobalState, + codebaseIndexRerankerUrl: "https://secure-reranker.com:8443", + } + mockContextProxy.getGlobalState.mockReturnValue(httpsState) + configManager = new CodeIndexConfigManager(mockContextProxy as any) + + const config = configManager.getRerankerConfig() + expect(config.url).toBe("https://secure-reranker.com:8443") + }) + + it("should handle empty URL string", () => { + const emptyUrlState = { + ...mockGlobalState, + codebaseIndexRerankerUrl: "", + } + mockContextProxy.getGlobalState.mockReturnValue(emptyUrlState) + configManager = new CodeIndexConfigManager(mockContextProxy as any) + + const config = configManager.getRerankerConfig() + expect(config.url).toBe("") + }) + }) + + describe("reranker model handling", () => { + it("should handle various model names", () => { + const models = [ + "ms-marco-MiniLM-L-6-v2", + "ms-marco-TinyBERT-L-2-v2", + "cross-encoder/ms-marco-electra-base", + "custom-model-v1", + ] + + models.forEach((model) => { + const modelState = { + ...mockGlobalState, + codebaseIndexRerankerModel: model, + } + mockContextProxy.getGlobalState.mockReturnValue(modelState) + const cm = new CodeIndexConfigManager(mockContextProxy as any) + + const config = cm.getRerankerConfig() + expect(config.model).toBe(model) + }) + }) + + it("should handle empty model string", () => { + const emptyModelState = { + ...mockGlobalState, + codebaseIndexRerankerModel: "", + } + mockContextProxy.getGlobalState.mockReturnValue(emptyModelState) + configManager = new CodeIndexConfigManager(mockContextProxy as any) + + const config = configManager.getRerankerConfig() + expect(config.model).toBe("") + }) + }) +}) diff --git a/src/tests/services/code-index/rerankers/factory.test.ts 
b/src/tests/services/code-index/rerankers/factory.test.ts
new file mode 100644
index 0000000000..6633a400c0
--- /dev/null
+++ b/src/tests/services/code-index/rerankers/factory.test.ts
@@ -0,0 +1,474 @@
+import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"
+import { RerankerFactory } from "../../../../services/code-index/rerankers/factory"
+import { LocalReranker } from "../../../../services/code-index/rerankers/local"
+import { RerankerConfig } from "../../../../services/code-index/interfaces/reranker"
+
+// Mock the LocalReranker
+vi.mock("../../../../services/code-index/rerankers/local")
+
+/**
+ * Unit tests for RerankerFactory: create() per provider, static validateConfig(),
+ * getSupportedProviders() and isProviderImplemented(). LocalReranker is auto-mocked
+ * and console.log/warn/error are replaced by silent spies that the tests assert on.
+ */
+describe("RerankerFactory", () => {
+	const mockValidConfig: RerankerConfig = {
+		enabled: true,
+		provider: "local",
+		url: "http://localhost:8080",
+		apiKey: "test-api-key",
+		model: "test-model",
+		topN: 100,
+		topK: 20,
+		timeout: 30000,
+	}
+
+	let consoleLogSpy: any
+	let consoleWarnSpy: any
+	let consoleErrorSpy: any
+
+	beforeEach(() => {
+		// Reset (not just clear) so that a mockImplementation installed by one test,
+		// e.g. the throwing LocalReranker constructor, cannot leak into the next test.
+		vi.resetAllMocks()
+		consoleLogSpy = vi.spyOn(console, "log").mockImplementation(() => {})
+		consoleWarnSpy = vi.spyOn(console, "warn").mockImplementation(() => {})
+		consoleErrorSpy = vi.spyOn(console, "error").mockImplementation(() => {})
+	})
+
+	afterEach(() => {
+		consoleLogSpy.mockRestore()
+		consoleWarnSpy.mockRestore()
+		consoleErrorSpy.mockRestore()
+	})
+
+	describe("create", () => {
+		it("should return undefined when reranking is disabled", async () => {
+			const disabledConfig = { ...mockValidConfig, enabled: false }
+
+			const result = await RerankerFactory.create(disabledConfig)
+
+			expect(result).toBeUndefined()
+			expect(consoleLogSpy).toHaveBeenCalledWith("Reranking is disabled in configuration")
+		})
+
+		it("should create local reranker successfully", async () => {
+			const mockReranker = {
+				validateConfiguration: vi.fn().mockResolvedValue({ valid: true }),
+				healthCheck: vi.fn().mockResolvedValue(true),
+			}
+			;(LocalReranker as any).mockImplementation(() => mockReranker)
+
+			const result = await RerankerFactory.create(mockValidConfig)
+
+			expect(LocalReranker).toHaveBeenCalledWith(mockValidConfig)
+			expect(mockReranker.validateConfiguration).toHaveBeenCalled()
+			expect(mockReranker.healthCheck).toHaveBeenCalled()
+			expect(result).toBe(mockReranker)
+			expect(consoleLogSpy).toHaveBeenCalledWith("Successfully created local reranker")
+		})
+
+		it("should return undefined when validation fails", async () => {
+			const mockReranker = {
+				validateConfiguration: vi.fn().mockResolvedValue({
+					valid: false,
+					error: "Invalid configuration",
+				}),
+				healthCheck: vi.fn(),
+			}
+			;(LocalReranker as any).mockImplementation(() => mockReranker)
+
+			const result = await RerankerFactory.create(mockValidConfig)
+
+			expect(result).toBeUndefined()
+			expect(mockReranker.validateConfiguration).toHaveBeenCalled()
+			expect(mockReranker.healthCheck).not.toHaveBeenCalled()
+			expect(consoleErrorSpy).toHaveBeenCalledWith(
+				"Reranker configuration validation failed: Invalid configuration",
+			)
+		})
+
+		it("should warn but continue when health check fails", async () => {
+			const mockReranker = {
+				validateConfiguration: vi.fn().mockResolvedValue({ valid: true }),
+				healthCheck: vi.fn().mockResolvedValue(false),
+			}
+			;(LocalReranker as any).mockImplementation(() => mockReranker)
+
+			const result = await RerankerFactory.create(mockValidConfig)
+
+			expect(result).toBe(mockReranker)
+			expect(mockReranker.healthCheck).toHaveBeenCalled()
+			expect(consoleWarnSpy).toHaveBeenCalledWith(
+				"Reranker health check failed, but continuing with initialization",
+			)
+		})
+
+		it("should return undefined for cohere provider (not implemented)", async () => {
+			const cohereConfig = { ...mockValidConfig, provider: "cohere" as const }
+
+			const result = await RerankerFactory.create(cohereConfig)
+
+			expect(result).toBeUndefined()
+			expect(consoleWarnSpy).toHaveBeenCalledWith("Cohere reranker not yet implemented")
+		})
+
+		it("should return undefined for openai provider (not implemented)", async () => {
+			const openaiConfig = { ...mockValidConfig, provider: "openai" as const }
+
+			const result = await RerankerFactory.create(openaiConfig)
+
+			expect(result).toBeUndefined()
+			expect(consoleWarnSpy).toHaveBeenCalledWith("OpenAI reranker not yet implemented")
+		})
+
+		it("should return undefined for custom provider (not implemented)", async () => {
+			const customConfig = { ...mockValidConfig, provider: "custom" as const }
+
+			const result = await RerankerFactory.create(customConfig)
+
+			expect(result).toBeUndefined()
+			expect(consoleWarnSpy).toHaveBeenCalledWith("Custom reranker not yet implemented")
+		})
+
+		it("should return undefined for unknown provider", async () => {
+			const unknownConfig = { ...mockValidConfig, provider: "unknown" as any }
+
+			const result = await RerankerFactory.create(unknownConfig)
+
+			expect(result).toBeUndefined()
+			expect(consoleErrorSpy).toHaveBeenCalledWith("Unknown reranker provider: unknown")
+		})
+
+		it("should handle constructor errors", async () => {
+			;(LocalReranker as any).mockImplementation(() => {
+				throw new Error("Constructor error")
+			})
+
+			const result = await RerankerFactory.create(mockValidConfig)
+
+			expect(result).toBeUndefined()
+			expect(consoleErrorSpy).toHaveBeenCalledWith("Failed to create reranker: Constructor error")
+		})
+
+		it("should handle validation errors", async () => {
+			const mockReranker = {
+				validateConfiguration: vi.fn().mockRejectedValue(new Error("Validation error")),
+				healthCheck: vi.fn(),
+			}
+			;(LocalReranker as any).mockImplementation(() => mockReranker)
+
+			const result = await RerankerFactory.create(mockValidConfig)
+
+			expect(result).toBeUndefined()
+			expect(consoleErrorSpy).toHaveBeenCalledWith("Failed to create reranker: Validation error")
+		})
+
+		it("should handle health check errors", async () => {
+			const mockReranker = {
+				validateConfiguration: vi.fn().mockResolvedValue({ valid: true }),
+				healthCheck: vi.fn().mockRejectedValue(new Error("Health check error")),
+			}
+			;(LocalReranker as any).mockImplementation(() => mockReranker)
+
+			const result = await RerankerFactory.create(mockValidConfig)
+
+			expect(result).toBeUndefined()
+			expect(consoleErrorSpy).toHaveBeenCalledWith("Failed to create reranker: Health check error")
+		})
+
+		it("should handle non-Error exceptions", async () => {
+			;(LocalReranker as any).mockImplementation(() => {
+				throw "String error"
+			})
+
+			const result = await RerankerFactory.create(mockValidConfig)
+
+			expect(result).toBeUndefined()
+			expect(consoleErrorSpy).toHaveBeenCalledWith("Failed to create reranker: String error")
+		})
+	})
+
+	describe("validateConfig", () => {
+		it("should validate valid local config", () => {
+			const result = RerankerFactory.validateConfig(mockValidConfig)
+
+			expect(result).toEqual({ valid: true })
+		})
+
+		it("should return error when provider is missing", () => {
+			const config = { ...mockValidConfig, provider: undefined as any }
+
+			const result = RerankerFactory.validateConfig(config)
+
+			expect(result).toEqual({
+				valid: false,
+				error: "Provider is required",
+			})
+		})
+
+		it("should return valid when disabled", () => {
+			const config = { ...mockValidConfig, enabled: false }
+
+			const result = RerankerFactory.validateConfig(config)
+
+			expect(result).toEqual({ valid: true })
+		})
+
+		describe("local provider validation", () => {
+			it("should require url for local provider", () => {
+				const config = { ...mockValidConfig, url: undefined }
+
+				const result = RerankerFactory.validateConfig(config)
+
+				expect(result).toEqual({
+					valid: false,
+					error: "Local reranker requires a URL",
+				})
+			})
+
+			it("should require apiKey for local provider", () => {
+				const config = { ...mockValidConfig, apiKey: undefined }
+
+				const result = RerankerFactory.validateConfig(config)
+
+				expect(result).toEqual({
+					valid: false,
+					error: "Local reranker requires an API key",
+				})
+			})
+		})
+
+		describe("cohere provider validation", () => {
+			it("should require apiKey for cohere provider", () => {
+				const config = {
+					...mockValidConfig,
+					provider: "cohere" as const,
+					apiKey: undefined,
+				}
+
+				const result = RerankerFactory.validateConfig(config)
+
+				expect(result).toEqual({
+					valid: false,
+					error: "Cohere reranker requires an API key",
+				})
+			})
+
+			it("should validate cohere config with apiKey", () => {
+				const config = {
+					...mockValidConfig,
+					provider: "cohere" as const,
+					apiKey: "cohere-key",
+					url: undefined,
+				}
+
+				const result = RerankerFactory.validateConfig(config)
+
+				expect(result).toEqual({ valid: true })
+			})
+		})
+
+		describe("openai provider validation", () => {
+			it("should require apiKey for openai provider", () => {
+				const config = {
+					...mockValidConfig,
+					provider: "openai" as const,
+					apiKey: undefined,
+				}
+
+				const result = RerankerFactory.validateConfig(config)
+
+				expect(result).toEqual({
+					valid: false,
+					error: "OpenAI reranker requires an API key",
+				})
+			})
+
+			it("should validate openai config with apiKey", () => {
+				const config = {
+					...mockValidConfig,
+					provider: "openai" as const,
+					apiKey: "openai-key",
+					url: undefined,
+				}
+
+				const result = RerankerFactory.validateConfig(config)
+
+				expect(result).toEqual({ valid: true })
+			})
+		})
+
+		describe("custom provider validation", () => {
+			it("should require url for custom provider", () => {
+				const config = {
+					...mockValidConfig,
+					provider: "custom" as const,
+					url: undefined,
+				}
+
+				const result = RerankerFactory.validateConfig(config)
+
+				expect(result).toEqual({
+					valid: false,
+					error: "Custom reranker requires a URL",
+				})
+			})
+
+			it("should validate custom config with url", () => {
+				const config = {
+					...mockValidConfig,
+					provider: "custom" as const,
+					url: "http://custom-reranker.com",
+					apiKey: undefined,
+				}
+
+				const result = RerankerFactory.validateConfig(config)
+
+				expect(result).toEqual({ valid: true })
+			})
+		})
+
+		it("should return error for unknown provider", () => {
+			const config = { ...mockValidConfig, provider: "unknown" as any }
+
+			const result = RerankerFactory.validateConfig(config)
+
+			expect(result).toEqual({
+				valid: false,
+				error: "Unknown provider: unknown",
+			})
+		})
+
+		describe("numeric field validation", () => {
+			it("should validate topN must be greater than 0", () => {
+				const config = { ...mockValidConfig, topN: 0 }
+
+				const result = RerankerFactory.validateConfig(config)
+
+				expect(result).toEqual({
+					valid: false,
+					error: "topN must be greater than 0",
+				})
+			})
+
+			it("should validate topN negative value", () => {
+				const config = { ...mockValidConfig, topN: -5 }
+
+				const result = RerankerFactory.validateConfig(config)
+
+				expect(result).toEqual({
+					valid: false,
+					error: "topN must be greater than 0",
+				})
+			})
+
+			it("should validate topK must be greater than 0", () => {
+				const config = { ...mockValidConfig, topK: 0 }
+
+				const result = RerankerFactory.validateConfig(config)
+
+				expect(result).toEqual({
+					valid: false,
+					error: "topK must be greater than 0",
+				})
+			})
+
+			it("should validate topK negative value", () => {
+				const config = { ...mockValidConfig, topK: -10 }
+
+				const result = RerankerFactory.validateConfig(config)
+
+				expect(result).toEqual({
+					valid: false,
+					error: "topK must be greater than 0",
+				})
+			})
+
+			it("should validate topK cannot be greater than topN", () => {
+				const config = { ...mockValidConfig, topN: 50, topK: 100 }
+
+				const result = RerankerFactory.validateConfig(config)
+
+				expect(result).toEqual({
+					valid: false,
+					error: "topK cannot be greater than topN",
+				})
+			})
+
+			it("should validate timeout must be greater than 0", () => {
+				const config = { ...mockValidConfig, timeout: 0 }
+
+				const result = RerankerFactory.validateConfig(config)
+
+				expect(result).toEqual({
+					valid: false,
+					error: "timeout must be greater than 0",
+				})
+			})
+
+			it("should validate timeout negative value", () => {
+				const config = { ...mockValidConfig, timeout: -1000 }
+
+				const result = RerankerFactory.validateConfig(config)
+
+				expect(result).toEqual({
+					valid: false,
+					error: "timeout must be greater than 0",
+				})
+			})
+
+			// NOTE(review): every numeric field is set below, so omitted (undefined) values are
+			// not exercised here; add a dedicated case once that behaviour is confirmed.
+			it("should accept a config whose numeric fields are all valid", () => {
+				const config = {
+					enabled: true,
+					provider: "local" as const,
+					url: "http://localhost",
+					apiKey: "key",
+					topN: 100,
+					topK: 20,
+					timeout: 10000,
+				}
+
+				const result = RerankerFactory.validateConfig(config)
+
+				expect(result).toEqual({ valid: true })
+			})
+		})
+	})
+
+	describe("getSupportedProviders", () => {
+		it("should return all supported providers", () => {
+			const providers = RerankerFactory.getSupportedProviders()
+
+			expect(providers).toEqual(["local", "cohere", "openai", "custom"])
+		})
+	})
+
+	describe("isProviderImplemented", () => {
+		it("should return true for local provider", () => {
+			expect(RerankerFactory.isProviderImplemented("local")).toBe(true)
+		})
+
+		it("should return false for cohere provider", () => {
+			expect(RerankerFactory.isProviderImplemented("cohere")).toBe(false)
+		})
+
+		it("should return false for openai provider", () => {
+			expect(RerankerFactory.isProviderImplemented("openai")).toBe(false)
+		})
+
+		it("should return false for custom provider", () => {
+			expect(RerankerFactory.isProviderImplemented("custom")).toBe(false)
+		})
+
+		it("should return false for unknown provider", () => {
+			expect(RerankerFactory.isProviderImplemented("unknown")).toBe(false)
+		})
+
+		it("should return false for empty string", () => {
+			expect(RerankerFactory.isProviderImplemented("")).toBe(false)
+		})
+	})
+})
diff --git a/src/tests/services/code-index/rerankers/local.test.ts b/src/tests/services/code-index/rerankers/local.test.ts
new file mode 100644
index 0000000000..531dd14286
--- /dev/null
+++ b/src/tests/services/code-index/rerankers/local.test.ts
@@ -0,0 +1,581 @@
+import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"
+import axios from "axios"
+import { LocalReranker } from "../../../../services/code-index/rerankers/local"
+import { RerankerConfig, RerankCandidate } from
"../../../../services/code-index/interfaces/reranker"
+
+// Mock axios so no real HTTP requests are issued; each test installs its own axios.create stub
+vi.mock("axios")
+
+describe("LocalReranker", () => {
+	const mockConfig: RerankerConfig = {
+		enabled: true,
+		provider: "local",
+		url: "http://localhost:8080",
+		apiKey: "test-api-key",
+		model: "test-model",
+		topN: 100,
+		topK: 20,
+		timeout: 30000,
+	}
+
+	let consoleLogSpy: any
+	let consoleErrorSpy: any
+
+	beforeEach(() => {
+		vi.clearAllMocks()
+		consoleLogSpy = vi.spyOn(console, "log").mockImplementation(() => {})
+		consoleErrorSpy = vi.spyOn(console, "error").mockImplementation(() => {})
+	})
+
+	afterEach(() => {
+		consoleLogSpy.mockRestore()
+		consoleErrorSpy.mockRestore()
+	})
+
+	describe("constructor", () => {
+		it("should create instance with valid config", () => {
+			const mockAxiosCreate = vi.fn().mockReturnValue({})
+			;(axios.create as any) = mockAxiosCreate
+
+			const reranker = new LocalReranker(mockConfig)
+
+			expect(reranker).toBeDefined()
+			expect(mockAxiosCreate).toHaveBeenCalledWith({
+				baseURL: "http://localhost:8080",
+				timeout: 30000,
+				headers: {
+					Authorization: "Bearer test-api-key",
+					"Content-Type": "application/json",
+				},
+			})
+		})
+
+		it("should throw error when url is missing", () => {
+			const invalidConfig = { ...mockConfig, url: undefined }
+
+			expect(() => new LocalReranker(invalidConfig)).toThrow("Local reranker requires a base URL")
+		})
+
+		it("should throw error when apiKey is missing", () => {
+			const invalidConfig = { ...mockConfig, apiKey: undefined }
+
+			expect(() => new LocalReranker(invalidConfig)).toThrow("Local reranker requires an API key")
+		})
+
+		it("should remove trailing slash from url", () => {
+			const mockAxiosCreate = vi.fn().mockReturnValue({})
+			;(axios.create as any) = mockAxiosCreate
+
+			const configWithTrailingSlash = { ...mockConfig, url: "http://localhost:8080/" }
+			new LocalReranker(configWithTrailingSlash)
+
+			expect(mockAxiosCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					baseURL: "http://localhost:8080",
+				}),
+			)
+		})
+
+		it("should pass the configured timeout to axios", () => {
+			const mockAxiosCreate = vi.fn().mockReturnValue({})
+			;(axios.create as any) = mockAxiosCreate
+
+			// NOTE(review): timeout is always supplied here, so the default-timeout path is not exercised.
+			new LocalReranker({ ...mockConfig, timeout: 30000 })
+
+			expect(mockAxiosCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					timeout: 30000,
+				}),
+			)
+		})
+	})
+
+	describe("rerank", () => {
+		let reranker: LocalReranker
+		let mockAxiosInstance: any
+
+		beforeEach(() => {
+			mockAxiosInstance = {
+				post: vi.fn(),
+			}
+			;(axios.create as any) = vi.fn().mockReturnValue(mockAxiosInstance)
+			reranker = new LocalReranker(mockConfig)
+		})
+
+		it("should successfully rerank candidates", async () => {
+			const query = "test query"
+			const candidates: RerankCandidate[] = [
+				{ id: "1", content: "First document" },
+				{ id: "2", content: "Second document" },
+				{ id: "3", content: "Third document" },
+			]
+
+			const mockResponse = {
+				data: [
+					{ score: 0.9, rank: 1 },
+					{ score: 0.7, rank: 2 },
+					{ score: 0.5, rank: 3 },
+				],
+			}
+			mockAxiosInstance.post.mockResolvedValueOnce(mockResponse)
+
+			const results = await reranker.rerank(query, candidates)
+
+			expect(mockAxiosInstance.post).toHaveBeenCalledWith("/rerank", {
+				query,
+				documents: ["First document", "Second document", "Third document"],
+				model: "test-model",
+				max_results: 20,
+			})
+
+			expect(results).toHaveLength(3)
+			expect(results[0]).toEqual({ id: "1", score: 0.9, rank: 1 })
+			expect(results[1]).toEqual({ id: "2", score: 0.7, rank: 2 })
+			expect(results[2]).toEqual({ id: "3", score: 0.5, rank: 3 })
+		})
+
+		it("should return empty array for empty candidates", async () => {
+			const results = await reranker.rerank("test query", [])
+
+			expect(results).toEqual([])
+			expect(mockAxiosInstance.post).not.toHaveBeenCalled()
+		})
+
+		it("should throw error for empty query", async () => {
+			const candidates: RerankCandidate[] = [{ id: "1", content: "Test" }]
+
+			await expect(reranker.rerank("", candidates)).rejects.toThrow("Query cannot be empty")
+			await expect(reranker.rerank(" ", candidates)).rejects.toThrow("Query cannot be empty")
+		})
+
+		it("should limit candidates to topN", async () => {
+			const candidates: RerankCandidate[] = Array.from({ length: 150 }, (_, i) => ({
+				id: String(i),
+				content: `Document ${i}`,
+			}))
+
+			mockAxiosInstance.post.mockResolvedValueOnce({ data: [] })
+
+			await reranker.rerank("test query", candidates)
+
+			const call = mockAxiosInstance.post.mock.calls[0]
+			expect(call[1].documents).toHaveLength(100) // capped at mockConfig.topN
+		})
+
+		it("should limit results to maxResults parameter", async () => {
+			const candidates: RerankCandidate[] = Array.from({ length: 30 }, (_, i) => ({
+				id: String(i),
+				content: `Document ${i}`,
+			}))
+
+			const mockResponse = {
+				data: Array.from({ length: 30 }, (_, i) => ({
+					score: 1 - i * 0.01,
+					rank: i + 1,
+				})),
+			}
+			mockAxiosInstance.post.mockResolvedValueOnce(mockResponse)
+
+			const results = await reranker.rerank("test query", candidates, 10)
+
+			expect(results).toHaveLength(10)
+			expect(mockAxiosInstance.post).toHaveBeenCalledWith(
+				"/rerank",
+				expect.objectContaining({
+					max_results: 10,
+				}),
+			)
+		})
+
+		it("should handle 401 authentication error", async () => {
+			const error = new Error("Unauthorized")
+			;(error as any).response = { status: 401, data: "Invalid API key" }
+			;(error as any).isAxiosError = true
+			mockAxiosInstance.post.mockRejectedValueOnce(error)
+			;(axios.isAxiosError as any) = vi.fn().mockReturnValue(true)
+
+			const candidates: RerankCandidate[] = [{ id: "1", content: "Test" }]
+
+			await expect(reranker.rerank("test query", candidates)).rejects.toThrow(
+				"local reranker authentication failed: Invalid API key",
+			)
+		})
+
+		it("should handle 404 endpoint not found error", async () => {
+			const error = new Error("Not Found")
+			;(error as any).response = { status: 404, data: "Endpoint not found" }
+			;(error as any).isAxiosError = true
+			mockAxiosInstance.post.mockRejectedValueOnce(error)
+			;(axios.isAxiosError as any) = vi.fn().mockReturnValue(true)
+
+			const candidates: RerankCandidate[] = [{ id: "1", content: "Test" }]
+
+			await expect(reranker.rerank("test query", candidates)).rejects.toThrow(
+				"local reranker endpoint failed: Rerank endpoint not found at http://localhost:8080/rerank",
+			)
+		})
+
+		it("should handle 429 rate limit error", async () => {
+			const error = new Error("Too Many Requests")
+			;(error as any).response = { status: 429, data: "Rate limit exceeded" }
+			;(error as any).isAxiosError = true
+			mockAxiosInstance.post.mockRejectedValueOnce(error)
+			;(axios.isAxiosError as any) = vi.fn().mockReturnValue(true)
+
+			const candidates: RerankCandidate[] = [{ id: "1", content: "Test" }]
+
+			await expect(reranker.rerank("test query", candidates)).rejects.toThrow(
+				"local reranker rate-limit failed: Rate limit exceeded",
+			)
+		})
+
+		it("should handle 500 server error", async () => {
+			const error = new Error("Internal Server Error")
+			;(error as any).response = { status: 500, data: { error: "Server error" } }
+			;(error as any).isAxiosError = true
+			mockAxiosInstance.post.mockRejectedValueOnce(error)
+			;(axios.isAxiosError as any) = vi.fn().mockReturnValue(true)
+
+			const candidates: RerankCandidate[] = [{ id: "1", content: "Test" }]
+
+			await expect(reranker.rerank("test query", candidates)).rejects.toThrow(
+				'local reranker rerank failed: API error (500): {"error":"Server error"}',
+			)
+		})
+
+		it("should handle timeout/no response error", async () => {
+			const error = new Error("Timeout")
+			;(error as any).request = {}
+			;(error as any).isAxiosError = true
+			mockAxiosInstance.post.mockRejectedValueOnce(error)
+			;(axios.isAxiosError as any) = vi.fn().mockReturnValue(true)
+
+			const candidates: RerankCandidate[] = [{ id: "1", content: "Test" }]
+
+			await expect(reranker.rerank("test query", candidates)).rejects.toThrow(
+				"local reranker connection failed: No response from reranker API at http://localhost:8080",
+			)
+		})
+
+		it("should handle invalid response format - not an array", async () => {
+			mockAxiosInstance.post.mockResolvedValueOnce({ data: { invalid: "response" } })
+			;(axios.isAxiosError as any) = vi.fn().mockReturnValue(false)
+
+			const candidates: RerankCandidate[] = [{ id: "1", content: "Test" }]
+
+			await expect(reranker.rerank("test query", candidates)).rejects.toThrow(
+				"local reranker rerank failed: Invalid response format from reranker API",
+			)
+		})
+
+		it("should handle invalid response format - null data", async () => {
+			mockAxiosInstance.post.mockResolvedValueOnce({ data: null })
+			;(axios.isAxiosError as any) = vi.fn().mockReturnValue(false)
+
+			const candidates: RerankCandidate[] = [{ id: "1", content: "Test" }]
+
+			await expect(reranker.rerank("test query", candidates)).rejects.toThrow(
+				"local reranker rerank failed: Invalid response format from reranker API",
+			)
+		})
+
+		it("should handle missing candidate for index", async () => {
+			const candidates: RerankCandidate[] = [{ id: "1", content: "First" }]
+
+			// Response has more items than candidates
+			const mockResponse = {
+				data: [
+					{ score: 0.9, rank: 1 },
+					{ score: 0.7, rank: 2 },
+				],
+			}
+			mockAxiosInstance.post.mockResolvedValueOnce(mockResponse)
+			;(axios.isAxiosError as any) = vi.fn().mockReturnValue(false)
+
+			await expect(reranker.rerank("test query", candidates)).rejects.toThrow(
+				"local reranker rerank failed: No candidate found for index 1",
+			)
+		})
+
+		it("should handle candidates without model in config", async () => {
+			const configWithoutModel = { ...mockConfig, model: undefined }
+			;(axios.create as any) = vi.fn().mockReturnValue(mockAxiosInstance)
+			const rerankerNoModel = new LocalReranker(configWithoutModel)
+
+			const candidates: RerankCandidate[] = [{ id: "1", content: "Test" }]
+			mockAxiosInstance.post.mockResolvedValueOnce({ data: [] })
+
+			await rerankerNoModel.rerank("test query", candidates)
+
+			const payload = mockAxiosInstance.post.mock.calls[0][1]
+			expect(payload).not.toHaveProperty("model")
+		})
+
+		it("should properly sort and assign ranks", async () => {
+			const candidates: RerankCandidate[] = [
+				{ id: "1", content: "First" },
+				{ id: "2", content: "Second" },
+				{ id: "3", content: "Third" },
+			]
+
+			// Response with unsorted scores
+			const mockResponse = {
+				data: [
+					{ score: 0.5, rank: 99 }, // Will be re-ranked
+					{ score: 0.9, rank: 99 },
+					{ score: 0.7, rank: 99 },
+				],
+			}
+			mockAxiosInstance.post.mockResolvedValueOnce(mockResponse)
+
+			const results = await reranker.rerank("test query", candidates)
+
+			// Should be sorted by score descending with correct ranks
+			expect(results[0]).toEqual({ id: "2", score: 0.9, rank: 1 })
+			expect(results[1]).toEqual({ id: "3", score: 0.7, rank: 2 })
+			expect(results[2]).toEqual({ id: "1", score: 0.5, rank: 3 })
+		})
+	})
+
+	describe("validateConfiguration", () => {
+		let reranker: LocalReranker
+		let mockAxiosInstance: any
+
+		beforeEach(() => {
+			mockAxiosInstance = {
+				post: vi.fn(),
+			}
+			;(axios.create as any) = vi.fn().mockReturnValue(mockAxiosInstance)
+			reranker = new LocalReranker(mockConfig)
+		})
+
+		it("should validate successfully with valid response", async () => {
+			mockAxiosInstance.post.mockResolvedValueOnce({
+				data: [{ score: 0.5, rank: 1 }],
+			})
+
+			const result = await reranker.validateConfiguration()
+
+			expect(result).toEqual({ valid: true })
+			expect(mockAxiosInstance.post).toHaveBeenCalledWith("/rerank", {
+				query: "test",
+				documents: ["test document"],
+				max_results: 1,
+				model: "test-model",
+			})
+		})
+
+		it("should validate successfully with empty response array", async () => {
+			mockAxiosInstance.post.mockResolvedValueOnce({ data: [] })
+
+			const result = await reranker.validateConfiguration()
+
+			expect(result).toEqual({ valid: true })
+		})
+
+		it("should fail validation for invalid response format", async () => {
+			mockAxiosInstance.post.mockResolvedValueOnce({ data: { invalid: "format" } })
+
+			const result = await reranker.validateConfiguration()
+
+			expect(result).toEqual({
+				valid: false,
+				error: "Invalid response format from reranker API",
+			})
+		})
+
+		it("should fail validation for missing score field", async () => {
+			mockAxiosInstance.post.mockResolvedValueOnce({
+				data: [{ rank: 1 }], // Missing score
+			})
+
+			const result = await reranker.validateConfiguration()
+
+			expect(result).toEqual({
+				valid: false,
+				error: 'Reranker API response missing required "score" field',
+			})
+		})
+
+		it("should handle 401 authentication error", async () => {
+			const error = new Error("Unauthorized")
+			;(error as any).response = { status: 401 }
+			;(error as any).isAxiosError = true
+			mockAxiosInstance.post.mockRejectedValueOnce(error)
+			;(axios.isAxiosError as any) = vi.fn().mockReturnValue(true)
+
+			const result = await reranker.validateConfiguration()
+
+			expect(result).toEqual({
+				valid: false,
+				error: "Invalid API key",
+			})
+		})
+
+		it("should handle 404 endpoint not found", async () => {
+			const error = new Error("Not Found")
+			;(error as any).response = { status: 404 }
+			;(error as any).isAxiosError = true
+			mockAxiosInstance.post.mockRejectedValueOnce(error)
+			;(axios.isAxiosError as any) = vi.fn().mockReturnValue(true)
+
+			const result = await reranker.validateConfiguration()
+
+			expect(result).toEqual({
+				valid: false,
+				error: "Rerank endpoint not found at http://localhost:8080/rerank",
+			})
+		})
+
+		it("should handle connection error", async () => {
+			const error = new Error("Connection error")
+			;(error as any).request = {}
+			;(error as any).isAxiosError = true
+			mockAxiosInstance.post.mockRejectedValueOnce(error)
+			;(axios.isAxiosError as any) = vi.fn().mockReturnValue(true)
+
+			const result = await reranker.validateConfiguration()
+
+			expect(result).toEqual({
+				valid: false,
+				error: "Cannot connect to reranker API at http://localhost:8080",
+			})
+		})
+
+		it("should handle generic errors", async () => {
+			const error = new Error("Generic error")
+			mockAxiosInstance.post.mockRejectedValueOnce(error)
+			;(axios.isAxiosError as any) = vi.fn().mockReturnValue(false)
+
+			const result = await reranker.validateConfiguration()
+
+			expect(result).toEqual({
+				valid: false,
+				error: "Configuration validation failed: Generic error",
+			})
+		})
+
+		it("should fail common config validation", async () => {
+			const invalidConfig = { ...mockConfig, topK: 0 }
+			;(axios.create as any) = vi.fn().mockReturnValue(mockAxiosInstance)
+			const invalidReranker = new LocalReranker(invalidConfig)
+
+			const result = await invalidReranker.validateConfiguration()
+
+			expect(result).toEqual({
+				valid: false,
+				error: "topK must be greater than 0",
+			})
+		})
+
+		it("should validate config without model", async () => {
+			const configWithoutModel = { ...mockConfig, model: undefined }
+			;(axios.create as any) = vi.fn().mockReturnValue(mockAxiosInstance)
+			const rerankerNoModel = new LocalReranker(configWithoutModel)
+
+			mockAxiosInstance.post.mockResolvedValueOnce({ data: [] })
+
+			await rerankerNoModel.validateConfiguration()
+
+			const payload = mockAxiosInstance.post.mock.calls[0][1]
+			expect(payload).not.toHaveProperty("model")
+		})
+	})
+
+	describe("healthCheck", () => {
+		let reranker: LocalReranker
+		let mockAxiosInstance: any
+
+		beforeEach(() => {
+			mockAxiosInstance = {
+				post: vi.fn(),
+			}
+			;(axios.create as any) = vi.fn().mockReturnValue(mockAxiosInstance)
+			reranker = new LocalReranker(mockConfig)
+		})
+
+		it("should return true for successful health check", async () => {
+			mockAxiosInstance.post.mockResolvedValueOnce({
+				status: 200,
+				data: [],
+			})
+
+			const result = await reranker.healthCheck()
+
+			expect(result).toBe(true)
+			expect(mockAxiosInstance.post).toHaveBeenCalledWith(
+				"/rerank",
+				{
+					query: "health check",
+					documents: ["test"],
+					max_results: 1,
+					model: "test-model",
+				},
+				{
+					timeout: 5000,
+				},
+			)
+		})
+
+		it("should return false for non-200 status", async () => {
+			mockAxiosInstance.post.mockResolvedValueOnce({
+				status: 201,
+				data: [],
+			})
+
+			const result = await reranker.healthCheck()
+
+			expect(result).toBe(false)
+		})
+
+		it("should return false for non-array response", async () => {
+			mockAxiosInstance.post.mockResolvedValueOnce({
+				status: 200,
+				data: { invalid: "response" },
+			})
+
+			const result = await reranker.healthCheck()
+
+			expect(result).toBe(false)
+		})
+
+		it("should return false on error", async () => {
+			mockAxiosInstance.post.mockRejectedValueOnce(new Error("Network error"))
+
+			const result = await reranker.healthCheck()
+
+			expect(result).toBe(false)
+			expect(consoleErrorSpy).toHaveBeenCalledWith("Health check failed:", expect.any(Error))
+		})
+
+		it("should use 5 second timeout", async () => {
+			mockAxiosInstance.post.mockResolvedValueOnce({
+				status: 200,
+				data: [],
+			})
+
+			await reranker.healthCheck()
+
+			expect(mockAxiosInstance.post).toHaveBeenCalledWith(expect.any(String), expect.any(Object), {
+				timeout: 5000,
+			})
+		})
+
+		it("should not include model if not configured", async () => {
+			const configWithoutModel = { ...mockConfig, model: undefined }
+			;(axios.create as any) = vi.fn().mockReturnValue(mockAxiosInstance)
+			const rerankerNoModel = new LocalReranker(configWithoutModel)
+
+			mockAxiosInstance.post.mockResolvedValueOnce({
+				status: 200,
+				data: [],
+			})
+
+			await rerankerNoModel.healthCheck()
+
+			const payload = mockAxiosInstance.post.mock.calls[0][1]
+			expect(payload).not.toHaveProperty("model")
+		})
+	})
+})
diff --git a/src/tests/services/code-index/search-service-reranking.test.ts b/src/tests/services/code-index/search-service-reranking.test.ts
new file mode 100644
index 0000000000..8af6525f75
--- /dev/null
+++ b/src/tests/services/code-index/search-service-reranking.test.ts
@@ -0,0 +1,516 @@
+import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"
+import { CodeIndexSearchService } from "../../../services/code-index/search-service"
+import {
CodeIndexConfigManager } from "../../../services/code-index/config-manager" +import { CodeIndexStateManager } from "../../../services/code-index/state-manager" +import { IEmbedder } from "../../../services/code-index/interfaces/embedder" +import { IVectorStore, VectorStoreSearchResult } from "../../../services/code-index/interfaces/vector-store" +import { IReranker, RerankCandidate, RerankResult } from "../../../services/code-index/interfaces/reranker" +import { TelemetryService } from "@roo-code/telemetry" +import { TelemetryEventName } from "@roo-code/types" + +// Mock dependencies +vi.mock("@roo-code/telemetry", () => ({ + TelemetryService: { + instance: { + captureEvent: vi.fn(), + }, + }, +})) + +describe("CodeIndexSearchService - Reranking Integration", () => { + let searchService: CodeIndexSearchService + let mockConfigManager: any + let mockStateManager: any + let mockEmbedder: any + let mockVectorStore: any + let mockReranker: any + let consoleLogSpy: any + let consoleErrorSpy: any + + // Sample data + const mockSearchResults: VectorStoreSearchResult[] = [ + { + id: 1, + score: 0.7, + payload: { + filePath: "/src/file1.ts", + startLine: 10, + endLine: 20, + codeChunk: 'function test1() { return "test1"; }', + }, + }, + { + id: 2, + score: 0.6, + payload: { + filePath: "/src/file2.ts", + startLine: 30, + endLine: 40, + codeChunk: 'const value = "test2";', + }, + }, + { + id: 3, + score: 0.5, + payload: { + filePath: "/src/file3.ts", + startLine: 50, + endLine: 60, + codeChunk: "class TestClass { constructor() {} }", + }, + }, + ] + + const mockRerankedResults: RerankResult[] = [ + { id: "2", score: 0.9 }, // Second result becomes first + { id: "3", score: 0.8 }, // Third result becomes second + { id: "1", score: 0.4 }, // First result becomes third + ] + + beforeEach(() => { + vi.clearAllMocks() + + // Mock console + consoleLogSpy = vi.spyOn(console, "log").mockImplementation(() => {}) + consoleErrorSpy = vi.spyOn(console, "error").mockImplementation(() => 
{}) + + // Mock ConfigManager + mockConfigManager = { + isFeatureEnabled: true, + isFeatureConfigured: true, + isRerankerEnabled: true, + currentSearchMinScore: 0.3, + currentSearchMaxResults: 20, + rerankerTopN: 100, + rerankerTopK: 20, + } + + // Mock StateManager + mockStateManager = { + getCurrentStatus: vi.fn().mockReturnValue({ + systemStatus: "Indexed", + }), + setSystemState: vi.fn(), + } + + // Mock Embedder + mockEmbedder = { + createEmbeddings: vi.fn().mockResolvedValue({ + embeddings: [[0.1, 0.2, 0.3, 0.4, 0.5]], + }), + } + + // Mock VectorStore + mockVectorStore = { + search: vi.fn().mockResolvedValue(mockSearchResults), + } + + // Mock Reranker + mockReranker = { + rerank: vi.fn().mockResolvedValue(mockRerankedResults), + validateConfiguration: vi.fn().mockResolvedValue({ valid: true }), + healthCheck: vi.fn().mockResolvedValue(true), + } + }) + + afterEach(() => { + consoleLogSpy.mockRestore() + consoleErrorSpy.mockRestore() + }) + + describe("Search with reranking enabled", () => { + beforeEach(() => { + searchService = new CodeIndexSearchService( + mockConfigManager as any, + mockStateManager as any, + mockEmbedder as any, + mockVectorStore as any, + mockReranker as any, + ) + }) + + it("should successfully rerank results when reranking is enabled", async () => { + const query = "test query" + + const results = await searchService.searchIndex(query) + + // Verify embedder was called + expect(mockEmbedder.createEmbeddings).toHaveBeenCalledWith([query]) + + // Verify vector store was called with topN limit + expect(mockVectorStore.search).toHaveBeenCalledWith( + [0.1, 0.2, 0.3, 0.4, 0.5], + undefined, + 0.3, + 100, // topN for reranking + ) + + // Verify reranker was called with correct candidates + expect(mockReranker.rerank).toHaveBeenCalledWith( + query, + expect.arrayContaining([ + expect.objectContaining({ + id: "1", + content: 'function test1() { return "test1"; }', + }), + expect.objectContaining({ + id: "2", + content: 'const value = 
"test2";', + }), + expect.objectContaining({ + id: "3", + content: "class TestClass { constructor() {} }", + }), + ]), + 20, // topK + ) + + // Verify results are reordered according to reranking + expect(results).toHaveLength(3) + expect(results[0].id).toBe(2) // ID 2 is now first + expect(results[0].score).toBe(0.9) // With new score + expect(results[1].id).toBe(3) + expect(results[1].score).toBe(0.8) + expect(results[2].id).toBe(1) + expect(results[2].score).toBe(0.4) + + // Verify payload is preserved + expect(results[0].payload?.filePath).toBe("/src/file2.ts") + expect(results[1].payload?.filePath).toBe("/src/file3.ts") + expect(results[2].payload?.filePath).toBe("/src/file1.ts") + + // Verify performance logging + expect(consoleLogSpy).toHaveBeenCalledWith( + expect.stringContaining("[CodeIndexSearchService] Reranking completed"), + ) + }) + + it("should fall back to vector results when reranking fails", async () => { + mockReranker.rerank.mockRejectedValueOnce(new Error("Reranking API error")) + + const results = await searchService.searchIndex("test query") + + // Should return original results limited to topK + expect(results).toHaveLength(3) + expect(results[0].id).toBe(1) // Original order + expect(results[0].score).toBe(0.7) // Original score + expect(results[1].id).toBe(2) + expect(results[2].id).toBe(3) + + // Verify error logging (logger uses console.log by default) + expect(consoleLogSpy).toHaveBeenCalledWith( + "[CodeIndexSearchService] Reranking failed, falling back to vector search results:", + expect.any(Error), + ) + + // Verify telemetry + expect(TelemetryService.instance.captureEvent).toHaveBeenCalledWith( + TelemetryEventName.CODE_INDEX_ERROR, + expect.objectContaining({ + error: "Reranking API error", + location: "searchIndex-reranking", + }), + ) + }) + + it("should respect topK limit when reranking returns more results", async () => { + // Create many search results + const manyResults = Array.from({ length: 50 }, (_, i) => ({ + id: i + 
1, + score: 0.9 - i * 0.01, + payload: { + filePath: `/src/file${i + 1}.ts`, + startLine: i * 10, + endLine: i * 10 + 10, + codeChunk: `code chunk ${i + 1}`, + }, + })) + mockVectorStore.search.mockResolvedValueOnce(manyResults) + + // Mock reranker to return all results + const manyRerankedResults = manyResults.map((r, i) => ({ + id: r.id.toString(), + score: 0.99 - i * 0.01, + rank: i + 1, + })) + mockReranker.rerank.mockResolvedValueOnce(manyRerankedResults.slice(0, 20)) + + const results = await searchService.searchIndex("test query") + + // Should be limited to topK + expect(results).toHaveLength(20) + expect(mockReranker.rerank).toHaveBeenCalledWith("test query", expect.any(Array), 20) + }) + + it("should handle empty vector search results", async () => { + mockVectorStore.search.mockResolvedValueOnce([]) + + const results = await searchService.searchIndex("test query") + + expect(results).toHaveLength(0) + expect(mockReranker.rerank).not.toHaveBeenCalled() + + // Should log vector search completion, not reranking + expect(consoleLogSpy).toHaveBeenCalledWith( + expect.stringContaining("[CodeIndexSearchService] Vector search completed"), + ) + }) + + it("should properly map metadata in reranker candidates", async () => { + await searchService.searchIndex("test query") + + const candidates = mockReranker.rerank.mock.calls[0][1] as RerankCandidate[] + + expect(candidates[0].metadata).toEqual({ + filePath: "/src/file1.ts", + startLine: 10, + endLine: 20, + score: 0.7, + }) + }) + + it("should handle results with missing payload gracefully", async () => { + const resultsWithMissingPayload: VectorStoreSearchResult[] = [ + { id: 1, score: 0.7, payload: undefined }, + { + id: 2, + score: 0.6, + payload: { filePath: "/src/file2.ts", startLine: 30, endLine: 40, codeChunk: "test code" }, + }, + ] + mockVectorStore.search.mockResolvedValueOnce(resultsWithMissingPayload) + mockReranker.rerank.mockResolvedValueOnce([ + { id: "2", score: 0.9, rank: 1 }, + { id: "1", score: 
0.8, rank: 2 }, + ]) + + const results = await searchService.searchIndex("test query") + + const candidates = mockReranker.rerank.mock.calls[0][1] as RerankCandidate[] + expect(candidates[0].content).toBe("") // Empty string for missing payload + expect(candidates[1].content).toBe("test code") + + expect(results).toHaveLength(2) + }) + + it("should use directory prefix in vector search", async () => { + const directoryPrefix = "/src/components" + + await searchService.searchIndex("test query", directoryPrefix) + + expect(mockVectorStore.search).toHaveBeenCalledWith(expect.any(Array), directoryPrefix, 0.3, 100) + }) + }) + + describe("Search with reranking disabled", () => { + it("should skip reranking when isRerankerEnabled is false", async () => { + mockConfigManager.isRerankerEnabled = false + searchService = new CodeIndexSearchService( + mockConfigManager as any, + mockStateManager as any, + mockEmbedder as any, + mockVectorStore as any, + mockReranker as any, + ) + + const results = await searchService.searchIndex("test query") + + // Should use currentSearchMaxResults instead of topN + expect(mockVectorStore.search).toHaveBeenCalledWith( + expect.any(Array), + undefined, + 0.3, + 20, // currentSearchMaxResults + ) + + // Reranker should not be called + expect(mockReranker.rerank).not.toHaveBeenCalled() + + // Results should be in original order + expect(results).toEqual(mockSearchResults) + + // Should log vector search completion + expect(consoleLogSpy).toHaveBeenCalledWith( + expect.stringContaining("[CodeIndexSearchService] Vector search completed"), + ) + }) + + it("should skip reranking when reranker is not provided", async () => { + searchService = new CodeIndexSearchService( + mockConfigManager as any, + mockStateManager as any, + mockEmbedder as any, + mockVectorStore as any, + undefined, // No reranker + ) + + const results = await searchService.searchIndex("test query") + + expect(mockVectorStore.search).toHaveBeenCalledWith( + expect.any(Array), + 
undefined, + 0.3, + 20, // currentSearchMaxResults + ) + expect(results).toEqual(mockSearchResults) + }) + }) + + describe("Error handling", () => { + it("should throw error when feature is disabled", async () => { + mockConfigManager.isFeatureEnabled = false + searchService = new CodeIndexSearchService( + mockConfigManager as any, + mockStateManager as any, + mockEmbedder as any, + mockVectorStore as any, + mockReranker as any, + ) + + await expect(searchService.searchIndex("test")).rejects.toThrow( + "Code index feature is disabled or not configured.", + ) + }) + + it("should throw error when feature is not configured", async () => { + mockConfigManager.isFeatureConfigured = false + searchService = new CodeIndexSearchService( + mockConfigManager as any, + mockStateManager as any, + mockEmbedder as any, + mockVectorStore as any, + mockReranker as any, + ) + + await expect(searchService.searchIndex("test")).rejects.toThrow( + "Code index feature is disabled or not configured.", + ) + }) + + it("should throw error when index is not ready", async () => { + mockStateManager.getCurrentStatus.mockReturnValue({ + systemStatus: "NotIndexed", + }) + searchService = new CodeIndexSearchService( + mockConfigManager as any, + mockStateManager as any, + mockEmbedder as any, + mockVectorStore as any, + mockReranker as any, + ) + + await expect(searchService.searchIndex("test")).rejects.toThrow( + "Code index is not ready for search. 
Current state: NotIndexed", + ) + }) + + it("should allow search during indexing state", async () => { + mockStateManager.getCurrentStatus.mockReturnValue({ + systemStatus: "Indexing", + }) + searchService = new CodeIndexSearchService( + mockConfigManager as any, + mockStateManager as any, + mockEmbedder as any, + mockVectorStore as any, + mockReranker as any, + ) + + const results = await searchService.searchIndex("test query") + + expect(results).toHaveLength(3) + expect(mockReranker.rerank).toHaveBeenCalled() + }) + + it("should handle embedding generation failure", async () => { + mockEmbedder.createEmbeddings.mockResolvedValueOnce({ + embeddings: [], + }) + searchService = new CodeIndexSearchService( + mockConfigManager as any, + mockStateManager as any, + mockEmbedder as any, + mockVectorStore as any, + mockReranker as any, + ) + + await expect(searchService.searchIndex("test")).rejects.toThrow("Failed to generate embedding for query.") + + expect(mockStateManager.setSystemState).toHaveBeenCalledWith( + "Error", + "Search failed: Failed to generate embedding for query.", + ) + + expect(TelemetryService.instance.captureEvent).toHaveBeenCalledWith( + TelemetryEventName.CODE_INDEX_ERROR, + expect.objectContaining({ + error: "Failed to generate embedding for query.", + location: "searchIndex", + }), + ) + }) + + it("should handle vector store search failure", async () => { + mockVectorStore.search.mockRejectedValueOnce(new Error("Vector store error")) + searchService = new CodeIndexSearchService( + mockConfigManager as any, + mockStateManager as any, + mockEmbedder as any, + mockVectorStore as any, + mockReranker as any, + ) + + await expect(searchService.searchIndex("test")).rejects.toThrow("Vector store error") + + expect(consoleLogSpy).toHaveBeenCalledWith( + "[CodeIndexSearchService] Error during search:", + expect.any(Error), + ) + }) + }) + + describe("Performance logging", () => { + beforeEach(() => { + searchService = new CodeIndexSearchService( + 
mockConfigManager as any, + mockStateManager as any, + mockEmbedder as any, + mockVectorStore as any, + mockReranker as any, + ) + }) + + it("should log performance metrics for successful reranking", async () => { + // Mock delays to test timing + mockVectorStore.search.mockImplementation(async () => { + await new Promise((resolve) => setTimeout(resolve, 50)) + return mockSearchResults + }) + mockReranker.rerank.mockImplementation(async () => { + await new Promise((resolve) => setTimeout(resolve, 30)) + return mockRerankedResults + }) + + await searchService.searchIndex("test query") + + // Check reranking performance log + const rerankingLog = consoleLogSpy.mock.calls.find((call: any) => call[0]?.includes("Reranking completed")) + expect(rerankingLog).toBeTruthy() + expect(rerankingLog[0]).toMatch(/Reranking completed in \d+ms/) + expect(rerankingLog[0]).toContain("Input: 3, Output: 3") + }) + + it("should log vector search performance when reranking is disabled", async () => { + mockConfigManager.isRerankerEnabled = false + + await searchService.searchIndex("test query") + + const vectorSearchLog = consoleLogSpy.mock.calls.find((call: any[]) => + call[0]?.includes("Vector search completed"), + ) + expect(vectorSearchLog).toBeTruthy() + expect(vectorSearchLog[0]).toMatch(/Vector search completed in \d+ms/) + expect(vectorSearchLog[0]).toContain("Results: 3") + }) + }) +}) diff --git a/webview-ui/src/components/chat/CodeIndexPopover.tsx b/webview-ui/src/components/chat/CodeIndexPopover.tsx index c85aaf6ea5..c842e4064e 100644 --- a/webview-ui/src/components/chat/CodeIndexPopover.tsx +++ b/webview-ui/src/components/chat/CodeIndexPopover.tsx @@ -63,6 +63,15 @@ interface LocalCodeIndexSettings { codebaseIndexSearchMaxResults?: number codebaseIndexSearchMinScore?: number + // Reranker settings + codebaseIndexRerankerEnabled?: boolean + codebaseIndexRerankerProvider?: "local" + codebaseIndexRerankerUrl?: string + codebaseIndexRerankerModel?: string + 
codebaseIndexRerankerTopN?: number + codebaseIndexRerankerTopK?: number + codebaseIndexRerankerTimeout?: number + // Secret settings (start empty, will be loaded separately) codeIndexOpenAiKey?: string codeIndexQdrantApiKey?: string @@ -70,19 +79,46 @@ interface LocalCodeIndexSettings { codebaseIndexOpenAiCompatibleApiKey?: string codebaseIndexGeminiApiKey?: string codebaseIndexMistralApiKey?: string + codebaseIndexRerankerApiKey?: string } // Validation schema for codebase index settings -const createValidationSchema = (provider: EmbedderProvider, t: any) => { - const baseSchema = z.object({ +const createValidationSchema = ( + provider: EmbedderProvider, + rerankerEnabled: boolean, + rerankerProvider: string | undefined, + t: any, +) => { + let baseSchema = z.object({ codebaseIndexEnabled: z.boolean(), codebaseIndexQdrantUrl: z .string() .min(1, t("settings:codeIndex.validation.qdrantUrlRequired")) .url(t("settings:codeIndex.validation.invalidQdrantUrl")), codeIndexQdrantApiKey: z.string().optional(), + codebaseIndexRerankerEnabled: z.boolean().optional(), + codebaseIndexRerankerTopN: z.number().min(10).max(500).optional(), + codebaseIndexRerankerTopK: z.number().min(5).max(100).optional(), + codebaseIndexRerankerTimeout: z.number().min(1000).max(30000).optional(), }) + // Add reranker validation if enabled + if (rerankerEnabled && rerankerProvider) { + switch (rerankerProvider) { + case "local": + baseSchema = baseSchema.extend({ + codebaseIndexRerankerUrl: z + .string() + .min(1, t("settings:codeIndex.validation.rerankerUrlRequired")) + .url(t("settings:codeIndex.validation.invalidRerankerUrl")), + codebaseIndexRerankerModel: z + .string() + .min(1, t("settings:codeIndex.validation.rerankerModelRequired")), + }) + break + } + } + switch (provider) { case "openai": return baseSchema.extend({ @@ -151,6 +187,8 @@ export const CodeIndexPopover: React.FC = ({ const [open, setOpen] = useState(false) const [isAdvancedSettingsOpen, setIsAdvancedSettingsOpen] = 
useState(false) const [isSetupSettingsOpen, setIsSetupSettingsOpen] = useState(false) + const [isRerankerSettingsOpen, setIsRerankerSettingsOpen] = useState(false) + const [isAdvancedRerankerOpen, setIsAdvancedRerankerOpen] = useState(false) const [indexingStatus, setIndexingStatus] = useState(externalIndexingStatus) @@ -174,12 +212,20 @@ export const CodeIndexPopover: React.FC = ({ codebaseIndexEmbedderModelDimension: undefined, codebaseIndexSearchMaxResults: CODEBASE_INDEX_DEFAULTS.DEFAULT_SEARCH_RESULTS, codebaseIndexSearchMinScore: CODEBASE_INDEX_DEFAULTS.DEFAULT_SEARCH_MIN_SCORE, + codebaseIndexRerankerEnabled: false, + codebaseIndexRerankerProvider: "local", + codebaseIndexRerankerUrl: "", + codebaseIndexRerankerModel: "ms-marco-MiniLM-L-6-v2", + codebaseIndexRerankerTopN: 100, + codebaseIndexRerankerTopK: 20, + codebaseIndexRerankerTimeout: 10000, codeIndexOpenAiKey: "", codeIndexQdrantApiKey: "", codebaseIndexOpenAiCompatibleBaseUrl: "", codebaseIndexOpenAiCompatibleApiKey: "", codebaseIndexGeminiApiKey: "", codebaseIndexMistralApiKey: "", + codebaseIndexRerankerApiKey: "", }) // Initial settings state - stores the settings when popover opens @@ -208,12 +254,20 @@ export const CodeIndexPopover: React.FC = ({ codebaseIndexConfig.codebaseIndexSearchMaxResults ?? CODEBASE_INDEX_DEFAULTS.DEFAULT_SEARCH_RESULTS, codebaseIndexSearchMinScore: codebaseIndexConfig.codebaseIndexSearchMinScore ?? CODEBASE_INDEX_DEFAULTS.DEFAULT_SEARCH_MIN_SCORE, + codebaseIndexRerankerEnabled: codebaseIndexConfig.codebaseIndexRerankerEnabled ?? false, + codebaseIndexRerankerProvider: codebaseIndexConfig.codebaseIndexRerankerProvider === "local" ? "local" as const : undefined, + codebaseIndexRerankerUrl: codebaseIndexConfig.codebaseIndexRerankerUrl || "", + codebaseIndexRerankerModel: codebaseIndexConfig.codebaseIndexRerankerModel || "ms-marco-MiniLM-L-6-v2", + codebaseIndexRerankerTopN: codebaseIndexConfig.codebaseIndexRerankerTopN ?? 
100, + codebaseIndexRerankerTopK: codebaseIndexConfig.codebaseIndexRerankerTopK ?? 20, + codebaseIndexRerankerTimeout: codebaseIndexConfig.codebaseIndexRerankerTimeout ?? 10000, codeIndexOpenAiKey: "", codeIndexQdrantApiKey: "", codebaseIndexOpenAiCompatibleBaseUrl: codebaseIndexConfig.codebaseIndexOpenAiCompatibleBaseUrl || "", codebaseIndexOpenAiCompatibleApiKey: "", codebaseIndexGeminiApiKey: "", codebaseIndexMistralApiKey: "", + codebaseIndexRerankerApiKey: "", } setInitialSettings(settings) setCurrentSettings(settings) @@ -308,6 +362,9 @@ export const CodeIndexPopover: React.FC = ({ if (!prev.codebaseIndexMistralApiKey || prev.codebaseIndexMistralApiKey === SECRET_PLACEHOLDER) { updated.codebaseIndexMistralApiKey = secretStatus.hasMistralApiKey ? SECRET_PLACEHOLDER : "" } + if (!prev.codebaseIndexRerankerApiKey || prev.codebaseIndexRerankerApiKey === SECRET_PLACEHOLDER) { + updated.codebaseIndexRerankerApiKey = secretStatus.hasRerankerApiKey ? SECRET_PLACEHOLDER : "" + } return updated } @@ -368,7 +425,12 @@ export const CodeIndexPopover: React.FC = ({ // Validation function const validateSettings = (): boolean => { - const schema = createValidationSchema(currentSettings.codebaseIndexEmbedderProvider, t) + const schema = createValidationSchema( + currentSettings.codebaseIndexEmbedderProvider, + currentSettings.codebaseIndexRerankerEnabled || false, + currentSettings.codebaseIndexRerankerProvider, + t, + ) // Prepare data for validation const dataToValidate: any = {} @@ -380,7 +442,8 @@ export const CodeIndexPopover: React.FC = ({ key === "codeIndexOpenAiKey" || key === "codebaseIndexOpenAiCompatibleApiKey" || key === "codebaseIndexGeminiApiKey" || - key === "codebaseIndexMistralApiKey" + key === "codebaseIndexMistralApiKey" || + key === "codebaseIndexRerankerApiKey" ) { dataToValidate[key] = "placeholder-valid" } @@ -1072,6 +1135,286 @@ export const CodeIndexPopover: React.FC = ({ )} + {/* Reranker Settings Disclosure */} +
+ + + {isRerankerSettingsOpen && ( +
+ {/* Enable Reranking */} +
+
+ + updateSetting("codebaseIndexRerankerEnabled", e.target.checked) + }> + + {t("settings:codeIndex.enableRerankerLabel")} + + + + + +
+
+ + {currentSettings.codebaseIndexRerankerEnabled && ( + <> + {/* Reranker Provider Dropdown */} +
+ + +
+ + {/* Model Input - Available for all providers */} +
+
+ + + + +
+ + updateSetting("codebaseIndexRerankerModel", e.target.value) + } + placeholder="e.g., ms-marco-MiniLM-L-6-v2" + className="w-full" + /> +

+ {t("settings:codeIndex.rerankerModelHelperText")} +

+
+ + {/* Provider-specific settings */} + {currentSettings.codebaseIndexRerankerProvider === "local" && ( + <> + {/* Reranker URL */} +
+ + + updateSetting( + "codebaseIndexRerankerUrl", + e.target.value, + ) + } + placeholder="http://localhost:8080" + className={cn("w-full", { + "border-red-500": formErrors.codebaseIndexRerankerUrl, + })} + /> + {formErrors.codebaseIndexRerankerUrl && ( +

+ {formErrors.codebaseIndexRerankerUrl} +

+ )} +
+ + {/* API Key for local reranker (optional) */} +
+ + + updateSetting( + "codebaseIndexRerankerApiKey", + e.target.value, + ) + } + placeholder={t( + "settings:codeIndex.rerankerApiKeyPlaceholder", + )} + className={cn("w-full", { + "border-red-500": + formErrors.codebaseIndexRerankerApiKey, + })} + /> + {formErrors.codebaseIndexRerankerApiKey && ( +

+ {formErrors.codebaseIndexRerankerApiKey} +

+ )} +
+ + )} + + {/* Advanced Reranking Parameters (collapsible) */} +
+ + + {isAdvancedRerankerOpen && ( +
+ {/* TopN - Candidates to rerank */} +
+
+ + + + +
+
+ + updateSetting( + "codebaseIndexRerankerTopN", + value[0], + ) + } + min={10} + max={500} + step={10} + className="flex-1" + /> + + {currentSettings.codebaseIndexRerankerTopN || 100} + +
+
+ + {/* TopK - Results to return */} +
+
+ + + + +
+
+ + updateSetting( + "codebaseIndexRerankerTopK", + value[0], + ) + } + min={5} + max={100} + step={5} + className="flex-1" + /> + + {currentSettings.codebaseIndexRerankerTopK || 20} + +
+
+ + {/* Timeout */} +
+
+ + + + +
+
+ + updateSetting( + "codebaseIndexRerankerTimeout", + value[0] * 1000, + ) + } + min={1} + max={30} + step={1} + className="flex-1" + /> + + {currentSettings.codebaseIndexRerankerTimeout + ? currentSettings.codebaseIndexRerankerTimeout / + 1000 + : 10} + s + +
+
+
+ )} +
+ + )} +
+ )} +
+ {/* Advanced Settings Disclosure */}