Merged
Changes from 5 commits
203 changes: 144 additions & 59 deletions dreadnode/agent/tools/fs.py
@@ -1,12 +1,14 @@
import contextlib
import asyncio
import re
import typing as t
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path

import aiofiles
import rigging as rg
from fsspec import AbstractFileSystem # type: ignore[import-untyped]
from loguru import logger
from pydantic import PrivateAttr
from upath import UPath

@@ -43,7 +45,9 @@
type="file",
name=relative,
size=path.stat().st_size,
modified=datetime.fromtimestamp(path.stat().st_mtime, tz=timezone.utc).strftime(
modified=datetime.fromtimestamp(
path.stat().st_mtime, tz=timezone.utc
).strftime(
"%Y-%m-%d %H:%M:%S",
),
)
@@ -68,6 +72,8 @@
"""Extra options for the universal filesystem."""
multi_modal: bool = Config(default=False)
"""Enable returning non-text context like images."""
max_concurrent_reads: int = Config(default=25)
"""Maximum number of concurrent file reads for grep operations."""

variant: t.Literal["read", "write"] = Config(default="read")

@@ -94,15 +100,26 @@

return full_path

def _safe_create_file(self, path: str) -> "UPath":
async def _safe_create_file(self, path: str) -> "UPath":
"""
Safely create a file and its parent directories if they don't exist.

Args:
path: Path to the file to create

Returns:
UPath: The resolved path to the created file
"""
file_path = self._resolve(path)

parent_path = file_path.parent
if not parent_path.exists():
parent_path.mkdir(parents=True, exist_ok=True)
await asyncio.to_thread(
lambda: parent_path.mkdir(parents=True, exist_ok=True)
)

if not file_path.exists():
file_path.touch()
await asyncio.to_thread(file_path.touch)

return file_path
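
UPath filesystem calls are synchronous, so the PR offloads them with asyncio.to_thread to keep the event loop free. A minimal standalone sketch of the same pattern (illustrative ensure_dir name, not part of this PR); since asyncio.to_thread forwards *args and **kwargs directly, the lambda wrappers above could also be written as plain calls:

import asyncio
from pathlib import Path

async def ensure_dir(path: Path) -> None:
    # Equivalent to: await asyncio.to_thread(lambda: path.mkdir(parents=True, exist_ok=True))
    await asyncio.to_thread(path.mkdir, parents=True, exist_ok=True)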

@@ -116,23 +133,34 @@
return full_path[len(base_path) :]

@tool_method(variants=["read", "write"], catch=True)
def read_file(
async def read_file(
self,
path: t.Annotated[str, "Path to the file to read"],
) -> rg.ContentImageUrl | str:
"""Read a file and return its contents."""
) -> rg.ContentImageUrl | str | bytes:
"""
Read a file and return its contents.

Returns:
- str: The file contents decoded as UTF-8 if possible.
- rg.ContentImageUrl: If the file is non-text and multi_modal is True.
- bytes: If the file cannot be decoded as UTF-8 and multi_modal is False.

Note:
Callers should be prepared to handle raw bytes if the file is not valid UTF-8 and multi_modal is False.
"""
_path = self._resolve(path)
content = _path.read_bytes()
async with aiofiles.open(_path, "rb") as f:
content = await f.read()

try:
return content.decode("utf-8")
except UnicodeDecodeError as e:
except UnicodeDecodeError:
if self.multi_modal:
return rg.ContentImageUrl.from_file(path)
raise ValueError("File is not a valid text file.") from e
return content
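
Per the docstring above, callers now face three possible return types. A minimal caller-side sketch (hypothetical show_file helper and fs instance, assuming the rigging import at the top of the module):

import rigging as rg

async def show_file(fs, path: str) -> None:
    # fs is assumed to be an instance of this filesystem tool.
    result = await fs.read_file(path)
    if isinstance(result, str):
        print(result)  # decoded UTF-8 text
    elif isinstance(result, rg.ContentImageUrl):
        print(f"non-text content returned for {path}")  # multi_modal=True
    else:
        print(f"{path}: {len(result)} raw bytes")  # undecodable and multi_modal=False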

@tool_method(variants=["read", "write"], catch=True)
def read_lines(
async def read_lines(
self,
path: t.Annotated[str, "Path to the file to read"],
start_line: t.Annotated[int, "Start line number (0-indexed)"] = 0,
@@ -150,8 +178,8 @@
if not _path.is_file():
raise ValueError(f"'{path}' is not a file.")

with _path.open("r") as f:
lines = f.readlines()
async with aiofiles.open(_path, "r") as f:

Check failure: Ruff (UP015) Unnecessary mode argument at dreadnode/agent/tools/fs.py:181:41 (GitHub Actions lint jobs, Python 3.10-3.13)
lines = await f.readlines()

if start_line < 0:
start_line = len(lines) + start_line
@@ -165,7 +193,7 @@
return "\n".join(lines[start_line:end_line])

@tool_method(variants=["read", "write"], catch=True)
def ls(
async def ls(
self,
path: t.Annotated[str, "Directory path to list"] = "",
) -> list[FilesystemItem]:
@@ -178,19 +206,19 @@
if not _path.is_dir():
raise ValueError(f"'{path}' is not a directory.")

items = list(_path.iterdir())
items = await asyncio.to_thread(lambda: list(_path.iterdir()))
return [FilesystemItem.from_path(item, self._upath) for item in items]

@tool_method(catch=True)
def glob(
async def glob(
self,
pattern: t.Annotated[str, "Glob pattern for file matching"],
) -> list[FilesystemItem]:
"""
Returns a list of paths matching a valid glob pattern. The pattern can
include ** for recursive matching, such as '/path/**/dir/*.py'.
"""
matches = list(self._upath.glob(pattern))
matches = await asyncio.to_thread(lambda: list(self._upath.glob(pattern)))

# Check to make sure all matches are within the base path
for match in matches:
@@ -200,7 +228,7 @@
return [FilesystemItem.from_path(match, self._upath) for match in matches]

@tool_method(variants=["read", "write"], catch=True)
def grep(
async def grep(
self,
pattern: t.Annotated[str, "Regular expression pattern to search for"],
path: t.Annotated[str, "File or directory path to search in"],
@@ -225,25 +253,28 @@
files_to_search.append(target_path)
elif target_path.is_dir():
files_to_search.extend(
list(target_path.rglob("*") if recursive else target_path.glob("*")),
await asyncio.to_thread(
lambda: list(
target_path.rglob("*") if recursive else target_path.glob("*")
)
),
)

matches: list[GrepMatch] = []
for file_path in [f for f in files_to_search if f.is_file()]:
if len(matches) >= max_results:
break

if file_path.stat().st_size > MAX_GREP_FILE_SIZE:
continue
# Filter to files only and check size
files_to_search = [
f
for f in files_to_search
if f.is_file() and f.stat().st_size <= MAX_GREP_FILE_SIZE
]

with contextlib.suppress(Exception):
with file_path.open("r") as f:
lines = f.readlines()
async def search_file(file_path: UPath) -> list[GrepMatch]:
"""Search a single file for matches."""
file_matches: list[GrepMatch] = []
try:
async with aiofiles.open(file_path, "r") as f:

Check failure: Ruff (UP015) Unnecessary mode argument at dreadnode/agent/tools/fs.py:274:53 (GitHub Actions lint jobs, Python 3.10-3.13)
lines = await f.readlines()

for i, line in enumerate(lines):
if len(matches) >= max_results:
break

if regex.search(line):
line_num = i + 1
context_start = max(0, i - 1)
@@ -253,39 +284,85 @@
for j in range(context_start, context_end):
prefix = ">" if j == i else " "
line_text = lines[j].rstrip("\r\n")
context.append(f"{prefix} {j + 1}: {shorten_string(line_text, 80)}")
context.append(
f"{prefix} {j + 1}: {shorten_string(line_text, 80)}"
)

rel_path = self._relative(file_path)
matches.append(
file_matches.append(
GrepMatch(
path=rel_path,
line_number=line_num,
line=shorten_string(line.rstrip("\r\n"), 80),
context=context,
),
)
except (
FileNotFoundError,
PermissionError,
IsADirectoryError,
UnicodeDecodeError,
OSError,
) as e:
logger.warning(f"Error occurred while searching file {file_path}: {e}")

return file_matches

# Search files in parallel with concurrency limit
semaphore = asyncio.Semaphore(self.max_concurrent_reads)

async def search_file_limited(file_path: UPath) -> list[GrepMatch]:
"""Search a single file with semaphore to limit concurrency."""
async with semaphore:
return await search_file(file_path)

all_matches: list[GrepMatch] = []
results = await asyncio.gather(
*[search_file_limited(file_path) for file_path in files_to_search]
)

# Flatten results and respect max_results
for file_matches in results:
all_matches.extend(file_matches)
if len(all_matches) >= max_results:
break

return matches
return all_matches[:max_results]
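
The semaphore-plus-gather shape above is a standard bounded-concurrency fan-out. A stripped-down sketch of just that pattern (illustrative bounded_gather name, not part of this PR):

import asyncio
import typing as t

async def bounded_gather(aws: t.Iterable[t.Awaitable[t.Any]], limit: int = 25) -> list[t.Any]:
    # At most `limit` awaitables run at once; the rest block on the semaphore.
    semaphore = asyncio.Semaphore(limit)

    async def run(aw: t.Awaitable[t.Any]) -> t.Any:
        async with semaphore:
            return await aw

    return await asyncio.gather(*(run(aw) for aw in aws))

One trade-off of this shape: unlike the old sequential loop, every file is searched before max_results is applied, so the early break above only trims the flattened list rather than stopping I/O early.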

@tool_method(variants=["write"], catch=True)
def write_file(
async def write_file(
self,
path: t.Annotated[str, "Path to write the file to"],
contents: t.Annotated[str, "Content to write to the file"],
) -> FilesystemItem:
"""Create or overwrite a file with the given contents."""
_path = self._safe_create_file(path)
with _path.open("w") as f:
f.write(contents)
_path = await self._safe_create_file(path)
async with aiofiles.open(_path, "w") as f:
await f.write(contents)

return FilesystemItem.from_path(_path, self._upath)

@tool_method(variants=["write"], catch=True)
def write_lines(
async def write_file_bytes(
self,
path: t.Annotated[str, "Path to write the file to"],
bytes: t.Annotated[bytes, "Bytes to write to the file"],
) -> FilesystemItem:
"""Create or overwrite a file with the given bytes."""
_path = await self._safe_create_file(path)
async with aiofiles.open(_path, "wb") as f:
await f.write(bytes)

return FilesystemItem.from_path(_path, self._upath)

@tool_method(variants=["write"], catch=True)
async def write_lines(
self,
path: t.Annotated[str, "Path to write to"],
contents: t.Annotated[str, "Content to write"],
insert_line: t.Annotated[int, "Line number to insert at (negative counts from end)"] = -1,
insert_line: t.Annotated[
int, "Line number to insert at (negative counts from end)"
] = -1,
mode: t.Annotated[str, "'insert' or 'overwrite'"] = "insert",
) -> FilesystemItem:
"""
@@ -295,11 +372,11 @@
if mode not in ["insert", "overwrite"]:
raise ValueError("Invalid mode. Use 'insert' or 'overwrite'")

_path = self._safe_create_file(path)
_path = await self._safe_create_file(path)

lines: list[str] = []
with _path.open("r") as f:
lines = f.readlines()
async with aiofiles.open(_path, "r") as f:

Check failure: Ruff (UP015) Unnecessary mode argument at dreadnode/agent/tools/fs.py:378:41 (GitHub Actions lint jobs, Python 3.10-3.13)
lines = await f.readlines()

# Normalize line endings in content
content_lines = [
@@ -319,24 +396,24 @@
elif mode == "overwrite":
lines[insert_line : insert_line + len(content_lines)] = content_lines

with _path.open("w") as f:
f.writelines(lines)
async with aiofiles.open(_path, "w") as f:
await f.writelines(lines)

return FilesystemItem.from_path(_path, self._upath)
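
A minimal usage sketch for the two write_lines modes (hypothetical fs instance and annotate helper; the exact negative-index insert logic lives in the unchanged lines elided from this hunk):

async def annotate(fs) -> None:
    # fs is assumed to be an instance of this filesystem tool.
    # Default insert_line=-1 positions the new content relative to the end of the file.
    await fs.write_lines("notes.txt", "TODO: revisit")
    # Replace content starting at line 0.
    await fs.write_lines("notes.txt", "# Updated header", insert_line=0, mode="overwrite")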

@tool_method(variants=["write"], catch=True)
def mkdir(
async def mkdir(
self,
path: t.Annotated[str, "Directory path to create"],
) -> FilesystemItem:
"""Create a directory and any necessary parent directories."""
dir_path = self._resolve(path)
dir_path.mkdir(parents=True, exist_ok=True)
await asyncio.to_thread(lambda: dir_path.mkdir(parents=True, exist_ok=True))

return FilesystemItem.from_path(dir_path, self._upath)

@tool_method(variants=["write"], catch=True)
def mv(
async def mv(
self,
src: t.Annotated[str, "Source path"],
dest: t.Annotated[str, "Destination path"],
@@ -348,14 +425,16 @@
if not src_path.exists():
raise ValueError(f"'{src}' not found")

dest_path.parent.mkdir(parents=True, exist_ok=True)
await asyncio.to_thread(
lambda: dest_path.parent.mkdir(parents=True, exist_ok=True)
)

src_path.rename(dest_path)
await asyncio.to_thread(lambda: src_path.rename(dest_path))

return FilesystemItem.from_path(dest_path, self._upath)

@tool_method(variants=["write"], catch=True)
def cp(
async def cp(
self,
src: t.Annotated[str, "Source file"],
dest: t.Annotated[str, "Destination path"],
@@ -370,15 +449,21 @@
if not src_path.is_file():
raise ValueError(f"'{src}' is not a file")

dest_path.parent.mkdir(parents=True, exist_ok=True)
await asyncio.to_thread(
lambda: dest_path.parent.mkdir(parents=True, exist_ok=True)
)

with src_path.open("rb") as src_file, dest_path.open("wb") as dest_file:
dest_file.write(src_file.read())
async with (
aiofiles.open(src_path, "rb") as src_file,
aiofiles.open(dest_path, "wb") as dest_file,
):
content = await src_file.read()
await dest_file.write(content)

return FilesystemItem.from_path(dest_path, self._upath)

@tool_method(variants=["write"], catch=True)
def delete(
async def delete(
self,
path: t.Annotated[str, "File or directory"],
) -> bool:
Expand All @@ -388,8 +473,8 @@
raise ValueError(f"'{path}' not found")

if _path.is_dir():
_path.rmdir()
await asyncio.to_thread(_path.rmdir)
else:
_path.unlink()
await asyncio.to_thread(_path.unlink)

return True