diff --git a/src/codegen/sdk/codebase/codebase_context.py b/src/codegen/sdk/codebase/codebase_context.py index 6e361ef1e..53ba72040 100644 --- a/src/codegen/sdk/codebase/codebase_context.py +++ b/src/codegen/sdk/codebase/codebase_context.py @@ -159,7 +159,6 @@ def __init__( # =====[ __init__ attributes ]===== self.projects = projects - self.io = io or FileIO() context = projects[0] self.node_classes = get_node_classes(context.programming_language) self.config = config or CodebaseConfig() @@ -169,6 +168,11 @@ def __init__( self.full_path = os.path.join(self.repo_path, context.base_path) if context.base_path else self.repo_path self.codeowners_parser = context.repo_operator.codeowners_parser self.base_url = context.repo_operator.base_url + if not self.config.allow_external: + # TODO: Fix this to be more robust with multiple projects + self.io = io or FileIO(allowed_paths=[Path(self.repo_path).resolve()]) + else: + self.io = io or FileIO() # =====[ computed attributes ]===== self.transaction_manager = TransactionManager() self._autocommit = AutoCommit(self) @@ -188,6 +192,13 @@ def __init__( logger.warning("WARNING: The codebase is using an unsupported language!") logger.warning("Some features may not work as expected. Advanced static analysis will be disabled but simple file IO will still work.") + # Assert config assertions + # External import resolution must be enabled if syspath is enabled + if self.config.py_resolve_syspath: + if not self.config.allow_external: + msg = "allow_external must be set to True when py_resolve_syspath is enabled" + raise ValueError(msg) + # Build the graph if not self.config.exp_lazy_graph: self.build_graph(context.repo_operator) diff --git a/src/codegen/sdk/codebase/io/file_io.py b/src/codegen/sdk/codebase/io/file_io.py index a0ced7e41..f59a28851 100644 --- a/src/codegen/sdk/codebase/io/file_io.py +++ b/src/codegen/sdk/codebase/io/file_io.py @@ -11,14 +11,24 @@ class FileIO(IO): """IO implementation that writes files to disk, and tracks pending changes.""" files: dict[Path, bytes] + allowed_paths: list[Path] | None - def __init__(self): + def __init__(self, allowed_paths: list[Path] | None = None): self.files = {} + self.allowed_paths = allowed_paths + + def _verify_path(self, path: Path) -> None: + if self.allowed_paths is not None: + if not any(path.resolve().is_relative_to(p.resolve()) for p in self.allowed_paths): + msg = f"Path {path.resolve()} is not within allowed paths {self.allowed_paths}" + raise BadWriteError(msg) def write_bytes(self, path: Path, content: bytes) -> None: + self._verify_path(path) self.files[path] = content def read_bytes(self, path: Path) -> bytes: + self._verify_path(path) if path in self.files: return self.files[path] else: @@ -26,6 +36,8 @@ def read_bytes(self, path: Path) -> bytes: def save_files(self, files: set[Path] | None = None) -> None: to_save = set(filter(lambda f: f in files, self.files)) if files is not None else self.files.keys() + for path in to_save: + self._verify_path(path) with ThreadPoolExecutor() as exec: exec.map(lambda path: path.write_bytes(self.files[path]), to_save) if files is None: @@ -40,12 +52,15 @@ def check_changes(self) -> None: self.files.clear() def delete_file(self, path: Path) -> None: + self._verify_path(path) self.untrack_file(path) if path.exists(): path.unlink() def untrack_file(self, path: Path) -> None: + self._verify_path(path) self.files.pop(path, None) def file_exists(self, path: Path) -> bool: + self._verify_path(path) return path.exists() diff --git a/src/codegen/sdk/core/codebase.py b/src/codegen/sdk/core/codebase.py index 9d6eaf2fd..c8a729741 100644 --- a/src/codegen/sdk/core/codebase.py +++ b/src/codegen/sdk/core/codebase.py @@ -213,13 +213,6 @@ def __init__( self.ctx = CodebaseContext(projects, config=config, secrets=secrets, io=io, progress=progress) self.console = Console(record=True, soft_wrap=True) - # Assert config assertions - # External import resolution must be enabled if syspath is enabled - if self.ctx.config.py_resolve_syspath: - if not self.ctx.config.allow_external: - msg = "allow_external must be set to True when py_resolve_syspath is enabled" - raise ValueError(msg) - @noapidoc def __str__(self) -> str: return f"" diff --git a/tests/unit/codegen/sdk/io/test_file_io.py b/tests/unit/codegen/sdk/io/test_file_io.py index 6f8148e4f..bb627da5d 100644 --- a/tests/unit/codegen/sdk/io/test_file_io.py +++ b/tests/unit/codegen/sdk/io/test_file_io.py @@ -1,6 +1,6 @@ import pytest -from codegen.sdk.codebase.io.file_io import FileIO +from codegen.sdk.codebase.io.file_io import BadWriteError, FileIO @pytest.fixture @@ -61,3 +61,102 @@ def test_delete_file(file_io, tmp_path): assert not test_file.exists() assert test_file not in file_io.files + + +def test_read_and_write_bounded(file_io, tmp_path): + allowed_dir = tmp_path / "allowed" + file_io.allowed_paths = [allowed_dir] + + allowed_file = allowed_dir / "test.txt" + content = b"test content" + + file_io.write_bytes(allowed_file, content) + assert file_io.read_bytes(allowed_file) == content + + with pytest.raises(BadWriteError) as exc_info: + bad_file = tmp_path / "test.txt" + file_io.write_bytes(bad_file, content) + + assert "is not within allowed paths" in str(exc_info.value) + + with pytest.raises(BadWriteError) as exc_info: + bad_file_2 = allowed_dir / ".." / "test2.txt" + file_io.write_bytes(bad_file_2, content) + + assert "is not within allowed paths" in str(exc_info.value) + + +def test_read_bounded(file_io, tmp_path): + allowed_dir = tmp_path / "allowed" + allowed_dir.mkdir(exist_ok=True) + file_io.allowed_paths = [allowed_dir] + + allowed_file = allowed_dir / "test.txt" + content = b"test content" + allowed_file.write_bytes(content) + + assert file_io.read_bytes(allowed_file) == content + + with pytest.raises(BadWriteError) as exc_info: + bad_file = tmp_path / "test.txt" + bad_file.write_bytes(content) + file_io.read_bytes(bad_file) + + assert "is not within allowed paths" in str(exc_info.value) + + with pytest.raises(BadWriteError) as exc_info: + bad_file_2 = allowed_dir / ".." / "test2.txt" + bad_file_2.write_bytes(content) + file_io.read_bytes(bad_file_2) + + assert "is not within allowed paths" in str(exc_info.value) + + +def test_delete_file_bounded(file_io, tmp_path): + allowed_dir = tmp_path / "allowed" + allowed_dir.mkdir(exist_ok=True) + file_io.allowed_paths = [allowed_dir] + + allowed_file = allowed_dir / "test.txt" + allowed_file.write_bytes(b"test content") + + file_io.delete_file(allowed_file) + + with pytest.raises(BadWriteError) as exc_info: + bad_file = tmp_path / "test.txt" + bad_file.write_bytes(b"test content") + file_io.delete_file(bad_file) + + assert "is not within allowed paths" in str(exc_info.value) + + with pytest.raises(BadWriteError) as exc_info: + bad_file_2 = allowed_dir / ".." / "test2.txt" + bad_file_2.write_bytes(b"test content") + file_io.delete_file(bad_file_2) + + assert "is not within allowed paths" in str(exc_info.value) + + +def test_file_exists_bounded(file_io, tmp_path): + allowed_dir = tmp_path / "allowed" + allowed_dir.mkdir(exist_ok=True) + file_io.allowed_paths = [allowed_dir] + + allowed_file = allowed_dir / "test.txt" + allowed_file.write_bytes(b"test content") + + assert file_io.file_exists(allowed_file) + + with pytest.raises(BadWriteError) as exc_info: + bad_file = tmp_path / "test.txt" + bad_file.write_bytes(b"test content") + file_io.file_exists(bad_file) + + assert "is not within allowed paths" in str(exc_info.value) + + with pytest.raises(BadWriteError) as exc_info: + bad_file_2 = allowed_dir / ".." / "test2.txt" + bad_file_2.write_bytes(b"test content") + file_io.file_exists(bad_file_2) + + assert "is not within allowed paths" in str(exc_info.value)