Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions src/ursa/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
"""

import re
import sqlite3
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import (
Expand All @@ -41,6 +43,8 @@
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.graph.state import CompiledStateGraph, StateGraph
from langgraph.prebuilt import ToolNode
from langgraph.store.base import BaseStore
from langgraph.store.sqlite import SqliteStore

from ursa.observability.timing import (
Telemetry, # for timing / telemetry / metrics
Expand All @@ -50,6 +54,17 @@
TState = TypeVar("TState", bound=Mapping[str, Any])


@dataclass(frozen=True, kw_only=True)
class AgentContext:
"""Immutable context provided during graph execution"""

workspace: Path
""" Workspace path for the agent """

tool_character_limit: int = 3000
""" Suggested limit on tool call responses """


def _to_snake(s: str) -> str:
"""Convert a string to snake_case format.

Expand Down Expand Up @@ -170,6 +185,11 @@ def name(self) -> str:
"""Agent name."""
return self.__class__.__name__

@property
def context(self) -> AgentContext:
"""Immutable run-scoped information provided to the Agent's graph"""
return AgentContext(workspace=self.workspace)

def add_node(
self,
f: Callable[..., Mapping[str, Any]],
Expand Down Expand Up @@ -512,11 +532,23 @@ def _normalize_inputs(self, inputs: InputLike) -> Mapping[str, Any]:
def compiled_graph(self) -> CompiledStateGraph:
"""Return the compiled StateGraph application for the agent."""
graph = self.build_graph()
compiled = graph.compile(checkpointer=self.checkpointer).with_config({
"recursion_limit": 50000
})
compiled = graph.compile(
checkpointer=self.checkpointer,
store=self.storage,
).with_config({"recursion_limit": 50000})
return self._finalize_graph(compiled)

@cached_property
def storage(self) -> BaseStore:
"""Create a SQLite-backed LangGraph store for persistent graph data."""
store_path = self.workspace / "graph_store.sqlite"
conn = sqlite3.connect(
store_path, check_same_thread=False, isolation_level=None
)
store = SqliteStore(conn)
store.setup()
return store

@final
def build_graph(self) -> StateGraph:
"""Build and return the StateGraph backing this agent."""
Expand Down
43 changes: 20 additions & 23 deletions src/ursa/agents/execution_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@
ToolMessage,
)
from langchain_core.messages.utils import count_tokens_approximately
from langchain_core.output_parsers import StrOutputParser
from langgraph.runtime import Runtime
from langgraph.types import Command

# Rich
from rich import get_console
from rich.markdown import Markdown
from rich.panel import Panel

from ursa.agents.base import AgentWithTools, BaseAgent
from ursa.agents.base import AgentContext, AgentWithTools, BaseAgent
from ursa.prompt_library.execution_prompts import (
executor_prompt,
get_safety_prompt,
Expand Down Expand Up @@ -82,15 +82,13 @@ class ExecutionState(TypedDict):
Fields:
- messages: list of messages (System/Human/AI/Tool).
- current_progress: short status string describing agent progress.
- code_files: list of filenames created or edited in the workspace.
- workspace: path to the working directory where files and commands run.
- symlinkdir: optional dict describing a symlink operation (source, dest,
is_linked).
"""

messages: list[AnyMessage]
current_progress: str
code_files: list[str]
workspace: Path
symlinkdir: dict
model: BaseChatModel
Expand Down Expand Up @@ -280,11 +278,11 @@ def _summarize_context(self, state: ExecutionState) -> ExecutionState:
pass

summarize_prompt = f"""
Your only tasks is to provide a detailed, comprehensive summary of the following
conversation.
Your only tasks is to provide a detailed, comprehensive summary of the following
conversation.

Your summary will be the only information retained from the conversation, so ensure
it contains all details that need to be remembered to meet the goals of the work.
Your summary will be the only information retained from the conversation, so ensure
it contains all details that need to be remembered to meet the goals of the work.

Conversation to summarize:
{conversation_to_summarize}
Expand Down Expand Up @@ -418,16 +416,10 @@ def tool_use(self, state: ExecutionState) -> ExecutionState:
for resp in update:
if isinstance(resp, Command):
new_state["messages"].extend(resp.update["messages"])
new_state.setdefault("code_files", []).extend(
resp.update["code_files"]
)
else:
new_state["messages"].extend(resp["messages"])
elif isinstance(update, Command):
new_state["messages"].extend(update.update["messages"])
new_state.setdefault("code_files", []).extend(
update.update["code_files"]
)
except Exception as e:
print(f"SOMETHING IS WRONG WITH {update}: {e}")
new_state["messages"].extend(update["messages"])
Expand Down Expand Up @@ -514,7 +506,9 @@ def recap(self, state: ExecutionState) -> ExecutionState:
# 5) Return a partial state update with only the summary content.
return new_state

def safety_check(self, state: ExecutionState) -> ExecutionState:
def safety_check(
self, state: ExecutionState, runtime: Runtime[AgentContext]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This appears to fail for me in any uses of the execution agent that traverse the graph because it needs two arguments but in the graph it only gets the state passed into it.

) -> ExecutionState:
"""Assess pending shell commands for safety and inject ToolMessages with results.

This method inspects the most recent AI tool calls, evaluates any run_command
Expand Down Expand Up @@ -544,15 +538,18 @@ def safety_check(self, state: ExecutionState) -> ExecutionState:
if tool_call["name"] != "run_command":
continue

query = tool_call["args"]["query"]
safety_result = StrOutputParser().invoke(
self.llm.invoke(
self.get_safety_prompt(
query, self.safe_codes, new_state.get("code_files", [])
),
self.build_config(tags=["safety_check"]),
if runtime.store is not None:
search_results = runtime.store.search(
("workspace", "file_edit"), limit=1000
)
)
edited_files = [item.key for item in search_results]
else:
edited_files = []
query = tool_call["args"]["query"]
safety_result = self.llm.invoke(
self.get_safety_prompt(query, self.safe_codes, edited_files),
self.build_config(tags=["safety_check"]),
).text

if "[NO]" in safety_result:
any_unsafe = True
Expand Down
14 changes: 5 additions & 9 deletions src/ursa/tools/read_file_tool.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import os
from typing import Annotated

from langchain.tools import ToolRuntime
from langchain_core.tools import tool
from langgraph.prebuilt import InjectedState

from ursa.agents.base import AgentContext
from ursa.util.parse import read_pdf_text, read_text_file


# Tools for ExecutionAgent
@tool
def read_file(filename: str, state: Annotated[dict, InjectedState]) -> str:
def read_file(filename: str, runtime: ToolRuntime[AgentContext]) -> str:
"""
Reads in a file with a given filename into a string. Can read in PDF
or files that are text/ASCII. Uses a PDF parser if the filename ends
Expand All @@ -18,12 +15,11 @@ def read_file(filename: str, state: Annotated[dict, InjectedState]) -> str:
Args:
filename: string filename to read in
"""
workspace_dir = state["workspace"]
full_filename = os.path.join(workspace_dir, filename)
full_filename = runtime.context.workspace.joinpath(filename)

print("[READING]: ", full_filename)
try:
if full_filename.lower().endswith(".pdf"):
if full_filename.suffix == ".pdf":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if full_filename.suffix == ".pdf":
if full_filename.suffix.lower() == ".pdf":

Doesn't this need to keep the "lower" to ensure it's not case sensitive (some pdf files end in .PDF rather than .pdf for instance)

file_contents = read_pdf_text(full_filename)
else:
file_contents = read_text_file(full_filename)
Expand Down
27 changes: 8 additions & 19 deletions src/ursa/tools/run_command_tool.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
import os
import subprocess
from typing import Annotated
from pathlib import Path

from langchain.tools import ToolRuntime
from langchain_core.tools import tool
from langgraph.prebuilt import InjectedState

# Global variables for the module.

# Set a limit for message characters - the user could overload
# that in their env, or maybe we could pull this out of the LLM parameters
MAX_TOOL_MSG_CHARS = int(os.getenv("MAX_TOOL_MSG_CHARS", "30000"))
from ursa.agents.base import AgentContext
from ursa.util.types import AsciiStr


@tool
def run_command(query: str, state: Annotated[dict, InjectedState]) -> str:
def run_command(query: AsciiStr, runtime: ToolRuntime[AgentContext]) -> str:
"""Execute a shell command in the workspace and return its combined output.

Runs the specified command using subprocess.run in the given workspace
Expand All @@ -29,9 +25,10 @@ def run_command(query: str, state: Annotated[dict, InjectedState]) -> str:
A formatted string with "STDOUT:" followed by the truncated stdout and
"STDERR:" followed by the truncated stderr.
"""
workspace_dir = state["workspace"]
workspace_dir = Path(runtime.context.workspace)

print("RUNNING: ", query)

try:
result = subprocess.run(
query,
Expand All @@ -45,18 +42,10 @@ def run_command(query: str, state: Annotated[dict, InjectedState]) -> str:
except KeyboardInterrupt:
print("Keyboard Interrupt of command: ", query)
stdout, stderr = "", "KeyboardInterrupt:"
except UnicodeDecodeError:
print(
f"Invalid Command: {query} - only 'utf-8' decodable characters allowed."
)
stdout, stderr = (
"",
f"Invalid Command: {query} - only 'utf-8' decodable characters allowed.:",
)

# Fit BOTH streams under a single overall cap
stdout_fit, stderr_fit = _fit_streams_to_budget(
stdout or "", stderr or "", MAX_TOOL_MSG_CHARS
stdout or "", stderr or "", runtime.context.tool_character_limit
)

print("STDOUT: ", stdout_fit)
Expand Down
Loading
Loading