diff --git a/src/basic_memory/cli/commands/__init__.py b/src/basic_memory/cli/commands/__init__.py index 1a8ae209..8b98c81c 100644 --- a/src/basic_memory/cli/commands/__init__.py +++ b/src/basic_memory/cli/commands/__init__.py @@ -1,7 +1,7 @@ """CLI commands for basic-memory.""" from . import status, db, import_memory_json, mcp, import_claude_conversations -from . import import_claude_projects, import_chatgpt, tool, project +from . import import_claude_projects, import_chatgpt, tool, project, format __all__ = [ "status", @@ -13,4 +13,5 @@ "import_chatgpt", "tool", "project", + "format", ] diff --git a/src/basic_memory/cli/commands/format.py b/src/basic_memory/cli/commands/format.py new file mode 100644 index 00000000..4aed2a02 --- /dev/null +++ b/src/basic_memory/cli/commands/format.py @@ -0,0 +1,183 @@ +"""Format command for basic-memory CLI.""" + +import asyncio +from pathlib import Path +from typing import Annotated, Optional + +import typer +from loguru import logger +from rich.console import Console +from rich.progress import Progress, SpinnerColumn, TextColumn + +from basic_memory.cli.app import app +from basic_memory.config import ConfigManager, get_project_config +from basic_memory.file_utils import format_file + +console = Console() + + +async def format_single_file(file_path: Path, app_config) -> tuple[Path, bool, Optional[str]]: + """Format a single file. + + Returns: + Tuple of (path, success, error_message) + """ + try: + result = await format_file(file_path, app_config) + if result is not None: + return (file_path, True, None) + else: + return (file_path, False, "No formatter configured or formatting skipped") + except Exception as e: + return (file_path, False, str(e)) + + +async def format_files( + paths: list[Path], app_config, show_progress: bool = True +) -> tuple[int, int, list[tuple[Path, str]]]: + """Format multiple files. + + Returns: + Tuple of (formatted_count, skipped_count, errors) + """ + formatted = 0 + skipped = 0 + errors: list[tuple[Path, str]] = [] + + if show_progress: + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task("Formatting files...", total=len(paths)) + + for file_path in paths: + path, success, error = await format_single_file(file_path, app_config) + if success: + formatted += 1 + elif error and "No formatter configured" not in error: + errors.append((path, error)) + else: + skipped += 1 + progress.update(task, advance=1) + else: + for file_path in paths: + path, success, error = await format_single_file(file_path, app_config) + if success: + formatted += 1 + elif error and "No formatter configured" not in error: + errors.append((path, error)) + else: + skipped += 1 + + return formatted, skipped, errors + + +async def run_format( + path: Optional[Path] = None, + project: Optional[str] = None, +) -> None: + """Run the format command.""" + app_config = ConfigManager().config + + # Check if formatting is enabled + if not app_config.format_on_save and not app_config.formatter_command and not app_config.formatters: + console.print( + "[yellow]No formatters configured. Set format_on_save=true and " + "formatter_command or formatters in your config.[/yellow]" + ) + console.print( + "\nExample config (~/.basic-memory/config.json):\n" + ' "format_on_save": true,\n' + ' "formatter_command": "prettier --write {file}"\n' + ) + raise typer.Exit(1) + + # Temporarily enable format_on_save for this command + # (so format_file actually runs the formatter) + original_format_on_save = app_config.format_on_save + app_config.format_on_save = True + + try: + # Determine which files to format + if path: + # Format specific file or directory + if path.is_file(): + files = [path] + elif path.is_dir(): + # Find all markdown and json files + files = list(path.rglob("*.md")) + list(path.rglob("*.json")) + list(path.rglob("*.canvas")) + else: + console.print(f"[red]Path not found: {path}[/red]") + raise typer.Exit(1) + else: + # Format all files in project + project_config = get_project_config(project) + project_path = Path(project_config.home) + + if not project_path.exists(): + console.print(f"[red]Project path not found: {project_path}[/red]") + raise typer.Exit(1) + + # Find all markdown and json files + files = ( + list(project_path.rglob("*.md")) + + list(project_path.rglob("*.json")) + + list(project_path.rglob("*.canvas")) + ) + + if not files: + console.print("[yellow]No files found to format.[/yellow]") + return + + console.print(f"Found {len(files)} file(s) to format...") + + formatted, skipped, errors = await format_files(files, app_config) + + # Print summary + console.print() + if formatted > 0: + console.print(f"[green]Formatted: {formatted} file(s)[/green]") + if skipped > 0: + console.print(f"[dim]Skipped: {skipped} file(s) (no formatter for extension)[/dim]") + if errors: + console.print(f"[red]Errors: {len(errors)} file(s)[/red]") + for path, error in errors: + console.print(f" [red]{path}[/red]: {error}") + + finally: + # Restore original setting + app_config.format_on_save = original_format_on_save + + +@app.command() +def format( + path: Annotated[ + Optional[Path], + typer.Argument(help="File or directory to format. Defaults to current project."), + ] = None, + project: Annotated[ + Optional[str], + typer.Option("--project", "-p", help="Project name to format."), + ] = None, +) -> None: + """Format files using configured formatters. + + Uses the formatter_command or formatters settings from your config. + By default, formats all .md, .json, and .canvas files in the current project. + + Examples: + basic-memory format # Format all files in current project + basic-memory format --project research # Format files in specific project + basic-memory format notes/meeting.md # Format a specific file + basic-memory format notes/ # Format all files in directory + """ + try: + asyncio.run(run_format(path, project)) + except Exception as e: + if not isinstance(e, typer.Exit): + logger.error(f"Error formatting files: {e}") + console.print(f"[red]Error formatting files: {e}[/red]") + raise typer.Exit(code=1) + raise diff --git a/src/basic_memory/cli/commands/import_chatgpt.py b/src/basic_memory/cli/commands/import_chatgpt.py index 045d1eff..7a55fdb7 100644 --- a/src/basic_memory/cli/commands/import_chatgpt.py +++ b/src/basic_memory/cli/commands/import_chatgpt.py @@ -7,7 +7,7 @@ import typer from basic_memory.cli.app import import_app -from basic_memory.config import get_project_config +from basic_memory.config import ConfigManager, get_project_config from basic_memory.importers import ChatGPTImporter from basic_memory.markdown import EntityParser, MarkdownProcessor from loguru import logger @@ -20,8 +20,9 @@ async def get_markdown_processor() -> MarkdownProcessor: """Get MarkdownProcessor instance.""" config = get_project_config() + app_config = ConfigManager().config entity_parser = EntityParser(config.home) - return MarkdownProcessor(entity_parser) + return MarkdownProcessor(entity_parser, app_config=app_config) @import_app.command(name="chatgpt", help="Import conversations from ChatGPT JSON export.") diff --git a/src/basic_memory/cli/commands/import_claude_conversations.py b/src/basic_memory/cli/commands/import_claude_conversations.py index 7a89da17..08362a33 100644 --- a/src/basic_memory/cli/commands/import_claude_conversations.py +++ b/src/basic_memory/cli/commands/import_claude_conversations.py @@ -7,7 +7,7 @@ import typer from basic_memory.cli.app import claude_app -from basic_memory.config import get_project_config +from basic_memory.config import ConfigManager, get_project_config from basic_memory.importers.claude_conversations_importer import ClaudeConversationsImporter from basic_memory.markdown import EntityParser, MarkdownProcessor from loguru import logger @@ -20,8 +20,9 @@ async def get_markdown_processor() -> MarkdownProcessor: """Get MarkdownProcessor instance.""" config = get_project_config() + app_config = ConfigManager().config entity_parser = EntityParser(config.home) - return MarkdownProcessor(entity_parser) + return MarkdownProcessor(entity_parser, app_config=app_config) @claude_app.command(name="conversations", help="Import chat conversations from Claude.ai.") diff --git a/src/basic_memory/cli/commands/import_claude_projects.py b/src/basic_memory/cli/commands/import_claude_projects.py index 553f15eb..dfbb331b 100644 --- a/src/basic_memory/cli/commands/import_claude_projects.py +++ b/src/basic_memory/cli/commands/import_claude_projects.py @@ -7,7 +7,7 @@ import typer from basic_memory.cli.app import claude_app -from basic_memory.config import get_project_config +from basic_memory.config import ConfigManager, get_project_config from basic_memory.importers.claude_projects_importer import ClaudeProjectsImporter from basic_memory.markdown import EntityParser, MarkdownProcessor from loguru import logger @@ -20,8 +20,9 @@ async def get_markdown_processor() -> MarkdownProcessor: """Get MarkdownProcessor instance.""" config = get_project_config() + app_config = ConfigManager().config entity_parser = EntityParser(config.home) - return MarkdownProcessor(entity_parser) + return MarkdownProcessor(entity_parser, app_config=app_config) @claude_app.command(name="projects", help="Import projects from Claude.ai.") diff --git a/src/basic_memory/cli/commands/import_memory_json.py b/src/basic_memory/cli/commands/import_memory_json.py index 84bf73a8..cc6d2d6b 100644 --- a/src/basic_memory/cli/commands/import_memory_json.py +++ b/src/basic_memory/cli/commands/import_memory_json.py @@ -7,7 +7,7 @@ import typer from basic_memory.cli.app import import_app -from basic_memory.config import get_project_config +from basic_memory.config import ConfigManager, get_project_config from basic_memory.importers.memory_json_importer import MemoryJsonImporter from basic_memory.markdown import EntityParser, MarkdownProcessor from loguru import logger @@ -20,8 +20,9 @@ async def get_markdown_processor() -> MarkdownProcessor: """Get MarkdownProcessor instance.""" config = get_project_config() + app_config = ConfigManager().config entity_parser = EntityParser(config.home) - return MarkdownProcessor(entity_parser) + return MarkdownProcessor(entity_parser, app_config=app_config) @import_app.command() diff --git a/src/basic_memory/config.py b/src/basic_memory/config.py index fc382f06..5fed89e8 100644 --- a/src/basic_memory/config.py +++ b/src/basic_memory/config.py @@ -165,6 +165,28 @@ class BasicMemoryConfig(BaseSettings): description="Skip expensive initialization synchronization. Useful for cloud/stateless deployments where project reconciliation is not needed.", ) + # File formatting configuration + format_on_save: bool = Field( + default=False, + description="Automatically format files after saving using configured formatter. Disabled by default.", + ) + + formatter_command: Optional[str] = Field( + default="npx prettier --write {file}", + description="Default formatter command. Use {file} as placeholder for file path. Uses npx prettier by default.", + ) + + formatters: Dict[str, str] = Field( + default_factory=dict, + description="Per-extension formatters. Keys are extensions (without dot), values are commands. Example: {'md': 'prettier --write {file}', 'json': 'prettier --write {file}'}", + ) + + formatter_timeout: float = Field( + default=5.0, + description="Maximum seconds to wait for formatter to complete", + gt=0, + ) + # Project path constraints project_root: Optional[str] = Field( default=None, diff --git a/src/basic_memory/deps.py b/src/basic_memory/deps.py index 8a2f42a8..0117b982 100644 --- a/src/basic_memory/deps.py +++ b/src/basic_memory/deps.py @@ -351,24 +351,32 @@ async def get_entity_parser_v2(project_config: ProjectConfigV2Dep) -> EntityPars EntityParserV2Dep = Annotated["EntityParser", Depends(get_entity_parser_v2)] -async def get_markdown_processor(entity_parser: EntityParserDep) -> MarkdownProcessor: - return MarkdownProcessor(entity_parser) +async def get_markdown_processor( + entity_parser: EntityParserDep, app_config: AppConfigDep +) -> MarkdownProcessor: + return MarkdownProcessor(entity_parser, app_config=app_config) MarkdownProcessorDep = Annotated[MarkdownProcessor, Depends(get_markdown_processor)] -async def get_markdown_processor_v2(entity_parser: EntityParserV2Dep) -> MarkdownProcessor: - return MarkdownProcessor(entity_parser) +async def get_markdown_processor_v2( + entity_parser: EntityParserV2Dep, app_config: AppConfigDep +) -> MarkdownProcessor: + return MarkdownProcessor(entity_parser, app_config=app_config) MarkdownProcessorV2Dep = Annotated[MarkdownProcessor, Depends(get_markdown_processor_v2)] async def get_file_service( - project_config: ProjectConfigDep, markdown_processor: MarkdownProcessorDep + project_config: ProjectConfigDep, + markdown_processor: MarkdownProcessorDep, + app_config: AppConfigDep, ) -> FileService: - file_service = FileService(project_config.home, markdown_processor) + file_service = FileService( + project_config.home, markdown_processor, app_config=app_config + ) logger.debug( f"Created FileService for project: {project_config.name}, base_path: {project_config.home} " ) @@ -379,9 +387,13 @@ async def get_file_service( async def get_file_service_v2( - project_config: ProjectConfigV2Dep, markdown_processor: MarkdownProcessorV2Dep + project_config: ProjectConfigV2Dep, + markdown_processor: MarkdownProcessorV2Dep, + app_config: AppConfigDep, ) -> FileService: - file_service = FileService(project_config.home, markdown_processor) + file_service = FileService( + project_config.home, markdown_processor, app_config=app_config + ) logger.debug( f"Created FileService for project: {project_config.name}, base_path: {project_config.home}" ) diff --git a/src/basic_memory/file_utils.py b/src/basic_memory/file_utils.py index 83ae5c32..773df11b 100644 --- a/src/basic_memory/file_utils.py +++ b/src/basic_memory/file_utils.py @@ -1,11 +1,13 @@ """Utilities for file operations.""" +import asyncio import hashlib +import shlex from dataclasses import dataclass from datetime import datetime from pathlib import Path import re -from typing import Any, Dict, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Union import aiofiles import yaml @@ -14,6 +16,9 @@ from basic_memory.utils import FilePath +if TYPE_CHECKING: + from basic_memory.config import BasicMemoryConfig + @dataclass class FileMetadata: @@ -100,6 +105,101 @@ async def write_file_atomic(path: FilePath, content: str) -> None: raise FileWriteError(f"Failed to write file {path}: {e}") +async def format_file( + path: Path, + config: "BasicMemoryConfig", +) -> Optional[str]: + """ + Format a file using configured formatter. + + Runs an external formatter (like prettier) on the file after it has been written. + The formatter command is determined by file extension, falling back to a global + formatter if no extension-specific one is configured. + + Args: + path: File to format + config: Configuration with formatter settings + + Returns: + Formatted content if successful, None if formatting was skipped or failed. + Failures are logged as warnings but don't raise exceptions. + """ + if not config.format_on_save: + return None + + extension = path.suffix.lstrip(".") + formatter = config.formatters.get(extension) or config.formatter_command + + if not formatter: + logger.debug("No formatter configured for extension", extension=extension) + return None + + # Replace {file} placeholder with the actual path + cmd = formatter.replace("{file}", str(path)) + + try: + # Parse command into args list for safer execution (no shell=True) + args = shlex.split(cmd) + + proc = await asyncio.create_subprocess_exec( + *args, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + try: + stdout, stderr = await asyncio.wait_for( + proc.communicate(), + timeout=config.formatter_timeout, + ) + except asyncio.TimeoutError: + proc.kill() + await proc.wait() + logger.warning( + "Formatter timed out", + path=str(path), + timeout=config.formatter_timeout, + ) + return None + + if proc.returncode != 0: + logger.warning( + "Formatter exited with non-zero status", + path=str(path), + returncode=proc.returncode, + stderr=stderr.decode("utf-8", errors="replace") if stderr else "", + ) + # Still try to read the file - formatter may have partially worked + # or the file may be unchanged + + # Read formatted content + async with aiofiles.open(path, mode="r", encoding="utf-8") as f: + formatted_content = await f.read() + + logger.debug( + "Formatted file successfully", + path=str(path), + formatter=args[0] if args else formatter, + ) + return formatted_content + + except FileNotFoundError: + # Formatter executable not found + logger.warning( + "Formatter executable not found", + command=cmd.split()[0] if cmd else "", + path=str(path), + ) + return None + except Exception as e: + logger.warning( + "Formatter failed", + path=str(path), + error=str(e), + ) + return None + + def has_frontmatter(content: str) -> bool: """ Check if content contains valid YAML frontmatter. diff --git a/src/basic_memory/markdown/markdown_processor.py b/src/basic_memory/markdown/markdown_processor.py index 42ca476d..2630bff4 100644 --- a/src/basic_memory/markdown/markdown_processor.py +++ b/src/basic_memory/markdown/markdown_processor.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Optional +from typing import TYPE_CHECKING, Optional from collections import OrderedDict from frontmatter import Post @@ -11,6 +11,9 @@ from basic_memory.markdown.entity_parser import EntityParser from basic_memory.markdown.schemas import EntityMarkdown, Observation, Relation +if TYPE_CHECKING: + from basic_memory.config import BasicMemoryConfig + class DirtyFileError(Exception): """Raised when attempting to write to a file that has been modified.""" @@ -36,9 +39,14 @@ class MarkdownProcessor: 3. Track schema changes (that's done by the database) """ - def __init__(self, entity_parser: EntityParser): - """Initialize processor with base path and parser.""" + def __init__( + self, + entity_parser: EntityParser, + app_config: Optional["BasicMemoryConfig"] = None, + ): + """Initialize processor with parser and optional config.""" self.entity_parser = entity_parser + self.app_config = app_config async def read_file(self, path: Path) -> EntityMarkdown: """Read and parse file into EntityMarkdown schema. @@ -123,7 +131,15 @@ async def write_file( # Write atomically and return checksum of updated file path.parent.mkdir(parents=True, exist_ok=True) await file_utils.write_file_atomic(path, final_content) - return await file_utils.compute_checksum(final_content) + + # Format file if configured + content_for_checksum = final_content + if self.app_config: + formatted_content = await file_utils.format_file(path, self.app_config) + if formatted_content is not None: + content_for_checksum = formatted_content + + return await file_utils.compute_checksum(content_for_checksum) def format_observations(self, observations: list[Observation]) -> str: """Format observations section in standard way. diff --git a/src/basic_memory/services/file_service.py b/src/basic_memory/services/file_service.py index 4b738372..fdc412cf 100644 --- a/src/basic_memory/services/file_service.py +++ b/src/basic_memory/services/file_service.py @@ -5,13 +5,16 @@ import mimetypes from datetime import datetime from pathlib import Path -from typing import Any, Dict, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union import aiofiles import yaml from basic_memory import file_utils + +if TYPE_CHECKING: + from basic_memory.config import BasicMemoryConfig from basic_memory.file_utils import FileError, FileMetadata, ParseError from basic_memory.markdown.markdown_processor import MarkdownProcessor from basic_memory.models import Entity as EntityModel @@ -42,9 +45,11 @@ def __init__( base_path: Path, markdown_processor: MarkdownProcessor, max_concurrent_files: int = 10, + app_config: Optional["BasicMemoryConfig"] = None, ): self.base_path = base_path.resolve() # Get absolute path self.markdown_processor = markdown_processor + self.app_config = app_config # Semaphore to limit concurrent file operations # Prevents OOM on large projects by processing files in batches self._file_semaphore = asyncio.Semaphore(max_concurrent_files) @@ -149,12 +154,15 @@ async def write_file(self, path: FilePath, content: str) -> str: Handles both absolute and relative paths. Relative paths are resolved against base_path. + If format_on_save is enabled in config, runs the configured formatter + after writing and returns the checksum of the formatted content. + Args: path: Where to write (Path or string) content: Content to write Returns: - Checksum of written content + Checksum of written content (or formatted content if formatting enabled) Raises: FileOperationError: If write fails @@ -177,8 +185,15 @@ async def write_file(self, path: FilePath, content: str) -> str: await file_utils.write_file_atomic(full_path, content) - # Compute and return checksum - checksum = await file_utils.compute_checksum(content) + # Format file if configured + final_content = content + if self.app_config: + formatted_content = await file_utils.format_file(full_path, self.app_config) + if formatted_content is not None: + final_content = formatted_content + + # Compute and return checksum of final content + checksum = await file_utils.compute_checksum(final_content) logger.debug(f"File write completed path={full_path}, {checksum=}") return checksum @@ -405,7 +420,15 @@ async def update_frontmatter(self, path: FilePath, updates: Dict[str, Any]) -> s ) await file_utils.write_file_atomic(full_path, final_content) - return await file_utils.compute_checksum(final_content) + + # Format file if configured + content_for_checksum = final_content + if self.app_config: + formatted_content = await file_utils.format_file(full_path, self.app_config) + if formatted_content is not None: + content_for_checksum = formatted_content + + return await file_utils.compute_checksum(content_for_checksum) except Exception as e: # Only log real errors (not YAML parsing, which is handled above) diff --git a/src/basic_memory/sync/sync_service.py b/src/basic_memory/sync/sync_service.py index be75d867..2edf66d2 100644 --- a/src/basic_memory/sync/sync_service.py +++ b/src/basic_memory/sync/sync_service.py @@ -1210,8 +1210,8 @@ async def get_sync_service(project: Project) -> SyncService: # pragma: no cover project_path = Path(project.path) entity_parser = EntityParser(project_path) - markdown_processor = MarkdownProcessor(entity_parser) - file_service = FileService(project_path, markdown_processor) + markdown_processor = MarkdownProcessor(entity_parser, app_config=app_config) + file_service = FileService(project_path, markdown_processor, app_config=app_config) # Initialize repositories entity_repository = EntityRepository(session_maker, project_id=project.id) diff --git a/tests/cli/test_cli_tool_exit.py b/tests/cli/test_cli_tool_exit.py new file mode 100644 index 00000000..b645eaad --- /dev/null +++ b/tests/cli/test_cli_tool_exit.py @@ -0,0 +1,86 @@ +"""Test that CLI tool commands exit cleanly without hanging. + +This test ensures that CLI commands properly clean up database connections +on exit, preventing process hangs. See GitHub issue for details. + +The issue occurs when: +1. ensure_initialization() calls asyncio.run(initialize_app()) +2. initialize_app() creates global database connections via db.get_or_create_db() +3. When asyncio.run() completes, the event loop closes +4. But the global database engine holds async connections that prevent clean exit +5. Process hangs indefinitely + +The fix ensures db.shutdown_db() is called before asyncio.run() returns. +""" + +import subprocess +import sys + +import pytest + + +class TestCLIToolExit: + """Test that CLI tool commands exit cleanly.""" + + @pytest.mark.parametrize( + "command", + [ + ["tool", "--help"], + ["tool", "write-note", "--help"], + ["tool", "read-note", "--help"], + ["tool", "search-notes", "--help"], + ["tool", "build-context", "--help"], + ["status"], # Also affected by same issue + ], + ) + def test_cli_command_exits_cleanly(self, command: list[str]): + """Test that CLI commands exit without hanging. + + Each command should complete within the timeout without requiring + manual termination (Ctrl+C). + """ + full_command = [sys.executable, "-m", "basic_memory.cli.main"] + command + + try: + result = subprocess.run( + full_command, + capture_output=True, + text=True, + timeout=10.0, # 10 second timeout - commands should complete in ~2s + ) + # Command should exit with code 0 for --help + assert result.returncode == 0, f"Command failed: {result.stderr}" + except subprocess.TimeoutExpired: + pytest.fail( + f"Command '{' '.join(command)}' hung and did not exit within timeout. " + "This indicates database connections are not being cleaned up properly." + ) + + def test_ensure_initialization_exits_cleanly(self): + """Test that ensure_initialization doesn't cause process hang. + + This test directly tests the initialization function that's called + by CLI commands, ensuring it cleans up database connections properly. + """ + code = """ +import asyncio +from basic_memory.config import ConfigManager +from basic_memory.services.initialization import ensure_initialization + +app_config = ConfigManager().config +ensure_initialization(app_config) +print("OK") +""" + try: + result = subprocess.run( + [sys.executable, "-c", code], + capture_output=True, + text=True, + timeout=10.0, + ) + assert "OK" in result.stdout, f"Unexpected output: {result.stdout}" + except subprocess.TimeoutExpired: + pytest.fail( + "ensure_initialization() caused process hang. " + "Database connections are not being cleaned up before event loop closes." + ) diff --git a/tests/test_config.py b/tests/test_config.py index 3014e126..f7abfea8 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -483,3 +483,116 @@ def test_model_post_init_uses_platform_native_separators(self, config_home, monk else: # Unix: should have forward slashes assert "/" in main_path + + +class TestFormattingConfig: + """Test file formatting configuration options.""" + + def test_format_on_save_defaults_to_false(self): + """Test that format_on_save is disabled by default.""" + config = BasicMemoryConfig() + assert config.format_on_save is False + + def test_format_on_save_can_be_enabled(self): + """Test that format_on_save can be set to True.""" + config = BasicMemoryConfig(format_on_save=True) + assert config.format_on_save is True + + def test_formatter_command_defaults_to_npx_prettier(self): + """Test that formatter_command defaults to npx prettier.""" + config = BasicMemoryConfig() + assert config.formatter_command == "npx prettier --write {file}" + + def test_formatter_command_can_be_set(self): + """Test that formatter_command can be configured.""" + config = BasicMemoryConfig(formatter_command="prettier --write {file}") + assert config.formatter_command == "prettier --write {file}" + + def test_formatters_defaults_to_empty_dict(self): + """Test that formatters defaults to empty dict.""" + config = BasicMemoryConfig() + assert config.formatters == {} + + def test_formatters_can_be_configured(self): + """Test that per-extension formatters can be configured.""" + config = BasicMemoryConfig( + formatters={ + "md": "prettier --write {file}", + "json": "jq . {file} > {file}.tmp && mv {file}.tmp {file}", + } + ) + assert config.formatters["md"] == "prettier --write {file}" + assert "json" in config.formatters + + def test_formatter_timeout_defaults_to_5_seconds(self): + """Test that formatter_timeout defaults to 5.0 seconds.""" + config = BasicMemoryConfig() + assert config.formatter_timeout == 5.0 + + def test_formatter_timeout_can_be_customized(self): + """Test that formatter_timeout can be set to a different value.""" + config = BasicMemoryConfig(formatter_timeout=10.0) + assert config.formatter_timeout == 10.0 + + def test_formatter_timeout_must_be_positive(self): + """Test that formatter_timeout validation rejects non-positive values.""" + import pydantic + + with pytest.raises(pydantic.ValidationError): + BasicMemoryConfig(formatter_timeout=0) + + with pytest.raises(pydantic.ValidationError): + BasicMemoryConfig(formatter_timeout=-1) + + def test_formatting_env_vars(self, monkeypatch): + """Test that formatting config can be set via environment variables.""" + monkeypatch.setenv("BASIC_MEMORY_FORMAT_ON_SAVE", "true") + monkeypatch.setenv("BASIC_MEMORY_FORMATTER_COMMAND", "prettier --write {file}") + monkeypatch.setenv("BASIC_MEMORY_FORMATTER_TIMEOUT", "10.0") + + config = BasicMemoryConfig() + + assert config.format_on_save is True + assert config.formatter_command == "prettier --write {file}" + assert config.formatter_timeout == 10.0 + + def test_formatters_env_var_json(self, monkeypatch): + """Test that formatters dict can be set via JSON environment variable.""" + import json + + formatters_json = json.dumps({"md": "prettier --write {file}", "json": "jq . {file}"}) + monkeypatch.setenv("BASIC_MEMORY_FORMATTERS", formatters_json) + + config = BasicMemoryConfig() + + assert config.formatters == {"md": "prettier --write {file}", "json": "jq . {file}"} + + def test_save_and_load_formatting_config(self): + """Test that formatting config survives save/load cycle.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + config_manager = ConfigManager() + config_manager.config_dir = temp_path / "basic-memory" + config_manager.config_file = config_manager.config_dir / "config.json" + config_manager.config_dir.mkdir(parents=True, exist_ok=True) + + # Create config with formatting settings + test_config = BasicMemoryConfig( + projects={"main": str(temp_path / "main")}, + format_on_save=True, + formatter_command="prettier --write {file}", + formatters={"md": "prettier --write {file}", "json": "prettier --write {file}"}, + formatter_timeout=10.0, + ) + config_manager.save_config(test_config) + + # Load and verify + loaded_config = config_manager.load_config() + assert loaded_config.format_on_save is True + assert loaded_config.formatter_command == "prettier --write {file}" + assert loaded_config.formatters == { + "md": "prettier --write {file}", + "json": "prettier --write {file}", + } + assert loaded_config.formatter_timeout == 10.0 diff --git a/tests/utils/test_file_utils.py b/tests/utils/test_file_utils.py index 5e7b7b77..aee74b0c 100644 --- a/tests/utils/test_file_utils.py +++ b/tests/utils/test_file_utils.py @@ -6,11 +6,13 @@ import random import string +from basic_memory.config import BasicMemoryConfig from basic_memory.file_utils import ( FileError, FileWriteError, ParseError, compute_checksum, + format_file, has_frontmatter, parse_frontmatter, remove_frontmatter, @@ -214,3 +216,188 @@ def test_sanitize_for_filename_removes_invalid_characters(): ) def test_sanitize_for_folder_edge_cases(input_folder, expected): assert sanitize_for_folder(input_folder) == expected + + +# ============================================================================= +# format_file tests +# ============================================================================= + + +@pytest.mark.asyncio +async def test_format_file_disabled_by_default(tmp_path: Path): + """Test that format_file returns None when format_on_save is False (default).""" + test_file = tmp_path / "test.md" + test_file.write_text("# Test\n") + + config = BasicMemoryConfig() + assert config.format_on_save is False + + result = await format_file(test_file, config) + assert result is None + + +@pytest.mark.asyncio +async def test_format_file_no_formatter_configured(tmp_path: Path): + """Test that format_file returns None when no formatter is configured for the extension.""" + test_file = tmp_path / "test.md" + test_file.write_text("# Test\n") + + # Explicitly set formatter_command to None to test the "no formatter" case + config = BasicMemoryConfig(format_on_save=True, formatter_command=None) + + result = await format_file(test_file, config) + assert result is None + + +@pytest.mark.asyncio +async def test_format_file_with_global_formatter(tmp_path: Path): + """Test formatting with global formatter_command.""" + test_file = tmp_path / "test.md" + original_content = "# Test\n" + test_file.write_text(original_content) + + # Use a simple formatter that just echoes content (cat) + config = BasicMemoryConfig( + format_on_save=True, + formatter_command="cat {file}", # This doesn't modify the file but runs successfully + ) + + result = await format_file(test_file, config) + assert result == original_content + + +@pytest.mark.asyncio +async def test_format_file_with_extension_specific_formatter(tmp_path: Path): + """Test formatting with extension-specific formatter.""" + test_file = tmp_path / "test.json" + original_content = '{"key": "value"}' + test_file.write_text(original_content) + + config = BasicMemoryConfig( + format_on_save=True, + formatter_command="echo global", # This should NOT be used + formatters={"json": "cat {file}"}, # Extension-specific should be used + ) + + result = await format_file(test_file, config) + assert result == original_content + + +@pytest.mark.asyncio +async def test_format_file_extension_specific_overrides_global(tmp_path: Path): + """Test that extension-specific formatter takes precedence over global.""" + test_file = tmp_path / "test.md" + original_content = "# Test\n" + test_file.write_text(original_content) + + # Use different commands to verify which one is used + # Since cat just reads the file, we can tell which was used by the content + config = BasicMemoryConfig( + format_on_save=True, + formatter_command="cat /dev/null", # Would return empty + formatters={"md": "cat {file}"}, # Should return original content + ) + + result = await format_file(test_file, config) + assert result == original_content + + +@pytest.mark.asyncio +async def test_format_file_falls_back_to_global(tmp_path: Path): + """Test that global formatter is used when no extension-specific one exists.""" + test_file = tmp_path / "test.txt" # No extension-specific formatter for .txt + original_content = "Some text\n" + test_file.write_text(original_content) + + config = BasicMemoryConfig( + format_on_save=True, + formatter_command="cat {file}", + formatters={"md": "echo wrong"}, # Only for .md, not .txt + ) + + result = await format_file(test_file, config) + assert result == original_content + + +@pytest.mark.asyncio +async def test_format_file_handles_nonexistent_formatter(tmp_path: Path): + """Test that format_file handles missing formatter executable gracefully.""" + test_file = tmp_path / "test.md" + test_file.write_text("# Test\n") + + config = BasicMemoryConfig( + format_on_save=True, + formatter_command="nonexistent_formatter_executable_12345 {file}", + ) + + result = await format_file(test_file, config) + assert result is None # Should return None on error + + +@pytest.mark.asyncio +async def test_format_file_handles_timeout(tmp_path: Path): + """Test that format_file handles formatter timeout gracefully.""" + test_file = tmp_path / "test.md" + test_file.write_text("# Test\n") + + config = BasicMemoryConfig( + format_on_save=True, + formatter_command="sleep 10", # Will timeout + formatter_timeout=0.1, # Very short timeout + ) + + result = await format_file(test_file, config) + assert result is None # Should return None on timeout + + +@pytest.mark.asyncio +async def test_format_file_handles_nonzero_exit(tmp_path: Path): + """Test that format_file handles non-zero exit codes gracefully.""" + test_file = tmp_path / "test.md" + original_content = "# Test\n" + test_file.write_text(original_content) + + config = BasicMemoryConfig( + format_on_save=True, + formatter_command="sh -c 'exit 1'", # Non-zero exit + ) + + result = await format_file(test_file, config) + # Should still return file content even with non-zero exit + assert result == original_content + + +@pytest.mark.asyncio +async def test_format_file_returns_modified_content(tmp_path: Path): + """Test that format_file returns the modified file content after formatting.""" + test_file = tmp_path / "test.md" + original_content = "original content" + test_file.write_text(original_content) + + # This formatter modifies the file to contain different content + config = BasicMemoryConfig( + format_on_save=True, + formatter_command="sh -c 'echo modified > {file}'", + ) + + result = await format_file(test_file, config) + assert result == "modified\n" + assert test_file.read_text() == "modified\n" + + +@pytest.mark.asyncio +async def test_format_file_with_spaces_in_path(tmp_path: Path): + """Test formatting files with spaces in path.""" + subdir = tmp_path / "path with spaces" + subdir.mkdir() + test_file = subdir / "my file.md" + original_content = "# Test\n" + test_file.write_text(original_content) + + config = BasicMemoryConfig( + format_on_save=True, + formatter_command="cat {file}", + ) + + result = await format_file(test_file, config) + assert result == original_content