Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 90 additions & 43 deletions openhands-sdk/openhands/sdk/agent/acp_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,18 +69,12 @@
)


# Seconds to wait after prompt() for pending session_update notifications
# to be processed. This is a best-effort workaround: the ACP protocol does
# not currently signal when all notifications for a turn have been delivered,
# so we yield to the event loop and then sleep briefly to allow in-flight
# handlers to finish. Override via ACP_NOTIFICATION_DRAIN_DELAY for slow or
# remote servers.
# TODO(https://github.com/agentclientprotocol/agent-client-protocol/issues/554):
# Replace with protocol-level synchronization once ACP supports a
# "turn complete" notification.
_NOTIFICATION_DRAIN_DELAY: float = float(
os.environ.get("ACP_NOTIFICATION_DRAIN_DELAY", "0.1")
)
# Maximum seconds to wait for a UsageUpdate notification after prompt()
# returns. The ACP server writes UsageUpdate to the wire before the
# PromptResponse, so under normal conditions the notification handler
# completes almost immediately. This timeout is a safety net for slow
# or remote servers.
_USAGE_UPDATE_TIMEOUT: float = float(os.environ.get("ACP_USAGE_UPDATE_TIMEOUT", "2.0"))

# Retry configuration for transient ACP connection errors.
# These errors can occur when the connection drops mid-conversation but the
Expand All @@ -103,17 +97,6 @@
_STREAM_READER_LIMIT: int = 100 * 1024 * 1024 # 100 MiB


async def _drain_notifications() -> None:
    """Best-effort drain of pending ``session_update`` notifications.

    ACP does not yet signal when all notifications for a turn have been
    delivered (see TODO above). A zero-length sleep yields so already-queued
    handlers run; the short follow-up sleep gives in-flight IO handlers a
    chance to finish.
    """
    for pause in (0, _NOTIFICATION_DRAIN_DELAY):
        await asyncio.sleep(pause)


def _make_dummy_llm() -> LLM:
    """Build a sentinel LLM placeholder; it should never be invoked directly."""
    sentinel = LLM(model="acp-managed")
    return sentinel
Expand Down Expand Up @@ -215,8 +198,12 @@ def __init__(self) -> None:
self.on_token: Any = None # ConversationTokenCallbackType | None
# Telemetry state from UsageUpdate (persists across turns)
self._last_cost: float = 0.0 # last cumulative cost seen
self._context_window: int = 0 # context window size from ACP
self._llm_ref: Any = None # reference to the sentinel LLM
self._last_cost_by_session: dict[str, float] = {}
self._context_window: int = 0 # last context window seen
self._context_window_by_session: dict[str, int] = {}
# Per-turn synchronization for UsageUpdate notifications.
self._turn_usage_updates: dict[str, Any] = {}
self._usage_received: dict[str, asyncio.Event] = {}
# Fork session state for ask_agent() — guarded by _fork_lock to
# prevent concurrent ask_agent() calls from colliding.
self._fork_lock = threading.Lock()
Expand All @@ -228,9 +215,27 @@ def reset(self) -> None:
self.accumulated_thoughts.clear()
self.accumulated_tool_calls.clear()
self.on_token = None
self._turn_usage_updates.clear()
self._usage_received.clear()
# Note: telemetry state (_last_cost, _context_window, etc.)
# is intentionally NOT cleared — it accumulates across turns.

def prepare_usage_sync(self, session_id: str) -> asyncio.Event:
    """Prepare per-turn UsageUpdate synchronization for a session.

    Discards any stale update left over from a previous turn and installs
    a fresh event that the prompt loop can wait on.
    """
    self._turn_usage_updates.pop(session_id, None)
    fresh_event = asyncio.Event()
    self._usage_received[session_id] = fresh_event
    return fresh_event

def get_turn_usage_update(self, session_id: str) -> Any:
    """Return the latest UsageUpdate observed for the current turn, if any."""
    try:
        return self._turn_usage_updates[session_id]
    except KeyError:
        return None

def pop_turn_usage_update(self, session_id: str) -> Any:
    """Consume per-turn UsageUpdate synchronization state for a session."""
    # The waiter event is no longer needed once the turn is over; drop it
    # before handing the stored update (or None) back to the caller.
    if session_id in self._usage_received:
        del self._usage_received[session_id]
    return self._turn_usage_updates.pop(session_id, None)

# -- Client protocol methods ------------------------------------------

async def session_update(
Expand Down Expand Up @@ -261,14 +266,13 @@ async def session_update(
if isinstance(update.content, TextContentBlock):
self.accumulated_thoughts.append(update.content.text)
elif isinstance(update, UsageUpdate):
# Update context window size
# Store the update for step()/ask_agent() to process in one place.
self._context_window = update.size
# Record incremental cost
if update.cost is not None and self._llm_ref is not None:
delta = update.cost.amount - self._last_cost
if delta > 0:
self._llm_ref.metrics.add_cost(delta)
self._last_cost = update.cost.amount
self._context_window_by_session[session_id] = update.size
self._turn_usage_updates[session_id] = update
event = self._usage_received.get(session_id)
if event is not None:
event.set()
elif isinstance(update, ToolCallStart):
self.accumulated_tool_calls.append(
{
Expand Down Expand Up @@ -458,14 +462,24 @@ def _record_usage(
response: PromptResponse | None,
session_id: str,
elapsed: float | None = None,
usage_update: UsageUpdate | None = None,
) -> None:
"""Record token usage, latency, and notify stats callback from a PromptResponse.
"""Record cost, token usage, latency, and notify stats callback once.

Args:
response: The ACP PromptResponse (may carry a ``usage`` field).
session_id: Session identifier used as the response_id for metrics.
elapsed: Wall-clock seconds for this prompt round-trip (optional).
usage_update: The synchronized ACP UsageUpdate for this turn, if any.
"""
if usage_update is not None and usage_update.cost is not None:
Comment thread
simonrosenberg marked this conversation as resolved.
last_cost = self._client._last_cost_by_session.get(session_id, 0.0)
delta = usage_update.cost.amount - last_cost
if delta > 0:
self.llm.metrics.add_cost(delta)
self._client._last_cost_by_session[session_id] = usage_update.cost.amount
self._client._last_cost = usage_update.cost.amount

if response is not None and response.usage is not None:
usage = response.usage
self.llm.metrics.add_token_usage(
Expand All @@ -474,7 +488,9 @@ def _record_usage(
cache_read_tokens=usage.cached_read_tokens or 0,
cache_write_tokens=usage.cached_write_tokens or 0,
reasoning_tokens=usage.thought_tokens or 0,
context_window=self._client._context_window,
context_window=self._client._context_window_by_session.get(
session_id, self._client._context_window
),
response_id=session_id,
)

Expand Down Expand Up @@ -568,7 +584,6 @@ def init_state(
def _start_acp_server(self, state: ConversationState) -> None:
"""Start the ACP subprocess and initialize the session."""
client = _OpenHandsACPBridge()
client._llm_ref = self.llm
self._client = client

# Build environment: inherit current env + ACP extras
Expand Down Expand Up @@ -712,11 +727,22 @@ def step(
try:

async def _prompt() -> PromptResponse:
usage_sync = self._client.prepare_usage_sync(self._session_id or "")
response = await self._conn.prompt(
[text_block(user_message)],
self._session_id,
)
await _drain_notifications()
if self._client.get_turn_usage_update(self._session_id or "") is None:
Comment thread
simonrosenberg marked this conversation as resolved.
try:
await asyncio.wait_for(
usage_sync.wait(), timeout=_USAGE_UPDATE_TIMEOUT
)
except TimeoutError:
logger.warning(
"UsageUpdate not received within %.1fs for session %s",
_USAGE_UPDATE_TIMEOUT,
self._session_id,
)
return response

# Send prompt to ACP server with retry logic for connection errors.
Expand All @@ -736,9 +762,8 @@ async def _prompt() -> PromptResponse:
response = self._executor.run_async(
_prompt, timeout=self.acp_prompt_timeout
)
break # Success, exit retry loop
break
except TimeoutError:
# Timeout is handled separately below, don't retry
raise
except _RETRIABLE_CONNECTION_ERRORS as e:
if attempt < max_retries:
Expand All @@ -754,17 +779,22 @@ async def _prompt() -> PromptResponse:
e,
)
time.sleep(delay)
# Reset accumulators for retry (partial state may be stale)
self._client.reset()
self._client.on_token = on_token
else:
# Max retries exceeded
raise

elapsed = time.monotonic() - t0
logger.info("ACP prompt returned in %.1fs", elapsed)

self._record_usage(response, self._session_id or "", elapsed=elapsed)
session_id = self._session_id or ""
usage_update = self._client.pop_turn_usage_update(session_id)
self._record_usage(
response,
session_id,
elapsed=elapsed,
usage_update=usage_update,
)

# Emit ACPToolCallEvents for each accumulated tool call
for tc in self._client.accumulated_tool_calls:
Expand Down Expand Up @@ -911,15 +941,32 @@ async def _fork_and_prompt() -> str:
client._fork_accumulated_text.clear()
try:
fork_t0 = time.monotonic()
usage_sync = client.prepare_usage_sync(fork_session_id)
response = await self._conn.prompt(
[text_block(question)],
fork_session_id,
)
await _drain_notifications()
if client.get_turn_usage_update(fork_session_id) is None:
try:
await asyncio.wait_for(
usage_sync.wait(), timeout=_USAGE_UPDATE_TIMEOUT
)
except TimeoutError:
logger.warning(
"UsageUpdate not received within %.1fs for fork session %s",
_USAGE_UPDATE_TIMEOUT,
fork_session_id,
)
fork_elapsed = time.monotonic() - fork_t0

result = "".join(client._fork_accumulated_text)
self._record_usage(response, fork_session_id, elapsed=fork_elapsed)
usage_update = client.pop_turn_usage_update(fork_session_id)
self._record_usage(
response,
fork_session_id,
elapsed=fork_elapsed,
usage_update=usage_update,
)
return result
finally:
client._fork_session_id = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -437,12 +437,17 @@ def _get_conversation_info(self) -> dict:
return self._cached_state

# Fallback to REST API if no cached state
resp = _send_request(
self._client,
"GET",
f"{self._conversation_info_base_path}/{self._conversation_id}",
)
state = resp.json()
return self.refresh_from_server()

def refresh_from_server(self) -> dict:
    """Fetch and cache the latest authoritative conversation state."""
    url = f"{self._conversation_info_base_path}/{self._conversation_id}"
    response = _send_request(self._client, "GET", url)
    fresh_state = response.json()
    # Cache under the lock so concurrent readers never see a torn update.
    with self._lock:
        self._cached_state = fresh_state
    return fresh_state

Expand Down Expand Up @@ -1039,6 +1044,7 @@ def _wait_for_run_completion(
ws_status,
elapsed,
)
self._state.refresh_from_server()
return
except Empty:
pass # Queue.get() timed out, fall through to REST polling
Expand All @@ -1064,11 +1070,15 @@ def _wait_for_run_completion(
logger.info(
"Run completed via REST fallback after %d consecutive "
"terminal polls (status: %s, elapsed: %.1fs). "
"Reconciling events...",
"Refreshing final state and reconciling events...",
consecutive_terminal_polls,
status,
elapsed,
)
final_info = self._state.refresh_from_server()
self._handle_conversation_status(
final_info.get("execution_status")
)
# Reconcile events to catch any that were missed via WS.
# This is only called in the fallback path, so it doesn't
# add overhead in the common case where WS works.
Expand Down
Loading
Loading