diff --git a/src/codegen/sdk/codebase/codebase_context.py b/src/codegen/sdk/codebase/codebase_context.py index abe519558..6214c4227 100644 --- a/src/codegen/sdk/codebase/codebase_context.py +++ b/src/codegen/sdk/codebase/codebase_context.py @@ -2,7 +2,6 @@ import os from collections import Counter, defaultdict -from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager from enum import IntEnum, auto, unique from functools import lru_cache @@ -16,6 +15,7 @@ from codegen.sdk.codebase.config_parser import ConfigParser, get_config_parser_for_language from codegen.sdk.codebase.diff_lite import ChangeType, DiffLite from codegen.sdk.codebase.flagging.flags import Flags +from codegen.sdk.codebase.io.file_io import FileIO from codegen.sdk.codebase.transaction_manager import TransactionManager from codegen.sdk.codebase.validation import get_edges, post_reset_validation from codegen.sdk.core.autocommit import AutoCommit, commiter @@ -23,7 +23,6 @@ from codegen.sdk.core.external.dependency_manager import DependencyManager, get_dependency_manager from codegen.sdk.core.external.language_engine import LanguageEngine, get_language_engine from codegen.sdk.enums import Edge, EdgeType, NodeType, ProgrammingLanguage -from codegen.sdk.extensions.io import write_changes from codegen.sdk.extensions.sort import sort_editables from codegen.sdk.extensions.utils import uncache_all from codegen.sdk.typescript.external.ts_declassify.ts_declassify import TSDeclassify @@ -37,6 +36,7 @@ from git import Commit as GitCommit from codegen.git.repo_operator.repo_operator import RepoOperator + from codegen.sdk.codebase.io.io import IO from codegen.sdk.codebase.node_classes.node_classes import NodeClasses from codegen.sdk.core.dataclasses.usage import Usage from codegen.sdk.core.expressions import Expression @@ -92,7 +92,6 @@ class CodebaseContext: pending_syncs: list[DiffLite] # Diffs that have been applied to disk, but not the graph (to be used for sync graph) all_syncs: 
list[DiffLite] # All diffs that have been applied to the graph (to be used for graph reset) _autocommit: AutoCommit - pending_files: set[SourceFile] generation: int parser: Parser[Expression] synced_commit: GitCommit | None @@ -110,6 +109,7 @@ class CodebaseContext: session_options: SessionOptions = SessionOptions() projects: list[ProjectConfig] unapplied_diffs: list[DiffLite] + io: IO def __init__( self, @@ -134,6 +134,7 @@ def __init__( # =====[ __init__ attributes ]===== self.projects = projects + self.io = FileIO() context = projects[0] self.node_classes = get_node_classes(context.programming_language) self.config = config @@ -165,7 +166,6 @@ def __init__( self.pending_syncs = [] self.all_syncs = [] self.unapplied_diffs = [] - self.pending_files = set() self.flags = Flags() def __repr__(self): @@ -259,7 +259,13 @@ def _reset_files(self, syncs: list[DiffLite]) -> None: files_to_remove.append(sync.path) modified_files.add(sync.path) logger.info(f"Writing {len(files_to_write)} files to disk and removing {len(files_to_remove)} files") - write_changes(files_to_remove, files_to_write) + for file in files_to_remove: + self.io.delete_file(file) + to_save = set() + for file, content in files_to_write: + self.io.write_file(file, content) + to_save.add(file) + self.io.save_files(to_save) @stopwatch def reset_codebase(self) -> None: @@ -270,7 +276,7 @@ def reset_codebase(self) -> None: def undo_applied_diffs(self) -> None: self.transaction_manager.clear_transactions() self.reset_codebase() - self.check_changes() + self.io.check_changes() self.pending_syncs.clear() # Discard pending changes if len(self.all_syncs) > 0: logger.info(f"Unapplying {len(self.all_syncs)} diffs to graph. 
Current graph commit: {self.synced_commit}") @@ -432,7 +438,7 @@ def _process_diff_files(self, files_to_sync: Mapping[SyncType, list[Path]], incr # Step 5: Add new files as nodes to graph (does not yet add edges) for filepath in files_to_sync[SyncType.ADD]: - content = filepath.read_text(errors="ignore") + content = self.io.read_text(filepath) # TODO: this is wrong with context changes if filepath.suffix in self.extensions: file_cls = self.node_classes.file_cls @@ -634,17 +640,6 @@ def remove_edge(self, u: NodeId, v: NodeId, *, edge_type: EdgeType | None = None continue self._graph.remove_edge_from_index(edge) - def check_changes(self) -> None: - for file in self.pending_files: - file.check_changes() - self.pending_files.clear() - - def write_files(self, files: set[Path] | None = None) -> None: - to_write = set(filter(lambda f: f.filepath in files, self.pending_files)) if files is not None else self.pending_files - with ThreadPoolExecutor() as exec: - exec.map(lambda f: f.write_pending_content(), to_write) - self.pending_files.difference_update(to_write) - @lru_cache(maxsize=10000) def to_absolute(self, filepath: PathLike | str) -> Path: path = Path(filepath) @@ -684,7 +679,7 @@ def commit_transactions(self, sync_graph: bool = True, sync_file: bool = True, f # Write files if requested if sync_file: - self.write_files(files) + self.io.save_files(files) # Sync the graph if requested if sync_graph and len(self.pending_syncs) > 0: diff --git a/src/codegen/sdk/codebase/io/file_io.py b/src/codegen/sdk/codebase/io/file_io.py new file mode 100644 index 000000000..081be4f22 --- /dev/null +++ b/src/codegen/sdk/codebase/io/file_io.py @@ -0,0 +1,51 @@ +import logging +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path + +from codegen.sdk.codebase.io.io import IO, BadWriteError + +logger = logging.getLogger(__name__) + + +class FileIO(IO): + """IO implementation that writes files to disk, and tracks pending changes.""" + + files: dict[Path, bytes] + + 
def __init__(self): + self.files = {} + + def write_bytes(self, path: Path, content: bytes) -> None: + self.files[path] = content + + def read_bytes(self, path: Path) -> bytes: + if path in self.files: + return self.files[path] + else: + return path.read_bytes() + + def save_files(self, files: set[Path] | None = None) -> None: + to_save = set(filter(lambda f: f in files, self.files)) if files is not None else self.files.keys() + with ThreadPoolExecutor() as exec: + exec.map(lambda path: path.write_bytes(self.files[path]), to_save) + if files is None: + self.files.clear() + else: + for path in to_save: + del self.files[path] + + def check_changes(self) -> None: + if self.files: + logger.error(BadWriteError("Directly called file write without calling commit_transactions")) + self.files.clear() + + def delete_file(self, path: Path) -> None: + self.untrack_file(path) + if path.exists(): + path.unlink() + + def untrack_file(self, path: Path) -> None: + self.files.pop(path, None) + + def file_exists(self, path: Path) -> bool: + return path.exists() diff --git a/src/codegen/sdk/codebase/io/io.py b/src/codegen/sdk/codebase/io/io.py new file mode 100644 index 000000000..3321f072b --- /dev/null +++ b/src/codegen/sdk/codebase/io/io.py @@ -0,0 +1,50 @@ +from abc import ABC, abstractmethod +from pathlib import Path + + +class BadWriteError(Exception): + pass + + +class IO(ABC): + def write_file(self, path: Path, content: str | bytes | None) -> None: + if content is None: + self.untrack_file(path) + elif isinstance(content, str): + self.write_text(path, content) + else: + self.write_bytes(path, content) + + def write_text(self, path: Path, content: str) -> None: + self.write_bytes(path, content.encode("utf-8")) + + @abstractmethod + def untrack_file(self, path: Path) -> None: + pass + + @abstractmethod + def write_bytes(self, path: Path, content: bytes) -> None: + pass + + @abstractmethod + def read_bytes(self, path: Path) -> bytes: + pass + + def read_text(self, path: Path) -> 
str: + return self.read_bytes(path).decode("utf-8") + + @abstractmethod + def save_files(self, files: set[Path] | None = None) -> None: + pass + + @abstractmethod + def check_changes(self) -> None: + pass + + @abstractmethod + def delete_file(self, path: Path) -> None: + pass + + @abstractmethod + def file_exists(self, path: Path) -> bool: + pass diff --git a/src/codegen/sdk/codebase/transactions.py b/src/codegen/sdk/codebase/transactions.py index d2e71f8fa..31d48b2e1 100644 --- a/src/codegen/sdk/codebase/transactions.py +++ b/src/codegen/sdk/codebase/transactions.py @@ -1,4 +1,3 @@ -import os from collections.abc import Callable from difflib import unified_diff from enum import IntEnum @@ -267,7 +266,7 @@ def __init__( def execute(self) -> None: """Renames the file""" - self.file.write_pending_content() + self.file.ctx.io.save_files({self.file.path}) self.file_path.rename(self.new_file_path) def get_diff(self) -> DiffLite: @@ -292,8 +291,7 @@ def __init__( def execute(self) -> None: """Removes the file""" - os.remove(self.file_path) - self.file._pending_content_bytes = None + self.file.ctx.io.delete_file(self.file.path) def get_diff(self) -> DiffLite: """Gets the diff produced by this transaction""" diff --git a/src/codegen/sdk/core/codebase.py b/src/codegen/sdk/core/codebase.py index 0324c05ed..ca2a6087b 100644 --- a/src/codegen/sdk/core/codebase.py +++ b/src/codegen/sdk/core/codebase.py @@ -481,17 +481,17 @@ def get_file(self, filepath: str, *, optional: bool = False, ignore_case: bool = def get_file_from_path(path: Path) -> File | None: try: - return File.from_content(path, path.read_text(), self.ctx, sync=False) + return File.from_content(path, self.ctx.io.read_text(path), self.ctx, sync=False) except UnicodeDecodeError: # Handle when file is a binary file - return File.from_content(path, path.read_bytes(), self.ctx, sync=False, binary=True) + return File.from_content(path, self.ctx.io.read_bytes(path), self.ctx, sync=False, binary=True) # Try to get the file 
from the graph first file = self.ctx.get_file(filepath, ignore_case=ignore_case) if file is not None: return file absolute_path = self.ctx.to_absolute(filepath) - if absolute_path.exists(): + if self.ctx.io.file_exists(absolute_path): return get_file_from_path(absolute_path) elif ignore_case: parent = absolute_path.parent diff --git a/src/codegen/sdk/core/file.py b/src/codegen/sdk/core/file.py index 67c2faf20..540800190 100644 --- a/src/codegen/sdk/core/file.py +++ b/src/codegen/sdk/core/file.py @@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Generic, Literal, Self, TypeVar, override from tree_sitter import Node as TSNode +from typing_extensions import deprecated from codegen.sdk._proxy import proxy_property from codegen.sdk.codebase.codebase_context import CodebaseContext @@ -45,10 +46,6 @@ logger = logging.getLogger(__name__) -class BadWriteError(Exception): - pass - - @apidoc class File(Editable[None]): """Represents a generic file. @@ -66,7 +63,6 @@ class File(Editable[None]): file_path: str path: Path node_type: Literal[NodeType.FILE] = NodeType.FILE - _pending_content_bytes: bytes | None = None _directory: Directory | None _pending_imports: set[str] _binary: bool = False @@ -117,10 +113,8 @@ def from_content(cls, filepath: str | Path, content: str | bytes, ctx: CodebaseC if not path.exists(): update_graph = True path.parent.mkdir(parents=True, exist_ok=True) - if not binary: - path.write_text(content) - else: - path.write_bytes(content) + ctx.io.write_file(path, content) + ctx.io.save_files({path}) new_file = cls(filepath, ctx, ts_node=None, binary=binary) return new_file @@ -133,10 +127,7 @@ def content_bytes(self) -> bytes: TODO: move rest of graph sitter to operate in bytes to prevent multi byte character issues? 
""" - # Check against None due to possibility of empty byte - if self._pending_content_bytes is None: - return self.path.read_bytes() - return self._pending_content_bytes + return self.ctx.io.read_bytes(self.path) @property @reader @@ -162,31 +153,18 @@ def content(self) -> str: @noapidoc def write(self, content: str | bytes, to_disk: bool = False) -> None: - """Writes string contents to the file.""" - self.write_bytes(content.encode("utf-8") if isinstance(content, str) else content, to_disk=to_disk) - - @noapidoc - def write_bytes(self, content_bytes: bytes, to_disk: bool = False) -> None: - self._pending_content_bytes = content_bytes - self.ctx.pending_files.add(self) + """Writes contents to the file.""" + self.ctx.io.write_file(self.path, content) if to_disk: - self.write_pending_content() + self.ctx.io.save_files({self.path}) if self.ts_node.start_byte == self.ts_node.end_byte: # TS didn't parse anything, register a write to make sure the transaction manager can restore the file later. 
self.edit("") @noapidoc - def write_pending_content(self) -> None: - if self._pending_content_bytes is not None: - self.path.write_bytes(self._pending_content_bytes) - self._pending_content_bytes = None - logger.debug("Finished write_pending_content") - - @noapidoc - @writer - def check_changes(self) -> None: - if self._pending_content_bytes is not None: - logger.error(BadWriteError("Directly called file write without calling commit_transactions")) + @deprecated("Use write instead") + def write_bytes(self, content_bytes: bytes, to_disk: bool = False) -> None: + self.write(content_bytes, to_disk=to_disk) @property @reader @@ -272,7 +250,7 @@ def remove(self) -> None: None """ self.transaction_manager.add_file_remove_transaction(self) - self._pending_content_bytes = None + self.ctx.io.write_file(self.path, None) @property def filepath(self) -> str: @@ -596,10 +574,11 @@ def from_content(cls, filepath: str | PathLike | Path, content: str, ctx: Codeba return None update_graph = False - if not path.exists(): + if not ctx.io.file_exists(path): update_graph = True path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(content) + ctx.io.write_file(path, content) + ctx.io.save_files({path}) if update_graph and sync: ctx.add_single_file(path) diff --git a/src/codegen/sdk/extensions/io.pyx b/src/codegen/sdk/extensions/io.pyx deleted file mode 100644 index 6b2c8fcd8..000000000 --- a/src/codegen/sdk/extensions/io.pyx +++ /dev/null @@ -1,12 +0,0 @@ -from pathlib import Path -from concurrent.futures import ThreadPoolExecutor -import os - - -def write_changes(files_to_remove: list[Path], files_to_write: list[tuple[Path, bytes]]): - # Start at the oldest sync and then apply non-conflicting newer changes - with ThreadPoolExecutor() as executor: - for file_to_remove in files_to_remove: - executor.submit(os.remove, file_to_remove) - for file_to_write, content in files_to_write: - executor.submit(file_to_write.write_bytes, content) diff --git 
# tests/unit/codegen/sdk/io/test_file_io.py
import pytest

from codegen.sdk.codebase.io.file_io import FileIO


@pytest.fixture
def file_io():
    """Provide a fresh FileIO with no pending writes for each test."""
    return FileIO()


def test_write_and_read_bytes(file_io, tmp_path):
    """Pending content is readable through the IO layer but not yet on disk."""
    test_file = tmp_path / "test.txt"
    content = b"test content"

    file_io.write_bytes(test_file, content)
    assert file_io.read_bytes(test_file) == content
    assert not test_file.exists()


def test_read_bytes_nonexistent_file(file_io, tmp_path):
    """A file with no pending content is read from disk.

    Despite the name, the file exists on disk; what is absent is a pending
    entry for it in the FileIO instance.
    """
    test_file = tmp_path / "test.txt"
    test_content = b"test content"

    # Create a real file
    test_file.write_bytes(test_content)

    # Reading should load from disk
    assert file_io.read_bytes(test_file) == test_content


def test_save_file(file_io, tmp_path):
    """Saving flushes pending content to disk and stops tracking it."""
    test_file = tmp_path / "test.txt"
    content = b"test content"

    file_io.write_bytes(test_file, content)
    file_io.save_files({test_file})

    assert test_file.exists()
    assert test_file.read_bytes() == content
    assert test_file not in file_io.files


def test_check_changes_with_pending_changes(file_io, tmp_path, caplog):
    """Unsaved pending content is reported through the error log."""
    test_file = tmp_path / "test.txt"
    file_io.write_bytes(test_file, b"test content")

    file_io.check_changes()

    assert "Directly called file write without calling commit_transactions" in caplog.text


def test_check_changes_no_pending_changes(file_io):
    """check_changes is a no-op when nothing is pending."""
    file_io.check_changes()  # Should not raise any exception


def test_delete_file(file_io, tmp_path):
    """Deleting removes the file from disk and drops its pending content."""
    test_file = tmp_path / "test.txt"
    test_file.write_bytes(b"test content")

    # read_bytes does not cache, so stage a pending write explicitly;
    # otherwise the untracking assertion below would be trivially true.
    file_io.write_bytes(test_file, b"pending content")
    assert test_file in file_io.files

    file_io.delete_file(test_file)

    assert not test_file.exists()
    assert test_file not in file_io.files