diff --git a/.github/workflows/claude-code-review.yml b/.github/workflows/claude-code-review.yml index c126aafe..7ef27671 100644 --- a/.github/workflows/claude-code-review.yml +++ b/.github/workflows/claude-code-review.yml @@ -54,6 +54,7 @@ jobs: - [ ] Unit tests for new functions/methods - [ ] Integration tests for new MCP tools - [ ] Test coverage for edge cases + - [ ] **100% test coverage maintained** (use `# pragma: no cover` only for truly hard-to-test code) - [ ] Documentation updated (README, docstrings) - [ ] CLAUDE.md updated if conventions change diff --git a/CLAUDE.md b/CLAUDE.md index a678842b..7da47184 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -117,17 +117,38 @@ counter += 1 # track retries for backoff calculation ### Codebase Architecture +See [docs/ARCHITECTURE.md](docs/ARCHITECTURE.md) for detailed architecture documentation. + +**Directory Structure:** - `/alembic` - Alembic db migrations -- `/api` - FastAPI implementation of REST endpoints -- `/cli` - Typer command-line interface +- `/api` - FastAPI REST endpoints + `container.py` composition root +- `/cli` - Typer CLI + `container.py` composition root +- `/deps` - Feature-scoped FastAPI dependencies (config, db, projects, repositories, services, importers) - `/importers` - Import functionality for Claude, ChatGPT, and other sources - `/markdown` - Markdown parsing and processing -- `/mcp` - Model Context Protocol server implementation +- `/mcp` - MCP server + `container.py` composition root + `clients/` typed API clients - `/models` - SQLAlchemy ORM models - `/repository` - Data access layer - `/schemas` - Pydantic models for validation - `/services` - Business logic layer -- `/sync` - File synchronization services +- `/sync` - File synchronization services + `coordinator.py` for lifecycle management + +**Composition Roots:** +Each entrypoint (API, MCP, CLI) has a composition root that: +- Reads `ConfigManager` (the only place that reads global config) +- Resolves runtime mode via `RuntimeMode` enum (TEST > CLOUD > LOCAL) +- Provides dependencies to downstream code explicitly + +**Typed API Clients (MCP):** +MCP tools use typed clients in `mcp/clients/` to communicate with the API: +- `KnowledgeClient` - Entity CRUD operations +- `SearchClient` - Search operations +- `MemoryClient` - Context building +- `DirectoryClient` - Directory listing +- `ResourceClient` - Resource reading +- `ProjectClient` - Project management + +Flow: MCP Tool → Typed Client → HTTP API → Router → Service → Repository ### Development Notes @@ -146,6 +167,7 @@ counter += 1 # track retries for backoff calculation - CI runs SQLite and Postgres tests in parallel for faster feedback - Performance benchmarks are in `test-int/test_sync_performance_benchmark.py` - Use pytest markers: `@pytest.mark.benchmark` for benchmarks, `@pytest.mark.slow` for slow tests +- **Coverage must stay at 100%**: Write tests for new code. Only use `# pragma: no cover` when tests would require excessive mocking (e.g., TYPE_CHECKING blocks, error handlers that need failure injection, runtime-mode-dependent code paths) ### Async Client Pattern (Important!) diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md new file mode 100644 index 00000000..313ab94a --- /dev/null +++ b/docs/ARCHITECTURE.md @@ -0,0 +1,412 @@ +# Basic Memory Architecture + +This document describes the architectural patterns and composition structure of Basic Memory. 
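+
+As a minimal sketch of the central pattern (the full lifespan wiring is shown in the sections below and in `src/basic_memory/api/app.py`), each entrypoint builds its composition-root container exactly once at startup:
+
+```python
+container = ApiContainer.create()  # single point where config is read
+set_container(container)
+app.state.container = container
+```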
+ +## Overview + +Basic Memory is a local-first knowledge management system with three entrypoints: +- **API** - FastAPI REST server for HTTP access +- **MCP** - Model Context Protocol server for LLM integration +- **CLI** - Typer command-line interface + +Each entrypoint uses a **composition root** pattern to manage configuration and dependencies. + +## Composition Roots + +### What is a Composition Root? + +A composition root is the single place in an application where dependencies are wired together. In Basic Memory, each entrypoint has its own composition root that: + +1. Reads configuration from `ConfigManager` +2. Resolves runtime mode (cloud/local/test) +3. Creates and provides dependencies to downstream code + +**Key principle**: Only composition roots read global configuration. All other modules receive configuration explicitly. + +### Container Structure + +Each entrypoint has a container dataclass in its package: + +``` +src/basic_memory/ +├── api/ +│ └── container.py # ApiContainer +├── mcp/ +│ └── container.py # McpContainer +├── cli/ +│ └── container.py # CliContainer +└── runtime.py # RuntimeMode enum and resolver +``` + +### Container Pattern + +All containers follow the same structure: + +```python +@dataclass +class Container: + config: BasicMemoryConfig + mode: RuntimeMode + + @classmethod + def create(cls) -> "Container": + """Create container by reading ConfigManager.""" + config = ConfigManager().config + mode = resolve_runtime_mode( + cloud_mode_enabled=config.cloud_mode_enabled, + is_test_env=config.is_test_env, + ) + return cls(config=config, mode=mode) + + @property + def some_computed_property(self) -> bool: + """Derived values based on config and mode.""" + return self.mode.is_local and self.config.some_setting + +# Module-level singleton +_container: Container | None = None + +def get_container() -> Container: + if _container is None: + raise RuntimeError("Container not initialized") + return _container + +def set_container(container: Container) -> None: + global _container + _container = container +``` + +### Runtime Mode Resolution + +The `RuntimeMode` enum centralizes mode detection: + +```python +class RuntimeMode(Enum): + LOCAL = "local" + CLOUD = "cloud" + TEST = "test" + + @property + def is_cloud(self) -> bool: + return self == RuntimeMode.CLOUD + + @property + def is_local(self) -> bool: + return self == RuntimeMode.LOCAL + + @property + def is_test(self) -> bool: + return self == RuntimeMode.TEST +``` + +Resolution follows this precedence: **TEST > CLOUD > LOCAL** + +```python +def resolve_runtime_mode(cloud_mode_enabled: bool, is_test_env: bool) -> RuntimeMode: + if is_test_env: + return RuntimeMode.TEST + if cloud_mode_enabled: + return RuntimeMode.CLOUD + return RuntimeMode.LOCAL +``` + +## Dependencies Package + +### Structure + +The `deps/` package provides FastAPI dependencies organized by feature: + +``` +src/basic_memory/deps/ +├── __init__.py # Re-exports for backwards compatibility +├── config.py # Configuration access +├── db.py # Database/session management +├── projects.py # Project resolution +├── repositories.py # Data access layer +├── services.py # Business logic layer +└── importers.py # Import functionality +``` + +### Usage in Routers + +```python +from basic_memory.deps.services import get_entity_service +from basic_memory.deps.projects import get_project_config + +@router.get("/entities/{id}") +async def get_entity( + id: int, + entity_service: EntityService = Depends(get_entity_service), + project: ProjectConfig = 
Depends(get_project_config), +): + return await entity_service.get(id) +``` + +### Backwards Compatibility + +The old `deps.py` file still exists as a thin re-export shim: + +```python +# deps.py - backwards compatibility shim +from basic_memory.deps import * +``` + +New code should import from specific submodules (`basic_memory.deps.services`) for clarity. + +## MCP Tools Architecture + +### Typed API Clients + +MCP tools communicate with the API through typed clients that encapsulate HTTP paths and response validation: + +``` +src/basic_memory/mcp/clients/ +├── __init__.py # Re-exports all clients +├── base.py # BaseClient with common logic +├── knowledge.py # KnowledgeClient - entity CRUD +├── search.py # SearchClient - search operations +├── memory.py # MemoryClient - context building +├── directory.py # DirectoryClient - directory listing +├── resource.py # ResourceClient - resource reading +└── project.py # ProjectClient - project management +``` + +### Client Pattern + +Each client encapsulates API paths and validates responses: + +```python +class KnowledgeClient(BaseClient): + """Client for knowledge/entity operations.""" + + async def resolve_entity(self, identifier: str) -> int: + """Resolve identifier to entity ID.""" + response = await call_get( + self.http_client, + f"{self._base_path}/resolve/{identifier}", + ) + return int(response.text) + + async def get_entity(self, entity_id: int) -> EntityResponse: + """Get entity by ID.""" + response = await call_get( + self.http_client, + f"{self._base_path}/entities/{entity_id}", + ) + return EntityResponse.model_validate(response.json()) +``` + +### Tool → Client → API Flow + +``` +MCP Tool (thin adapter) + ↓ +Typed Client (encapsulates paths, validates responses) + ↓ +HTTP API (FastAPI router) + ↓ +Service Layer (business logic) + ↓ +Repository Layer (data access) +``` + +Example tool using typed client: + +```python +@mcp.tool() +async def search_notes(query: str, project: str | None = None) -> SearchResponse: + async with get_client() as client: + active_project = await get_active_project(client, project) + + # Import client inside function to avoid circular imports + from basic_memory.mcp.clients import SearchClient + + search_client = SearchClient(client, active_project.external_id) + return await search_client.search(query) +``` + +## Sync Coordination + +### SyncCoordinator + +The `SyncCoordinator` centralizes sync/watch lifecycle management: + +```python +@dataclass +class SyncCoordinator: + """Coordinates file sync and watch operations.""" + + status: SyncStatus = SyncStatus.NOT_STARTED + sync_task: asyncio.Task | None = None + watch_service: WatchService | None = None + + async def start(self, ...): + """Start sync and watch operations.""" + + async def stop(self): + """Stop all sync operations gracefully.""" + + def get_status_info(self) -> dict: + """Get current sync status for observability.""" +``` + +### Status Enum + +```python +class SyncStatus(Enum): + NOT_STARTED = "not_started" + STARTING = "starting" + RUNNING = "running" + STOPPING = "stopping" + STOPPED = "stopped" + ERROR = "error" +``` + +## Project Resolution + +### ProjectResolver + +Unified project selection across all entrypoints: + +```python +class ProjectResolver: + """Resolves which project to use based on context.""" + + def resolve( + self, + explicit_project: str | None = None, + ) -> ResolvedProject: + """Resolve project using three-tier hierarchy: + 1. Explicit project parameter + 2. Default project from config + 3. 
Single available project
+        """
+```
+
+### Resolution Modes
+
+```python
+class ResolutionMode(Enum):
+    EXPLICIT = "explicit"  # User specified project
+    DEFAULT = "default"  # Using configured default
+    SINGLE_PROJECT = "single"  # Only one project exists
+    FALLBACK = "fallback"  # Using first available
+```
+
+## Testing Patterns
+
+### Container Testing
+
+Each container has corresponding tests:
+
+```
+tests/
+├── api/test_api_container.py
+├── mcp/test_mcp_container.py
+└── cli/test_cli_container.py
+```
+
+Tests verify:
+- Container creation from config
+- Runtime mode properties
+- Container accessor functions (get/set)
+
+### Mocking Typed Clients
+
+When testing MCP tools, mock at the client level:
+
+```python
+def test_search_notes(monkeypatch):
+    import basic_memory.mcp.clients as clients_mod
+
+    class MockSearchClient:
+        def __init__(self, *args, **kwargs):
+            # Accept the (http_client, project_id) args the tool passes
+            pass
+
+        async def search(self, query):
+            return SearchResponse(results=[...])
+
+    monkeypatch.setattr(clients_mod, "SearchClient", MockSearchClient)
+```
+
+## Design Principles
+
+### 1. Explicit Dependencies
+
+Modules receive configuration explicitly rather than reading globals:
+
+```python
+# Good - explicit injection
+async def sync_files(config: BasicMemoryConfig):
+    ...
+
+# Avoid - hidden global access
+async def sync_files():
+    config = ConfigManager().config  # Hidden coupling
+```
+
+### 2. Single Responsibility
+
+Each layer has a clear responsibility:
+- **Containers**: Wire dependencies
+- **Clients**: Encapsulate HTTP communication
+- **Services**: Business logic
+- **Repositories**: Data access
+- **Tools/Routers**: Thin adapters
+
+### 3. Deferred Imports
+
+To avoid circular imports, typed clients are imported inside functions:
+
+```python
+async def my_tool():
+    async with get_client() as client:
+        # Import here to avoid circular dependency
+        from basic_memory.mcp.clients import KnowledgeClient
+
+        knowledge_client = KnowledgeClient(client, project_id)
+```
+
+### 4. Backwards Compatibility
+
+When refactoring, maintain backwards compatibility via shims:
+
+```python
+# Old module becomes a shim
+from basic_memory.new_location import *
+
+# Docstring explains migration path
+"""
+DEPRECATED: Import from basic_memory.new_location instead.
+This shim will be removed in a future version.
+"""
+```
+
+## File Organization
+
+```
+src/basic_memory/
+├── api/
+│   ├── container.py         # API composition root
+│   ├── routers/             # FastAPI routers
+│   └── ...
+├── mcp/
+│   ├── container.py         # MCP composition root
+│   ├── clients/             # Typed API clients
+│   ├── tools/               # MCP tool definitions
+│   └── server.py            # MCP server setup
+├── cli/
+│   ├── container.py         # CLI composition root
+│   ├── app.py               # Typer app
+│   └── commands/            # CLI command groups
+├── deps/
+│   ├── config.py            # Config dependencies
+│   ├── db.py                # Database dependencies
+│   ├── projects.py          # Project dependencies
+│   ├── repositories.py      # Repository dependencies
+│   ├── services.py          # Service dependencies
+│   └── importers.py         # Importer dependencies
+├── sync/
+│   ├── coordinator.py       # SyncCoordinator
+│   └── ...
+├── runtime.py # RuntimeMode resolution +├── project_resolver.py # Unified project selection +└── config.py # Configuration management +``` diff --git a/src/basic_memory/api/app.py b/src/basic_memory/api/app.py index c548e9e4..455d9a5a 100644 --- a/src/basic_memory/api/app.py +++ b/src/basic_memory/api/app.py @@ -1,6 +1,5 @@ """FastAPI application for basic-memory knowledge graph API.""" -import asyncio from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException @@ -8,7 +7,7 @@ from loguru import logger from basic_memory import __version__ as version -from basic_memory import db +from basic_memory.api.container import ApiContainer, set_container from basic_memory.api.routers import ( directory_router, importer_router, @@ -30,8 +29,8 @@ prompt_router as v2_prompt, importer_router as v2_importer, ) -from basic_memory.config import ConfigManager, init_api_logging -from basic_memory.services.initialization import initialize_file_sync, initialize_app +from basic_memory.config import init_api_logging +from basic_memory.services.initialization import initialize_app @asynccontextmanager @@ -41,47 +40,36 @@ async def lifespan(app: FastAPI): # pragma: no cover # Initialize logging for API (stdout in cloud mode, file otherwise) init_api_logging() - app_config = ConfigManager().config - logger.info("Starting Basic Memory API") + # --- Composition Root --- + # Create container and read config (single point of config access) + container = ApiContainer.create() + set_container(container) + app.state.container = container - await initialize_app(app_config) + logger.info(f"Starting Basic Memory API (mode={container.mode.name})") + + await initialize_app(container.config) # Cache database connections in app state for performance logger.info("Initializing database and caching connections...") - engine, session_maker = await db.get_or_create_db(app_config.database_path) + engine, session_maker = await container.init_database() app.state.engine = engine app.state.session_maker = session_maker logger.info("Database connections cached in app state") - # Start file sync if enabled - if app_config.sync_changes and not app_config.is_test_env: - logger.info(f"Sync changes enabled: {app_config.sync_changes}") - - # start file sync task in background - async def _file_sync_runner() -> None: - await initialize_file_sync(app_config) + # Create and start sync coordinator (lifecycle centralized in coordinator) + sync_coordinator = container.create_sync_coordinator() + await sync_coordinator.start() + app.state.sync_coordinator = sync_coordinator - app.state.sync_task = asyncio.create_task(_file_sync_runner()) - else: - if app_config.is_test_env: - logger.info("Test environment detected. Skipping file sync service.") - else: - logger.info("Sync changes disabled. 
Skipping file sync service.") - app.state.sync_task = None - - # proceed with startup + # Proceed with startup yield + # Shutdown - coordinator handles clean task cancellation logger.info("Shutting down Basic Memory API") - if app.state.sync_task: - logger.info("Stopping sync...") - app.state.sync_task.cancel() # pyright: ignore - try: - await app.state.sync_task - except asyncio.CancelledError: - logger.info("Sync task cancelled successfully") - - await db.shutdown_db() + await sync_coordinator.stop() + + await container.shutdown_database() # Initialize FastAPI app diff --git a/src/basic_memory/api/container.py b/src/basic_memory/api/container.py new file mode 100644 index 00000000..a333f30f --- /dev/null +++ b/src/basic_memory/api/container.py @@ -0,0 +1,133 @@ +"""API composition root for Basic Memory. + +This container owns reading ConfigManager and environment variables for the +API entrypoint. Downstream modules receive config/dependencies explicitly +rather than reading globals. + +Design principles: +- Only this module reads ConfigManager directly +- Runtime mode (cloud/local/test) is resolved here +- Factories for services are provided, not singletons +""" + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, AsyncSession + +from basic_memory import db +from basic_memory.config import BasicMemoryConfig, ConfigManager +from basic_memory.runtime import RuntimeMode, resolve_runtime_mode + +if TYPE_CHECKING: # pragma: no cover + from basic_memory.sync import SyncCoordinator + + +@dataclass +class ApiContainer: + """Composition root for the API entrypoint. + + Holds resolved configuration and runtime context. + Created once at app startup, then used to wire dependencies. + """ + + config: BasicMemoryConfig + mode: RuntimeMode + + # --- Database --- + # Cached database connections (set during lifespan startup) + engine: AsyncEngine | None = None + session_maker: async_sessionmaker[AsyncSession] | None = None + + @classmethod + def create(cls) -> "ApiContainer": # pragma: no cover + """Create container by reading ConfigManager. + + This is the single point where API reads global config. + """ + config = ConfigManager().config + mode = resolve_runtime_mode( + cloud_mode_enabled=config.cloud_mode_enabled, + is_test_env=config.is_test_env, + ) + return cls(config=config, mode=mode) + + # --- Runtime Mode Properties --- + + @property + def should_sync_files(self) -> bool: + """Whether file sync should be started. + + Sync is enabled when: + - sync_changes is True in config + - Not in test mode (tests manage their own sync) + """ + return self.config.sync_changes and not self.mode.is_test + + @property + def sync_skip_reason(self) -> str | None: # pragma: no cover + """Reason why sync is skipped, or None if sync should run. + + Useful for logging why sync was disabled. + """ + if self.mode.is_test: + return "Test environment detected" + if not self.config.sync_changes: + return "Sync changes disabled" + return None + + def create_sync_coordinator(self) -> "SyncCoordinator": # pragma: no cover + """Create a SyncCoordinator with this container's settings. 
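+
+        Typical lifespan usage (this mirrors `api/app.py`)::
+
+            sync_coordinator = container.create_sync_coordinator()
+            await sync_coordinator.start()
+            ...
+            await sync_coordinator.stop()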
+ + Returns: + SyncCoordinator configured for this runtime environment + """ + # Deferred import to avoid circular dependency + from basic_memory.sync import SyncCoordinator + + return SyncCoordinator( + config=self.config, + should_sync=self.should_sync_files, + skip_reason=self.sync_skip_reason, + ) + + # --- Database Factory --- + + async def init_database( # pragma: no cover + self, + ) -> tuple[AsyncEngine, async_sessionmaker[AsyncSession]]: + """Initialize and cache database connections. + + Returns: + Tuple of (engine, session_maker) + """ + engine, session_maker = await db.get_or_create_db(self.config.database_path) + self.engine = engine + self.session_maker = session_maker + return engine, session_maker + + async def shutdown_database(self) -> None: # pragma: no cover + """Clean up database connections.""" + await db.shutdown_db() + + +# Module-level container instance (set by lifespan) +# This allows deps.py to access the container without reading ConfigManager +_container: ApiContainer | None = None + + +def get_container() -> ApiContainer: + """Get the current API container. + + Raises: + RuntimeError: If container hasn't been initialized + """ + if _container is None: + raise RuntimeError("API container not initialized. Call set_container() first.") + return _container + + +def set_container(container: ApiContainer) -> None: + """Set the API container (called by lifespan).""" + global _container + _container = container diff --git a/src/basic_memory/api/v2/routers/knowledge_router.py b/src/basic_memory/api/v2/routers/knowledge_router.py index d3401797..77f3fd7f 100644 --- a/src/basic_memory/api/v2/routers/knowledge_router.py +++ b/src/basic_memory/api/v2/routers/knowledge_router.py @@ -285,7 +285,7 @@ async def edit_entity_by_id( # Verify entity exists entity = await entity_repository.get_by_external_id(entity_id) - if not entity: + if not entity: # pragma: no cover raise HTTPException( status_code=404, detail=f"Entity with external_id '{entity_id}' not found" ) @@ -394,7 +394,7 @@ async def move_entity( try: # First, get the entity by external_id to verify it exists entity = await entity_repository.get_by_external_id(entity_id) - if not entity: + if not entity: # pragma: no cover raise HTTPException( status_code=404, detail=f"Entity with external_id '{entity_id}' not found" ) diff --git a/src/basic_memory/api/v2/routers/project_router.py b/src/basic_memory/api/v2/routers/project_router.py index 9c9b754e..47365659 100644 --- a/src/basic_memory/api/v2/routers/project_router.py +++ b/src/basic_memory/api/v2/routers/project_router.py @@ -202,7 +202,7 @@ async def update_project_by_id( # Get updated project info (use the same external_id) updated_project = await project_repository.get_by_external_id(project_id) - if not updated_project: + if not updated_project: # pragma: no cover raise HTTPException( status_code=404, detail=f"Project with external_id '{project_id}' not found after update", diff --git a/src/basic_memory/cli/app.py b/src/basic_memory/cli/app.py index f297878c..ea749735 100644 --- a/src/basic_memory/cli/app.py +++ b/src/basic_memory/cli/app.py @@ -14,7 +14,8 @@ import typer # noqa: E402 -from basic_memory.config import ConfigManager, init_cli_logging # noqa: E402 +from basic_memory.cli.container import CliContainer, set_container # noqa: E402 +from basic_memory.config import init_cli_logging # noqa: E402 from basic_memory.telemetry import show_notice_if_needed, track_app_started # noqa: E402 @@ -47,6 +48,11 @@ def app_callback( # Initialize logging for CLI (file 
only, no stdout) init_cli_logging() + # --- Composition Root --- + # Create container and read config (single point of config access) + container = CliContainer.create() + set_container(container) + # Show telemetry notice and track CLI startup # Skip for 'mcp' command - it handles its own telemetry in lifespan # Skip for 'telemetry' command - avoid issues when user is managing telemetry @@ -65,8 +71,7 @@ def app_callback( ): from basic_memory.services.initialization import ensure_initialization - app_config = ConfigManager().config - ensure_initialization(app_config) + ensure_initialization(container.config) ## import diff --git a/src/basic_memory/cli/container.py b/src/basic_memory/cli/container.py new file mode 100644 index 00000000..e375923c --- /dev/null +++ b/src/basic_memory/cli/container.py @@ -0,0 +1,84 @@ +"""CLI composition root for Basic Memory. + +This container owns reading ConfigManager and environment variables for the +CLI entrypoint. Downstream modules receive config/dependencies explicitly +rather than reading globals. + +Design principles: +- Only this module reads ConfigManager directly +- Runtime mode (cloud/local/test) is resolved here +- Different CLI commands may need different initialization +""" + +from dataclasses import dataclass + +from basic_memory.config import BasicMemoryConfig, ConfigManager +from basic_memory.runtime import RuntimeMode, resolve_runtime_mode + + +@dataclass +class CliContainer: + """Composition root for the CLI entrypoint. + + Holds resolved configuration and runtime context. + Created once at CLI startup, then used by subcommands. + """ + + config: BasicMemoryConfig + mode: RuntimeMode + + @classmethod + def create(cls) -> "CliContainer": + """Create container by reading ConfigManager. + + This is the single point where CLI reads global config. + """ + config = ConfigManager().config + mode = resolve_runtime_mode( + cloud_mode_enabled=config.cloud_mode_enabled, + is_test_env=config.is_test_env, + ) + return cls(config=config, mode=mode) + + # --- Runtime Mode Properties --- + + @property + def is_cloud_mode(self) -> bool: + """Whether running in cloud mode.""" + return self.mode.is_cloud + + +# Module-level container instance (set by app callback) +_container: CliContainer | None = None + + +def get_container() -> CliContainer: + """Get the current CLI container. + + Returns: + The CLI container + + Raises: + RuntimeError: If container hasn't been initialized + """ + if _container is None: + raise RuntimeError("CLI container not initialized. Call set_container() first.") + return _container + + +def set_container(container: CliContainer) -> None: + """Set the CLI container (called by app callback).""" + global _container + _container = container + + +def get_or_create_container() -> CliContainer: + """Get existing container or create new one. + + This is useful for CLI commands that might be called before + the main app callback runs (e.g., eager options). 
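+
+    Example (hypothetical eager-option handler)::
+
+        container = get_or_create_container()
+        if container.is_cloud_mode:
+            ...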
+ """ + global _container + if _container is None: + _container = CliContainer.create() + return _container diff --git a/src/basic_memory/db.py b/src/basic_memory/db.py index 73921880..c440e37d 100644 --- a/src/basic_memory/db.py +++ b/src/basic_memory/db.py @@ -242,18 +242,24 @@ def _create_postgres_engine(db_url: str, config: BasicMemoryConfig) -> AsyncEngi def _create_engine_and_session( - db_path: Path, db_type: DatabaseType = DatabaseType.FILESYSTEM + db_path: Path, + db_type: DatabaseType = DatabaseType.FILESYSTEM, + config: Optional[BasicMemoryConfig] = None, ) -> tuple[AsyncEngine, async_sessionmaker[AsyncSession]]: """Internal helper to create engine and session maker. Args: db_path: Path to database file (used for SQLite, ignored for Postgres) db_type: Type of database (MEMORY, FILESYSTEM, or POSTGRES) + config: Optional explicit config. If not provided, reads from ConfigManager. + Prefer passing explicitly from composition roots. Returns: Tuple of (engine, session_maker) """ - config = ConfigManager().config + # Prefer explicit parameter; fall back to ConfigManager for backwards compatibility + if config is None: + config = ConfigManager().config db_url = DatabaseType.get_db_url(db_path, db_type, config) logger.debug(f"Creating engine for db_url: {db_url}") @@ -272,17 +278,29 @@ async def get_or_create_db( db_path: Path, db_type: DatabaseType = DatabaseType.FILESYSTEM, ensure_migrations: bool = True, + config: Optional[BasicMemoryConfig] = None, ) -> tuple[AsyncEngine, async_sessionmaker[AsyncSession]]: # pragma: no cover - """Get or create database engine and session maker.""" + """Get or create database engine and session maker. + + Args: + db_path: Path to database file + db_type: Type of database + ensure_migrations: Whether to run migrations + config: Optional explicit config. If not provided, reads from ConfigManager. + Prefer passing explicitly from composition roots. + """ global _engine, _session_maker + # Prefer explicit parameter; fall back to ConfigManager for backwards compatibility + if config is None: + config = ConfigManager().config + if _engine is None: - _engine, _session_maker = _create_engine_and_session(db_path, db_type) + _engine, _session_maker = _create_engine_and_session(db_path, db_type, config) # Run migrations automatically unless explicitly disabled if ensure_migrations: - app_config = ConfigManager().config - await run_migrations(app_config, db_type) + await run_migrations(config, db_type) # These checks should never fail since we just created the engine and session maker # if they were None, but we'll check anyway for the type checker @@ -311,17 +329,23 @@ async def shutdown_db() -> None: # pragma: no cover async def engine_session_factory( db_path: Path, db_type: DatabaseType = DatabaseType.MEMORY, + config: Optional[BasicMemoryConfig] = None, ) -> AsyncGenerator[tuple[AsyncEngine, async_sessionmaker[AsyncSession]], None]: """Create engine and session factory. Note: This is primarily used for testing where we want a fresh database for each test. For production use, use get_or_create_db() instead. + + Args: + db_path: Path to database file + db_type: Type of database + config: Optional explicit config. If not provided, reads from ConfigManager. 
""" global _engine, _session_maker # Use the same helper function as production code - _engine, _session_maker = _create_engine_and_session(db_path, db_type) + _engine, _session_maker = _create_engine_and_session(db_path, db_type, config) try: # Verify that engine and session maker are initialized diff --git a/src/basic_memory/deps.py b/src/basic_memory/deps.py index 2ed24db8..f5065ad4 100644 --- a/src/basic_memory/deps.py +++ b/src/basic_memory/deps.py @@ -1,1014 +1,16 @@ -"""Dependency injection functions for basic-memory services.""" - -from typing import Annotated -from loguru import logger - -from fastapi import Depends, HTTPException, Path, status, Request -from sqlalchemy.ext.asyncio import ( - AsyncSession, - AsyncEngine, - async_sessionmaker, -) -import pathlib - -from basic_memory import db -from basic_memory.config import ProjectConfig, BasicMemoryConfig, ConfigManager -from basic_memory.importers import ( - ChatGPTImporter, - ClaudeConversationsImporter, - ClaudeProjectsImporter, - MemoryJsonImporter, -) -from basic_memory.markdown import EntityParser -from basic_memory.markdown.markdown_processor import MarkdownProcessor -from basic_memory.repository.entity_repository import EntityRepository -from basic_memory.repository.observation_repository import ObservationRepository -from basic_memory.repository.project_repository import ProjectRepository -from basic_memory.repository.relation_repository import RelationRepository -from basic_memory.repository.search_repository import SearchRepository, create_search_repository -from basic_memory.services import EntityService, ProjectService -from basic_memory.services.context_service import ContextService -from basic_memory.services.directory_service import DirectoryService -from basic_memory.services.file_service import FileService -from basic_memory.services.link_resolver import LinkResolver -from basic_memory.services.search_service import SearchService -from basic_memory.sync import SyncService -from basic_memory.utils import generate_permalink - - -def get_app_config() -> BasicMemoryConfig: # pragma: no cover - app_config = ConfigManager().config - return app_config - - -AppConfigDep = Annotated[BasicMemoryConfig, Depends(get_app_config)] # pragma: no cover - - -## project - - -async def get_project_config( - project: "ProjectPathDep", project_repository: "ProjectRepositoryDep" -) -> ProjectConfig: # pragma: no cover - """Get the current project referenced from request state. - - Args: - request: The current request object - project_repository: Repository for project operations - - Returns: - The resolved project config - - Raises: - HTTPException: If project is not found - """ - # Convert project name to permalink for lookup - project_permalink = generate_permalink(str(project)) - project_obj = await project_repository.get_by_permalink(project_permalink) - if project_obj: - return ProjectConfig(name=project_obj.name, home=pathlib.Path(project_obj.path)) - - # Not found - raise HTTPException( # pragma: no cover - status_code=status.HTTP_404_NOT_FOUND, detail=f"Project '{project}' not found." - ) - - -ProjectConfigDep = Annotated[ProjectConfig, Depends(get_project_config)] # pragma: no cover - - -async def get_project_config_v2( - project_id: "ProjectIdPathDep", project_repository: "ProjectRepositoryDep" -) -> ProjectConfig: # pragma: no cover - """Get the project config for v2 API (uses integer project_id from path). 
- - Args: - project_id: The validated numeric project ID from the URL path - project_repository: Repository for project operations - - Returns: - The resolved project config - - Raises: - HTTPException: If project is not found - """ - project_obj = await project_repository.get_by_id(project_id) - if project_obj: - return ProjectConfig(name=project_obj.name, home=pathlib.Path(project_obj.path)) - - # Not found (this should not happen since ProjectIdPathDep already validates existence) - raise HTTPException( # pragma: no cover - status_code=status.HTTP_404_NOT_FOUND, detail=f"Project with ID {project_id} not found." - ) - - -ProjectConfigV2Dep = Annotated[ProjectConfig, Depends(get_project_config_v2)] - - -async def get_project_config_v2_external( - project_id: "ProjectExternalIdPathDep", project_repository: "ProjectRepositoryDep" -) -> ProjectConfig: # pragma: no cover - """Get the project config for v2 API (uses external_id UUID from path). - - Args: - project_id: The internal project ID resolved from external_id - project_repository: Repository for project operations - - Returns: - The resolved project config - - Raises: - HTTPException: If project is not found - """ - project_obj = await project_repository.get_by_id(project_id) - if project_obj: - return ProjectConfig(name=project_obj.name, home=pathlib.Path(project_obj.path)) - - # Not found (this should not happen since ProjectExternalIdPathDep already validates) - raise HTTPException( # pragma: no cover - status_code=status.HTTP_404_NOT_FOUND, detail=f"Project with ID {project_id} not found." - ) - - -ProjectConfigV2ExternalDep = Annotated[ProjectConfig, Depends(get_project_config_v2_external)] # pragma: no cover - -## sqlalchemy - - -async def get_engine_factory( - request: Request, -) -> tuple[AsyncEngine, async_sessionmaker[AsyncSession]]: # pragma: no cover - """Get cached engine and session maker from app state. - - For API requests, returns cached connections from app.state for optimal performance. - For non-API contexts (CLI), falls back to direct database connection. - """ - # Try to get cached connections from app state (API context) - if ( - hasattr(request, "app") - and hasattr(request.app.state, "engine") - and hasattr(request.app.state, "session_maker") - ): - return request.app.state.engine, request.app.state.session_maker - - # Fallback for non-API contexts (CLI) - logger.debug("Using fallback database connection for non-API context") - app_config = get_app_config() - engine, session_maker = await db.get_or_create_db(app_config.database_path) - return engine, session_maker - - -EngineFactoryDep = Annotated[ - tuple[AsyncEngine, async_sessionmaker[AsyncSession]], Depends(get_engine_factory) -] - - -async def get_session_maker(engine_factory: EngineFactoryDep) -> async_sessionmaker[AsyncSession]: - """Get session maker.""" - _, session_maker = engine_factory - return session_maker - - -SessionMakerDep = Annotated[async_sessionmaker, Depends(get_session_maker)] - - -## repositories - - -async def get_project_repository( - session_maker: SessionMakerDep, -) -> ProjectRepository: - """Get the project repository.""" - return ProjectRepository(session_maker) - - -ProjectRepositoryDep = Annotated[ProjectRepository, Depends(get_project_repository)] -ProjectPathDep = Annotated[str, Path()] # Use Path dependency to extract from URL - - -async def validate_project_id( - project_id: int, - project_repository: ProjectRepositoryDep, -) -> int: - """Validate that a numeric project ID exists in the database. 
- - This is used for v2 API endpoints that take project IDs as integers in the path. - The project_id parameter will be automatically extracted from the URL path by FastAPI. - - Args: - project_id: The numeric project ID from the URL path - project_repository: Repository for project operations - - Returns: - The validated project ID - - Raises: - HTTPException: If project with that ID is not found - """ - project_obj = await project_repository.get_by_id(project_id) - if not project_obj: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Project with ID {project_id} not found.", - ) - return project_id - - -# V2 API: Validated integer project ID from path -ProjectIdPathDep = Annotated[int, Depends(validate_project_id)] - - -async def validate_project_external_id( - project_id: str, - project_repository: ProjectRepositoryDep, -) -> int: - """Validate that a project external_id (UUID) exists in the database. - - This is used for v2 API endpoints that take project external_ids as strings in the path. - The project_id parameter will be automatically extracted from the URL path by FastAPI. - - Args: - project_id: The external UUID from the URL path (named project_id for URL consistency) - project_repository: Repository for project operations - - Returns: - The internal numeric project ID (for use by repositories) - - Raises: - HTTPException: If project with that external_id is not found - """ - project_obj = await project_repository.get_by_external_id(project_id) - if not project_obj: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Project with external_id '{project_id}' not found.", - ) - return project_obj.id - - -# V2 API: Validated external UUID project ID from path (returns internal int ID) -ProjectExternalIdPathDep = Annotated[int, Depends(validate_project_external_id)] - - -async def get_project_id( - project_repository: ProjectRepositoryDep, - project: ProjectPathDep, -) -> int: - """Get the current project ID from request state. - - When using sub-applications with /{project} mounting, the project value - is stored in request.state by middleware. - - Args: - request: The current request object - project_repository: Repository for project operations - - Returns: - The resolved project ID - - Raises: - HTTPException: If project is not found - """ - # Convert project name to permalink for lookup - project_permalink = generate_permalink(str(project)) - project_obj = await project_repository.get_by_permalink(project_permalink) - if project_obj: - return project_obj.id - - # Try by name if permalink lookup fails - project_obj = await project_repository.get_by_name(str(project)) # pragma: no cover - if project_obj: # pragma: no cover - return project_obj.id - - # Not found - raise HTTPException( # pragma: no cover - status_code=status.HTTP_404_NOT_FOUND, detail=f"Project '{project}' not found." - ) - - -""" -The project_id dependency is used in the following: -- EntityRepository -- ObservationRepository -- RelationRepository -- SearchRepository -- ProjectInfoRepository +"""Dependency injection functions for basic-memory services. + +DEPRECATED: This module is a backwards-compatibility shim. 
+Import from basic_memory.deps package submodules instead: +- basic_memory.deps.config for configuration +- basic_memory.deps.db for database/session +- basic_memory.deps.projects for project resolution +- basic_memory.deps.repositories for data access +- basic_memory.deps.services for business logic +- basic_memory.deps.importers for import functionality + +This file will be removed once all callers are migrated. """ -ProjectIdDep = Annotated[int, Depends(get_project_id)] - - -async def get_entity_repository( - session_maker: SessionMakerDep, - project_id: ProjectIdDep, -) -> EntityRepository: - """Create an EntityRepository instance for the current project.""" - return EntityRepository(session_maker, project_id=project_id) - - -EntityRepositoryDep = Annotated[EntityRepository, Depends(get_entity_repository)] - - -async def get_entity_repository_v2( - session_maker: SessionMakerDep, - project_id: ProjectIdPathDep, -) -> EntityRepository: - """Create an EntityRepository instance for v2 API (uses integer project_id from path).""" - return EntityRepository(session_maker, project_id=project_id) - - -EntityRepositoryV2Dep = Annotated[EntityRepository, Depends(get_entity_repository_v2)] - - -async def get_entity_repository_v2_external( - session_maker: SessionMakerDep, - project_id: ProjectExternalIdPathDep, -) -> EntityRepository: - """Create an EntityRepository instance for v2 API (uses external_id from path).""" - return EntityRepository(session_maker, project_id=project_id) - - -EntityRepositoryV2ExternalDep = Annotated[EntityRepository, Depends(get_entity_repository_v2_external)] - - -async def get_observation_repository( - session_maker: SessionMakerDep, - project_id: ProjectIdDep, -) -> ObservationRepository: - """Create an ObservationRepository instance for the current project.""" - return ObservationRepository(session_maker, project_id=project_id) - - -ObservationRepositoryDep = Annotated[ObservationRepository, Depends(get_observation_repository)] - - -async def get_observation_repository_v2( - session_maker: SessionMakerDep, - project_id: ProjectIdPathDep, -) -> ObservationRepository: - """Create an ObservationRepository instance for v2 API.""" - return ObservationRepository(session_maker, project_id=project_id) - - -ObservationRepositoryV2Dep = Annotated[ - ObservationRepository, Depends(get_observation_repository_v2) -] - - -async def get_observation_repository_v2_external( - session_maker: SessionMakerDep, - project_id: ProjectExternalIdPathDep, -) -> ObservationRepository: - """Create an ObservationRepository instance for v2 API (uses external_id).""" - return ObservationRepository(session_maker, project_id=project_id) - - -ObservationRepositoryV2ExternalDep = Annotated[ - ObservationRepository, Depends(get_observation_repository_v2_external) -] - - -async def get_relation_repository( - session_maker: SessionMakerDep, - project_id: ProjectIdDep, -) -> RelationRepository: - """Create a RelationRepository instance for the current project.""" - return RelationRepository(session_maker, project_id=project_id) - - -RelationRepositoryDep = Annotated[RelationRepository, Depends(get_relation_repository)] - - -async def get_relation_repository_v2( - session_maker: SessionMakerDep, - project_id: ProjectIdPathDep, -) -> RelationRepository: - """Create a RelationRepository instance for v2 API.""" - return RelationRepository(session_maker, project_id=project_id) - - -RelationRepositoryV2Dep = Annotated[RelationRepository, Depends(get_relation_repository_v2)] - - -async def 
get_relation_repository_v2_external( - session_maker: SessionMakerDep, - project_id: ProjectExternalIdPathDep, -) -> RelationRepository: - """Create a RelationRepository instance for v2 API (uses external_id).""" - return RelationRepository(session_maker, project_id=project_id) - - -RelationRepositoryV2ExternalDep = Annotated[RelationRepository, Depends(get_relation_repository_v2_external)] - - -async def get_search_repository( - session_maker: SessionMakerDep, - project_id: ProjectIdDep, -) -> SearchRepository: - """Create a backend-specific SearchRepository instance for the current project. - - Uses factory function to return SQLiteSearchRepository or PostgresSearchRepository - based on database backend configuration. - """ - return create_search_repository(session_maker, project_id=project_id) - - -SearchRepositoryDep = Annotated[SearchRepository, Depends(get_search_repository)] - - -async def get_search_repository_v2( - session_maker: SessionMakerDep, - project_id: ProjectIdPathDep, -) -> SearchRepository: - """Create a SearchRepository instance for v2 API.""" - return create_search_repository(session_maker, project_id=project_id) - - -SearchRepositoryV2Dep = Annotated[SearchRepository, Depends(get_search_repository_v2)] - - -async def get_search_repository_v2_external( - session_maker: SessionMakerDep, - project_id: ProjectExternalIdPathDep, -) -> SearchRepository: - """Create a SearchRepository instance for v2 API (uses external_id).""" - return create_search_repository(session_maker, project_id=project_id) - - -SearchRepositoryV2ExternalDep = Annotated[SearchRepository, Depends(get_search_repository_v2_external)] - - -# ProjectInfoRepository is deprecated and will be removed in a future version. -# Use ProjectRepository instead, which has the same functionality plus more project-specific operations. 
- -## services - - -async def get_entity_parser(project_config: ProjectConfigDep) -> EntityParser: - return EntityParser(project_config.home) - - -EntityParserDep = Annotated["EntityParser", Depends(get_entity_parser)] - - -async def get_entity_parser_v2(project_config: ProjectConfigV2Dep) -> EntityParser: - return EntityParser(project_config.home) - - -EntityParserV2Dep = Annotated["EntityParser", Depends(get_entity_parser_v2)] - - -async def get_entity_parser_v2_external(project_config: ProjectConfigV2ExternalDep) -> EntityParser: - return EntityParser(project_config.home) - - -EntityParserV2ExternalDep = Annotated["EntityParser", Depends(get_entity_parser_v2_external)] - - -async def get_markdown_processor( - entity_parser: EntityParserDep, app_config: AppConfigDep -) -> MarkdownProcessor: - return MarkdownProcessor(entity_parser, app_config=app_config) - - -MarkdownProcessorDep = Annotated[MarkdownProcessor, Depends(get_markdown_processor)] - - -async def get_markdown_processor_v2( - entity_parser: EntityParserV2Dep, app_config: AppConfigDep -) -> MarkdownProcessor: - return MarkdownProcessor(entity_parser, app_config=app_config) - - -MarkdownProcessorV2Dep = Annotated[MarkdownProcessor, Depends(get_markdown_processor_v2)] - - -async def get_markdown_processor_v2_external( - entity_parser: EntityParserV2ExternalDep, app_config: AppConfigDep -) -> MarkdownProcessor: - return MarkdownProcessor(entity_parser, app_config=app_config) - - -MarkdownProcessorV2ExternalDep = Annotated[MarkdownProcessor, Depends(get_markdown_processor_v2_external)] - - -async def get_file_service( - project_config: ProjectConfigDep, - markdown_processor: MarkdownProcessorDep, - app_config: AppConfigDep, -) -> FileService: - file_service = FileService(project_config.home, markdown_processor, app_config=app_config) - logger.debug( - f"Created FileService for project: {project_config.name}, base_path: {project_config.home} " - ) - return file_service - - -FileServiceDep = Annotated[FileService, Depends(get_file_service)] - - -async def get_file_service_v2( - project_config: ProjectConfigV2Dep, - markdown_processor: MarkdownProcessorV2Dep, - app_config: AppConfigDep, -) -> FileService: - file_service = FileService(project_config.home, markdown_processor, app_config=app_config) - logger.debug( - f"Created FileService for project: {project_config.name}, base_path: {project_config.home}" - ) - return file_service - - -FileServiceV2Dep = Annotated[FileService, Depends(get_file_service_v2)] - - -async def get_file_service_v2_external( - project_config: ProjectConfigV2ExternalDep, - markdown_processor: MarkdownProcessorV2ExternalDep, - app_config: AppConfigDep, -) -> FileService: - file_service = FileService(project_config.home, markdown_processor, app_config=app_config) - logger.debug( - f"Created FileService for project: {project_config.name}, base_path: {project_config.home}" - ) - return file_service - - -FileServiceV2ExternalDep = Annotated[FileService, Depends(get_file_service_v2_external)] - - -async def get_entity_service( - entity_repository: EntityRepositoryDep, - observation_repository: ObservationRepositoryDep, - relation_repository: RelationRepositoryDep, - entity_parser: EntityParserDep, - file_service: FileServiceDep, - link_resolver: "LinkResolverDep", - search_service: "SearchServiceDep", - app_config: AppConfigDep, -) -> EntityService: - """Create EntityService with repository.""" - return EntityService( - entity_repository=entity_repository, - observation_repository=observation_repository, - 
relation_repository=relation_repository, - entity_parser=entity_parser, - file_service=file_service, - link_resolver=link_resolver, - search_service=search_service, - app_config=app_config, - ) - - -EntityServiceDep = Annotated[EntityService, Depends(get_entity_service)] - - -async def get_entity_service_v2( - entity_repository: EntityRepositoryV2Dep, - observation_repository: ObservationRepositoryV2Dep, - relation_repository: RelationRepositoryV2Dep, - entity_parser: EntityParserV2Dep, - file_service: FileServiceV2Dep, - link_resolver: "LinkResolverV2Dep", - search_service: "SearchServiceV2Dep", - app_config: AppConfigDep, -) -> EntityService: - """Create EntityService for v2 API.""" - return EntityService( - entity_repository=entity_repository, - observation_repository=observation_repository, - relation_repository=relation_repository, - entity_parser=entity_parser, - file_service=file_service, - link_resolver=link_resolver, - search_service=search_service, - app_config=app_config, - ) - - -EntityServiceV2Dep = Annotated[EntityService, Depends(get_entity_service_v2)] - - -async def get_entity_service_v2_external( - entity_repository: EntityRepositoryV2ExternalDep, - observation_repository: ObservationRepositoryV2ExternalDep, - relation_repository: RelationRepositoryV2ExternalDep, - entity_parser: EntityParserV2ExternalDep, - file_service: FileServiceV2ExternalDep, - link_resolver: "LinkResolverV2ExternalDep", - search_service: "SearchServiceV2ExternalDep", - app_config: AppConfigDep, -) -> EntityService: - """Create EntityService for v2 API (uses external_id).""" - return EntityService( - entity_repository=entity_repository, - observation_repository=observation_repository, - relation_repository=relation_repository, - entity_parser=entity_parser, - file_service=file_service, - link_resolver=link_resolver, - search_service=search_service, - app_config=app_config, - ) - - -EntityServiceV2ExternalDep = Annotated[EntityService, Depends(get_entity_service_v2_external)] - - -async def get_search_service( - search_repository: SearchRepositoryDep, - entity_repository: EntityRepositoryDep, - file_service: FileServiceDep, -) -> SearchService: - """Create SearchService with dependencies.""" - return SearchService(search_repository, entity_repository, file_service) - - -SearchServiceDep = Annotated[SearchService, Depends(get_search_service)] - - -async def get_search_service_v2( - search_repository: SearchRepositoryV2Dep, - entity_repository: EntityRepositoryV2Dep, - file_service: FileServiceV2Dep, -) -> SearchService: - """Create SearchService for v2 API.""" - return SearchService(search_repository, entity_repository, file_service) - - -SearchServiceV2Dep = Annotated[SearchService, Depends(get_search_service_v2)] - - -async def get_search_service_v2_external( - search_repository: SearchRepositoryV2ExternalDep, - entity_repository: EntityRepositoryV2ExternalDep, - file_service: FileServiceV2ExternalDep, -) -> SearchService: - """Create SearchService for v2 API (uses external_id).""" - return SearchService(search_repository, entity_repository, file_service) - - -SearchServiceV2ExternalDep = Annotated[SearchService, Depends(get_search_service_v2_external)] - - -async def get_link_resolver( - entity_repository: EntityRepositoryDep, search_service: SearchServiceDep -) -> LinkResolver: - return LinkResolver(entity_repository=entity_repository, search_service=search_service) - - -LinkResolverDep = Annotated[LinkResolver, Depends(get_link_resolver)] - - -async def get_link_resolver_v2( - entity_repository: 
EntityRepositoryV2Dep, search_service: SearchServiceV2Dep -) -> LinkResolver: - return LinkResolver(entity_repository=entity_repository, search_service=search_service) - - -LinkResolverV2Dep = Annotated[LinkResolver, Depends(get_link_resolver_v2)] - - -async def get_link_resolver_v2_external( - entity_repository: EntityRepositoryV2ExternalDep, search_service: SearchServiceV2ExternalDep -) -> LinkResolver: - return LinkResolver(entity_repository=entity_repository, search_service=search_service) - - -LinkResolverV2ExternalDep = Annotated[LinkResolver, Depends(get_link_resolver_v2_external)] - - -async def get_context_service( - search_repository: SearchRepositoryDep, - entity_repository: EntityRepositoryDep, - observation_repository: ObservationRepositoryDep, -) -> ContextService: - return ContextService( - search_repository=search_repository, - entity_repository=entity_repository, - observation_repository=observation_repository, - ) - - -ContextServiceDep = Annotated[ContextService, Depends(get_context_service)] - - -async def get_context_service_v2( - search_repository: SearchRepositoryV2Dep, - entity_repository: EntityRepositoryV2Dep, - observation_repository: ObservationRepositoryV2Dep, -) -> ContextService: - """Create ContextService for v2 API.""" - return ContextService( - search_repository=search_repository, - entity_repository=entity_repository, - observation_repository=observation_repository, - ) - - -ContextServiceV2Dep = Annotated[ContextService, Depends(get_context_service_v2)] - - -async def get_context_service_v2_external( - search_repository: SearchRepositoryV2ExternalDep, - entity_repository: EntityRepositoryV2ExternalDep, - observation_repository: ObservationRepositoryV2ExternalDep, -) -> ContextService: - """Create ContextService for v2 API (uses external_id).""" - return ContextService( - search_repository=search_repository, - entity_repository=entity_repository, - observation_repository=observation_repository, - ) - - -ContextServiceV2ExternalDep = Annotated[ContextService, Depends(get_context_service_v2_external)] - - -async def get_sync_service( - app_config: AppConfigDep, - entity_service: EntityServiceDep, - entity_parser: EntityParserDep, - entity_repository: EntityRepositoryDep, - relation_repository: RelationRepositoryDep, - project_repository: ProjectRepositoryDep, - search_service: SearchServiceDep, - file_service: FileServiceDep, -) -> SyncService: # pragma: no cover - """ - - :rtype: object - """ - return SyncService( - app_config=app_config, - entity_service=entity_service, - entity_parser=entity_parser, - entity_repository=entity_repository, - relation_repository=relation_repository, - project_repository=project_repository, - search_service=search_service, - file_service=file_service, - ) - - -SyncServiceDep = Annotated[SyncService, Depends(get_sync_service)] - - -async def get_sync_service_v2( - app_config: AppConfigDep, - entity_service: EntityServiceV2Dep, - entity_parser: EntityParserV2Dep, - entity_repository: EntityRepositoryV2Dep, - relation_repository: RelationRepositoryV2Dep, - project_repository: ProjectRepositoryDep, - search_service: SearchServiceV2Dep, - file_service: FileServiceV2Dep, -) -> SyncService: # pragma: no cover - """Create SyncService for v2 API.""" - return SyncService( - app_config=app_config, - entity_service=entity_service, - entity_parser=entity_parser, - entity_repository=entity_repository, - relation_repository=relation_repository, - project_repository=project_repository, - search_service=search_service, - 
file_service=file_service, - ) - - -SyncServiceV2Dep = Annotated[SyncService, Depends(get_sync_service_v2)] - - -async def get_sync_service_v2_external( - app_config: AppConfigDep, - entity_service: EntityServiceV2ExternalDep, - entity_parser: EntityParserV2ExternalDep, - entity_repository: EntityRepositoryV2ExternalDep, - relation_repository: RelationRepositoryV2ExternalDep, - project_repository: ProjectRepositoryDep, - search_service: SearchServiceV2ExternalDep, - file_service: FileServiceV2ExternalDep, -) -> SyncService: # pragma: no cover - """Create SyncService for v2 API (uses external_id).""" - return SyncService( - app_config=app_config, - entity_service=entity_service, - entity_parser=entity_parser, - entity_repository=entity_repository, - relation_repository=relation_repository, - project_repository=project_repository, - search_service=search_service, - file_service=file_service, - ) - - -SyncServiceV2ExternalDep = Annotated[SyncService, Depends(get_sync_service_v2_external)] - - -async def get_project_service( - project_repository: ProjectRepositoryDep, -) -> ProjectService: - """Create ProjectService with repository.""" - return ProjectService(repository=project_repository) - - -ProjectServiceDep = Annotated[ProjectService, Depends(get_project_service)] - - -async def get_directory_service( - entity_repository: EntityRepositoryDep, -) -> DirectoryService: - """Create DirectoryService with dependencies.""" - return DirectoryService( - entity_repository=entity_repository, - ) - - -DirectoryServiceDep = Annotated[DirectoryService, Depends(get_directory_service)] - - -async def get_directory_service_v2( - entity_repository: EntityRepositoryV2Dep, -) -> DirectoryService: - """Create DirectoryService for v2 API (uses integer project_id from path).""" - return DirectoryService( - entity_repository=entity_repository, - ) - - -DirectoryServiceV2Dep = Annotated[DirectoryService, Depends(get_directory_service_v2)] - - -async def get_directory_service_v2_external( - entity_repository: EntityRepositoryV2ExternalDep, -) -> DirectoryService: - """Create DirectoryService for v2 API (uses external_id from path).""" - return DirectoryService( - entity_repository=entity_repository, - ) - - -DirectoryServiceV2ExternalDep = Annotated[DirectoryService, Depends(get_directory_service_v2_external)] - - -# Import - - -async def get_chatgpt_importer( - project_config: ProjectConfigDep, - markdown_processor: MarkdownProcessorDep, - file_service: FileServiceDep, -) -> ChatGPTImporter: - """Create ChatGPTImporter with dependencies.""" - return ChatGPTImporter(project_config.home, markdown_processor, file_service) - - -ChatGPTImporterDep = Annotated[ChatGPTImporter, Depends(get_chatgpt_importer)] - - -async def get_claude_conversations_importer( - project_config: ProjectConfigDep, - markdown_processor: MarkdownProcessorDep, - file_service: FileServiceDep, -) -> ClaudeConversationsImporter: - """Create ClaudeConversationsImporter with dependencies.""" - return ClaudeConversationsImporter(project_config.home, markdown_processor, file_service) - - -ClaudeConversationsImporterDep = Annotated[ - ClaudeConversationsImporter, Depends(get_claude_conversations_importer) -] - - -async def get_claude_projects_importer( - project_config: ProjectConfigDep, - markdown_processor: MarkdownProcessorDep, - file_service: FileServiceDep, -) -> ClaudeProjectsImporter: - """Create ClaudeProjectsImporter with dependencies.""" - return ClaudeProjectsImporter(project_config.home, markdown_processor, file_service) - - 
-ClaudeProjectsImporterDep = Annotated[ClaudeProjectsImporter, Depends(get_claude_projects_importer)] - - -async def get_memory_json_importer( - project_config: ProjectConfigDep, - markdown_processor: MarkdownProcessorDep, - file_service: FileServiceDep, -) -> MemoryJsonImporter: - """Create MemoryJsonImporter with dependencies.""" - return MemoryJsonImporter(project_config.home, markdown_processor, file_service) - - -MemoryJsonImporterDep = Annotated[MemoryJsonImporter, Depends(get_memory_json_importer)] - - -# V2 Import dependencies - - -async def get_chatgpt_importer_v2( - project_config: ProjectConfigV2Dep, - markdown_processor: MarkdownProcessorV2Dep, - file_service: FileServiceV2Dep, -) -> ChatGPTImporter: - """Create ChatGPTImporter with v2 dependencies.""" - return ChatGPTImporter(project_config.home, markdown_processor, file_service) - - -ChatGPTImporterV2Dep = Annotated[ChatGPTImporter, Depends(get_chatgpt_importer_v2)] - - -async def get_claude_conversations_importer_v2( - project_config: ProjectConfigV2Dep, - markdown_processor: MarkdownProcessorV2Dep, - file_service: FileServiceV2Dep, -) -> ClaudeConversationsImporter: - """Create ClaudeConversationsImporter with v2 dependencies.""" - return ClaudeConversationsImporter(project_config.home, markdown_processor, file_service) - - -ClaudeConversationsImporterV2Dep = Annotated[ - ClaudeConversationsImporter, Depends(get_claude_conversations_importer_v2) -] - - -async def get_claude_projects_importer_v2( - project_config: ProjectConfigV2Dep, - markdown_processor: MarkdownProcessorV2Dep, - file_service: FileServiceV2Dep, -) -> ClaudeProjectsImporter: - """Create ClaudeProjectsImporter with v2 dependencies.""" - return ClaudeProjectsImporter(project_config.home, markdown_processor, file_service) - - -ClaudeProjectsImporterV2Dep = Annotated[ - ClaudeProjectsImporter, Depends(get_claude_projects_importer_v2) -] - - -async def get_memory_json_importer_v2( - project_config: ProjectConfigV2Dep, - markdown_processor: MarkdownProcessorV2Dep, - file_service: FileServiceV2Dep, -) -> MemoryJsonImporter: - """Create MemoryJsonImporter with v2 dependencies.""" - return MemoryJsonImporter(project_config.home, markdown_processor, file_service) - - -MemoryJsonImporterV2Dep = Annotated[MemoryJsonImporter, Depends(get_memory_json_importer_v2)] - - -# V2 External Import dependencies (using external_id) - - -async def get_chatgpt_importer_v2_external( - project_config: ProjectConfigV2ExternalDep, - markdown_processor: MarkdownProcessorV2ExternalDep, - file_service: FileServiceV2ExternalDep, -) -> ChatGPTImporter: - """Create ChatGPTImporter with v2 external_id dependencies.""" - return ChatGPTImporter(project_config.home, markdown_processor, file_service) - - -ChatGPTImporterV2ExternalDep = Annotated[ChatGPTImporter, Depends(get_chatgpt_importer_v2_external)] - - -async def get_claude_conversations_importer_v2_external( - project_config: ProjectConfigV2ExternalDep, - markdown_processor: MarkdownProcessorV2ExternalDep, - file_service: FileServiceV2ExternalDep, -) -> ClaudeConversationsImporter: - """Create ClaudeConversationsImporter with v2 external_id dependencies.""" - return ClaudeConversationsImporter(project_config.home, markdown_processor, file_service) - - -ClaudeConversationsImporterV2ExternalDep = Annotated[ - ClaudeConversationsImporter, Depends(get_claude_conversations_importer_v2_external) -] - - -async def get_claude_projects_importer_v2_external( - project_config: ProjectConfigV2ExternalDep, - markdown_processor: 
MarkdownProcessorV2ExternalDep, - file_service: FileServiceV2ExternalDep, -) -> ClaudeProjectsImporter: - """Create ClaudeProjectsImporter with v2 external_id dependencies.""" - return ClaudeProjectsImporter(project_config.home, markdown_processor, file_service) - - -ClaudeProjectsImporterV2ExternalDep = Annotated[ - ClaudeProjectsImporter, Depends(get_claude_projects_importer_v2_external) -] - - -async def get_memory_json_importer_v2_external( - project_config: ProjectConfigV2ExternalDep, - markdown_processor: MarkdownProcessorV2ExternalDep, - file_service: FileServiceV2ExternalDep, -) -> MemoryJsonImporter: - """Create MemoryJsonImporter with v2 external_id dependencies.""" - return MemoryJsonImporter(project_config.home, markdown_processor, file_service) - -MemoryJsonImporterV2ExternalDep = Annotated[MemoryJsonImporter, Depends(get_memory_json_importer_v2_external)] +# Re-export everything from the deps package for backwards compatibility +from basic_memory.deps import * # noqa: F401, F403 # pragma: no cover diff --git a/src/basic_memory/deps/__init__.py b/src/basic_memory/deps/__init__.py new file mode 100644 index 00000000..9924e7b3 --- /dev/null +++ b/src/basic_memory/deps/__init__.py @@ -0,0 +1,293 @@ +"""Dependency injection for basic-memory. + +This package provides FastAPI dependencies organized by feature: +- config: Application configuration +- db: Database/session management +- projects: Project resolution and config +- repositories: Data access layer +- services: Business logic layer +- importers: Import functionality + +For backwards compatibility, all dependencies are re-exported from this module. +New code should import from specific submodules to reduce coupling. +""" + +# Re-export everything for backwards compatibility +# Eventually, callers should import from specific submodules + +from basic_memory.deps.config import ( + get_app_config, + AppConfigDep, +) + +from basic_memory.deps.db import ( + get_engine_factory, + EngineFactoryDep, + get_session_maker, + SessionMakerDep, +) + +from basic_memory.deps.projects import ( + get_project_repository, + ProjectRepositoryDep, + ProjectPathDep, + get_project_id, + ProjectIdDep, + get_project_config, + ProjectConfigDep, + validate_project_id, + ProjectIdPathDep, + get_project_config_v2, + ProjectConfigV2Dep, + validate_project_external_id, + ProjectExternalIdPathDep, + get_project_config_v2_external, + ProjectConfigV2ExternalDep, +) + +from basic_memory.deps.repositories import ( + get_entity_repository, + EntityRepositoryDep, + get_entity_repository_v2, + EntityRepositoryV2Dep, + get_entity_repository_v2_external, + EntityRepositoryV2ExternalDep, + get_observation_repository, + ObservationRepositoryDep, + get_observation_repository_v2, + ObservationRepositoryV2Dep, + get_observation_repository_v2_external, + ObservationRepositoryV2ExternalDep, + get_relation_repository, + RelationRepositoryDep, + get_relation_repository_v2, + RelationRepositoryV2Dep, + get_relation_repository_v2_external, + RelationRepositoryV2ExternalDep, + get_search_repository, + SearchRepositoryDep, + get_search_repository_v2, + SearchRepositoryV2Dep, + get_search_repository_v2_external, + SearchRepositoryV2ExternalDep, +) + +from basic_memory.deps.services import ( + get_entity_parser, + EntityParserDep, + get_entity_parser_v2, + EntityParserV2Dep, + get_entity_parser_v2_external, + EntityParserV2ExternalDep, + get_markdown_processor, + MarkdownProcessorDep, + get_markdown_processor_v2, + MarkdownProcessorV2Dep, + get_markdown_processor_v2_external, + 
MarkdownProcessorV2ExternalDep, + get_file_service, + FileServiceDep, + get_file_service_v2, + FileServiceV2Dep, + get_file_service_v2_external, + FileServiceV2ExternalDep, + get_search_service, + SearchServiceDep, + get_search_service_v2, + SearchServiceV2Dep, + get_search_service_v2_external, + SearchServiceV2ExternalDep, + get_link_resolver, + LinkResolverDep, + get_link_resolver_v2, + LinkResolverV2Dep, + get_link_resolver_v2_external, + LinkResolverV2ExternalDep, + get_entity_service, + EntityServiceDep, + get_entity_service_v2, + EntityServiceV2Dep, + get_entity_service_v2_external, + EntityServiceV2ExternalDep, + get_context_service, + ContextServiceDep, + get_context_service_v2, + ContextServiceV2Dep, + get_context_service_v2_external, + ContextServiceV2ExternalDep, + get_sync_service, + SyncServiceDep, + get_sync_service_v2, + SyncServiceV2Dep, + get_sync_service_v2_external, + SyncServiceV2ExternalDep, + get_project_service, + ProjectServiceDep, + get_directory_service, + DirectoryServiceDep, + get_directory_service_v2, + DirectoryServiceV2Dep, + get_directory_service_v2_external, + DirectoryServiceV2ExternalDep, +) + +from basic_memory.deps.importers import ( + get_chatgpt_importer, + ChatGPTImporterDep, + get_chatgpt_importer_v2, + ChatGPTImporterV2Dep, + get_chatgpt_importer_v2_external, + ChatGPTImporterV2ExternalDep, + get_claude_conversations_importer, + ClaudeConversationsImporterDep, + get_claude_conversations_importer_v2, + ClaudeConversationsImporterV2Dep, + get_claude_conversations_importer_v2_external, + ClaudeConversationsImporterV2ExternalDep, + get_claude_projects_importer, + ClaudeProjectsImporterDep, + get_claude_projects_importer_v2, + ClaudeProjectsImporterV2Dep, + get_claude_projects_importer_v2_external, + ClaudeProjectsImporterV2ExternalDep, + get_memory_json_importer, + MemoryJsonImporterDep, + get_memory_json_importer_v2, + MemoryJsonImporterV2Dep, + get_memory_json_importer_v2_external, + MemoryJsonImporterV2ExternalDep, +) + +__all__ = [ + # Config + "get_app_config", + "AppConfigDep", + # Database + "get_engine_factory", + "EngineFactoryDep", + "get_session_maker", + "SessionMakerDep", + # Projects + "get_project_repository", + "ProjectRepositoryDep", + "ProjectPathDep", + "get_project_id", + "ProjectIdDep", + "get_project_config", + "ProjectConfigDep", + "validate_project_id", + "ProjectIdPathDep", + "get_project_config_v2", + "ProjectConfigV2Dep", + "validate_project_external_id", + "ProjectExternalIdPathDep", + "get_project_config_v2_external", + "ProjectConfigV2ExternalDep", + # Repositories + "get_entity_repository", + "EntityRepositoryDep", + "get_entity_repository_v2", + "EntityRepositoryV2Dep", + "get_entity_repository_v2_external", + "EntityRepositoryV2ExternalDep", + "get_observation_repository", + "ObservationRepositoryDep", + "get_observation_repository_v2", + "ObservationRepositoryV2Dep", + "get_observation_repository_v2_external", + "ObservationRepositoryV2ExternalDep", + "get_relation_repository", + "RelationRepositoryDep", + "get_relation_repository_v2", + "RelationRepositoryV2Dep", + "get_relation_repository_v2_external", + "RelationRepositoryV2ExternalDep", + "get_search_repository", + "SearchRepositoryDep", + "get_search_repository_v2", + "SearchRepositoryV2Dep", + "get_search_repository_v2_external", + "SearchRepositoryV2ExternalDep", + # Services + "get_entity_parser", + "EntityParserDep", + "get_entity_parser_v2", + "EntityParserV2Dep", + "get_entity_parser_v2_external", + "EntityParserV2ExternalDep", + "get_markdown_processor", + 
"MarkdownProcessorDep", + "get_markdown_processor_v2", + "MarkdownProcessorV2Dep", + "get_markdown_processor_v2_external", + "MarkdownProcessorV2ExternalDep", + "get_file_service", + "FileServiceDep", + "get_file_service_v2", + "FileServiceV2Dep", + "get_file_service_v2_external", + "FileServiceV2ExternalDep", + "get_search_service", + "SearchServiceDep", + "get_search_service_v2", + "SearchServiceV2Dep", + "get_search_service_v2_external", + "SearchServiceV2ExternalDep", + "get_link_resolver", + "LinkResolverDep", + "get_link_resolver_v2", + "LinkResolverV2Dep", + "get_link_resolver_v2_external", + "LinkResolverV2ExternalDep", + "get_entity_service", + "EntityServiceDep", + "get_entity_service_v2", + "EntityServiceV2Dep", + "get_entity_service_v2_external", + "EntityServiceV2ExternalDep", + "get_context_service", + "ContextServiceDep", + "get_context_service_v2", + "ContextServiceV2Dep", + "get_context_service_v2_external", + "ContextServiceV2ExternalDep", + "get_sync_service", + "SyncServiceDep", + "get_sync_service_v2", + "SyncServiceV2Dep", + "get_sync_service_v2_external", + "SyncServiceV2ExternalDep", + "get_project_service", + "ProjectServiceDep", + "get_directory_service", + "DirectoryServiceDep", + "get_directory_service_v2", + "DirectoryServiceV2Dep", + "get_directory_service_v2_external", + "DirectoryServiceV2ExternalDep", + # Importers + "get_chatgpt_importer", + "ChatGPTImporterDep", + "get_chatgpt_importer_v2", + "ChatGPTImporterV2Dep", + "get_chatgpt_importer_v2_external", + "ChatGPTImporterV2ExternalDep", + "get_claude_conversations_importer", + "ClaudeConversationsImporterDep", + "get_claude_conversations_importer_v2", + "ClaudeConversationsImporterV2Dep", + "get_claude_conversations_importer_v2_external", + "ClaudeConversationsImporterV2ExternalDep", + "get_claude_projects_importer", + "ClaudeProjectsImporterDep", + "get_claude_projects_importer_v2", + "ClaudeProjectsImporterV2Dep", + "get_claude_projects_importer_v2_external", + "ClaudeProjectsImporterV2ExternalDep", + "get_memory_json_importer", + "MemoryJsonImporterDep", + "get_memory_json_importer_v2", + "MemoryJsonImporterV2Dep", + "get_memory_json_importer_v2_external", + "MemoryJsonImporterV2ExternalDep", +] diff --git a/src/basic_memory/deps/config.py b/src/basic_memory/deps/config.py new file mode 100644 index 00000000..b868151a --- /dev/null +++ b/src/basic_memory/deps/config.py @@ -0,0 +1,26 @@ +"""Configuration dependency injection for basic-memory. + +This module provides configuration-related dependencies. +Note: Long-term goal is to minimize direct ConfigManager access +and inject config from composition roots instead. +""" + +from typing import Annotated + +from fastapi import Depends + +from basic_memory.config import BasicMemoryConfig, ConfigManager + + +def get_app_config() -> BasicMemoryConfig: # pragma: no cover + """Get the application configuration. + + Note: This is a transitional dependency. The goal is for composition roots + to read ConfigManager and inject config explicitly. During migration, + this provides the same behavior as before. + """ + app_config = ConfigManager().config + return app_config + + +AppConfigDep = Annotated[BasicMemoryConfig, Depends(get_app_config)] diff --git a/src/basic_memory/deps/db.py b/src/basic_memory/deps/db.py new file mode 100644 index 00000000..0038ac47 --- /dev/null +++ b/src/basic_memory/deps/db.py @@ -0,0 +1,56 @@ +"""Database dependency injection for basic-memory. 
+ +This module provides database-related dependencies: +- Engine and session maker factories +- Session dependencies for request handling +""" + +from typing import Annotated + +from fastapi import Depends, Request +from loguru import logger +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, +) + +from basic_memory import db +from basic_memory.deps.config import get_app_config + + +async def get_engine_factory( + request: Request, +) -> tuple[AsyncEngine, async_sessionmaker[AsyncSession]]: # pragma: no cover + """Get cached engine and session maker from app state. + + For API requests, returns cached connections from app.state for optimal performance. + For non-API contexts (CLI), falls back to direct database connection. + """ + # Try to get cached connections from app state (API context) + if ( + hasattr(request, "app") + and hasattr(request.app.state, "engine") + and hasattr(request.app.state, "session_maker") + ): + return request.app.state.engine, request.app.state.session_maker + + # Fallback for non-API contexts (CLI) + logger.debug("Using fallback database connection for non-API context") + app_config = get_app_config() + engine, session_maker = await db.get_or_create_db(app_config.database_path) + return engine, session_maker + + +EngineFactoryDep = Annotated[ + tuple[AsyncEngine, async_sessionmaker[AsyncSession]], Depends(get_engine_factory) +] + + +async def get_session_maker(engine_factory: EngineFactoryDep) -> async_sessionmaker[AsyncSession]: + """Get session maker.""" + _, session_maker = engine_factory + return session_maker + + +SessionMakerDep = Annotated[async_sessionmaker, Depends(get_session_maker)] diff --git a/src/basic_memory/deps/importers.py b/src/basic_memory/deps/importers.py new file mode 100644 index 00000000..2209192c --- /dev/null +++ b/src/basic_memory/deps/importers.py @@ -0,0 +1,200 @@ +"""Importer dependency injection for basic-memory. 
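For the cached path in `get_engine_factory` to hit, the API entrypoint has to stash the engine and session maker on `app.state` at startup. A minimal sketch of that wiring, reusing the same `db.get_or_create_db` helper the fallback branch calls (the lifespan placement is an assumption; the actual startup code is outside this diff):

```python
from contextlib import asynccontextmanager

from fastapi import FastAPI

from basic_memory import db
from basic_memory.config import ConfigManager


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Seed the cache that get_engine_factory checks on every request.
    app_config = ConfigManager().config
    engine, session_maker = await db.get_or_create_db(app_config.database_path)
    app.state.engine = engine
    app.state.session_maker = session_maker
    yield


app = FastAPI(lifespan=lifespan)
```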
+ +This module provides importer dependencies: +- ChatGPTImporter +- ClaudeConversationsImporter +- ClaudeProjectsImporter +- MemoryJsonImporter +""" + +from typing import Annotated + +from fastapi import Depends + +from basic_memory.deps.projects import ( + ProjectConfigDep, + ProjectConfigV2Dep, + ProjectConfigV2ExternalDep, +) +from basic_memory.deps.services import ( + FileServiceDep, + FileServiceV2Dep, + FileServiceV2ExternalDep, + MarkdownProcessorDep, + MarkdownProcessorV2Dep, + MarkdownProcessorV2ExternalDep, +) +from basic_memory.importers import ( + ChatGPTImporter, + ClaudeConversationsImporter, + ClaudeProjectsImporter, + MemoryJsonImporter, +) + + +# --- ChatGPT Importer --- + + +async def get_chatgpt_importer( + project_config: ProjectConfigDep, + markdown_processor: MarkdownProcessorDep, + file_service: FileServiceDep, +) -> ChatGPTImporter: + """Create ChatGPTImporter with dependencies.""" + return ChatGPTImporter(project_config.home, markdown_processor, file_service) + + +ChatGPTImporterDep = Annotated[ChatGPTImporter, Depends(get_chatgpt_importer)] + + +async def get_chatgpt_importer_v2( # pragma: no cover + project_config: ProjectConfigV2Dep, + markdown_processor: MarkdownProcessorV2Dep, + file_service: FileServiceV2Dep, +) -> ChatGPTImporter: + """Create ChatGPTImporter with v2 dependencies.""" + return ChatGPTImporter(project_config.home, markdown_processor, file_service) + + +ChatGPTImporterV2Dep = Annotated[ChatGPTImporter, Depends(get_chatgpt_importer_v2)] + + +async def get_chatgpt_importer_v2_external( + project_config: ProjectConfigV2ExternalDep, + markdown_processor: MarkdownProcessorV2ExternalDep, + file_service: FileServiceV2ExternalDep, +) -> ChatGPTImporter: + """Create ChatGPTImporter with v2 external_id dependencies.""" + return ChatGPTImporter(project_config.home, markdown_processor, file_service) + + +ChatGPTImporterV2ExternalDep = Annotated[ChatGPTImporter, Depends(get_chatgpt_importer_v2_external)] + + +# --- Claude Conversations Importer --- + + +async def get_claude_conversations_importer( + project_config: ProjectConfigDep, + markdown_processor: MarkdownProcessorDep, + file_service: FileServiceDep, +) -> ClaudeConversationsImporter: + """Create ClaudeConversationsImporter with dependencies.""" + return ClaudeConversationsImporter(project_config.home, markdown_processor, file_service) + + +ClaudeConversationsImporterDep = Annotated[ + ClaudeConversationsImporter, Depends(get_claude_conversations_importer) +] + + +async def get_claude_conversations_importer_v2( # pragma: no cover + project_config: ProjectConfigV2Dep, + markdown_processor: MarkdownProcessorV2Dep, + file_service: FileServiceV2Dep, +) -> ClaudeConversationsImporter: + """Create ClaudeConversationsImporter with v2 dependencies.""" + return ClaudeConversationsImporter(project_config.home, markdown_processor, file_service) + + +ClaudeConversationsImporterV2Dep = Annotated[ + ClaudeConversationsImporter, Depends(get_claude_conversations_importer_v2) +] + + +async def get_claude_conversations_importer_v2_external( + project_config: ProjectConfigV2ExternalDep, + markdown_processor: MarkdownProcessorV2ExternalDep, + file_service: FileServiceV2ExternalDep, +) -> ClaudeConversationsImporter: + """Create ClaudeConversationsImporter with v2 external_id dependencies.""" + return ClaudeConversationsImporter(project_config.home, markdown_processor, file_service) + + +ClaudeConversationsImporterV2ExternalDep = Annotated[ + ClaudeConversationsImporter, 
Depends(get_claude_conversations_importer_v2_external) +] + + +# --- Claude Projects Importer --- + + +async def get_claude_projects_importer( + project_config: ProjectConfigDep, + markdown_processor: MarkdownProcessorDep, + file_service: FileServiceDep, +) -> ClaudeProjectsImporter: + """Create ClaudeProjectsImporter with dependencies.""" + return ClaudeProjectsImporter(project_config.home, markdown_processor, file_service) + + +ClaudeProjectsImporterDep = Annotated[ClaudeProjectsImporter, Depends(get_claude_projects_importer)] + + +async def get_claude_projects_importer_v2( # pragma: no cover + project_config: ProjectConfigV2Dep, + markdown_processor: MarkdownProcessorV2Dep, + file_service: FileServiceV2Dep, +) -> ClaudeProjectsImporter: + """Create ClaudeProjectsImporter with v2 dependencies.""" + return ClaudeProjectsImporter(project_config.home, markdown_processor, file_service) + + +ClaudeProjectsImporterV2Dep = Annotated[ + ClaudeProjectsImporter, Depends(get_claude_projects_importer_v2) +] + + +async def get_claude_projects_importer_v2_external( + project_config: ProjectConfigV2ExternalDep, + markdown_processor: MarkdownProcessorV2ExternalDep, + file_service: FileServiceV2ExternalDep, +) -> ClaudeProjectsImporter: + """Create ClaudeProjectsImporter with v2 external_id dependencies.""" + return ClaudeProjectsImporter(project_config.home, markdown_processor, file_service) + + +ClaudeProjectsImporterV2ExternalDep = Annotated[ + ClaudeProjectsImporter, Depends(get_claude_projects_importer_v2_external) +] + + +# --- Memory JSON Importer --- + + +async def get_memory_json_importer( + project_config: ProjectConfigDep, + markdown_processor: MarkdownProcessorDep, + file_service: FileServiceDep, +) -> MemoryJsonImporter: + """Create MemoryJsonImporter with dependencies.""" + return MemoryJsonImporter(project_config.home, markdown_processor, file_service) + + +MemoryJsonImporterDep = Annotated[MemoryJsonImporter, Depends(get_memory_json_importer)] + + +async def get_memory_json_importer_v2( # pragma: no cover + project_config: ProjectConfigV2Dep, + markdown_processor: MarkdownProcessorV2Dep, + file_service: FileServiceV2Dep, +) -> MemoryJsonImporter: + """Create MemoryJsonImporter with v2 dependencies.""" + return MemoryJsonImporter(project_config.home, markdown_processor, file_service) + + +MemoryJsonImporterV2Dep = Annotated[MemoryJsonImporter, Depends(get_memory_json_importer_v2)] + + +async def get_memory_json_importer_v2_external( + project_config: ProjectConfigV2ExternalDep, + markdown_processor: MarkdownProcessorV2ExternalDep, + file_service: FileServiceV2ExternalDep, +) -> MemoryJsonImporter: + """Create MemoryJsonImporter with v2 external_id dependencies.""" + return MemoryJsonImporter(project_config.home, markdown_processor, file_service) + + +MemoryJsonImporterV2ExternalDep = Annotated[ + MemoryJsonImporter, Depends(get_memory_json_importer_v2_external) +] diff --git a/src/basic_memory/deps/projects.py b/src/basic_memory/deps/projects.py new file mode 100644 index 00000000..2393e6d2 --- /dev/null +++ b/src/basic_memory/deps/projects.py @@ -0,0 +1,238 @@ +"""Project dependency injection for basic-memory. 
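In a router, an importer alias is declared like any other dependency, and FastAPI builds the project config, markdown processor, and file service chain per request. A sketch with a hypothetical route path (the importer's methods beyond construction are not shown in this diff, so the handler body is illustrative):

```python
from fastapi import APIRouter

from basic_memory.deps.importers import ChatGPTImporterDep

router = APIRouter()


# Hypothetical path; the {project} parameter feeds ProjectPathDep upstream.
@router.post("/{project}/import/chatgpt")
async def import_chatgpt(importer: ChatGPTImporterDep) -> dict[str, str]:
    # The importer arrives fully wired with the project's home directory,
    # markdown processor, and file service.
    return {"importer": type(importer).__name__}
```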
+
+This module provides project-related dependencies:
+- Project path extraction from URL
+- Project config resolution
+- Project ID validation
+- Project repository
+"""
+
+import pathlib
+from typing import Annotated
+
+from fastapi import Depends, HTTPException, Path, status
+
+from basic_memory.config import ProjectConfig
+from basic_memory.deps.db import SessionMakerDep
+from basic_memory.repository.project_repository import ProjectRepository
+from basic_memory.utils import generate_permalink
+
+
+# --- Project Repository ---
+
+
+async def get_project_repository(
+    session_maker: SessionMakerDep,
+) -> ProjectRepository:
+    """Get the project repository."""
+    return ProjectRepository(session_maker)
+
+
+ProjectRepositoryDep = Annotated[ProjectRepository, Depends(get_project_repository)]
+
+
+# --- Path Extraction ---
+
+# V1 API: Project name from URL path
+ProjectPathDep = Annotated[str, Path()]
+
+
+# --- Project ID Resolution (V1 API) ---
+
+
+async def get_project_id(
+    project_repository: ProjectRepositoryDep,
+    project: ProjectPathDep,
+) -> int:
+    """Resolve the project name in the URL path to its numeric project ID.
+
+    Looks up the project by permalink first, then falls back to the name.
+
+    Args:
+        project_repository: Repository for project operations
+        project: The project name from URL path
+
+    Returns:
+        The resolved project ID
+
+    Raises:
+        HTTPException: If project is not found
+    """
+    # Convert project name to permalink for lookup
+    project_permalink = generate_permalink(str(project))
+    project_obj = await project_repository.get_by_permalink(project_permalink)
+    if project_obj:
+        return project_obj.id
+
+    # Try by name if permalink lookup fails
+    project_obj = await project_repository.get_by_name(str(project))  # pragma: no cover
+    if project_obj:  # pragma: no cover
+        return project_obj.id
+
+    # Not found
+    raise HTTPException(  # pragma: no cover
+        status_code=status.HTTP_404_NOT_FOUND, detail=f"Project '{project}' not found."
+    )
+
+
+ProjectIdDep = Annotated[int, Depends(get_project_id)]
+
+
+# --- Project Config Resolution (V1 API) ---
+
+
+async def get_project_config(
+    project: ProjectPathDep, project_repository: ProjectRepositoryDep
+) -> ProjectConfig:  # pragma: no cover
+    """Get the project config for the project named in the URL path.
+
+    Args:
+        project: The project name from URL path
+        project_repository: Repository for project operations
+
+    Returns:
+        The resolved project config
+
+    Raises:
+        HTTPException: If project is not found
+    """
+    # Convert project name to permalink for lookup
+    project_permalink = generate_permalink(str(project))
+    project_obj = await project_repository.get_by_permalink(project_permalink)
+    if project_obj:
+        return ProjectConfig(name=project_obj.name, home=pathlib.Path(project_obj.path))
+
+    # Not found
+    raise HTTPException(  # pragma: no cover
+        status_code=status.HTTP_404_NOT_FOUND, detail=f"Project '{project}' not found."
+    )
+
+
+ProjectConfigDep = Annotated[ProjectConfig, Depends(get_project_config)]
+
+
+# --- V2 API: Integer Project ID from Path ---
+
+
+async def validate_project_id(
+    project_id: int,
+    project_repository: ProjectRepositoryDep,
+) -> int:
+    """Validate that a numeric project ID exists in the database.
+
+    This is used for v2 API endpoints that take project IDs as integers in the path.
+    The project_id parameter will be automatically extracted from the URL path by FastAPI.
+ + Args: + project_id: The numeric project ID from the URL path + project_repository: Repository for project operations + + Returns: + The validated project ID + + Raises: + HTTPException: If project with that ID is not found + """ + project_obj = await project_repository.get_by_id(project_id) + if not project_obj: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Project with ID {project_id} not found.", + ) + return project_id + + +ProjectIdPathDep = Annotated[int, Depends(validate_project_id)] + + +async def get_project_config_v2( + project_id: ProjectIdPathDep, project_repository: ProjectRepositoryDep +) -> ProjectConfig: # pragma: no cover + """Get the project config for v2 API (uses integer project_id from path). + + Args: + project_id: The validated numeric project ID from the URL path + project_repository: Repository for project operations + + Returns: + The resolved project config + + Raises: + HTTPException: If project is not found + """ + project_obj = await project_repository.get_by_id(project_id) + if project_obj: + return ProjectConfig(name=project_obj.name, home=pathlib.Path(project_obj.path)) + + # Not found (this should not happen since ProjectIdPathDep already validates existence) + raise HTTPException( # pragma: no cover + status_code=status.HTTP_404_NOT_FOUND, detail=f"Project with ID {project_id} not found." + ) + + +ProjectConfigV2Dep = Annotated[ProjectConfig, Depends(get_project_config_v2)] + + +# --- V2 API: External UUID Project ID from Path --- + + +async def validate_project_external_id( + project_id: str, + project_repository: ProjectRepositoryDep, +) -> int: + """Validate that a project external_id (UUID) exists in the database. + + This is used for v2 API endpoints that take project external_ids as strings in the path. + The project_id parameter will be automatically extracted from the URL path by FastAPI. + + Args: + project_id: The external UUID from the URL path (named project_id for URL consistency) + project_repository: Repository for project operations + + Returns: + The internal numeric project ID (for use by repositories) + + Raises: + HTTPException: If project with that external_id is not found + """ + project_obj = await project_repository.get_by_external_id(project_id) + if not project_obj: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Project with external_id '{project_id}' not found.", + ) + return project_obj.id + + +ProjectExternalIdPathDep = Annotated[int, Depends(validate_project_external_id)] + + +async def get_project_config_v2_external( + project_id: ProjectExternalIdPathDep, project_repository: ProjectRepositoryDep +) -> ProjectConfig: # pragma: no cover + """Get the project config for v2 API (uses external_id UUID from path). + + Args: + project_id: The internal project ID resolved from external_id + project_repository: Repository for project operations + + Returns: + The resolved project config + + Raises: + HTTPException: If project is not found + """ + project_obj = await project_repository.get_by_id(project_id) + if project_obj: + return ProjectConfig(name=project_obj.name, home=pathlib.Path(project_obj.path)) + + # Not found (this should not happen since ProjectExternalIdPathDep already validates) + raise HTTPException( # pragma: no cover + status_code=status.HTTP_404_NOT_FOUND, detail=f"Project with ID {project_id} not found." 
+ ) + + +ProjectConfigV2ExternalDep = Annotated[ + ProjectConfig, Depends(get_project_config_v2_external) +] diff --git a/src/basic_memory/deps/repositories.py b/src/basic_memory/deps/repositories.py new file mode 100644 index 00000000..24b87e90 --- /dev/null +++ b/src/basic_memory/deps/repositories.py @@ -0,0 +1,179 @@ +"""Repository dependency injection for basic-memory. + +This module provides repository dependencies: +- EntityRepository +- ObservationRepository +- RelationRepository +- SearchRepository + +Each repository is scoped to a project ID from the request. +""" + +from typing import Annotated + +from fastapi import Depends + +from basic_memory.deps.db import SessionMakerDep +from basic_memory.deps.projects import ( + ProjectIdDep, + ProjectIdPathDep, + ProjectExternalIdPathDep, +) +from basic_memory.repository.entity_repository import EntityRepository +from basic_memory.repository.observation_repository import ObservationRepository +from basic_memory.repository.relation_repository import RelationRepository +from basic_memory.repository.search_repository import SearchRepository, create_search_repository + + +# --- Entity Repository --- + + +async def get_entity_repository( + session_maker: SessionMakerDep, + project_id: ProjectIdDep, +) -> EntityRepository: + """Create an EntityRepository instance for the current project.""" + return EntityRepository(session_maker, project_id=project_id) + + +EntityRepositoryDep = Annotated[EntityRepository, Depends(get_entity_repository)] + + +async def get_entity_repository_v2( # pragma: no cover + session_maker: SessionMakerDep, + project_id: ProjectIdPathDep, +) -> EntityRepository: + """Create an EntityRepository instance for v2 API (uses integer project_id from path).""" + return EntityRepository(session_maker, project_id=project_id) + + +EntityRepositoryV2Dep = Annotated[EntityRepository, Depends(get_entity_repository_v2)] + + +async def get_entity_repository_v2_external( + session_maker: SessionMakerDep, + project_id: ProjectExternalIdPathDep, +) -> EntityRepository: + """Create an EntityRepository instance for v2 API (uses external_id from path).""" + return EntityRepository(session_maker, project_id=project_id) + + +EntityRepositoryV2ExternalDep = Annotated[EntityRepository, Depends(get_entity_repository_v2_external)] + + +# --- Observation Repository --- + + +async def get_observation_repository( + session_maker: SessionMakerDep, + project_id: ProjectIdDep, +) -> ObservationRepository: + """Create an ObservationRepository instance for the current project.""" + return ObservationRepository(session_maker, project_id=project_id) + + +ObservationRepositoryDep = Annotated[ObservationRepository, Depends(get_observation_repository)] + + +async def get_observation_repository_v2( # pragma: no cover + session_maker: SessionMakerDep, + project_id: ProjectIdPathDep, +) -> ObservationRepository: + """Create an ObservationRepository instance for v2 API.""" + return ObservationRepository(session_maker, project_id=project_id) + + +ObservationRepositoryV2Dep = Annotated[ + ObservationRepository, Depends(get_observation_repository_v2) +] + + +async def get_observation_repository_v2_external( + session_maker: SessionMakerDep, + project_id: ProjectExternalIdPathDep, +) -> ObservationRepository: + """Create an ObservationRepository instance for v2 API (uses external_id).""" + return ObservationRepository(session_maker, project_id=project_id) + + +ObservationRepositoryV2ExternalDep = Annotated[ + ObservationRepository, 
Depends(get_observation_repository_v2_external) +] + + +# --- Relation Repository --- + + +async def get_relation_repository( + session_maker: SessionMakerDep, + project_id: ProjectIdDep, +) -> RelationRepository: + """Create a RelationRepository instance for the current project.""" + return RelationRepository(session_maker, project_id=project_id) + + +RelationRepositoryDep = Annotated[RelationRepository, Depends(get_relation_repository)] + + +async def get_relation_repository_v2( # pragma: no cover + session_maker: SessionMakerDep, + project_id: ProjectIdPathDep, +) -> RelationRepository: + """Create a RelationRepository instance for v2 API.""" + return RelationRepository(session_maker, project_id=project_id) + + +RelationRepositoryV2Dep = Annotated[RelationRepository, Depends(get_relation_repository_v2)] + + +async def get_relation_repository_v2_external( + session_maker: SessionMakerDep, + project_id: ProjectExternalIdPathDep, +) -> RelationRepository: + """Create a RelationRepository instance for v2 API (uses external_id).""" + return RelationRepository(session_maker, project_id=project_id) + + +RelationRepositoryV2ExternalDep = Annotated[ + RelationRepository, Depends(get_relation_repository_v2_external) +] + + +# --- Search Repository --- + + +async def get_search_repository( + session_maker: SessionMakerDep, + project_id: ProjectIdDep, +) -> SearchRepository: + """Create a backend-specific SearchRepository instance for the current project. + + Uses factory function to return SQLiteSearchRepository or PostgresSearchRepository + based on database backend configuration. + """ + return create_search_repository(session_maker, project_id=project_id) + + +SearchRepositoryDep = Annotated[SearchRepository, Depends(get_search_repository)] + + +async def get_search_repository_v2( # pragma: no cover + session_maker: SessionMakerDep, + project_id: ProjectIdPathDep, +) -> SearchRepository: + """Create a SearchRepository instance for v2 API.""" + return create_search_repository(session_maker, project_id=project_id) + + +SearchRepositoryV2Dep = Annotated[SearchRepository, Depends(get_search_repository_v2)] + + +async def get_search_repository_v2_external( + session_maker: SessionMakerDep, + project_id: ProjectExternalIdPathDep, +) -> SearchRepository: + """Create a SearchRepository instance for v2 API (uses external_id).""" + return create_search_repository(session_maker, project_id=project_id) + + +SearchRepositoryV2ExternalDep = Annotated[SearchRepository, Depends(get_search_repository_v2_external)] diff --git a/src/basic_memory/deps/services.py b/src/basic_memory/deps/services.py new file mode 100644 index 00000000..e5e82c34 --- /dev/null +++ b/src/basic_memory/deps/services.py @@ -0,0 +1,480 @@ +"""Service dependency injection for basic-memory. 
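Taken together, the three project dependencies give routers a uniform way to scope a request whatever the path carries: a name (v1), an integer ID (v2), or an external UUID (v2 external). A sketch of hypothetical routes showing how the path-parameter names line up with each dependency:

```python
from fastapi import APIRouter

from basic_memory.deps.projects import (
    ProjectExternalIdPathDep,
    ProjectIdDep,
    ProjectIdPathDep,
)

router = APIRouter()


@router.get("/{project}/project-id")  # v1: project name in the path
async def by_name(project_id: ProjectIdDep) -> int:
    return project_id


@router.get("/v2/projects/{project_id}/project-id")  # v2: integer id in the path
async def by_id(project_id: ProjectIdPathDep) -> int:
    return project_id


@router.get("/v2/ext/{project_id}/project-id")  # hypothetical path; UUID in the path
async def by_external_id(internal_id: ProjectExternalIdPathDep) -> int:
    # The UUID from the path is resolved to the internal numeric id.
    return internal_id
```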
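Every repository factory above repeats one shape: take the session maker plus a resolved project ID, and bake the ID into the repository at construction time. The v1, v2, and v2-external variants differ only in which project dependency supplies that ID. A condensed restatement of the pattern (illustrative only, not a new API):

```python
from typing import Annotated

from fastapi import Depends

from basic_memory.deps.db import SessionMakerDep
from basic_memory.deps.projects import ProjectIdDep
from basic_memory.repository.entity_repository import EntityRepository


async def get_demo_entity_repository(
    session_maker: SessionMakerDep,
    project_id: ProjectIdDep,  # swap in ProjectIdPathDep / ProjectExternalIdPathDep for v2
) -> EntityRepository:
    # Project scoping is fixed at construction; queries need no per-call filter.
    return EntityRepository(session_maker, project_id=project_id)


DemoEntityRepositoryDep = Annotated[EntityRepository, Depends(get_demo_entity_repository)]
```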
+ +This module provides service-layer dependencies: +- EntityParser, MarkdownProcessor +- FileService, EntityService +- SearchService, LinkResolver, ContextService +- SyncService, ProjectService, DirectoryService +""" + +from typing import Annotated + +from fastapi import Depends +from loguru import logger + +from basic_memory.deps.config import AppConfigDep +from basic_memory.deps.projects import ( + ProjectConfigDep, + ProjectConfigV2Dep, + ProjectConfigV2ExternalDep, + ProjectRepositoryDep, +) +from basic_memory.deps.repositories import ( + EntityRepositoryDep, + EntityRepositoryV2Dep, + EntityRepositoryV2ExternalDep, + ObservationRepositoryDep, + ObservationRepositoryV2Dep, + ObservationRepositoryV2ExternalDep, + RelationRepositoryDep, + RelationRepositoryV2Dep, + RelationRepositoryV2ExternalDep, + SearchRepositoryDep, + SearchRepositoryV2Dep, + SearchRepositoryV2ExternalDep, +) +from basic_memory.markdown import EntityParser +from basic_memory.markdown.markdown_processor import MarkdownProcessor +from basic_memory.services import EntityService, ProjectService +from basic_memory.services.context_service import ContextService +from basic_memory.services.directory_service import DirectoryService +from basic_memory.services.file_service import FileService +from basic_memory.services.link_resolver import LinkResolver +from basic_memory.services.search_service import SearchService +from basic_memory.sync import SyncService + + +# --- Entity Parser --- + + +async def get_entity_parser(project_config: ProjectConfigDep) -> EntityParser: + return EntityParser(project_config.home) + + +EntityParserDep = Annotated["EntityParser", Depends(get_entity_parser)] + + +async def get_entity_parser_v2(project_config: ProjectConfigV2Dep) -> EntityParser: # pragma: no cover + return EntityParser(project_config.home) + + +EntityParserV2Dep = Annotated["EntityParser", Depends(get_entity_parser_v2)] + + +async def get_entity_parser_v2_external(project_config: ProjectConfigV2ExternalDep) -> EntityParser: + return EntityParser(project_config.home) + + +EntityParserV2ExternalDep = Annotated["EntityParser", Depends(get_entity_parser_v2_external)] + + +# --- Markdown Processor --- + + +async def get_markdown_processor( + entity_parser: EntityParserDep, app_config: AppConfigDep +) -> MarkdownProcessor: + return MarkdownProcessor(entity_parser, app_config=app_config) + + +MarkdownProcessorDep = Annotated[MarkdownProcessor, Depends(get_markdown_processor)] + + +async def get_markdown_processor_v2( # pragma: no cover + entity_parser: EntityParserV2Dep, app_config: AppConfigDep +) -> MarkdownProcessor: + return MarkdownProcessor(entity_parser, app_config=app_config) + + +MarkdownProcessorV2Dep = Annotated[MarkdownProcessor, Depends(get_markdown_processor_v2)] + + +async def get_markdown_processor_v2_external( + entity_parser: EntityParserV2ExternalDep, app_config: AppConfigDep +) -> MarkdownProcessor: + return MarkdownProcessor(entity_parser, app_config=app_config) + + +MarkdownProcessorV2ExternalDep = Annotated[ + MarkdownProcessor, Depends(get_markdown_processor_v2_external) +] + + +# --- File Service --- + + +async def get_file_service( + project_config: ProjectConfigDep, + markdown_processor: MarkdownProcessorDep, + app_config: AppConfigDep, +) -> FileService: + file_service = FileService(project_config.home, markdown_processor, app_config=app_config) + logger.debug( + f"Created FileService for project: {project_config.name}, base_path: {project_config.home} " + ) + return file_service + + +FileServiceDep = 
Annotated[FileService, Depends(get_file_service)] + + +async def get_file_service_v2( # pragma: no cover + project_config: ProjectConfigV2Dep, + markdown_processor: MarkdownProcessorV2Dep, + app_config: AppConfigDep, +) -> FileService: + file_service = FileService(project_config.home, markdown_processor, app_config=app_config) + logger.debug( + f"Created FileService for project: {project_config.name}, base_path: {project_config.home}" + ) + return file_service + + +FileServiceV2Dep = Annotated[FileService, Depends(get_file_service_v2)] + + +async def get_file_service_v2_external( + project_config: ProjectConfigV2ExternalDep, + markdown_processor: MarkdownProcessorV2ExternalDep, + app_config: AppConfigDep, +) -> FileService: + file_service = FileService(project_config.home, markdown_processor, app_config=app_config) + logger.debug( + f"Created FileService for project: {project_config.name}, base_path: {project_config.home}" + ) + return file_service + + +FileServiceV2ExternalDep = Annotated[FileService, Depends(get_file_service_v2_external)] + + +# --- Search Service --- + + +async def get_search_service( + search_repository: SearchRepositoryDep, + entity_repository: EntityRepositoryDep, + file_service: FileServiceDep, +) -> SearchService: + """Create SearchService with dependencies.""" + return SearchService(search_repository, entity_repository, file_service) + + +SearchServiceDep = Annotated[SearchService, Depends(get_search_service)] + + +async def get_search_service_v2( # pragma: no cover + search_repository: SearchRepositoryV2Dep, + entity_repository: EntityRepositoryV2Dep, + file_service: FileServiceV2Dep, +) -> SearchService: + """Create SearchService for v2 API.""" + return SearchService(search_repository, entity_repository, file_service) + + +SearchServiceV2Dep = Annotated[SearchService, Depends(get_search_service_v2)] + + +async def get_search_service_v2_external( + search_repository: SearchRepositoryV2ExternalDep, + entity_repository: EntityRepositoryV2ExternalDep, + file_service: FileServiceV2ExternalDep, +) -> SearchService: + """Create SearchService for v2 API (uses external_id).""" + return SearchService(search_repository, entity_repository, file_service) + + +SearchServiceV2ExternalDep = Annotated[SearchService, Depends(get_search_service_v2_external)] + + +# --- Link Resolver --- + + +async def get_link_resolver( + entity_repository: EntityRepositoryDep, search_service: SearchServiceDep +) -> LinkResolver: + return LinkResolver(entity_repository=entity_repository, search_service=search_service) + + +LinkResolverDep = Annotated[LinkResolver, Depends(get_link_resolver)] + + +async def get_link_resolver_v2( # pragma: no cover + entity_repository: EntityRepositoryV2Dep, search_service: SearchServiceV2Dep +) -> LinkResolver: + return LinkResolver(entity_repository=entity_repository, search_service=search_service) + + +LinkResolverV2Dep = Annotated[LinkResolver, Depends(get_link_resolver_v2)] + + +async def get_link_resolver_v2_external( + entity_repository: EntityRepositoryV2ExternalDep, search_service: SearchServiceV2ExternalDep +) -> LinkResolver: + return LinkResolver(entity_repository=entity_repository, search_service=search_service) + + +LinkResolverV2ExternalDep = Annotated[LinkResolver, Depends(get_link_resolver_v2_external)] + + +# --- Entity Service --- + + +async def get_entity_service( + entity_repository: EntityRepositoryDep, + observation_repository: ObservationRepositoryDep, + relation_repository: RelationRepositoryDep, + entity_parser: EntityParserDep, + 
file_service: FileServiceDep, + link_resolver: LinkResolverDep, + search_service: SearchServiceDep, + app_config: AppConfigDep, +) -> EntityService: + """Create EntityService with repository.""" + return EntityService( + entity_repository=entity_repository, + observation_repository=observation_repository, + relation_repository=relation_repository, + entity_parser=entity_parser, + file_service=file_service, + link_resolver=link_resolver, + search_service=search_service, + app_config=app_config, + ) + + +EntityServiceDep = Annotated[EntityService, Depends(get_entity_service)] + + +async def get_entity_service_v2( # pragma: no cover + entity_repository: EntityRepositoryV2Dep, + observation_repository: ObservationRepositoryV2Dep, + relation_repository: RelationRepositoryV2Dep, + entity_parser: EntityParserV2Dep, + file_service: FileServiceV2Dep, + link_resolver: LinkResolverV2Dep, + search_service: SearchServiceV2Dep, + app_config: AppConfigDep, +) -> EntityService: + """Create EntityService for v2 API.""" + return EntityService( + entity_repository=entity_repository, + observation_repository=observation_repository, + relation_repository=relation_repository, + entity_parser=entity_parser, + file_service=file_service, + link_resolver=link_resolver, + search_service=search_service, + app_config=app_config, + ) + + +EntityServiceV2Dep = Annotated[EntityService, Depends(get_entity_service_v2)] + + +async def get_entity_service_v2_external( + entity_repository: EntityRepositoryV2ExternalDep, + observation_repository: ObservationRepositoryV2ExternalDep, + relation_repository: RelationRepositoryV2ExternalDep, + entity_parser: EntityParserV2ExternalDep, + file_service: FileServiceV2ExternalDep, + link_resolver: LinkResolverV2ExternalDep, + search_service: SearchServiceV2ExternalDep, + app_config: AppConfigDep, +) -> EntityService: + """Create EntityService for v2 API (uses external_id).""" + return EntityService( + entity_repository=entity_repository, + observation_repository=observation_repository, + relation_repository=relation_repository, + entity_parser=entity_parser, + file_service=file_service, + link_resolver=link_resolver, + search_service=search_service, + app_config=app_config, + ) + + +EntityServiceV2ExternalDep = Annotated[EntityService, Depends(get_entity_service_v2_external)] + + +# --- Context Service --- + + +async def get_context_service( + search_repository: SearchRepositoryDep, + entity_repository: EntityRepositoryDep, + observation_repository: ObservationRepositoryDep, +) -> ContextService: + return ContextService( + search_repository=search_repository, + entity_repository=entity_repository, + observation_repository=observation_repository, + ) + + +ContextServiceDep = Annotated[ContextService, Depends(get_context_service)] + + +async def get_context_service_v2( # pragma: no cover + search_repository: SearchRepositoryV2Dep, + entity_repository: EntityRepositoryV2Dep, + observation_repository: ObservationRepositoryV2Dep, +) -> ContextService: + """Create ContextService for v2 API.""" + return ContextService( + search_repository=search_repository, + entity_repository=entity_repository, + observation_repository=observation_repository, + ) + + +ContextServiceV2Dep = Annotated[ContextService, Depends(get_context_service_v2)] + + +async def get_context_service_v2_external( + search_repository: SearchRepositoryV2ExternalDep, + entity_repository: EntityRepositoryV2ExternalDep, + observation_repository: ObservationRepositoryV2ExternalDep, +) -> ContextService: + """Create ContextService for 
v2 API (uses external_id).""" + return ContextService( + search_repository=search_repository, + entity_repository=entity_repository, + observation_repository=observation_repository, + ) + + +ContextServiceV2ExternalDep = Annotated[ContextService, Depends(get_context_service_v2_external)] + + +# --- Sync Service --- + + +async def get_sync_service( + app_config: AppConfigDep, + entity_service: EntityServiceDep, + entity_parser: EntityParserDep, + entity_repository: EntityRepositoryDep, + relation_repository: RelationRepositoryDep, + project_repository: ProjectRepositoryDep, + search_service: SearchServiceDep, + file_service: FileServiceDep, +) -> SyncService: # pragma: no cover + return SyncService( + app_config=app_config, + entity_service=entity_service, + entity_parser=entity_parser, + entity_repository=entity_repository, + relation_repository=relation_repository, + project_repository=project_repository, + search_service=search_service, + file_service=file_service, + ) + + +SyncServiceDep = Annotated[SyncService, Depends(get_sync_service)] + + +async def get_sync_service_v2( + app_config: AppConfigDep, + entity_service: EntityServiceV2Dep, + entity_parser: EntityParserV2Dep, + entity_repository: EntityRepositoryV2Dep, + relation_repository: RelationRepositoryV2Dep, + project_repository: ProjectRepositoryDep, + search_service: SearchServiceV2Dep, + file_service: FileServiceV2Dep, +) -> SyncService: # pragma: no cover + """Create SyncService for v2 API.""" + return SyncService( + app_config=app_config, + entity_service=entity_service, + entity_parser=entity_parser, + entity_repository=entity_repository, + relation_repository=relation_repository, + project_repository=project_repository, + search_service=search_service, + file_service=file_service, + ) + + +SyncServiceV2Dep = Annotated[SyncService, Depends(get_sync_service_v2)] + + +async def get_sync_service_v2_external( + app_config: AppConfigDep, + entity_service: EntityServiceV2ExternalDep, + entity_parser: EntityParserV2ExternalDep, + entity_repository: EntityRepositoryV2ExternalDep, + relation_repository: RelationRepositoryV2ExternalDep, + project_repository: ProjectRepositoryDep, + search_service: SearchServiceV2ExternalDep, + file_service: FileServiceV2ExternalDep, +) -> SyncService: # pragma: no cover + """Create SyncService for v2 API (uses external_id).""" + return SyncService( + app_config=app_config, + entity_service=entity_service, + entity_parser=entity_parser, + entity_repository=entity_repository, + relation_repository=relation_repository, + project_repository=project_repository, + search_service=search_service, + file_service=file_service, + ) + + +SyncServiceV2ExternalDep = Annotated[SyncService, Depends(get_sync_service_v2_external)] + + +# --- Project Service --- + + +async def get_project_service( + project_repository: ProjectRepositoryDep, +) -> ProjectService: + """Create ProjectService with repository.""" + return ProjectService(repository=project_repository) + + +ProjectServiceDep = Annotated[ProjectService, Depends(get_project_service)] + + +# --- Directory Service --- + + +async def get_directory_service( + entity_repository: EntityRepositoryDep, +) -> DirectoryService: + """Create DirectoryService with dependencies.""" + return DirectoryService( + entity_repository=entity_repository, + ) + + +DirectoryServiceDep = Annotated[DirectoryService, Depends(get_directory_service)] + + +async def get_directory_service_v2( # pragma: no cover + entity_repository: EntityRepositoryV2Dep, +) -> DirectoryService: + """Create 
DirectoryService for v2 API (uses integer project_id from path).""" + return DirectoryService( + entity_repository=entity_repository, + ) + + +DirectoryServiceV2Dep = Annotated[DirectoryService, Depends(get_directory_service_v2)] + + +async def get_directory_service_v2_external( + entity_repository: EntityRepositoryV2ExternalDep, +) -> DirectoryService: + """Create DirectoryService for v2 API (uses external_id from path).""" + return DirectoryService( + entity_repository=entity_repository, + ) + + +DirectoryServiceV2ExternalDep = Annotated[DirectoryService, Depends(get_directory_service_v2_external)] diff --git a/src/basic_memory/importers/chatgpt_importer.py b/src/basic_memory/importers/chatgpt_importer.py index 7ddc2329..1184c1a3 100644 --- a/src/basic_memory/importers/chatgpt_importer.py +++ b/src/basic_memory/importers/chatgpt_importer.py @@ -15,7 +15,7 @@ class ChatGPTImporter(Importer[ChatImportResult]): """Service for importing ChatGPT conversations.""" - def handle_error( + def handle_error( # pragma: no cover self, message: str, error: Optional[Exception] = None ) -> ChatImportResult: """Return a failed ChatImportResult with an error message.""" diff --git a/src/basic_memory/importers/claude_conversations_importer.py b/src/basic_memory/importers/claude_conversations_importer.py index 9682d771..5e89ad0e 100644 --- a/src/basic_memory/importers/claude_conversations_importer.py +++ b/src/basic_memory/importers/claude_conversations_importer.py @@ -15,7 +15,7 @@ class ClaudeConversationsImporter(Importer[ChatImportResult]): """Service for importing Claude conversations.""" - def handle_error( + def handle_error( # pragma: no cover self, message: str, error: Optional[Exception] = None ) -> ChatImportResult: """Return a failed ChatImportResult with an error message.""" diff --git a/src/basic_memory/importers/claude_projects_importer.py b/src/basic_memory/importers/claude_projects_importer.py index 549f1ca0..3e914083 100644 --- a/src/basic_memory/importers/claude_projects_importer.py +++ b/src/basic_memory/importers/claude_projects_importer.py @@ -14,7 +14,7 @@ class ClaudeProjectsImporter(Importer[ProjectImportResult]): """Service for importing Claude projects.""" - def handle_error( + def handle_error( # pragma: no cover self, message: str, error: Optional[Exception] = None ) -> ProjectImportResult: """Return a failed ProjectImportResult with an error message.""" diff --git a/src/basic_memory/importers/memory_json_importer.py b/src/basic_memory/importers/memory_json_importer.py index 24f4ee53..579a7db2 100644 --- a/src/basic_memory/importers/memory_json_importer.py +++ b/src/basic_memory/importers/memory_json_importer.py @@ -13,7 +13,7 @@ class MemoryJsonImporter(Importer[EntityImportResult]): """Service for importing memory.json format data.""" - def handle_error( + def handle_error( # pragma: no cover self, message: str, error: Optional[Exception] = None ) -> EntityImportResult: """Return a failed EntityImportResult with an error message.""" diff --git a/src/basic_memory/mcp/clients/__init__.py b/src/basic_memory/mcp/clients/__init__.py new file mode 100644 index 00000000..c2705ecb --- /dev/null +++ b/src/basic_memory/mcp/clients/__init__.py @@ -0,0 +1,28 @@ +"""Typed internal API clients for MCP tools. + +These clients encapsulate API paths, error handling, and response validation. +MCP tools become thin adapters that call these clients and format results. 
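The payoff of the service wiring above is that a route declares a single annotation and FastAPI resolves the whole graph: `EntityServiceDep` transitively pulls in the parser, file service, link resolver, search service, three repositories, and the app config for one request. A sketch (route path hypothetical):

```python
from fastapi import APIRouter

from basic_memory.deps.services import EntityServiceDep

router = APIRouter()


@router.get("/{project}/entity-service-demo")  # hypothetical path
async def demo(entity_service: EntityServiceDep) -> dict[str, str]:
    # One annotation; FastAPI constructs the full dependency graph behind it,
    # all scoped to the project resolved from {project}.
    return {"service": type(entity_service).__name__}
```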
+ +Usage: + from basic_memory.mcp.clients import KnowledgeClient, SearchClient + + async with get_client() as http_client: + knowledge = KnowledgeClient(http_client, project_id) + entity = await knowledge.create_entity(entity_data) +""" + +from basic_memory.mcp.clients.knowledge import KnowledgeClient +from basic_memory.mcp.clients.search import SearchClient +from basic_memory.mcp.clients.memory import MemoryClient +from basic_memory.mcp.clients.directory import DirectoryClient +from basic_memory.mcp.clients.resource import ResourceClient +from basic_memory.mcp.clients.project import ProjectClient + +__all__ = [ + "KnowledgeClient", + "SearchClient", + "MemoryClient", + "DirectoryClient", + "ResourceClient", + "ProjectClient", +] diff --git a/src/basic_memory/mcp/clients/directory.py b/src/basic_memory/mcp/clients/directory.py new file mode 100644 index 00000000..5444aa62 --- /dev/null +++ b/src/basic_memory/mcp/clients/directory.py @@ -0,0 +1,70 @@ +"""Typed client for directory API operations. + +Encapsulates all /v2/projects/{project_id}/directory/* endpoints. +""" + +from typing import Optional, Any + +from httpx import AsyncClient + +from basic_memory.mcp.tools.utils import call_get + + +class DirectoryClient: + """Typed client for directory listing operations. + + Centralizes: + - API path construction for /v2/projects/{project_id}/directory/* + - Response validation + - Consistent error handling through call_* utilities + + Usage: + async with get_client() as http_client: + client = DirectoryClient(http_client, project_id) + nodes = await client.list("/", depth=2) + """ + + def __init__(self, http_client: AsyncClient, project_id: str): + """Initialize the directory client. + + Args: + http_client: HTTPX AsyncClient for making requests + project_id: Project external_id (UUID) for API calls + """ + self.http_client = http_client + self.project_id = project_id + self._base_path = f"/v2/projects/{project_id}/directory" + + async def list( + self, + dir_name: str = "/", + *, + depth: int = 1, + file_name_glob: Optional[str] = None, + ) -> list[dict[str, Any]]: + """List directory contents. + + Args: + dir_name: Directory path to list (default: root) + depth: How deep to traverse (default: 1) + file_name_glob: Optional glob pattern to filter files + + Returns: + List of directory nodes with their contents + + Raises: + ToolError: If the request fails + """ + params: dict = { + "dir_name": dir_name, + "depth": depth, + } + if file_name_glob: + params["file_name_glob"] = file_name_glob + + response = await call_get( + self.http_client, + f"{self._base_path}/list", + params=params, + ) + return response.json() diff --git a/src/basic_memory/mcp/clients/knowledge.py b/src/basic_memory/mcp/clients/knowledge.py new file mode 100644 index 00000000..a0e2eb6a --- /dev/null +++ b/src/basic_memory/mcp/clients/knowledge.py @@ -0,0 +1,176 @@ +"""Typed client for knowledge/entity API operations. + +Encapsulates all /v2/projects/{project_id}/knowledge/* endpoints. +""" + +from typing import Any + +from httpx import AsyncClient + +from basic_memory.mcp.tools.utils import call_get, call_post, call_put, call_patch, call_delete +from basic_memory.schemas.response import EntityResponse, DeleteEntitiesResponse + + +class KnowledgeClient: + """Typed client for knowledge graph entity operations. 
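Beyond the docstring's minimal call, `DirectoryClient`'s optional parameters compose naturally; a usage sketch, assuming the caller already holds an authenticated `AsyncClient` and a project external UUID:

```python
from httpx import AsyncClient

from basic_memory.mcp.clients import DirectoryClient


async def list_markdown(http_client: AsyncClient, project_id: str) -> list[dict]:
    client = DirectoryClient(http_client, project_id)
    # file_name_glob is only sent when provided; depth controls traversal.
    return await client.list("/", depth=2, file_name_glob="*.md")
```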
+ + Centralizes: + - API path construction for /v2/projects/{project_id}/knowledge/* + - Response validation via Pydantic models + - Consistent error handling through call_* utilities + + Usage: + async with get_client() as http_client: + client = KnowledgeClient(http_client, project_id) + entity = await client.create_entity(entity_data) + """ + + def __init__(self, http_client: AsyncClient, project_id: str): + """Initialize the knowledge client. + + Args: + http_client: HTTPX AsyncClient for making requests + project_id: Project external_id (UUID) for API calls + """ + self.http_client = http_client + self.project_id = project_id + self._base_path = f"/v2/projects/{project_id}/knowledge" + + # --- Entity CRUD Operations --- + + async def create_entity(self, entity_data: dict[str, Any]) -> EntityResponse: + """Create a new entity. + + Args: + entity_data: Entity data including title, content, folder, etc. + + Returns: + EntityResponse with created entity details + + Raises: + ToolError: If the request fails + """ + response = await call_post( + self.http_client, + f"{self._base_path}/entities", + json=entity_data, + ) + return EntityResponse.model_validate(response.json()) + + async def update_entity(self, entity_id: str, entity_data: dict[str, Any]) -> EntityResponse: + """Update an existing entity (full replacement). + + Args: + entity_id: Entity external_id (UUID) + entity_data: Complete entity data for replacement + + Returns: + EntityResponse with updated entity details + + Raises: + ToolError: If the request fails + """ + response = await call_put( + self.http_client, + f"{self._base_path}/entities/{entity_id}", + json=entity_data, + ) + return EntityResponse.model_validate(response.json()) + + async def get_entity(self, entity_id: str) -> EntityResponse: + """Get an entity by ID. + + Args: + entity_id: Entity external_id (UUID) + + Returns: + EntityResponse with entity details + + Raises: + ToolError: If the entity is not found or request fails + """ + response = await call_get( + self.http_client, + f"{self._base_path}/entities/{entity_id}", + ) + return EntityResponse.model_validate(response.json()) + + async def patch_entity(self, entity_id: str, patch_data: dict[str, Any]) -> EntityResponse: + """Partially update an entity. + + Args: + entity_id: Entity external_id (UUID) + patch_data: Partial entity data to update + + Returns: + EntityResponse with updated entity details + + Raises: + ToolError: If the request fails + """ + response = await call_patch( + self.http_client, + f"{self._base_path}/entities/{entity_id}", + json=patch_data, + ) + return EntityResponse.model_validate(response.json()) + + async def delete_entity(self, entity_id: str) -> DeleteEntitiesResponse: + """Delete an entity. + + Args: + entity_id: Entity external_id (UUID) + + Returns: + DeleteEntitiesResponse confirming deletion + + Raises: + ToolError: If the entity is not found or request fails + """ + response = await call_delete( + self.http_client, + f"{self._base_path}/entities/{entity_id}", + ) + return DeleteEntitiesResponse.model_validate(response.json()) + + async def move_entity(self, entity_id: str, destination_path: str) -> EntityResponse: + """Move an entity to a new location. 
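A typical tool flow chains these calls, resolving a human-friendly identifier to the external UUID before mutating. A sketch of a create/patch/delete roundtrip (the payload fields are illustrative; the accepted schema is defined by the knowledge API, not this client):

```python
from httpx import AsyncClient

from basic_memory.mcp.clients import KnowledgeClient


async def roundtrip(http_client: AsyncClient, project_id: str) -> None:
    client = KnowledgeClient(http_client, project_id)

    # Illustrative payload; see the knowledge API schemas for the real fields.
    await client.create_entity({"title": "Demo note", "folder": "notes"})

    # Map a permalink/title/path to the entity's external UUID.
    entity_id = await client.resolve_entity("Demo note")

    await client.patch_entity(entity_id, {"title": "Demo note (renamed)"})
    await client.delete_entity(entity_id)
```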
+ + Args: + entity_id: Entity external_id (UUID) + destination_path: New file path for the entity + + Returns: + EntityResponse with updated entity details + + Raises: + ToolError: If the request fails + """ + response = await call_put( + self.http_client, + f"{self._base_path}/entities/{entity_id}/move", + json={"destination_path": destination_path}, + ) + return EntityResponse.model_validate(response.json()) + + # --- Resolution --- + + async def resolve_entity(self, identifier: str) -> str: + """Resolve a string identifier to an entity external_id. + + Args: + identifier: The identifier to resolve (permalink, title, or path) + + Returns: + The resolved entity external_id (UUID) + + Raises: + ToolError: If the identifier cannot be resolved + """ + response = await call_post( + self.http_client, + f"{self._base_path}/resolve", + json={"identifier": identifier}, + ) + data = response.json() + return data["external_id"] diff --git a/src/basic_memory/mcp/clients/memory.py b/src/basic_memory/mcp/clients/memory.py new file mode 100644 index 00000000..de42c63e --- /dev/null +++ b/src/basic_memory/mcp/clients/memory.py @@ -0,0 +1,120 @@ +"""Typed client for memory/context API operations. + +Encapsulates all /v2/projects/{project_id}/memory/* endpoints. +""" + +from typing import Optional + +from httpx import AsyncClient + +from basic_memory.mcp.tools.utils import call_get +from basic_memory.schemas.memory import GraphContext + + +class MemoryClient: + """Typed client for memory context operations. + + Centralizes: + - API path construction for /v2/projects/{project_id}/memory/* + - Response validation via Pydantic models + - Consistent error handling through call_* utilities + + Usage: + async with get_client() as http_client: + client = MemoryClient(http_client, project_id) + context = await client.build_context("memory://specs/search") + """ + + def __init__(self, http_client: AsyncClient, project_id: str): + """Initialize the memory client. + + Args: + http_client: HTTPX AsyncClient for making requests + project_id: Project external_id (UUID) for API calls + """ + self.http_client = http_client + self.project_id = project_id + self._base_path = f"/v2/projects/{project_id}/memory" + + async def build_context( + self, + path: str, + *, + depth: int = 1, + timeframe: Optional[str] = None, + page: int = 1, + page_size: int = 10, + max_related: int = 10, + ) -> GraphContext: + """Build context from a memory path. + + Args: + path: The path to build context for (without memory:// prefix) + depth: How deep to traverse relations + timeframe: Time filter (e.g., "7d", "1 week") + page: Page number (1-indexed) + page_size: Results per page + max_related: Maximum related items per result + + Returns: + GraphContext with hierarchical results + + Raises: + ToolError: If the request fails + """ + params: dict = { + "depth": depth, + "page": page, + "page_size": page_size, + "max_related": max_related, + } + if timeframe: + params["timeframe"] = timeframe + + response = await call_get( + self.http_client, + f"{self._base_path}/{path}", + params=params, + ) + return GraphContext.model_validate(response.json()) + + async def recent( + self, + *, + timeframe: str = "7d", + depth: int = 1, + types: Optional[list[str]] = None, + page: int = 1, + page_size: int = 10, + ) -> GraphContext: + """Get recent activity. 
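A small sketch of `MemoryClient.build_context`, mirroring what the `build_context` MCP tool does later in this diff; the path and timeframe values are placeholders, and the exact shape of the returned `GraphContext` depends on `schemas.memory`.

```python
from basic_memory.mcp.async_client import get_client
from basic_memory.mcp.clients import MemoryClient


async def show_context(project_id: str) -> None:
    """Sketch: build hierarchical context for a path."""
    async with get_client() as http_client:
        client = MemoryClient(http_client, project_id)
        # The path is passed without the memory:// prefix, per the client docstring
        context = await client.build_context("specs/search", depth=2, timeframe="7d")
        print(context)  # GraphContext; exact fields depend on schemas.memory
```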
+ + Args: + timeframe: Time filter (e.g., "7d", "1 week", "2 days ago") + depth: How deep to traverse relations + types: Filter by item types + page: Page number (1-indexed) + page_size: Results per page + + Returns: + GraphContext with recent activity + + Raises: + ToolError: If the request fails + """ + params: dict = { + "timeframe": timeframe, + "depth": depth, + "page": page, + "page_size": page_size, + } + if types: + # Join types as comma-separated string if provided + params["type"] = ",".join(types) if isinstance(types, list) else types + + response = await call_get( + self.http_client, + f"{self._base_path}/recent", + params=params, + ) + return GraphContext.model_validate(response.json()) diff --git a/src/basic_memory/mcp/clients/project.py b/src/basic_memory/mcp/clients/project.py new file mode 100644 index 00000000..89500490 --- /dev/null +++ b/src/basic_memory/mcp/clients/project.py @@ -0,0 +1,89 @@ +"""Typed client for project API operations. + +Encapsulates project-level endpoints. +""" + +from typing import Any + +from httpx import AsyncClient + +from basic_memory.mcp.tools.utils import call_get, call_post, call_delete +from basic_memory.schemas.project_info import ProjectList, ProjectStatusResponse + + +class ProjectClient: + """Typed client for project management operations. + + Centralizes: + - API path construction for project endpoints + - Response validation via Pydantic models + - Consistent error handling through call_* utilities + + Note: This client does not require a project_id since it operates + across projects. + + Usage: + async with get_client() as http_client: + client = ProjectClient(http_client) + projects = await client.list_projects() + """ + + def __init__(self, http_client: AsyncClient): + """Initialize the project client. + + Args: + http_client: HTTPX AsyncClient for making requests + """ + self.http_client = http_client + + async def list_projects(self) -> ProjectList: + """List all available projects. + + Returns: + ProjectList with all projects and default project name + + Raises: + ToolError: If the request fails + """ + response = await call_get( + self.http_client, + "/projects/projects", + ) + return ProjectList.model_validate(response.json()) + + async def create_project(self, project_data: dict[str, Any]) -> ProjectStatusResponse: + """Create a new project. + + Args: + project_data: Project creation data (name, path, set_default) + + Returns: + ProjectStatusResponse with creation result + + Raises: + ToolError: If the request fails + """ + response = await call_post( + self.http_client, + "/projects/projects", + json=project_data, + ) + return ProjectStatusResponse.model_validate(response.json()) + + async def delete_project(self, project_external_id: str) -> ProjectStatusResponse: + """Delete a project by its external ID. + + Args: + project_external_id: Project external ID (UUID) + + Returns: + ProjectStatusResponse with deletion result + + Raises: + ToolError: If the request fails + """ + response = await call_delete( + self.http_client, + f"/v2/projects/{project_external_id}", + ) + return ProjectStatusResponse.model_validate(response.json()) diff --git a/src/basic_memory/mcp/clients/resource.py b/src/basic_memory/mcp/clients/resource.py new file mode 100644 index 00000000..cc44807b --- /dev/null +++ b/src/basic_memory/mcp/clients/resource.py @@ -0,0 +1,71 @@ +"""Typed client for resource API operations. + +Encapsulates all /v2/projects/{project_id}/resource/* endpoints. 
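For reference, a minimal sketch of the project-level `ProjectClient`, which takes no `project_id` since it operates across projects; this is the same list-then-inspect pattern `delete_project` uses later in this diff.

```python
from basic_memory.mcp.async_client import get_client
from basic_memory.mcp.clients import ProjectClient


async def print_projects() -> None:
    """Sketch: enumerate projects, as delete_project does before validating."""
    async with get_client() as http_client:
        client = ProjectClient(http_client)
        project_list = await client.list_projects()
        for item in project_list.projects:
            print(item.name, item.external_id)
```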
+""" + +from typing import Optional + +from httpx import AsyncClient, Response + +from basic_memory.mcp.tools.utils import call_get + + +class ResourceClient: + """Typed client for resource operations. + + Centralizes: + - API path construction for /v2/projects/{project_id}/resource/* + - Consistent error handling through call_* utilities + + Note: This client returns raw Response objects for resources since they + may be text, images, or other binary content that needs special handling. + + Usage: + async with get_client() as http_client: + client = ResourceClient(http_client, project_id) + response = await client.read(entity_id) + text = response.text + """ + + def __init__(self, http_client: AsyncClient, project_id: str): + """Initialize the resource client. + + Args: + http_client: HTTPX AsyncClient for making requests + project_id: Project external_id (UUID) for API calls + """ + self.http_client = http_client + self.project_id = project_id + self._base_path = f"/v2/projects/{project_id}/resource" + + async def read( + self, + entity_id: str, + *, + page: Optional[int] = None, + page_size: Optional[int] = None, + ) -> Response: + """Read a resource by entity ID. + + Args: + entity_id: Entity external_id (UUID) + page: Optional page number for paginated content + page_size: Optional page size for paginated content + + Returns: + Raw HTTP Response (caller handles text/binary content) + + Raises: + ToolError: If the resource is not found or request fails + """ + params: dict = {} + if page is not None: + params["page"] = page + if page_size is not None: + params["page_size"] = page_size + + return await call_get( + self.http_client, + f"{self._base_path}/{entity_id}", + params=params if params else None, + ) diff --git a/src/basic_memory/mcp/clients/search.py b/src/basic_memory/mcp/clients/search.py new file mode 100644 index 00000000..a7c3ccb9 --- /dev/null +++ b/src/basic_memory/mcp/clients/search.py @@ -0,0 +1,65 @@ +"""Typed client for search API operations. + +Encapsulates all /v2/projects/{project_id}/search/* endpoints. +""" + +from typing import Any + +from httpx import AsyncClient + +from basic_memory.mcp.tools.utils import call_post +from basic_memory.schemas.search import SearchResponse + + +class SearchClient: + """Typed client for search operations. + + Centralizes: + - API path construction for /v2/projects/{project_id}/search/* + - Response validation via Pydantic models + - Consistent error handling through call_* utilities + + Usage: + async with get_client() as http_client: + client = SearchClient(http_client, project_id) + results = await client.search(search_query.model_dump()) + """ + + def __init__(self, http_client: AsyncClient, project_id: str): + """Initialize the search client. + + Args: + http_client: HTTPX AsyncClient for making requests + project_id: Project external_id (UUID) for API calls + """ + self.http_client = http_client + self.project_id = project_id + self._base_path = f"/v2/projects/{project_id}/search" + + async def search( + self, + query: dict[str, Any], + *, + page: int = 1, + page_size: int = 10, + ) -> SearchResponse: + """Search across all content in the knowledge base. 
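Because `ResourceClient` deliberately returns the raw httpx `Response`, content-type handling stays with the caller. A minimal sketch of that branching, with placeholder IDs:

```python
from basic_memory.mcp.async_client import get_client
from basic_memory.mcp.clients import ResourceClient


async def read_first_page(project_id: str, entity_id: str) -> str | bytes:
    """Sketch: read a resource and let the caller branch on content type."""
    async with get_client() as http_client:
        client = ResourceClient(http_client, project_id)
        response = await client.read(entity_id, page=1, page_size=10)
        content_type = response.headers.get("content-type", "")
        # Text content is decoded; anything else stays as raw bytes
        return response.text if content_type.startswith("text/") else response.content
```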
+ + Args: + query: Search query dict (from SearchQuery.model_dump()) + page: Page number (1-indexed) + page_size: Results per page + + Returns: + SearchResponse with results and pagination + + Raises: + ToolError: If the request fails + """ + response = await call_post( + self.http_client, + f"{self._base_path}/", + json=query, + params={"page": page, "page_size": page_size}, + ) + return SearchResponse.model_validate(response.json()) diff --git a/src/basic_memory/mcp/container.py b/src/basic_memory/mcp/container.py new file mode 100644 index 00000000..cd48f002 --- /dev/null +++ b/src/basic_memory/mcp/container.py @@ -0,0 +1,110 @@ +"""MCP composition root for Basic Memory. + +This container owns reading ConfigManager and environment variables for the +MCP server entrypoint. Downstream modules receive config/dependencies explicitly +rather than reading globals. + +Design principles: +- Only this module reads ConfigManager directly +- Runtime mode (cloud/local/test) is resolved here +- File sync decisions are centralized here +""" + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from basic_memory.config import BasicMemoryConfig, ConfigManager +from basic_memory.runtime import RuntimeMode, resolve_runtime_mode + +if TYPE_CHECKING: # pragma: no cover + from basic_memory.sync import SyncCoordinator + + +@dataclass +class McpContainer: + """Composition root for the MCP server entrypoint. + + Holds resolved configuration and runtime context. + Created once at server startup, then used to wire dependencies. + """ + + config: BasicMemoryConfig + mode: RuntimeMode + + @classmethod + def create(cls) -> "McpContainer": + """Create container by reading ConfigManager. + + This is the single point where MCP reads global config. + """ + config = ConfigManager().config + mode = resolve_runtime_mode( + cloud_mode_enabled=config.cloud_mode_enabled, + is_test_env=config.is_test_env, + ) + return cls(config=config, mode=mode) + + # --- Runtime Mode Properties --- + + @property + def should_sync_files(self) -> bool: + """Whether local file sync should be started. + + Sync is enabled when: + - sync_changes is True in config + - Not in test mode (tests manage their own sync) + - Not in cloud mode (cloud handles sync differently) + """ + return ( + self.config.sync_changes and not self.mode.is_test and not self.mode.is_cloud + ) + + @property + def sync_skip_reason(self) -> str | None: + """Reason why sync is skipped, or None if sync should run. + + Useful for logging why sync was disabled. + """ + if self.mode.is_test: + return "Test environment detected" + if self.mode.is_cloud: + return "Cloud mode enabled" + if not self.config.sync_changes: + return "Sync changes disabled" + return None + + def create_sync_coordinator(self) -> "SyncCoordinator": + """Create a SyncCoordinator with this container's settings. + + Returns: + SyncCoordinator configured for this runtime environment + """ + # Deferred import to avoid circular dependency + from basic_memory.sync import SyncCoordinator + + return SyncCoordinator( + config=self.config, + should_sync=self.should_sync_files, + skip_reason=self.sync_skip_reason, + ) + + +# Module-level container instance (set by lifespan) +_container: McpContainer | None = None + + +def get_container() -> McpContainer: + """Get the current MCP container. + + Raises: + RuntimeError: If container hasn't been initialized + """ + if _container is None: + raise RuntimeError("MCP container not initialized. 
Call set_container() first.") + return _container + + +def set_container(container: McpContainer) -> None: + """Set the MCP container (called by lifespan).""" + global _container + _container = container diff --git a/src/basic_memory/mcp/project_context.py b/src/basic_memory/mcp/project_context.py index 0708dfeb..786ff699 100644 --- a/src/basic_memory/mcp/project_context.py +++ b/src/basic_memory/mcp/project_context.py @@ -2,9 +2,12 @@ Provides project lookup utilities for MCP tools. Handles project validation and context management in one place. + +Note: This module uses ProjectResolver for unified project resolution. +The resolve_project_parameter function is a thin wrapper for backwards +compatibility with existing MCP tools. """ -import os from typing import Optional, List from httpx import AsyncClient from httpx._types import ( @@ -14,16 +17,25 @@ from fastmcp import Context from basic_memory.config import ConfigManager +from basic_memory.project_resolver import ProjectResolver from basic_memory.schemas.project_info import ProjectItem, ProjectList from basic_memory.utils import generate_permalink async def resolve_project_parameter( - project: Optional[str] = None, allow_discovery: bool = False + project: Optional[str] = None, + allow_discovery: bool = False, + cloud_mode: Optional[bool] = None, + default_project_mode: Optional[bool] = None, + default_project: Optional[str] = None, ) -> Optional[str]: """Resolve project parameter using three-tier hierarchy. - if config.cloud_mode: + This is a thin wrapper around ProjectResolver for backwards compatibility. + New code should consider using ProjectResolver directly for more detailed + resolution information. + + if cloud_mode: project is required (unless allow_discovery=True for tools that support discovery mode) else: Resolution order: @@ -35,41 +47,31 @@ async def resolve_project_parameter( project: Optional explicit project parameter allow_discovery: If True, allows returning None in cloud mode for discovery mode (used by tools like recent_activity that can operate across all projects) + cloud_mode: Optional explicit cloud mode. If not provided, reads from ConfigManager. + default_project_mode: Optional explicit default project mode. If not provided, reads from ConfigManager. + default_project: Optional explicit default project. If not provided, reads from ConfigManager. Returns: Resolved project name or None if no resolution possible """ - - config = ConfigManager().config - # if cloud_mode, project is required (unless discovery mode is allowed) - if config.cloud_mode: - if project: - logger.debug(f"project: {project}, cloud_mode: {config.cloud_mode}") - return project - elif allow_discovery: - logger.debug("cloud_mode: discovery mode allowed, returning None") - return None - else: - raise ValueError("No project specified. 
Project is required for cloud mode.") - - # Priority 1: CLI constraint overrides everything (--project arg sets env var) - constrained_project = os.environ.get("BASIC_MEMORY_MCP_PROJECT") - if constrained_project: - logger.debug(f"Using CLI constrained project: {constrained_project}") - return constrained_project - - # Priority 2: Explicit project parameter - if project: - logger.debug(f"Using explicit project parameter: {project}") - return project - - # Priority 3: Default project mode - if config.default_project_mode: - logger.debug(f"Using default project from config: {config.default_project}") - return config.default_project - - # No resolution possible - return None + # Load config for any values not explicitly provided + if cloud_mode is None or default_project_mode is None or default_project is None: + config = ConfigManager().config + if cloud_mode is None: + cloud_mode = config.cloud_mode + if default_project_mode is None: + default_project_mode = config.default_project_mode + if default_project is None: + default_project = config.default_project + + # Create resolver with configuration and resolve + resolver = ProjectResolver.from_env( + cloud_mode=cloud_mode, + default_project_mode=default_project_mode, + default_project=default_project, + ) + result = resolver.resolve(project=project, allow_discovery=allow_discovery) + return result.project async def get_project_names(client: AsyncClient, headers: HeaderTypes | None = None) -> List[str]: diff --git a/src/basic_memory/mcp/server.py b/src/basic_memory/mcp/server.py index a3e3a463..ab99d503 100644 --- a/src/basic_memory/mcp/server.py +++ b/src/basic_memory/mcp/server.py @@ -2,15 +2,14 @@ Basic Memory FastMCP server. """ -import asyncio from contextlib import asynccontextmanager from fastmcp import FastMCP from loguru import logger from basic_memory import db -from basic_memory.config import ConfigManager -from basic_memory.services.initialization import initialize_app, initialize_file_sync +from basic_memory.mcp.container import McpContainer, set_container +from basic_memory.services.initialization import initialize_app from basic_memory.telemetry import show_notice_if_needed, track_app_started @@ -21,11 +20,15 @@ async def lifespan(app: FastMCP): Handles: - Database initialization and migrations - Telemetry notice and tracking - - File sync in background (if enabled and not in cloud mode) + - File sync via SyncCoordinator (if enabled and not in cloud mode) - Proper cleanup on shutdown """ - app_config = ConfigManager().config - logger.info("Starting Basic Memory MCP server") + # --- Composition Root --- + # Create container and read config (single point of config access) + container = McpContainer.create() + set_container(container) + + logger.info(f"Starting Basic Memory MCP server (mode={container.mode.name})") # Show telemetry notice (first run only) and track startup show_notice_if_needed() @@ -37,35 +40,18 @@ async def lifespan(app: FastMCP): engine_was_none = db._engine is None # Initialize app (runs migrations, reconciles projects) - await initialize_app(app_config) - - # Start file sync as background task (if enabled and not in cloud mode) - sync_task = None - if app_config.is_test_env: - logger.info("Test environment detected - skipping local file sync") - elif app_config.sync_changes and not app_config.cloud_mode_enabled: # pragma: no cover - logger.info("Starting file sync in background") - - async def _file_sync_runner() -> None: - await initialize_file_sync(app_config) + await initialize_app(container.config) - 
sync_task = asyncio.create_task(_file_sync_runner()) - elif app_config.cloud_mode_enabled: # pragma: no cover - logger.info("Cloud mode enabled - skipping local file sync") - else: # pragma: no cover - logger.info("Sync changes disabled - skipping file sync") + # Create and start sync coordinator (lifecycle centralized in coordinator) + sync_coordinator = container.create_sync_coordinator() + await sync_coordinator.start() try: yield finally: - # Shutdown + # Shutdown - coordinator handles clean task cancellation logger.info("Shutting down Basic Memory MCP server") - if sync_task: # pragma: no cover - sync_task.cancel() - try: - await sync_task - except asyncio.CancelledError: - logger.info("File sync task cancelled") + await sync_coordinator.stop() # Only shutdown DB if we created it (not if test fixture provided it) if engine_was_none: diff --git a/src/basic_memory/mcp/tools/build_context.py b/src/basic_memory/mcp/tools/build_context.py index 437fd448..0a65f633 100644 --- a/src/basic_memory/mcp/tools/build_context.py +++ b/src/basic_memory/mcp/tools/build_context.py @@ -8,7 +8,6 @@ from basic_memory.mcp.async_client import get_client from basic_memory.mcp.project_context import get_active_project from basic_memory.mcp.server import mcp -from basic_memory.mcp.tools.utils import call_get from basic_memory.telemetry import track_mcp_tool from basic_memory.schemas.base import TimeFrame from basic_memory.schemas.memory import ( @@ -106,15 +105,16 @@ async def build_context( # Get the active project using the new stateless approach active_project = await get_active_project(client, project, context) - response = await call_get( - client, - f"/v2/projects/{active_project.external_id}/memory/{memory_url_path(url)}", - params={ - "depth": depth, - "timeframe": timeframe, - "page": page, - "page_size": page_size, - "max_related": max_related, - }, + # Import here to avoid circular import + from basic_memory.mcp.clients import MemoryClient + + # Use typed MemoryClient for API calls + memory_client = MemoryClient(client, active_project.external_id) + return await memory_client.build_context( + memory_url_path(url), + depth=depth or 1, + timeframe=timeframe, + page=page, + page_size=page_size, + max_related=max_related, ) - return GraphContext.model_validate(response.json()) diff --git a/src/basic_memory/mcp/tools/delete_note.py b/src/basic_memory/mcp/tools/delete_note.py index b252fc6a..4db824c6 100644 --- a/src/basic_memory/mcp/tools/delete_note.py +++ b/src/basic_memory/mcp/tools/delete_note.py @@ -6,11 +6,9 @@ from mcp.server.fastmcp.exceptions import ToolError from basic_memory.mcp.project_context import get_active_project -from basic_memory.mcp.tools.utils import call_delete, resolve_entity_id from basic_memory.mcp.server import mcp from basic_memory.mcp.async_client import get_client from basic_memory.telemetry import track_mcp_tool -from basic_memory.schemas import DeleteEntitiesResponse def _format_delete_error_response(project: str, error_message: str, identifier: str) -> str: @@ -208,9 +206,15 @@ async def delete_note( async with get_client() as client: active_project = await get_active_project(client, project, context) + # Import here to avoid circular import + from basic_memory.mcp.clients import KnowledgeClient + + # Use typed KnowledgeClient for API calls + knowledge_client = KnowledgeClient(client, active_project.external_id) + try: # Resolve identifier to entity ID - entity_id = await resolve_entity_id(client, active_project.external_id, identifier) + entity_id = await 
knowledge_client.resolve_entity(identifier) except ToolError as e: # If entity not found, return False (note doesn't exist) if "Entity not found" in str(e) or "not found" in str(e).lower(): @@ -226,10 +230,7 @@ async def delete_note( try: # Call the DELETE endpoint - response = await call_delete( - client, f"/v2/projects/{active_project.external_id}/knowledge/entities/{entity_id}" - ) - result = DeleteEntitiesResponse.model_validate(response.json()) + result = await knowledge_client.delete_entity(entity_id) if result.deleted: logger.info( diff --git a/src/basic_memory/mcp/tools/edit_note.py b/src/basic_memory/mcp/tools/edit_note.py index c90a5c5d..cc110169 100644 --- a/src/basic_memory/mcp/tools/edit_note.py +++ b/src/basic_memory/mcp/tools/edit_note.py @@ -8,9 +8,7 @@ from basic_memory.mcp.async_client import get_client from basic_memory.mcp.project_context import get_active_project, add_project_metadata from basic_memory.mcp.server import mcp -from basic_memory.mcp.tools.utils import call_patch, resolve_entity_id from basic_memory.telemetry import track_mcp_tool -from basic_memory.schemas import EntityResponse def _format_error_response( @@ -236,8 +234,14 @@ async def edit_note( # Use the PATCH endpoint to edit the entity try: + # Import here to avoid circular import + from basic_memory.mcp.clients import KnowledgeClient + + # Use typed KnowledgeClient for API calls + knowledge_client = KnowledgeClient(client, active_project.external_id) + # Resolve identifier to entity ID - entity_id = await resolve_entity_id(client, active_project.external_id, identifier) + entity_id = await knowledge_client.resolve_entity(identifier) # Prepare the edit request data edit_data = { @@ -254,9 +258,7 @@ async def edit_note( edit_data["expected_replacements"] = str(expected_replacements) # Call the PATCH endpoint - url = f"/v2/projects/{active_project.external_id}/knowledge/entities/{entity_id}" - response = await call_patch(client, url, json=edit_data) - result = EntityResponse.model_validate(response.json()) + result = await knowledge_client.patch_entity(entity_id, edit_data) # Format summary summary = [ @@ -311,11 +313,10 @@ async def edit_note( permalink=result.permalink, observations_count=len(result.observations), relations_count=len(result.relations), - status_code=response.status_code, ) - result = "\n".join(summary) - return add_project_metadata(result, active_project.name) + summary_result = "\n".join(summary) + return add_project_metadata(summary_result, active_project.name) except Exception as e: logger.error(f"Error editing note: {e}") diff --git a/src/basic_memory/mcp/tools/list_directory.py b/src/basic_memory/mcp/tools/list_directory.py index e27f3d0f..c8019f89 100644 --- a/src/basic_memory/mcp/tools/list_directory.py +++ b/src/basic_memory/mcp/tools/list_directory.py @@ -8,7 +8,6 @@ from basic_memory.mcp.async_client import get_client from basic_memory.mcp.project_context import get_active_project from basic_memory.mcp.server import mcp -from basic_memory.mcp.tools.utils import call_get from basic_memory.telemetry import track_mcp_tool @@ -68,26 +67,16 @@ async def list_directory( async with get_client() as client: active_project = await get_active_project(client, project, context) - # Prepare query parameters - params = { - "dir_name": dir_name, - "depth": str(depth), - } - if file_name_glob: - params["file_name_glob"] = file_name_glob - logger.debug( f"Listing directory '{dir_name}' in project {project} with depth={depth}, glob='{file_name_glob}'" ) - # Call the API endpoint - response = 
await call_get( - client, - f"/v2/projects/{active_project.external_id}/directory/list", - params=params, - ) + # Import here to avoid circular import + from basic_memory.mcp.clients import DirectoryClient - nodes = response.json() + # Use typed DirectoryClient for API calls + directory_client = DirectoryClient(client, active_project.external_id) + nodes = await directory_client.list(dir_name, depth=depth, file_name_glob=file_name_glob) if not nodes: filter_desc = "" diff --git a/src/basic_memory/mcp/tools/move_note.py b/src/basic_memory/mcp/tools/move_note.py index 4da70d20..66d42548 100644 --- a/src/basic_memory/mcp/tools/move_note.py +++ b/src/basic_memory/mcp/tools/move_note.py @@ -8,10 +8,7 @@ from basic_memory.mcp.async_client import get_client from basic_memory.mcp.server import mcp -from basic_memory.mcp.tools.utils import call_get, call_put, resolve_entity_id from basic_memory.mcp.project_context import get_active_project -from basic_memory.schemas import EntityResponse -from basic_memory.schemas.project_info import ProjectList from basic_memory.telemetry import track_mcp_tool from basic_memory.utils import validate_project_path @@ -31,9 +28,12 @@ async def _detect_cross_project_move_attempt( Error message with guidance if cross-project move is detected, None otherwise """ try: - # Get list of all available projects to check against - response = await call_get(client, "/projects/projects") - project_list = ProjectList.model_validate(response.json()) + # Import here to avoid circular import + from basic_memory.mcp.clients import ProjectClient + + # Use typed ProjectClient for API calls + project_client = ProjectClient(client) + project_list = await project_client.list_projects() project_names = [p.name.lower() for p in project_list.projects] # Check if destination path contains any project names @@ -436,15 +436,19 @@ async def move_note( logger.info(f"Detected cross-project move attempt: {identifier} -> {destination_path}") return cross_project_error + # Import here to avoid circular import + from basic_memory.mcp.clients import KnowledgeClient + + # Use typed KnowledgeClient for API calls + knowledge_client = KnowledgeClient(client, active_project.external_id) + # Get the source entity information for extension validation source_ext = "md" # Default to .md if we can't determine source extension try: # Resolve identifier to entity ID - entity_id = await resolve_entity_id(client, active_project.external_id, identifier) + entity_id = await knowledge_client.resolve_entity(identifier) # Fetch source entity information to get the current file extension - url = f"/v2/projects/{active_project.external_id}/knowledge/entities/{entity_id}" - response = await call_get(client, url) - source_entity = EntityResponse.model_validate(response.json()) + source_entity = await knowledge_client.get_entity(entity_id) if "." 
in source_entity.file_path: source_ext = source_entity.file_path.split(".")[-1] except Exception as e: @@ -475,11 +479,9 @@ async def move_note( # Get the source entity to check its file extension try: # Resolve identifier to entity ID (might already be cached from above) - entity_id = await resolve_entity_id(client, active_project.external_id, identifier) + entity_id = await knowledge_client.resolve_entity(identifier) # Fetch source entity information - url = f"/v2/projects/{active_project.external_id}/knowledge/entities/{entity_id}" - response = await call_get(client, url) - source_entity = EntityResponse.model_validate(response.json()) + source_entity = await knowledge_client.get_entity(entity_id) # Extract file extensions source_ext = ( @@ -515,17 +517,10 @@ async def move_note( try: # Resolve identifier to entity ID for the move operation - entity_id = await resolve_entity_id(client, active_project.external_id, identifier) - - # Prepare move request (v2 API only needs destination_path) - move_data = { - "destination_path": destination_path, - } + entity_id = await knowledge_client.resolve_entity(identifier) - # Call the v2 move API endpoint (PUT method, entity_id in URL) - url = f"/v2/projects/{active_project.external_id}/knowledge/entities/{entity_id}/move" - response = await call_put(client, url, json=move_data) - result = EntityResponse.model_validate(response.json()) + # Call the move API using KnowledgeClient + result = await knowledge_client.move_entity(entity_id, destination_path) # Build success message result_lines = [ @@ -544,7 +539,6 @@ async def move_note( identifier=identifier, destination_path=destination_path, project=active_project.name, - status_code=response.status_code, ) return "\n".join(result_lines) diff --git a/src/basic_memory/mcp/tools/project_management.py b/src/basic_memory/mcp/tools/project_management.py index 986eb93e..79e6aedf 100644 --- a/src/basic_memory/mcp/tools/project_management.py +++ b/src/basic_memory/mcp/tools/project_management.py @@ -9,12 +9,7 @@ from basic_memory.mcp.async_client import get_client from basic_memory.mcp.server import mcp -from basic_memory.mcp.tools.utils import call_get, call_post, call_delete -from basic_memory.schemas.project_info import ( - ProjectList, - ProjectStatusResponse, - ProjectInfoRequest, -) +from basic_memory.schemas.project_info import ProjectInfoRequest from basic_memory.telemetry import track_mcp_tool from basic_memory.utils import generate_permalink @@ -49,9 +44,12 @@ async def list_memory_projects(context: Context | None = None) -> str: # Check if server is constrained to a specific project constrained_project = os.environ.get("BASIC_MEMORY_MCP_PROJECT") - # Get projects from API - response = await call_get(client, "/projects/projects") - project_list = ProjectList.model_validate(response.json()) + # Import here to avoid circular import + from basic_memory.mcp.clients import ProjectClient + + # Use typed ProjectClient for API calls + project_client = ProjectClient(client) + project_list = await project_client.list_projects() if constrained_project: result = f"Project: {constrained_project}\n\n" @@ -109,9 +107,12 @@ async def create_memory_project( name=project_name, path=project_path, set_default=set_default ) - # Call API to create project - response = await call_post(client, "/projects/projects", json=project_request.model_dump()) - status_response = ProjectStatusResponse.model_validate(response.json()) + # Import here to avoid circular import + from basic_memory.mcp.clients import ProjectClient + + # 
Use typed ProjectClient for API calls + project_client = ProjectClient(client) + status_response = await project_client.create_project(project_request.model_dump()) result = f"✓ {status_response.message}\n\n" @@ -160,9 +161,14 @@ async def delete_project(project_name: str, context: Context | None = None) -> s if context: # pragma: no cover await context.info(f"Deleting project: {project_name}") + # Import here to avoid circular import + from basic_memory.mcp.clients import ProjectClient + + # Use typed ProjectClient for API calls + project_client = ProjectClient(client) + # Get project info before deletion to validate it exists - response = await call_get(client, "/projects/projects") - project_list = ProjectList.model_validate(response.json()) + project_list = await project_client.list_projects() # Find the project by permalink (derived from name). # Note: The API response uses `ProjectItem` which derives `permalink` from `name`, @@ -181,9 +187,8 @@ async def delete_project(project_name: str, context: Context | None = None) -> s f"Project '{project_name}' not found. Available projects: {', '.join(available_projects)}" ) - # Call v2 API to delete project using project external_id - response = await call_delete(client, f"/v2/projects/{target_project.external_id}") - status_response = ProjectStatusResponse.model_validate(response.json()) + # Delete project using project external_id + status_response = await project_client.delete_project(target_project.external_id) result = f"✓ {status_response.message}\n\n" diff --git a/src/basic_memory/mcp/tools/read_note.py b/src/basic_memory/mcp/tools/read_note.py index 8c76fa2e..be1d9f4f 100644 --- a/src/basic_memory/mcp/tools/read_note.py +++ b/src/basic_memory/mcp/tools/read_note.py @@ -10,7 +10,6 @@ from basic_memory.mcp.project_context import get_active_project from basic_memory.mcp.server import mcp from basic_memory.mcp.tools.search import search_notes -from basic_memory.mcp.tools.utils import call_get, resolve_entity_id from basic_memory.telemetry import track_mcp_tool from basic_memory.schemas.memory import memory_url_path from basic_memory.utils import validate_project_path @@ -105,16 +104,19 @@ async def read_note( f"Attempting to read note from Project: {active_project.name} identifier: {entity_path}" ) + # Import here to avoid circular import + from basic_memory.mcp.clients import KnowledgeClient, ResourceClient + + # Use typed clients for API calls + knowledge_client = KnowledgeClient(client, active_project.external_id) + resource_client = ResourceClient(client, active_project.external_id) + try: # Try to resolve identifier to entity ID - entity_id = await resolve_entity_id(client, active_project.external_id, entity_path) + entity_id = await knowledge_client.resolve_entity(entity_path) # Fetch content using entity ID - response = await call_get( - client, - f"/v2/projects/{active_project.external_id}/resource/{entity_id}", - params={"page": page, "page_size": page_size}, - ) + response = await resource_client.read(entity_id, page=page, page_size=page_size) # If successful, return the content if response.status_code == 200: @@ -136,14 +138,10 @@ async def read_note( if result.permalink: try: # Resolve the permalink to entity ID - entity_id = await resolve_entity_id(client, active_project.external_id, result.permalink) + entity_id = await knowledge_client.resolve_entity(result.permalink) # Fetch content using the entity ID - response = await call_get( - client, - f"/v2/projects/{active_project.external_id}/resource/{entity_id}", - 
params={"page": page, "page_size": page_size}, - ) + response = await resource_client.read(entity_id, page=page, page_size=page_size) if response.status_code == 200: logger.info(f"Found note by title search: {result.permalink}") diff --git a/src/basic_memory/mcp/tools/search.py b/src/basic_memory/mcp/tools/search.py index 39dd9aba..d51165a6 100644 --- a/src/basic_memory/mcp/tools/search.py +++ b/src/basic_memory/mcp/tools/search.py @@ -9,7 +9,6 @@ from basic_memory.mcp.async_client import get_client from basic_memory.mcp.project_context import get_active_project from basic_memory.mcp.server import mcp -from basic_memory.mcp.tools.utils import call_post from basic_memory.telemetry import track_mcp_tool from basic_memory.schemas.search import SearchItemType, SearchQuery, SearchResponse @@ -365,13 +364,16 @@ async def search_notes( logger.info(f"Searching for {search_query} in project {active_project.name}") try: - response = await call_post( - client, - f"/v2/projects/{active_project.external_id}/search/", - json=search_query.model_dump(), - params={"page": page, "page_size": page_size}, + # Import here to avoid circular import (tools → clients → utils → tools) + from basic_memory.mcp.clients import SearchClient + + # Use typed SearchClient for API calls + search_client = SearchClient(client, active_project.external_id) + result = await search_client.search( + search_query.model_dump(), + page=page, + page_size=page_size, ) - result = SearchResponse.model_validate(response.json()) # Check if we got no results and provide helpful guidance if not result.results: diff --git a/src/basic_memory/mcp/tools/write_note.py b/src/basic_memory/mcp/tools/write_note.py index f0864844..f6060e46 100644 --- a/src/basic_memory/mcp/tools/write_note.py +++ b/src/basic_memory/mcp/tools/write_note.py @@ -7,9 +7,7 @@ from basic_memory.mcp.async_client import get_client from basic_memory.mcp.project_context import get_active_project, add_project_metadata from basic_memory.mcp.server import mcp -from basic_memory.mcp.tools.utils import call_put, call_post, resolve_entity_id from basic_memory.telemetry import track_mcp_tool -from basic_memory.schemas import EntityResponse from fastmcp import Context from basic_memory.schemas.base import Entity from basic_memory.utils import parse_tags, validate_project_path @@ -153,13 +151,17 @@ async def write_note( entity_metadata=metadata, ) + # Import here to avoid circular import + from basic_memory.mcp.clients import KnowledgeClient + + # Use typed KnowledgeClient for API calls + knowledge_client = KnowledgeClient(client, active_project.external_id) + # Try to create the entity first (optimistic create) logger.debug(f"Attempting to create entity permalink={entity.permalink}") action = "Created" # Default to created try: - url = f"/v2/projects/{active_project.external_id}/knowledge/entities" - response = await call_post(client, url, json=entity.model_dump()) - result = EntityResponse.model_validate(response.json()) + result = await knowledge_client.create_entity(entity.model_dump()) action = "Created" except Exception as e: # If creation failed due to conflict (already exists), try to update @@ -172,10 +174,8 @@ async def write_note( try: if not entity.permalink: raise ValueError("Entity permalink is required for updates") # pragma: no cover - entity_id = await resolve_entity_id(client, active_project.external_id, entity.permalink) - url = f"/v2/projects/{active_project.external_id}/knowledge/entities/{entity_id}" - response = await call_put(client, url, 
json=entity.model_dump()) - result = EntityResponse.model_validate(response.json()) + entity_id = await knowledge_client.resolve_entity(entity.permalink) + result = await knowledge_client.update_entity(entity_id, entity.model_dump()) action = "Updated" except Exception as update_error: # pragma: no cover # Re-raise the original error if update also fails @@ -224,7 +224,7 @@ async def write_note( # Log the response with structured data logger.info( - f"MCP tool response: tool=write_note project={active_project.name} action={action} permalink={result.permalink} observations_count={len(result.observations)} relations_count={len(result.relations)} resolved_relations={resolved} unresolved_relations={unresolved} status_code={response.status_code}" + f"MCP tool response: tool=write_note project={active_project.name} action={action} permalink={result.permalink} observations_count={len(result.observations)} relations_count={len(result.relations)} resolved_relations={resolved} unresolved_relations={unresolved}" ) - result = "\n".join(summary) - return add_project_metadata(result, active_project.name) + summary_result = "\n".join(summary) + return add_project_metadata(summary_result, active_project.name) diff --git a/src/basic_memory/project_resolver.py b/src/basic_memory/project_resolver.py new file mode 100644 index 00000000..20b129f4 --- /dev/null +++ b/src/basic_memory/project_resolver.py @@ -0,0 +1,222 @@ +"""Unified project resolution across MCP, API, and CLI. + +This module provides a single canonical implementation of project resolution +logic, eliminating duplicated decision trees across the codebase. + +The resolution follows a three-tier hierarchy: +1. Constrained mode: BASIC_MEMORY_MCP_PROJECT env var (highest priority) +2. Explicit parameter: Project passed directly to operation +3. Default project: Used when default_project_mode=true (lowest priority) + +In cloud mode, project is required unless discovery mode is explicitly allowed. +""" + +import os +from dataclasses import dataclass +from enum import Enum, auto +from typing import Optional + +from loguru import logger + + +class ResolutionMode(Enum): + """How the project was resolved.""" + + CLOUD_EXPLICIT = auto() # Explicit project in cloud mode + CLOUD_DISCOVERY = auto() # Discovery mode allowed in cloud (no project) + ENV_CONSTRAINT = auto() # BASIC_MEMORY_MCP_PROJECT env var + EXPLICIT = auto() # Explicit project parameter + DEFAULT = auto() # default_project with default_project_mode=true + NONE = auto() # No resolution possible + + +@dataclass(frozen=True) +class ResolvedProject: + """Result of project resolution. + + Attributes: + project: The resolved project name, or None if not resolved + mode: How the project was resolved + reason: Human-readable explanation of resolution + """ + + project: Optional[str] + mode: ResolutionMode + reason: str + + @property + def is_resolved(self) -> bool: + """Whether a project was successfully resolved.""" + return self.project is not None + + @property + def is_discovery_mode(self) -> bool: + """Whether we're in discovery mode (no specific project).""" + return self.mode == ResolutionMode.CLOUD_DISCOVERY or ( + self.mode == ResolutionMode.NONE and self.project is None + ) + + +@dataclass +class ProjectResolver: + """Unified project resolution logic. + + Resolves the effective project given requested project, environment + constraints, and configuration settings. + + This is the single canonical implementation of project resolution, + used by MCP tools, API routes, and CLI commands. 
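A small sketch of the `ResolvedProject` result type's semantics, using only the fields and properties defined above:

```python
from basic_memory.project_resolver import ResolvedProject, ResolutionMode

# A resolution produced by the default-project tier
resolved = ResolvedProject(
    project="main",
    mode=ResolutionMode.DEFAULT,
    reason="Default project mode: main",
)
assert resolved.is_resolved and not resolved.is_discovery_mode

# Cloud discovery mode carries no project, but deliberately so
discovery = ResolvedProject(
    project=None,
    mode=ResolutionMode.CLOUD_DISCOVERY,
    reason="Discovery mode enabled in cloud",
)
assert discovery.is_discovery_mode and not discovery.is_resolved
```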
+ + Args: + cloud_mode: Whether running in cloud mode (project required) + default_project_mode: Whether to use default project when not specified + default_project: The default project name + constrained_project: Optional env-constrained project override + (typically from BASIC_MEMORY_MCP_PROJECT) + """ + + cloud_mode: bool = False + default_project_mode: bool = False + default_project: Optional[str] = None + constrained_project: Optional[str] = None + + @classmethod + def from_env( + cls, + cloud_mode: bool = False, + default_project_mode: bool = False, + default_project: Optional[str] = None, + ) -> "ProjectResolver": + """Create resolver with constrained_project from environment. + + Args: + cloud_mode: Whether running in cloud mode + default_project_mode: Whether to use default project when not specified + default_project: The default project name + + Returns: + ProjectResolver configured with current environment + """ + constrained = os.environ.get("BASIC_MEMORY_MCP_PROJECT") + return cls( + cloud_mode=cloud_mode, + default_project_mode=default_project_mode, + default_project=default_project, + constrained_project=constrained, + ) + + def resolve( + self, + project: Optional[str] = None, + allow_discovery: bool = False, + ) -> ResolvedProject: + """Resolve project using the three-tier hierarchy. + + Resolution order: + 1. Cloud mode check (project required unless discovery allowed) + 2. Constrained project from env var (highest priority in local mode) + 3. Explicit project parameter + 4. Default project if default_project_mode=true + + Args: + project: Optional explicit project parameter + allow_discovery: If True, allows returning None in cloud mode + for discovery operations (e.g., recent_activity across projects) + + Returns: + ResolvedProject with project name, resolution mode, and reason + + Raises: + ValueError: If in cloud mode and no project specified (unless discovery allowed) + """ + # --- Cloud Mode Handling --- + # In cloud mode, project is required unless discovery is explicitly allowed + if self.cloud_mode: + if project: + logger.debug(f"Cloud mode: using explicit project '{project}'") + return ResolvedProject( + project=project, + mode=ResolutionMode.CLOUD_EXPLICIT, + reason=f"Explicit project in cloud mode: {project}", + ) + elif allow_discovery: + logger.debug("Cloud mode: discovery mode allowed, no project required") + return ResolvedProject( + project=None, + mode=ResolutionMode.CLOUD_DISCOVERY, + reason="Discovery mode enabled in cloud", + ) + else: + raise ValueError("No project specified. 
Project is required for cloud mode.") + + # --- Local Mode: Three-Tier Hierarchy --- + + # Priority 1: CLI constraint overrides everything + if self.constrained_project: + logger.debug(f"Using CLI constrained project: {self.constrained_project}") + return ResolvedProject( + project=self.constrained_project, + mode=ResolutionMode.ENV_CONSTRAINT, + reason=f"Environment constraint: BASIC_MEMORY_MCP_PROJECT={self.constrained_project}", + ) + + # Priority 2: Explicit project parameter + if project: + logger.debug(f"Using explicit project parameter: {project}") + return ResolvedProject( + project=project, + mode=ResolutionMode.EXPLICIT, + reason=f"Explicit parameter: {project}", + ) + + # Priority 3: Default project mode + if self.default_project_mode and self.default_project: + logger.debug(f"Using default project from config: {self.default_project}") + return ResolvedProject( + project=self.default_project, + mode=ResolutionMode.DEFAULT, + reason=f"Default project mode: {self.default_project}", + ) + + # No resolution possible + logger.debug("No project resolution possible") + return ResolvedProject( + project=None, + mode=ResolutionMode.NONE, + reason="No project specified and default_project_mode is disabled", + ) + + def require_project( + self, + project: Optional[str] = None, + error_message: Optional[str] = None, + ) -> ResolvedProject: + """Resolve project, raising an error if not resolved. + + Convenience method for operations that require a project. + + Args: + project: Optional explicit project parameter + error_message: Custom error message if project not resolved + + Returns: + ResolvedProject (always with a non-None project) + + Raises: + ValueError: If project could not be resolved + """ + result = self.resolve(project, allow_discovery=False) + if not result.is_resolved: + msg = error_message or ( + "No project specified. Either set 'default_project_mode=true' in config, " + "or provide a 'project' argument." + ) + raise ValueError(msg) + return result + + +__all__ = [ + "ProjectResolver", + "ResolvedProject", + "ResolutionMode", +] diff --git a/src/basic_memory/repository/entity_repository.py b/src/basic_memory/repository/entity_repository.py index 6b9fc4f2..323317ca 100644 --- a/src/basic_memory/repository/entity_repository.py +++ b/src/basic_memory/repository/entity_repository.py @@ -33,7 +33,7 @@ def __init__(self, session_maker: async_sessionmaker[AsyncSession], project_id: """ super().__init__(session_maker, Entity, project_id=project_id) - async def get_by_id(self, entity_id: int) -> Optional[Entity]: + async def get_by_id(self, entity_id: int) -> Optional[Entity]: # pragma: no cover """Get entity by numeric ID. Args: diff --git a/src/basic_memory/repository/search_repository.py b/src/basic_memory/repository/search_repository.py index 0204cdc9..4be70aae 100644 --- a/src/basic_memory/repository/search_repository.py +++ b/src/basic_memory/repository/search_repository.py @@ -68,20 +68,27 @@ async def execute_query(self, query, params: dict) -> Result: def create_search_repository( - session_maker: async_sessionmaker[AsyncSession], project_id: int + session_maker: async_sessionmaker[AsyncSession], + project_id: int, + database_backend: Optional[DatabaseBackend] = None, ) -> SearchRepository: """Factory function to create the appropriate search repository based on database backend. Args: session_maker: SQLAlchemy async session maker project_id: Project ID for the repository + database_backend: Optional explicit backend. If not provided, reads from ConfigManager. 
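A minimal sketch of the three-tier precedence implemented by `resolve()` above, in local mode with placeholder project names:

```python
from basic_memory.project_resolver import ProjectResolver, ResolutionMode

resolver = ProjectResolver(default_project_mode=True, default_project="main")

# Tier 2: an explicit parameter beats the default project...
assert resolver.resolve(project="work").mode is ResolutionMode.EXPLICIT

# Tier 3: ...and the default is used when nothing is passed
assert resolver.resolve().project == "main"

# Tier 1: an env constraint (BASIC_MEMORY_MCP_PROJECT) outranks both
constrained = ProjectResolver(
    default_project_mode=True,
    default_project="main",
    constrained_project="locked",
)
assert constrained.resolve(project="work").project == "locked"
```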
+ Prefer passing explicitly from composition roots. Returns: SearchRepository: Backend-appropriate search repository instance """ - config = ConfigManager().config + # Prefer explicit parameter; fall back to ConfigManager for backwards compatibility + if database_backend is None: + config = ConfigManager().config + database_backend = config.database_backend - if config.database_backend == DatabaseBackend.POSTGRES: # pragma: no cover + if database_backend == DatabaseBackend.POSTGRES: # pragma: no cover return PostgresSearchRepository(session_maker, project_id=project_id) # pragma: no cover else: return SQLiteSearchRepository(session_maker, project_id=project_id) diff --git a/src/basic_memory/runtime.py b/src/basic_memory/runtime.py new file mode 100644 index 00000000..852481b8 --- /dev/null +++ b/src/basic_memory/runtime.py @@ -0,0 +1,61 @@ +"""Runtime mode resolution for Basic Memory. + +This module centralizes runtime mode detection, ensuring cloud/local/test +determination happens in one place rather than scattered across modules. + +Composition roots (containers) read ConfigManager and use this module +to resolve the runtime mode, then pass the result downstream. +""" + +from enum import Enum, auto + + +class RuntimeMode(Enum): + """Runtime modes for Basic Memory.""" + + LOCAL = auto() # Local standalone mode (default) + CLOUD = auto() # Cloud mode with remote sync + TEST = auto() # Test environment + + @property + def is_cloud(self) -> bool: + return self == RuntimeMode.CLOUD + + @property + def is_local(self) -> bool: + return self == RuntimeMode.LOCAL + + @property + def is_test(self) -> bool: + return self == RuntimeMode.TEST + + +def resolve_runtime_mode( + cloud_mode_enabled: bool, + is_test_env: bool, +) -> RuntimeMode: + """Resolve the runtime mode from configuration flags. + + This is the single source of truth for mode resolution. + Composition roots call this with config values they've read. + + Args: + cloud_mode_enabled: Whether cloud mode is enabled in config + is_test_env: Whether running in test environment + + Returns: + The resolved RuntimeMode + """ + # Trigger: test environment is detected + # Why: tests need special handling (no file sync, isolated DB) + # Outcome: returns TEST mode, skipping cloud mode check + if is_test_env: + return RuntimeMode.TEST + + # Trigger: cloud mode is enabled in config + # Why: cloud mode changes auth, sync, and API behavior + # Outcome: returns CLOUD mode for remote-first behavior + if cloud_mode_enabled: + return RuntimeMode.CLOUD + + return RuntimeMode.LOCAL \ No newline at end of file diff --git a/src/basic_memory/sync/__init__.py b/src/basic_memory/sync/__init__.py index 4a856168..286b0418 100644 --- a/src/basic_memory/sync/__init__.py +++ b/src/basic_memory/sync/__init__.py @@ -1,6 +1,7 @@ """Basic Memory sync services.""" +from .coordinator import SyncCoordinator, SyncStatus from .sync_service import SyncService from .watch_service import WatchService -__all__ = ["SyncService", "WatchService"] +__all__ = ["SyncService", "WatchService", "SyncCoordinator", "SyncStatus"] diff --git a/src/basic_memory/sync/coordinator.py b/src/basic_memory/sync/coordinator.py new file mode 100644 index 00000000..e9dde8e3 --- /dev/null +++ b/src/basic_memory/sync/coordinator.py @@ -0,0 +1,160 @@ +"""SyncCoordinator - centralized sync/watch lifecycle management. + +This module provides a single coordinator that manages the lifecycle of +file synchronization and watch services across all entry points (API, MCP, CLI). 
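The TEST > CLOUD > LOCAL precedence of `resolve_runtime_mode` above can be stated directly:

```python
from basic_memory.runtime import RuntimeMode, resolve_runtime_mode

# TEST wins even when cloud mode is also enabled (TEST > CLOUD > LOCAL)
assert resolve_runtime_mode(cloud_mode_enabled=True, is_test_env=True) is RuntimeMode.TEST
assert resolve_runtime_mode(cloud_mode_enabled=True, is_test_env=False) is RuntimeMode.CLOUD
assert resolve_runtime_mode(cloud_mode_enabled=False, is_test_env=False) is RuntimeMode.LOCAL
```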
+ +The coordinator handles: +- Starting/stopping watch service +- Scheduling background sync +- Reporting status +- Clean shutdown behavior +""" + +import asyncio +from dataclasses import dataclass, field +from enum import Enum, auto +from typing import Optional + +from loguru import logger + +from basic_memory.config import BasicMemoryConfig + + +class SyncStatus(Enum): + """Status of the sync coordinator.""" + + NOT_STARTED = auto() + STARTING = auto() + RUNNING = auto() + STOPPING = auto() + STOPPED = auto() + ERROR = auto() + + +@dataclass +class SyncCoordinator: + """Centralized coordinator for sync/watch lifecycle. + + Manages the lifecycle of file synchronization services, providing: + - Unified start/stop interface + - Status tracking + - Clean shutdown with proper task cancellation + + Args: + config: BasicMemoryConfig with sync settings + should_sync: Whether sync should be enabled (from container decision) + skip_reason: Human-readable reason if sync is skipped + + Usage: + coordinator = SyncCoordinator(config=config, should_sync=True) + await coordinator.start() + # ... application runs ... + await coordinator.stop() + """ + + config: BasicMemoryConfig + should_sync: bool = True + skip_reason: Optional[str] = None + + # Internal state (not constructor args) + _status: SyncStatus = field(default=SyncStatus.NOT_STARTED, init=False) + _sync_task: Optional[asyncio.Task] = field(default=None, init=False) + + @property + def status(self) -> SyncStatus: + """Current status of the coordinator.""" + return self._status + + @property + def is_running(self) -> bool: + """Whether sync is currently running.""" + return self._status == SyncStatus.RUNNING + + async def start(self) -> None: + """Start the sync/watch service if enabled. + + This is a non-blocking call that starts the sync task in the background. + Use stop() to cleanly shut down. + """ + if not self.should_sync: + if self.skip_reason: + logger.info(f"{self.skip_reason} - skipping local file sync") + self._status = SyncStatus.STOPPED + return + + if self._status in (SyncStatus.RUNNING, SyncStatus.STARTING): + logger.warning("Sync coordinator already running or starting") + return + + self._status = SyncStatus.STARTING + logger.info("Starting file sync in background") + + try: + # Deferred import to avoid circular dependency + from basic_memory.services.initialization import initialize_file_sync + + async def _file_sync_runner() -> None: # pragma: no cover + """Run the file sync service.""" + try: + await initialize_file_sync(self.config) + except asyncio.CancelledError: + logger.debug("File sync cancelled") + raise + except Exception as e: + logger.error(f"Error in file sync: {e}") + self._status = SyncStatus.ERROR + raise + + self._sync_task = asyncio.create_task(_file_sync_runner()) + self._status = SyncStatus.RUNNING + logger.info("Sync coordinator started successfully") + + except Exception as e: # pragma: no cover + logger.error(f"Failed to start sync coordinator: {e}") + self._status = SyncStatus.ERROR + raise + + async def stop(self) -> None: + """Stop the sync/watch service cleanly. + + Cancels the background task and waits for it to complete. + Safe to call even if not running. 
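A sketch of the start/stop lifecycle the MCP and API lifespans follow with `SyncCoordinator`; note the `ConfigManager` read here stands in for what a composition root would normally do, and starting the coordinator kicks off real file sync.

```python
import asyncio

from basic_memory.config import ConfigManager
from basic_memory.sync import SyncCoordinator, SyncStatus


async def run_with_sync() -> None:
    """Sketch: start the coordinator, run the app, then stop it cleanly."""
    config = ConfigManager().config  # in real code only composition roots read this
    coordinator = SyncCoordinator(config=config, should_sync=True)
    await coordinator.start()
    assert coordinator.status is SyncStatus.RUNNING
    try:
        await asyncio.sleep(1)  # stand-in for the application's lifetime
    finally:
        await coordinator.stop()
        assert coordinator.status is SyncStatus.STOPPED
```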
+ """ + if self._status in (SyncStatus.NOT_STARTED, SyncStatus.STOPPED): + return + + if self._sync_task is None: # pragma: no cover + self._status = SyncStatus.STOPPED + return + + self._status = SyncStatus.STOPPING + logger.info("Stopping sync coordinator...") + + self._sync_task.cancel() + try: + await self._sync_task + except asyncio.CancelledError: + logger.info("File sync task cancelled successfully") + + self._sync_task = None + self._status = SyncStatus.STOPPED + logger.info("Sync coordinator stopped") + + def get_status_info(self) -> dict: + """Get status information for reporting. + + Returns: + Dictionary with status details for diagnostics + """ + return { + "status": self._status.name, + "should_sync": self.should_sync, + "skip_reason": self.skip_reason, + "has_task": self._sync_task is not None, + } + + +__all__ = [ + "SyncCoordinator", + "SyncStatus", +] diff --git a/test-int/mcp/test_lifespan_shutdown_sync_task_cancellation_integration.py b/test-int/mcp/test_lifespan_shutdown_sync_task_cancellation_integration.py index 489303b6..3abfdc6b 100644 --- a/test-int/mcp/test_lifespan_shutdown_sync_task_cancellation_integration.py +++ b/test-int/mcp/test_lifespan_shutdown_sync_task_cancellation_integration.py @@ -22,7 +22,7 @@ def test_lifespan_shutdown_awaits_sync_task_cancellation(app, monkeypatch): - In the buggy version, shutdown proceeded directly to db.shutdown_db() immediately after calling cancel(), so at *entry* to shutdown_db the task is still not done. - - In the fixed version, lifespan does `await sync_task` before shutdown_db, + - In the fixed version, SyncCoordinator.stop() awaits the task before returning, so by the time shutdown_db is called, the task is done (cancelled). """ @@ -34,29 +34,40 @@ def test_lifespan_shutdown_awaits_sync_task_cancellation(app, monkeypatch): import importlib api_app_module = importlib.import_module("basic_memory.api.app") + container_module = importlib.import_module("basic_memory.api.container") + init_module = importlib.import_module("basic_memory.services.initialization") # Keep startup cheap: we don't need real DB init for this ordering test. async def _noop_initialize_app(_app_config): return None - async def _fake_get_or_create_db(*_args, **_kwargs): - return object(), object() - monkeypatch.setattr(api_app_module, "initialize_app", _noop_initialize_app) - monkeypatch.setattr(api_app_module.db, "get_or_create_db", _fake_get_or_create_db) + + # Patch the container's init_database to return fake objects + async def _fake_init_database(self): + self.engine = object() + self.session_maker = object() + return self.engine, self.session_maker + + monkeypatch.setattr(container_module.ApiContainer, "init_database", _fake_init_database) # Make the sync task long-lived so it must be cancelled on shutdown. + # Patch at the source module where SyncCoordinator imports it. async def _fake_initialize_file_sync(_app_config): await asyncio.Event().wait() - monkeypatch.setattr(api_app_module, "initialize_file_sync", _fake_initialize_file_sync) + monkeypatch.setattr(init_module, "initialize_file_sync", _fake_initialize_file_sync) # Assert ordering: shutdown_db must be called only after the sync_task is done. - async def _assert_sync_task_done_before_db_shutdown(): - assert api_app_module.app.state.sync_task is not None - assert api_app_module.app.state.sync_task.done() - - monkeypatch.setattr(api_app_module.db, "shutdown_db", _assert_sync_task_done_before_db_shutdown) + # SyncCoordinator stores the task in _sync_task attribute. 
+ async def _assert_sync_task_done_before_db_shutdown(self): + sync_coordinator = api_app_module.app.state.sync_coordinator + assert sync_coordinator._sync_task is not None + assert sync_coordinator._sync_task.done() + + monkeypatch.setattr( + container_module.ApiContainer, "shutdown_database", _assert_sync_task_done_before_db_shutdown + ) async def _run_client_once(): async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: diff --git a/tests/api/test_api_container.py b/tests/api/test_api_container.py new file mode 100644 index 00000000..fb13d6c5 --- /dev/null +++ b/tests/api/test_api_container.py @@ -0,0 +1,62 @@ +"""Tests for API container composition root.""" + +import pytest + +from basic_memory.api.container import ( + ApiContainer, + get_container, + set_container, +) +from basic_memory.runtime import RuntimeMode + + +class TestApiContainer: + """Tests for ApiContainer.""" + + def test_create_from_config(self, app_config): + """Container can be created from config manager.""" + container = ApiContainer(config=app_config, mode=RuntimeMode.LOCAL) + assert container.config == app_config + assert container.mode == RuntimeMode.LOCAL + + def test_should_sync_files_when_enabled_and_not_test(self, app_config): + """Sync should be enabled when config says so and not in test mode.""" + app_config.sync_changes = True + container = ApiContainer(config=app_config, mode=RuntimeMode.LOCAL) + assert container.should_sync_files is True + + def test_should_not_sync_files_when_disabled(self, app_config): + """Sync should be disabled when config says so.""" + app_config.sync_changes = False + container = ApiContainer(config=app_config, mode=RuntimeMode.LOCAL) + assert container.should_sync_files is False + + def test_should_not_sync_files_in_test_mode(self, app_config): + """Sync should be disabled in test mode regardless of config.""" + app_config.sync_changes = True + container = ApiContainer(config=app_config, mode=RuntimeMode.TEST) + assert container.should_sync_files is False + + +class TestContainerAccessors: + """Tests for container get/set functions.""" + + def test_get_container_raises_when_not_set(self, monkeypatch): + """get_container raises RuntimeError when container not initialized.""" + # Clear any existing container + import basic_memory.api.container as container_module + + monkeypatch.setattr(container_module, "_container", None) + + with pytest.raises(RuntimeError, match="API container not initialized"): + get_container() + + def test_set_and_get_container(self, app_config, monkeypatch): + """set_container allows get_container to return the container.""" + import basic_memory.api.container as container_module + + container = ApiContainer(config=app_config, mode=RuntimeMode.LOCAL) + monkeypatch.setattr(container_module, "_container", None) + + set_container(container) + assert get_container() is container diff --git a/tests/cli/test_cli_container.py b/tests/cli/test_cli_container.py new file mode 100644 index 00000000..149fb771 --- /dev/null +++ b/tests/cli/test_cli_container.py @@ -0,0 +1,96 @@ +"""Tests for CLI container composition root.""" + +import pytest + +from basic_memory.cli.container import ( + CliContainer, + get_container, + set_container, + get_or_create_container, +) +from basic_memory.runtime import RuntimeMode + + +class TestCliContainer: + """Tests for CliContainer.""" + + def test_create_from_config(self, app_config): + """Container can be created from config.""" + container = CliContainer(config=app_config, mode=RuntimeMode.LOCAL) 
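+ # Constructed directly rather than via the shared create() classmethod,
+ # which would read global ConfigManager state.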
+ assert container.config == app_config + assert container.mode == RuntimeMode.LOCAL + + def test_is_cloud_mode_when_cloud(self, app_config): + """is_cloud_mode returns True in cloud mode.""" + container = CliContainer(config=app_config, mode=RuntimeMode.CLOUD) + assert container.is_cloud_mode is True + + def test_is_cloud_mode_when_local(self, app_config): + """is_cloud_mode returns False in local mode.""" + container = CliContainer(config=app_config, mode=RuntimeMode.LOCAL) + assert container.is_cloud_mode is False + + def test_is_cloud_mode_when_test(self, app_config): + """is_cloud_mode returns False in test mode.""" + container = CliContainer(config=app_config, mode=RuntimeMode.TEST) + assert container.is_cloud_mode is False + + +class TestContainerAccessors: + """Tests for container get/set functions.""" + + def test_get_container_raises_when_not_set(self, monkeypatch): + """get_container raises RuntimeError when container not initialized.""" + import basic_memory.cli.container as container_module + + monkeypatch.setattr(container_module, "_container", None) + + with pytest.raises(RuntimeError, match="CLI container not initialized"): + get_container() + + def test_set_and_get_container(self, app_config, monkeypatch): + """set_container allows get_container to return the container.""" + import basic_memory.cli.container as container_module + + container = CliContainer(config=app_config, mode=RuntimeMode.LOCAL) + monkeypatch.setattr(container_module, "_container", None) + + set_container(container) + assert get_container() is container + + +class TestGetOrCreateContainer: + """Tests for get_or_create_container - unique to CLI container.""" + + def test_creates_new_when_none_exists(self, monkeypatch): + """get_or_create_container creates a new container when none exists.""" + import basic_memory.cli.container as container_module + + monkeypatch.setattr(container_module, "_container", None) + + container = get_or_create_container() + assert container is not None + assert isinstance(container, CliContainer) + + def test_returns_existing_when_set(self, app_config, monkeypatch): + """get_or_create_container returns existing container if already set.""" + import basic_memory.cli.container as container_module + + existing = CliContainer(config=app_config, mode=RuntimeMode.LOCAL) + monkeypatch.setattr(container_module, "_container", existing) + + result = get_or_create_container() + assert result is existing + + def test_sets_module_level_container(self, monkeypatch): + """get_or_create_container sets the module-level container.""" + import basic_memory.cli.container as container_module + + monkeypatch.setattr(container_module, "_container", None) + + container = get_or_create_container() + + # Verify it was set at module level + assert container_module._container is container + # Verify get_container now works + assert get_container() is container diff --git a/tests/mcp/clients/__init__.py b/tests/mcp/clients/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/mcp/clients/test_clients.py b/tests/mcp/clients/test_clients.py new file mode 100644 index 00000000..333f9fa6 --- /dev/null +++ b/tests/mcp/clients/test_clients.py @@ -0,0 +1,312 @@ +"""Tests for typed API clients.""" + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from basic_memory.mcp.clients import ( + KnowledgeClient, + SearchClient, + MemoryClient, + DirectoryClient, + ResourceClient, + ProjectClient, +) + + +class TestKnowledgeClient: + """Tests for KnowledgeClient.""" + + def 
test_init(self): + """Test client initialization.""" + mock_http = MagicMock() + client = KnowledgeClient(mock_http, "project-123") + assert client.http_client is mock_http + assert client.project_id == "project-123" + assert client._base_path == "/v2/projects/project-123/knowledge" + + @pytest.mark.asyncio + async def test_create_entity(self, monkeypatch): + """Test create_entity calls correct endpoint.""" + from basic_memory.mcp.clients import knowledge as knowledge_mod + + mock_response = MagicMock() + mock_response.json.return_value = { + "permalink": "test", + "title": "Test", + "file_path": "test.md", + "entity_type": "note", + "content_type": "text/markdown", + "observations": [], + "relations": [], + "created_at": "2024-01-01T00:00:00", + "updated_at": "2024-01-01T00:00:00", + } + + async def mock_call_post(client, url, **kwargs): + assert "/v2/projects/proj-123/knowledge/entities" in url + return mock_response + + monkeypatch.setattr(knowledge_mod, "call_post", mock_call_post) + + mock_http = MagicMock() + client = KnowledgeClient(mock_http, "proj-123") + result = await client.create_entity({"title": "Test"}) + assert result.title == "Test" + + @pytest.mark.asyncio + async def test_resolve_entity(self, monkeypatch): + """Test resolve_entity returns external_id.""" + from basic_memory.mcp.clients import knowledge as knowledge_mod + + mock_response = MagicMock() + mock_response.json.return_value = {"external_id": "entity-uuid-123"} + + async def mock_call_post(client, url, **kwargs): + assert "/v2/projects/proj-123/knowledge/resolve" in url + return mock_response + + monkeypatch.setattr(knowledge_mod, "call_post", mock_call_post) + + mock_http = MagicMock() + client = KnowledgeClient(mock_http, "proj-123") + result = await client.resolve_entity("my-note") + assert result == "entity-uuid-123" + + +class TestSearchClient: + """Tests for SearchClient.""" + + def test_init(self): + """Test client initialization.""" + mock_http = MagicMock() + client = SearchClient(mock_http, "project-123") + assert client.http_client is mock_http + assert client.project_id == "project-123" + assert client._base_path == "/v2/projects/project-123/search" + + @pytest.mark.asyncio + async def test_search(self, monkeypatch): + """Test search calls correct endpoint.""" + from basic_memory.mcp.clients import search as search_mod + + mock_response = MagicMock() + mock_response.json.return_value = { + "results": [], + "current_page": 1, + "page_size": 10, + } + + async def mock_call_post(client, url, **kwargs): + assert "/v2/projects/proj-123/search/" in url + assert kwargs.get("params") == {"page": 1, "page_size": 10} + return mock_response + + monkeypatch.setattr(search_mod, "call_post", mock_call_post) + + mock_http = MagicMock() + client = SearchClient(mock_http, "proj-123") + result = await client.search({"text": "query"}, page=1, page_size=10) + assert result.results == [] + assert result.current_page == 1 + + +class TestMemoryClient: + """Tests for MemoryClient.""" + + def test_init(self): + """Test client initialization.""" + mock_http = MagicMock() + client = MemoryClient(mock_http, "project-123") + assert client.http_client is mock_http + assert client.project_id == "project-123" + assert client._base_path == "/v2/projects/project-123/memory" + + @pytest.mark.asyncio + async def test_build_context(self, monkeypatch): + """Test build_context calls correct endpoint.""" + from basic_memory.mcp.clients import memory as memory_mod + from datetime import datetime + + mock_response = MagicMock() + 
mock_response.json.return_value = { + "results": [], + "metadata": { + "depth": 1, + "generated_at": datetime.now().isoformat(), + }, + } + + async def mock_call_get(client, url, **kwargs): + assert "/v2/projects/proj-123/memory/specs/search" in url + return mock_response + + monkeypatch.setattr(memory_mod, "call_get", mock_call_get) + + mock_http = MagicMock() + client = MemoryClient(mock_http, "proj-123") + result = await client.build_context("specs/search") + assert result.results == [] + + + @pytest.mark.asyncio + async def test_recent(self, monkeypatch): + """Test recent calls correct endpoint.""" + from basic_memory.mcp.clients import memory as memory_mod + from datetime import datetime + + mock_response = MagicMock() + mock_response.json.return_value = { + "results": [], + "metadata": { + "depth": 2, + "generated_at": datetime.now().isoformat(), + }, + } + + async def mock_call_get(client, url, **kwargs): + assert "/v2/projects/proj-123/memory/recent" in url + params = kwargs.get("params", {}) + assert params.get("timeframe") == "7d" + assert params.get("depth") == 2 + return mock_response + + monkeypatch.setattr(memory_mod, "call_get", mock_call_get) + + mock_http = MagicMock() + client = MemoryClient(mock_http, "proj-123") + result = await client.recent(timeframe="7d", depth=2) + assert result.results == [] + assert result.metadata.depth == 2 + + @pytest.mark.asyncio + async def test_recent_with_types(self, monkeypatch): + """Test recent with types filter.""" + from basic_memory.mcp.clients import memory as memory_mod + from datetime import datetime + + mock_response = MagicMock() + mock_response.json.return_value = { + "results": [], + "metadata": { + "depth": 1, + "generated_at": datetime.now().isoformat(), + }, + } + + async def mock_call_get(client, url, **kwargs): + assert "/v2/projects/proj-123/memory/recent" in url + params = kwargs.get("params", {}) + assert params.get("type") == "note,spec" + return mock_response + + monkeypatch.setattr(memory_mod, "call_get", mock_call_get) + + mock_http = MagicMock() + client = MemoryClient(mock_http, "proj-123") + result = await client.recent(types=["note", "spec"]) + assert result.results == [] + + +class TestDirectoryClient: + """Tests for DirectoryClient.""" + + def test_init(self): + """Test client initialization.""" + mock_http = MagicMock() + client = DirectoryClient(mock_http, "project-123") + assert client.http_client is mock_http + assert client.project_id == "project-123" + assert client._base_path == "/v2/projects/project-123/directory" + + @pytest.mark.asyncio + async def test_list(self, monkeypatch): + """Test list calls correct endpoint.""" + from basic_memory.mcp.clients import directory as directory_mod + + mock_response = MagicMock() + mock_response.json.return_value = [{"name": "folder", "type": "directory"}] + + async def mock_call_get(client, url, **kwargs): + assert "/v2/projects/proj-123/directory/list" in url + return mock_response + + monkeypatch.setattr(directory_mod, "call_get", mock_call_get) + + mock_http = MagicMock() + client = DirectoryClient(mock_http, "proj-123") + result = await client.list("/") + assert len(result) == 1 + assert result[0]["name"] == "folder" + + +class TestResourceClient: + """Tests for ResourceClient.""" + + def test_init(self): + """Test client initialization.""" + mock_http = MagicMock() + client = ResourceClient(mock_http, "project-123") + assert client.http_client is mock_http + assert client.project_id == "project-123" + assert client._base_path == 
"/v2/projects/project-123/resource" + + @pytest.mark.asyncio + async def test_read(self, monkeypatch): + """Test read calls correct endpoint.""" + from basic_memory.mcp.clients import resource as resource_mod + + mock_response = MagicMock() + mock_response.text = "# Note content" + + async def mock_call_get(client, url, **kwargs): + assert "/v2/projects/proj-123/resource/entity-123" in url + return mock_response + + monkeypatch.setattr(resource_mod, "call_get", mock_call_get) + + mock_http = MagicMock() + client = ResourceClient(mock_http, "proj-123") + result = await client.read("entity-123") + assert result.text == "# Note content" + + +class TestProjectClient: + """Tests for ProjectClient.""" + + def test_init(self): + """Test client initialization.""" + mock_http = MagicMock() + client = ProjectClient(mock_http) + assert client.http_client is mock_http + + @pytest.mark.asyncio + async def test_list_projects(self, monkeypatch): + """Test list_projects calls correct endpoint.""" + from basic_memory.mcp.clients import project as project_mod + + mock_response = MagicMock() + mock_response.json.return_value = { + "projects": [ + { + "id": 1, + "external_id": "uuid-123", + "name": "test-project", + "path": "/path/to/project", + "is_default": True, + } + ], + "default_project": "test-project", + } + + async def mock_call_get(client, url, **kwargs): + assert "/projects/projects" in url + return mock_response + + monkeypatch.setattr(project_mod, "call_get", mock_call_get) + + mock_http = MagicMock() + client = ProjectClient(mock_http) + result = await client.list_projects() + assert len(result.projects) == 1 + assert result.projects[0].name == "test-project" + assert result.default_project == "test-project" diff --git a/tests/mcp/test_mcp_container.py b/tests/mcp/test_mcp_container.py new file mode 100644 index 00000000..eb12ef1c --- /dev/null +++ b/tests/mcp/test_mcp_container.py @@ -0,0 +1,93 @@ +"""Tests for MCP container composition root.""" + +import pytest + +from basic_memory.mcp.container import ( + McpContainer, + get_container, + set_container, +) +from basic_memory.runtime import RuntimeMode + + +class TestMcpContainer: + """Tests for McpContainer.""" + + def test_create_from_config(self, app_config): + """Container can be created from config manager.""" + container = McpContainer(config=app_config, mode=RuntimeMode.LOCAL) + assert container.config == app_config + assert container.mode == RuntimeMode.LOCAL + + def test_should_sync_files_when_enabled_local_mode(self, app_config): + """Sync should be enabled in local mode when config says so.""" + app_config.sync_changes = True + container = McpContainer(config=app_config, mode=RuntimeMode.LOCAL) + assert container.should_sync_files is True + + def test_should_not_sync_files_when_disabled(self, app_config): + """Sync should be disabled when config says so.""" + app_config.sync_changes = False + container = McpContainer(config=app_config, mode=RuntimeMode.LOCAL) + assert container.should_sync_files is False + + def test_should_not_sync_files_in_test_mode(self, app_config): + """Sync should be disabled in test mode regardless of config.""" + app_config.sync_changes = True + container = McpContainer(config=app_config, mode=RuntimeMode.TEST) + assert container.should_sync_files is False + + def test_should_not_sync_files_in_cloud_mode(self, app_config): + """Sync should be disabled in cloud mode (cloud handles sync differently).""" + app_config.sync_changes = True + container = McpContainer(config=app_config, mode=RuntimeMode.CLOUD) + 
assert container.should_sync_files is False + + +class TestSyncSkipReason: + """Tests for sync_skip_reason property.""" + + def test_skip_reason_in_test_mode(self, app_config): + """Returns test message when in test mode.""" + container = McpContainer(config=app_config, mode=RuntimeMode.TEST) + assert container.sync_skip_reason == "Test environment detected" + + def test_skip_reason_in_cloud_mode(self, app_config): + """Returns cloud message when in cloud mode.""" + container = McpContainer(config=app_config, mode=RuntimeMode.CLOUD) + assert container.sync_skip_reason == "Cloud mode enabled" + + def test_skip_reason_when_sync_disabled(self, app_config): + """Returns disabled message when sync is disabled.""" + app_config.sync_changes = False + container = McpContainer(config=app_config, mode=RuntimeMode.LOCAL) + assert container.sync_skip_reason == "Sync changes disabled" + + def test_no_skip_reason_when_should_sync(self, app_config): + """Returns None when sync should run.""" + app_config.sync_changes = True + container = McpContainer(config=app_config, mode=RuntimeMode.LOCAL) + assert container.sync_skip_reason is None + + +class TestContainerAccessors: + """Tests for container get/set functions.""" + + def test_get_container_raises_when_not_set(self, monkeypatch): + """get_container raises RuntimeError when container not initialized.""" + import basic_memory.mcp.container as container_module + + monkeypatch.setattr(container_module, "_container", None) + + with pytest.raises(RuntimeError, match="MCP container not initialized"): + get_container() + + def test_set_and_get_container(self, app_config, monkeypatch): + """set_container allows get_container to return the container.""" + import basic_memory.mcp.container as container_module + + container = McpContainer(config=app_config, mode=RuntimeMode.LOCAL) + monkeypatch.setattr(container_module, "_container", None) + + set_container(container) + assert get_container() is container diff --git a/tests/mcp/test_tool_move_note.py b/tests/mcp/test_tool_move_note.py index 489cd742..c4b795fe 100644 --- a/tests/mcp/test_tool_move_note.py +++ b/tests/mcp/test_tool_move_note.py @@ -12,12 +12,19 @@ async def test_detect_cross_project_move_attempt_is_defensive_on_api_error(monke """Cross-project detection should fail open (return None) if the projects API errors.""" import importlib - move_note_module = importlib.import_module("basic_memory.mcp.tools.move_note") + clients_mod = importlib.import_module("basic_memory.mcp.clients") + + # Mock ProjectClient to raise an exception on list_projects + class MockProjectClient: + def __init__(self, *args, **kwargs): + pass - async def boom(*args, **kwargs): - raise RuntimeError("boom") + async def list_projects(self, *args, **kwargs): + raise RuntimeError("boom") - monkeypatch.setattr(move_note_module, "call_get", boom) + monkeypatch.setattr(clients_mod, "ProjectClient", MockProjectClient) + + move_note_module = importlib.import_module("basic_memory.mcp.tools.move_note") result = await move_note_module._detect_cross_project_move_attempt( client=None, diff --git a/tests/mcp/test_tool_read_note.py b/tests/mcp/test_tool_read_note.py index 499f1d6c..cb15db6e 100644 --- a/tests/mcp/test_tool_read_note.py +++ b/tests/mcp/test_tool_read_note.py @@ -33,19 +33,20 @@ async def test_read_note_title_search_fallback_fetches_by_permalink(monkeypatch, ) import importlib - - read_note_module = importlib.import_module("basic_memory.mcp.tools.read_note") - from basic_memory.mcp.tools.utils import resolve_entity_id as 
real_resolve_entity_id from basic_memory.schemas.memory import memory_url_path + clients_mod = importlib.import_module("basic_memory.mcp.clients") + OriginalKnowledgeClient = clients_mod.KnowledgeClient direct_identifier = memory_url_path("Fallback Title Note") - async def selective_resolve(client, project_id, identifier: str) -> int: - if identifier == direct_identifier: - raise RuntimeError("force direct lookup failure") - return await real_resolve_entity_id(client, project_id, identifier) + class SelectiveKnowledgeClient(OriginalKnowledgeClient): + async def resolve_entity(self, identifier: str) -> int: + # Fail on the direct identifier to force fallback to title search + if identifier == direct_identifier: + raise RuntimeError("force direct lookup failure") + return await super().resolve_entity(identifier) - monkeypatch.setattr(read_note_module, "resolve_entity_id", selective_resolve) + monkeypatch.setattr(clients_mod, "KnowledgeClient", SelectiveKnowledgeClient) content = await read_note.fn("Fallback Title Note", project=test_project.name) assert "fallback content" in content @@ -59,6 +60,8 @@ async def test_read_note_returns_related_results_when_text_search_finds_matches( import importlib read_note_module = importlib.import_module("basic_memory.mcp.tools.read_note") + clients_mod = importlib.import_module("basic_memory.mcp.clients") + OriginalKnowledgeClient = clients_mod.KnowledgeClient async def fake_search_notes_fn(*, query, search_type, **kwargs): if search_type == "title": @@ -88,10 +91,11 @@ async def fake_search_notes_fn(*, query, search_type, **kwargs): ) # Ensure direct resolution doesn't short-circuit the fallback logic. - async def boom(*args, **kwargs): - raise RuntimeError("force fallback") + class FailingKnowledgeClient(OriginalKnowledgeClient): + async def resolve_entity(self, identifier: str) -> int: + raise RuntimeError("force fallback") - monkeypatch.setattr(read_note_module, "resolve_entity_id", boom) + monkeypatch.setattr(clients_mod, "KnowledgeClient", FailingKnowledgeClient) monkeypatch.setattr(read_note_module.search_notes, "fn", fake_search_notes_fn) result = await read_note.fn("missing-note", project=test_project.name) diff --git a/tests/mcp/test_tool_search.py b/tests/mcp/test_tool_search.py index 89596835..0899be13 100644 --- a/tests/mcp/test_tool_search.py +++ b/tests/mcp/test_tool_search.py @@ -292,6 +292,7 @@ async def test_search_notes_exception_handling(self, monkeypatch): import importlib search_mod = importlib.import_module("basic_memory.mcp.tools.search") + clients_mod = importlib.import_module("basic_memory.mcp.clients") class StubProject: project_url = "http://test" @@ -302,11 +303,17 @@ class StubProject: async def fake_get_active_project(*args, **kwargs): return StubProject() - async def fake_call_post(*args, **kwargs): - raise Exception("syntax error") + # Mock SearchClient to raise an exception + class MockSearchClient: + def __init__(self, *args, **kwargs): + pass + + async def search(self, *args, **kwargs): + raise Exception("syntax error") monkeypatch.setattr(search_mod, "get_active_project", fake_get_active_project) - monkeypatch.setattr(search_mod, "call_post", fake_call_post) + # Patch at the clients module level where the import happens + monkeypatch.setattr(clients_mod, "SearchClient", MockSearchClient) result = await search_mod.search_notes.fn(project="test-project", query="test query") assert isinstance(result, str) @@ -318,6 +325,7 @@ async def test_search_notes_permission_error(self, monkeypatch): import importlib search_mod = 
importlib.import_module("basic_memory.mcp.tools.search") + clients_mod = importlib.import_module("basic_memory.mcp.clients") class StubProject: project_url = "http://test" @@ -328,11 +336,17 @@ class StubProject: async def fake_get_active_project(*args, **kwargs): return StubProject() - async def fake_call_post(*args, **kwargs): - raise Exception("permission denied") + # Mock SearchClient to raise a permission error + class MockSearchClient: + def __init__(self, *args, **kwargs): + pass + + async def search(self, *args, **kwargs): + raise Exception("permission denied") monkeypatch.setattr(search_mod, "get_active_project", fake_get_active_project) - monkeypatch.setattr(search_mod, "call_post", fake_call_post) + # Patch at the clients module level where the import happens + monkeypatch.setattr(clients_mod, "SearchClient", MockSearchClient) result = await search_mod.search_notes.fn(project="test-project", query="test query") assert isinstance(result, str) diff --git a/tests/sync/test_coordinator.py b/tests/sync/test_coordinator.py new file mode 100644 index 00000000..fec07c35 --- /dev/null +++ b/tests/sync/test_coordinator.py @@ -0,0 +1,135 @@ +"""Tests for SyncCoordinator - centralized sync/watch lifecycle.""" + +import pytest +from unittest.mock import AsyncMock, patch + +from basic_memory.config import BasicMemoryConfig +from basic_memory.sync.coordinator import SyncCoordinator, SyncStatus + + +class TestSyncCoordinator: + """Test SyncCoordinator class.""" + + @pytest.fixture + def mock_config(self): + """Create a mock config for testing.""" + return BasicMemoryConfig() + + def test_initial_status(self, mock_config): + """Coordinator starts in NOT_STARTED state.""" + coordinator = SyncCoordinator(config=mock_config) + assert coordinator.status == SyncStatus.NOT_STARTED + assert coordinator.is_running is False + + @pytest.mark.asyncio + async def test_start_when_sync_disabled(self, mock_config): + """When should_sync is False, start() sets status to STOPPED.""" + coordinator = SyncCoordinator( + config=mock_config, + should_sync=False, + skip_reason="Test skip", + ) + + await coordinator.start() + + assert coordinator.status == SyncStatus.STOPPED + assert coordinator.is_running is False + + @pytest.mark.asyncio + async def test_stop_when_not_started(self, mock_config): + """Stop is safe to call when not started.""" + coordinator = SyncCoordinator(config=mock_config) + + await coordinator.stop() # Should not raise + + assert coordinator.status == SyncStatus.NOT_STARTED + + @pytest.mark.asyncio + async def test_stop_when_stopped(self, mock_config): + """Stop is idempotent when already stopped.""" + coordinator = SyncCoordinator( + config=mock_config, + should_sync=False, + ) + await coordinator.start() # Sets to STOPPED + + await coordinator.stop() # Should not raise + + assert coordinator.status == SyncStatus.STOPPED + + def test_get_status_info(self, mock_config): + """get_status_info returns diagnostic info.""" + coordinator = SyncCoordinator( + config=mock_config, + should_sync=True, + skip_reason=None, + ) + + info = coordinator.get_status_info() + + assert info["status"] == "NOT_STARTED" + assert info["should_sync"] is True + assert info["skip_reason"] is None + assert info["has_task"] is False + + def test_get_status_info_with_skip_reason(self, mock_config): + """get_status_info includes skip reason.""" + coordinator = SyncCoordinator( + config=mock_config, + should_sync=False, + skip_reason="Test environment detected", + ) + + info = coordinator.get_status_info() + + assert 
info["should_sync"] is False + assert info["skip_reason"] == "Test environment detected" + + @pytest.mark.asyncio + async def test_start_creates_task(self, mock_config): + """When should_sync is True, start() creates a background task.""" + coordinator = SyncCoordinator( + config=mock_config, + should_sync=True, + ) + + # Mock initialize_file_sync to avoid actually starting sync + # The import happens inside start(), so patch at the source module + with patch( + "basic_memory.services.initialization.initialize_file_sync", + new_callable=AsyncMock, + ): + # Start coordinator + await coordinator.start() + + # Should be running with a task + assert coordinator.status == SyncStatus.RUNNING + assert coordinator.is_running is True + assert coordinator._sync_task is not None + + # Stop to clean up + await coordinator.stop() + + assert coordinator.status == SyncStatus.STOPPED + assert coordinator._sync_task is None + + @pytest.mark.asyncio + async def test_start_already_running(self, mock_config): + """Starting when already running is a no-op.""" + coordinator = SyncCoordinator( + config=mock_config, + should_sync=True, + ) + + with patch( + "basic_memory.services.initialization.initialize_file_sync", + new_callable=AsyncMock, + ): + await coordinator.start() + first_task = coordinator._sync_task + + # Start again - should not create new task + await coordinator.start() + assert coordinator._sync_task is first_task + + await coordinator.stop() diff --git a/tests/test_deps.py b/tests/test_deps.py index 79dba6e2..815819b2 100644 --- a/tests/test_deps.py +++ b/tests/test_deps.py @@ -8,6 +8,7 @@ from fastapi import HTTPException from basic_memory.deps import get_project_config, get_project_id +from basic_memory.deps.projects import validate_project_id from basic_memory.models.project import Project from basic_memory.repository.project_repository import ProjectRepository @@ -204,3 +205,28 @@ async def test_get_project_config_case_sensitivity( # All should resolve to the same project assert config1.name == config2.name == config3.name == "My Test Project" assert config1.home == config2.home == config3.home == Path("/my/test/project") + + +# --- Tests for validate_project_id (v2 API) --- + + +@pytest.mark.asyncio +async def test_validate_project_id_success( + project_repository: ProjectRepository, test_project: Project +): + """Test that validate_project_id returns project_id when project exists.""" + project_id = await validate_project_id( + project_id=test_project.id, project_repository=project_repository + ) + + assert project_id == test_project.id + + +@pytest.mark.asyncio +async def test_validate_project_id_not_found(project_repository: ProjectRepository): + """Test that validate_project_id raises HTTPException when project not found.""" + with pytest.raises(HTTPException) as exc_info: + await validate_project_id(project_id=99999, project_repository=project_repository) + + assert exc_info.value.status_code == 404 + assert "Project with ID 99999 not found" in exc_info.value.detail diff --git a/tests/test_project_resolver.py b/tests/test_project_resolver.py new file mode 100644 index 00000000..1d4b4e66 --- /dev/null +++ b/tests/test_project_resolver.py @@ -0,0 +1,210 @@ +"""Tests for ProjectResolver - unified project resolution logic.""" + +import os +import pytest +from basic_memory.project_resolver import ( + ProjectResolver, + ResolvedProject, + ResolutionMode, +) + + +class TestProjectResolver: + """Test ProjectResolver class.""" + + def test_cloud_mode_requires_project(self): + """In cloud mode, 
project is required.""" + resolver = ProjectResolver(cloud_mode=True) + with pytest.raises(ValueError, match="Project is required for cloud mode"): + resolver.resolve(project=None) + + def test_cloud_mode_with_explicit_project(self): + """In cloud mode, explicit project is accepted.""" + resolver = ProjectResolver(cloud_mode=True) + result = resolver.resolve(project="my-project") + + assert result.project == "my-project" + assert result.mode == ResolutionMode.CLOUD_EXPLICIT + assert result.is_resolved is True + assert result.is_discovery_mode is False + + def test_cloud_mode_discovery_allowed(self): + """In cloud mode with allow_discovery, None is acceptable.""" + resolver = ProjectResolver(cloud_mode=True) + result = resolver.resolve(project=None, allow_discovery=True) + + assert result.project is None + assert result.mode == ResolutionMode.CLOUD_DISCOVERY + assert result.is_resolved is False + assert result.is_discovery_mode is True + + def test_local_mode_env_constraint_priority(self, monkeypatch): + """Env constraint has highest priority in local mode.""" + monkeypatch.setenv("BASIC_MEMORY_MCP_PROJECT", "constrained-project") + resolver = ProjectResolver.from_env( + cloud_mode=False, + default_project_mode=True, + default_project="default-project", + ) + + # Even with explicit project and default, env constraint wins + result = resolver.resolve(project="explicit-project") + + assert result.project == "constrained-project" + assert result.mode == ResolutionMode.ENV_CONSTRAINT + assert result.is_resolved is True + + def test_local_mode_explicit_project(self): + """Explicit project parameter has second priority.""" + resolver = ProjectResolver( + cloud_mode=False, + default_project_mode=True, + default_project="default-project", + ) + + result = resolver.resolve(project="explicit-project") + + assert result.project == "explicit-project" + assert result.mode == ResolutionMode.EXPLICIT + + def test_local_mode_default_project(self): + """Default project is used when default_project_mode is true.""" + resolver = ProjectResolver( + cloud_mode=False, + default_project_mode=True, + default_project="my-default", + ) + + result = resolver.resolve(project=None) + + assert result.project == "my-default" + assert result.mode == ResolutionMode.DEFAULT + + def test_local_mode_no_default_when_mode_disabled(self): + """Default project is NOT used when default_project_mode is false.""" + resolver = ProjectResolver( + cloud_mode=False, + default_project_mode=False, + default_project="my-default", + ) + + result = resolver.resolve(project=None) + + assert result.project is None + assert result.mode == ResolutionMode.NONE + assert result.is_resolved is False + + def test_local_mode_no_resolution_possible(self): + """When nothing is configured, resolution returns None.""" + resolver = ProjectResolver(cloud_mode=False) + result = resolver.resolve(project=None) + + assert result.project is None + assert result.mode == ResolutionMode.NONE + assert "default_project_mode is disabled" in result.reason + + def test_require_project_success(self): + """require_project returns result when project resolved.""" + resolver = ProjectResolver( + cloud_mode=False, + default_project_mode=True, + default_project="required-project", + ) + + result = resolver.require_project() + + assert result.project == "required-project" + assert result.is_resolved is True + + def test_require_project_raises_on_failure(self): + """require_project raises ValueError when not resolved.""" + resolver = ProjectResolver(cloud_mode=False, 
default_project_mode=False) + + with pytest.raises(ValueError, match="No project specified"): + resolver.require_project() + + def test_require_project_custom_error_message(self): + """require_project uses custom error message.""" + resolver = ProjectResolver(cloud_mode=False, default_project_mode=False) + + with pytest.raises(ValueError, match="Custom error message"): + resolver.require_project(error_message="Custom error message") + + def test_from_env_without_env_var(self, monkeypatch): + """from_env without BASIC_MEMORY_MCP_PROJECT set.""" + monkeypatch.delenv("BASIC_MEMORY_MCP_PROJECT", raising=False) + resolver = ProjectResolver.from_env( + cloud_mode=False, + default_project_mode=True, + default_project="test", + ) + + assert resolver.constrained_project is None + result = resolver.resolve(project="explicit") + assert result.mode == ResolutionMode.EXPLICIT + + def test_from_env_with_env_var(self, monkeypatch): + """from_env with BASIC_MEMORY_MCP_PROJECT set.""" + monkeypatch.setenv("BASIC_MEMORY_MCP_PROJECT", "env-project") + resolver = ProjectResolver.from_env() + + assert resolver.constrained_project == "env-project" + + +class TestResolvedProject: + """Test ResolvedProject dataclass.""" + + def test_is_resolved_true(self): + """is_resolved returns True when project is set.""" + result = ResolvedProject( + project="test", + mode=ResolutionMode.EXPLICIT, + reason="test", + ) + assert result.is_resolved is True + + def test_is_resolved_false(self): + """is_resolved returns False when project is None.""" + result = ResolvedProject( + project=None, + mode=ResolutionMode.NONE, + reason="test", + ) + assert result.is_resolved is False + + def test_is_discovery_mode_cloud(self): + """is_discovery_mode is True for CLOUD_DISCOVERY.""" + result = ResolvedProject( + project=None, + mode=ResolutionMode.CLOUD_DISCOVERY, + reason="test", + ) + assert result.is_discovery_mode is True + + def test_is_discovery_mode_none(self): + """is_discovery_mode is True for NONE with no project.""" + result = ResolvedProject( + project=None, + mode=ResolutionMode.NONE, + reason="test", + ) + assert result.is_discovery_mode is True + + def test_is_discovery_mode_false(self): + """is_discovery_mode is False when project is resolved.""" + result = ResolvedProject( + project="test", + mode=ResolutionMode.EXPLICIT, + reason="test", + ) + assert result.is_discovery_mode is False + + def test_frozen_dataclass(self): + """ResolvedProject is immutable.""" + result = ResolvedProject( + project="test", + mode=ResolutionMode.EXPLICIT, + reason="test", + ) + with pytest.raises(AttributeError): + result.project = "changed" # type: ignore diff --git a/tests/test_runtime.py b/tests/test_runtime.py new file mode 100644 index 00000000..42cdc99d --- /dev/null +++ b/tests/test_runtime.py @@ -0,0 +1,53 @@ +"""Tests for runtime mode resolution.""" + + +from basic_memory.runtime import RuntimeMode, resolve_runtime_mode + + +class TestRuntimeMode: + """Tests for RuntimeMode enum.""" + + def test_local_mode_properties(self): + mode = RuntimeMode.LOCAL + assert mode.is_local is True + assert mode.is_cloud is False + assert mode.is_test is False + + def test_cloud_mode_properties(self): + mode = RuntimeMode.CLOUD + assert mode.is_local is False + assert mode.is_cloud is True + assert mode.is_test is False + + def test_test_mode_properties(self): + mode = RuntimeMode.TEST + assert mode.is_local is False + assert mode.is_cloud is False + assert mode.is_test is True + + +class TestResolveRuntimeMode: + """Tests for resolve_runtime_mode 
function.""" + + def test_resolves_to_test_when_test_env(self): + """Test environment takes precedence over cloud mode.""" + mode = resolve_runtime_mode(cloud_mode_enabled=True, is_test_env=True) + assert mode == RuntimeMode.TEST + + def test_resolves_to_cloud_when_enabled(self): + """Cloud mode is used when enabled and not in test env.""" + mode = resolve_runtime_mode(cloud_mode_enabled=True, is_test_env=False) + assert mode == RuntimeMode.CLOUD + + def test_resolves_to_local_by_default(self): + """Local mode is the default when no other modes apply.""" + mode = resolve_runtime_mode(cloud_mode_enabled=False, is_test_env=False) + assert mode == RuntimeMode.LOCAL + + def test_test_env_overrides_cloud_mode(self): + """Test environment should override cloud mode.""" + # When both are enabled, test takes precedence + mode = resolve_runtime_mode(cloud_mode_enabled=True, is_test_env=True) + assert mode == RuntimeMode.TEST + assert mode.is_test is True + assert mode.is_cloud is False
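+
+ def test_precedence_exhaustive(self):
+ """Supplementary sketch: exhaustive precedence table.
+
+ Not part of the original suite; enumerates every flag combination
+ against the documented TEST > CLOUD > LOCAL precedence order.
+ """
+ cases = {
+ (False, False): RuntimeMode.LOCAL,
+ (True, False): RuntimeMode.CLOUD,
+ (False, True): RuntimeMode.TEST,
+ (True, True): RuntimeMode.TEST,
+ }
+ for (cloud_enabled, test_env), expected in cases.items():
+ mode = resolve_runtime_mode(cloud_mode_enabled=cloud_enabled, is_test_env=test_env)
+ assert mode == expected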