Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 172 additions & 52 deletions openhands-sdk/openhands/sdk/agent/agent.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import json
from collections.abc import Callable
from dataclasses import dataclass, field

from pydantic import ValidationError, model_validator
from pydantic import PrivateAttr, ValidationError, model_validator

import openhands.sdk.security.analyzer as analyzer
import openhands.sdk.security.risk as risk
from openhands.sdk.agent.base import AgentBase
from openhands.sdk.agent.critic_mixin import CriticMixin
from openhands.sdk.agent.parallel_executor import ParallelToolExecutor
from openhands.sdk.agent.utils import (
fix_malformed_tool_arguments,
make_llm_completion,
Expand All @@ -22,6 +25,7 @@
from openhands.sdk.event import (
ActionEvent,
AgentErrorEvent,
Event,
MessageEvent,
ObservationEvent,
SystemPromptEvent,
Expand Down Expand Up @@ -72,6 +76,133 @@
INIT_STATE_PREFIX_SCAN_WINDOW = 3


@dataclass(frozen=True, slots=True)
class _ActionBatch:
    """Prepared batch of tool calls, frozen once constructed.

    Encapsulates the full lifecycle of one batch of actions: preparation
    (truncating everything after FinishTool, splitting out hook-blocked
    actions, executing the rest), ordered event emission, and the
    post-batch state transition. Agent-specific behavior (iterative
    refinement, state mutation) is injected as callables so this class
    stays decoupled from the Agent class.
    """

    # Actions in their original order, truncated just past FinishTool.
    action_events: list[ActionEvent]
    # True when the (truncated) batch ends with a FinishTool call.
    has_finish: bool
    # action id -> rejection reason, for actions blocked by a PreToolUse hook.
    blocked_reasons: dict[str, str] = field(default_factory=dict)
    # action id -> events produced by executing that action.
    results_by_id: dict[str, list[Event]] = field(default_factory=dict)

    @staticmethod
    def _truncate_at_finish(
        action_events: list[ActionEvent],
    ) -> tuple[list[ActionEvent], bool]:
        """
        Return (events[:finish+1], True) or (events, False).
        Discards and logs any calls after FinishTool.
        """
        finish_idx = None
        for idx, event in enumerate(action_events):
            if event.tool_name == FinishTool.name:
                finish_idx = idx
                break
        if finish_idx is None:
            return action_events, False

        dropped = action_events[finish_idx + 1 :]
        if dropped:
            dropped_names = ", ".join(ev.tool_name for ev in dropped)
            logger.warning(
                f"Discarding {len(dropped)} tool call(s) "
                f"after FinishTool: {dropped_names}"
            )
        return action_events[: finish_idx + 1], True

    @classmethod
    def prepare(
        cls,
        action_events: list[ActionEvent],
        state: ConversationState,
        executor: ParallelToolExecutor,
        tool_runner: Callable[[ActionEvent], list[Event]],
    ) -> "_ActionBatch":
        """Truncate, partition blocked actions, execute the rest, return the batch."""
        action_events, has_finish = cls._truncate_at_finish(action_events)

        # Partition: hook-blocked actions are recorded, the rest run.
        blocked: dict[str, str] = {}
        runnable: list[ActionEvent] = []
        for event in action_events:
            rejection = state.pop_blocked_action(event.id)
            if rejection is None:
                runnable.append(event)
            else:
                blocked[event.id] = rejection

        outcomes = executor.execute_batch(runnable, tool_runner)
        return cls(
            action_events=action_events,
            has_finish=has_finish,
            blocked_reasons=blocked,
            results_by_id={ev.id: res for ev, res in zip(runnable, outcomes)},
        )

    def emit(self, on_event: ConversationCallbackType) -> None:
        """Emit all events in original action order."""
        for event in self.action_events:
            if event.id in self.blocked_reasons:
                reason = self.blocked_reasons[event.id]
                logger.info(f"Action '{event.tool_name}' blocked by hook: {reason}")
                on_event(
                    UserRejectObservation(
                        action_id=event.id,
                        tool_name=event.tool_name,
                        tool_call_id=event.tool_call_id,
                        rejection_reason=reason,
                        rejection_source="hook",
                    )
                )
            else:
                for produced in self.results_by_id[event.id]:
                    on_event(produced)

    def finalize(
        self,
        on_event: ConversationCallbackType,
        check_iterative_refinement: Callable[[ActionEvent], tuple[bool, str | None]],
        mark_finished: Callable[[], None],
    ) -> None:
        """Transition state after FinishTool, or inject iterative-refinement followup.

        Args:
            on_event: Callback for emitting events.
            check_iterative_refinement: Returns (should_continue, followup)
                for a FinishTool action event.
            mark_finished: Called to set the conversation execution status
                to FINISHED when the agent is done.
        """
        # No FinishTool in this batch: nothing to finalize.
        if not self.has_finish:
            return
        finish_event = self.action_events[-1]
        # A hook-blocked FinishTool never ran, so it must not end the run.
        if finish_event.id in self.blocked_reasons:
            return

        keep_going, followup = check_iterative_refinement(finish_event)
        if keep_going and followup:
            # Inject the follow-up as a user message and let the loop continue.
            on_event(
                MessageEvent(
                    source="user",
                    llm_message=Message(
                        role="user",
                        content=[TextContent(text=followup)],
                    ),
                )
            )
        else:
            mark_finished()


class Agent(CriticMixin, AgentBase):
"""Main agent implementation for OpenHands.

Expand All @@ -97,6 +228,10 @@ class Agent(CriticMixin, AgentBase):
```
"""

_parallel_executor: ParallelToolExecutor = PrivateAttr(
default_factory=ParallelToolExecutor
)

@model_validator(mode="before")
@classmethod
def _add_security_prompt_as_default(cls, data):
def _execute_actions(
    self,
    conversation: LocalConversation,
    action_events: list[ActionEvent],
    on_event: ConversationCallbackType,
) -> None:
    """Prepare a batch, emit results, and handle finish."""
    state = conversation.state

    def _mark_finished() -> None:
        # Invoked only when the batch ended with an unblocked FinishTool.
        state.execution_status = ConversationExecutionStatus.FINISHED

    batch = _ActionBatch.prepare(
        action_events,
        state=state,
        executor=self._parallel_executor,
        tool_runner=lambda ae: self._execute_action_event(conversation, ae),
    )
    batch.emit(on_event)
    batch.finalize(
        on_event=on_event,
        check_iterative_refinement=lambda ae: self._check_iterative_refinement(
            conversation, ae
        ),
        mark_finished=_mark_finished,
    )

@observe(name="agent.step", ignore_inputs=["state", "on_event"])
def step(
Expand Down Expand Up @@ -659,38 +812,26 @@ def _get_action_event(
on_event(action_event)
return action_event

@observe(ignore_inputs=["state", "on_event"])
@observe()
def _execute_action_event(
self,
conversation: LocalConversation,
action_event: ActionEvent,
on_event: ConversationCallbackType,
):
"""Execute an action event and update the conversation state.
) -> list[Event]:
"""Execute a single tool and return the resulting events.

It will call the tool's executor and update the state & call callback fn
with the observation.
Called from parallel threads by _execute_actions. This method must
not mutate shared conversation state (blocked_actions,
execution_status) — those transitions are handled by the caller
on the main thread.

If the action was blocked by a PreToolUse hook (recorded in
state.blocked_actions), a UserRejectObservation is emitted instead
of executing the action.
"""
state = conversation.state

# Check if this action was blocked by a PreToolUse hook
reason = state.pop_blocked_action(action_event.id)
if reason is not None:
logger.info(f"Action '{action_event.tool_name}' blocked by hook: {reason}")
rejection = UserRejectObservation(
action_id=action_event.id,
tool_name=action_event.tool_name,
tool_call_id=action_event.tool_call_id,
rejection_reason=reason,
rejection_source="hook",
)
on_event(rejection)
return rejection
Note: the tool itself receives ``conversation`` and may mutate it
Comment thread
VascoSch92 marked this conversation as resolved.
(e.g. filesystem, working directory). Thread safety of individual
tools is the tool's responsibility.

Returns a list of events (observation or error). Events are NOT
emitted here — the caller is responsible for emitting them in order.
"""
tool = self.tools_map.get(action_event.tool_name, None)
if tool is None:
raise RuntimeError(
Expand Down Expand Up @@ -720,36 +861,15 @@ def _execute_action_event(
tool_name=tool.name,
tool_call_id=action_event.tool_call.id,
)
on_event(error_event)
return error_event
return [error_event]

obs_event = ObservationEvent(
observation=observation,
action_id=action_event.id,
tool_name=tool.name,
tool_call_id=action_event.tool_call.id,
)
on_event(obs_event)

# Set conversation state
if tool.name == FinishTool.name:
# Check if iterative refinement should continue
should_continue, followup = self._check_iterative_refinement(
conversation, action_event
)
if should_continue and followup:
# Send follow-up message and continue agent loop
followup_msg = MessageEvent(
source="user",
llm_message=Message(
role="user", content=[TextContent(text=followup)]
),
)
on_event(followup_msg)
# Don't set FINISHED - let the agent continue
else:
state.execution_status = ConversationExecutionStatus.FINISHED
return obs_event
return [obs_event]

def _maybe_emit_vllm_tokens(
self, llm_response: LLMResponse, on_event: ConversationCallbackType
Expand Down
Loading
Loading