Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/codegen/sdk/codebase/codebase_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ def __init__(

# =====[ __init__ attributes ]=====
self.projects = projects
self.io = io or FileIO()
context = projects[0]
self.node_classes = get_node_classes(context.programming_language)
self.config = config or CodebaseConfig()
Expand All @@ -169,6 +168,11 @@ def __init__(
self.full_path = os.path.join(self.repo_path, context.base_path) if context.base_path else self.repo_path
self.codeowners_parser = context.repo_operator.codeowners_parser
self.base_url = context.repo_operator.base_url
if not self.config.allow_external:
# TODO: Fix this to be more robust with multiple projects
self.io = io or FileIO(allowed_paths=[Path(self.repo_path).resolve()])
else:
self.io = io or FileIO()
# =====[ computed attributes ]=====
self.transaction_manager = TransactionManager()
self._autocommit = AutoCommit(self)
Expand All @@ -188,6 +192,13 @@ def __init__(
logger.warning("WARNING: The codebase is using an unsupported language!")
logger.warning("Some features may not work as expected. Advanced static analysis will be disabled but simple file IO will still work.")

# Assert config assertions
# External import resolution must be enabled if syspath is enabled
if self.config.py_resolve_syspath:
if not self.config.allow_external:
msg = "allow_external must be set to True when py_resolve_syspath is enabled"
raise ValueError(msg)

# Build the graph
if not self.config.exp_lazy_graph:
self.build_graph(context.repo_operator)
Expand Down
17 changes: 16 additions & 1 deletion src/codegen/sdk/codebase/io/file_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,33 @@ class FileIO(IO):
"""IO implementation that writes files to disk, and tracks pending changes."""

files: dict[Path, bytes]
allowed_paths: list[Path] | None

def __init__(self):
def __init__(self, allowed_paths: list[Path] | None = None):
self.files = {}
self.allowed_paths = allowed_paths

def _verify_path(self, path: Path) -> None:
if self.allowed_paths is not None:
if not any(path.resolve().is_relative_to(p.resolve()) for p in self.allowed_paths):
msg = f"Path {path.resolve()} is not within allowed paths {self.allowed_paths}"
raise BadWriteError(msg)

def write_bytes(self, path: Path, content: bytes) -> None:
self._verify_path(path)
self.files[path] = content

def read_bytes(self, path: Path) -> bytes:
self._verify_path(path)
if path in self.files:
return self.files[path]
else:
return path.read_bytes()

def save_files(self, files: set[Path] | None = None) -> None:
to_save = set(filter(lambda f: f in files, self.files)) if files is not None else self.files.keys()
for path in to_save:
self._verify_path(path)
with ThreadPoolExecutor() as exec:
exec.map(lambda path: path.write_bytes(self.files[path]), to_save)
if files is None:
Expand All @@ -40,12 +52,15 @@ def check_changes(self) -> None:
self.files.clear()

def delete_file(self, path: Path) -> None:
self._verify_path(path)
self.untrack_file(path)
if path.exists():
path.unlink()

def untrack_file(self, path: Path) -> None:
self._verify_path(path)
self.files.pop(path, None)

def file_exists(self, path: Path) -> bool:
self._verify_path(path)
return path.exists()
7 changes: 0 additions & 7 deletions src/codegen/sdk/core/codebase.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,13 +213,6 @@ def __init__(
self.ctx = CodebaseContext(projects, config=config, secrets=secrets, io=io, progress=progress)
self.console = Console(record=True, soft_wrap=True)

# Assert config assertions
# External import resolution must be enabled if syspath is enabled
if self.ctx.config.py_resolve_syspath:
if not self.ctx.config.allow_external:
msg = "allow_external must be set to True when py_resolve_syspath is enabled"
raise ValueError(msg)

@noapidoc
def __str__(self) -> str:
    """Return a one-line summary built from the name, language and repo path."""
    summary = f"<Codebase(name={self.name}, language={self.language}, path={self.repo_path})>"
    return summary
Expand Down
101 changes: 100 additions & 1 deletion tests/unit/codegen/sdk/io/test_file_io.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from codegen.sdk.codebase.io.file_io import FileIO
from codegen.sdk.codebase.io.file_io import BadWriteError, FileIO


@pytest.fixture
Expand Down Expand Up @@ -61,3 +61,102 @@ def test_delete_file(file_io, tmp_path):

assert not test_file.exists()
assert test_file not in file_io.files


def test_read_and_write_bounded(file_io, tmp_path):
    """Staged writes are accepted inside the allowed directory and rejected outside it."""
    allowed_dir = tmp_path / "allowed"
    file_io.allowed_paths = [allowed_dir]
    content = b"test content"

    allowed_file = allowed_dir / "test.txt"
    file_io.write_bytes(allowed_file, content)
    assert file_io.read_bytes(allowed_file) == content

    # A sibling of the allowed directory is out of bounds.
    bad_file = tmp_path / "test.txt"
    with pytest.raises(BadWriteError, match="is not within allowed paths"):
        file_io.write_bytes(bad_file, content)
    # The rejected write must not have been staged.
    assert bad_file not in file_io.files

    # ".." must not be usable to step out of the allowed directory.
    bad_file_2 = allowed_dir / ".." / "test2.txt"
    with pytest.raises(BadWriteError, match="is not within allowed paths"):
        file_io.write_bytes(bad_file_2, content)
    assert bad_file_2 not in file_io.files


def test_read_bounded(file_io, tmp_path):
    """Reads from disk succeed inside the allowed directory and are rejected outside it."""
    allowed_dir = tmp_path / "allowed"
    allowed_dir.mkdir(exist_ok=True)
    file_io.allowed_paths = [allowed_dir]
    content = b"test content"

    allowed_file = allowed_dir / "test.txt"
    allowed_file.write_bytes(content)
    assert file_io.read_bytes(allowed_file) == content

    # Fixture setup stays outside pytest.raises so that only the call under
    # test can satisfy the expectation.
    bad_file = tmp_path / "test.txt"
    bad_file.write_bytes(content)
    with pytest.raises(BadWriteError, match="is not within allowed paths"):
        file_io.read_bytes(bad_file)

    # ".." must not be usable to step out of the allowed directory.
    bad_file_2 = allowed_dir / ".." / "test2.txt"
    bad_file_2.write_bytes(content)
    with pytest.raises(BadWriteError, match="is not within allowed paths"):
        file_io.read_bytes(bad_file_2)


def test_delete_file_bounded(file_io, tmp_path):
    """Deletes succeed inside the allowed directory and are rejected outside it."""
    allowed_dir = tmp_path / "allowed"
    allowed_dir.mkdir(exist_ok=True)
    file_io.allowed_paths = [allowed_dir]

    allowed_file = allowed_dir / "test.txt"
    allowed_file.write_bytes(b"test content")
    file_io.delete_file(allowed_file)
    assert not allowed_file.exists()

    # Fixture setup stays outside pytest.raises so that only the call under
    # test can satisfy the expectation.
    bad_file = tmp_path / "test.txt"
    bad_file.write_bytes(b"test content")
    with pytest.raises(BadWriteError, match="is not within allowed paths"):
        file_io.delete_file(bad_file)
    # A rejected delete must leave the file in place.
    assert bad_file.exists()

    # ".." must not be usable to step out of the allowed directory.
    bad_file_2 = allowed_dir / ".." / "test2.txt"
    bad_file_2.write_bytes(b"test content")
    with pytest.raises(BadWriteError, match="is not within allowed paths"):
        file_io.delete_file(bad_file_2)
    assert bad_file_2.exists()


def test_file_exists_bounded(file_io, tmp_path):
    """Existence checks succeed inside the allowed directory and are rejected outside it."""
    allowed_dir = tmp_path / "allowed"
    allowed_dir.mkdir(exist_ok=True)
    file_io.allowed_paths = [allowed_dir]

    allowed_file = allowed_dir / "test.txt"
    allowed_file.write_bytes(b"test content")
    assert file_io.file_exists(allowed_file)

    # Fixture setup stays outside pytest.raises so that only the call under
    # test can satisfy the expectation.
    bad_file = tmp_path / "test.txt"
    bad_file.write_bytes(b"test content")
    with pytest.raises(BadWriteError, match="is not within allowed paths"):
        file_io.file_exists(bad_file)

    # ".." must not be usable to step out of the allowed directory.
    bad_file_2 = allowed_dir / ".." / "test2.txt"
    bad_file_2.write_bytes(b"test content")
    with pytest.raises(BadWriteError, match="is not within allowed paths"):
        file_io.file_exists(bad_file_2)
Loading