Skip to content

Commit 6b02df0

Browse files
simonrosenbergDebug Agent
andauthored
fix: synchronize ACP telemetry and refresh remote final state (#2460)
Co-authored-by: Debug Agent <debug@example.com>
1 parent 2d027b4 commit 6b02df0

File tree

5 files changed

+403
-160
lines changed

5 files changed

+403
-160
lines changed

openhands-sdk/openhands/sdk/agent/acp_agent.py

Lines changed: 90 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -69,18 +69,12 @@
6969
)
7070

7171

72-
# Seconds to wait after prompt() for pending session_update notifications
73-
# to be processed. This is a best-effort workaround: the ACP protocol does
74-
# not currently signal when all notifications for a turn have been delivered,
75-
# so we yield to the event loop and then sleep briefly to allow in-flight
76-
# handlers to finish. Override via ACP_NOTIFICATION_DRAIN_DELAY for slow or
77-
# remote servers.
78-
# TODO(https://github.com/agentclientprotocol/agent-client-protocol/issues/554):
79-
# Replace with protocol-level synchronization once ACP supports a
80-
# "turn complete" notification.
81-
_NOTIFICATION_DRAIN_DELAY: float = float(
82-
os.environ.get("ACP_NOTIFICATION_DRAIN_DELAY", "0.1")
83-
)
72+
# Maximum seconds to wait for a UsageUpdate notification after prompt()
73+
# returns. The ACP server writes UsageUpdate to the wire before the
74+
# PromptResponse, so under normal conditions the notification handler
75+
# completes almost immediately. This timeout is a safety net for slow
76+
# or remote servers.
77+
_USAGE_UPDATE_TIMEOUT: float = float(os.environ.get("ACP_USAGE_UPDATE_TIMEOUT", "2.0"))
8478

8579
# Retry configuration for transient ACP connection errors.
8680
# These errors can occur when the connection drops mid-conversation but the
@@ -103,17 +97,6 @@
10397
_STREAM_READER_LIMIT: int = 100 * 1024 * 1024 # 100 MiB
10498

10599

106-
async def _drain_notifications() -> None:
107-
"""Best-effort drain of pending ``session_update`` notifications.
108-
109-
ACP does not yet signal when all notifications for a turn have been
110-
delivered (see TODO above). We yield to the event loop so already-queued
111-
handlers run, then sleep briefly to allow in-flight IO handlers to finish.
112-
"""
113-
await asyncio.sleep(0)
114-
await asyncio.sleep(_NOTIFICATION_DRAIN_DELAY)
115-
116-
117100
def _make_dummy_llm() -> LLM:
118101
"""Create a dummy LLM that should never be called directly."""
119102
return LLM(model="acp-managed")
@@ -215,8 +198,12 @@ def __init__(self) -> None:
215198
self.on_token: Any = None # ConversationTokenCallbackType | None
216199
# Telemetry state from UsageUpdate (persists across turns)
217200
self._last_cost: float = 0.0 # last cumulative cost seen
218-
self._context_window: int = 0 # context window size from ACP
219-
self._llm_ref: Any = None # reference to the sentinel LLM
201+
self._last_cost_by_session: dict[str, float] = {}
202+
self._context_window: int = 0 # last context window seen
203+
self._context_window_by_session: dict[str, int] = {}
204+
# Per-turn synchronization for UsageUpdate notifications.
205+
self._turn_usage_updates: dict[str, Any] = {}
206+
self._usage_received: dict[str, asyncio.Event] = {}
220207
# Fork session state for ask_agent() — guarded by _fork_lock to
221208
# prevent concurrent ask_agent() calls from colliding.
222209
self._fork_lock = threading.Lock()
@@ -228,9 +215,27 @@ def reset(self) -> None:
228215
self.accumulated_thoughts.clear()
229216
self.accumulated_tool_calls.clear()
230217
self.on_token = None
218+
self._turn_usage_updates.clear()
219+
self._usage_received.clear()
231220
# Note: telemetry state (_last_cost, _context_window, etc.)
232221
# is intentionally NOT cleared — it accumulates across turns.
233222

223+
def prepare_usage_sync(self, session_id: str) -> asyncio.Event:
224+
"""Prepare per-turn UsageUpdate synchronization for a session."""
225+
event = asyncio.Event()
226+
self._usage_received[session_id] = event
227+
self._turn_usage_updates.pop(session_id, None)
228+
return event
229+
230+
def get_turn_usage_update(self, session_id: str) -> Any:
231+
"""Return the latest UsageUpdate observed for the current turn."""
232+
return self._turn_usage_updates.get(session_id)
233+
234+
def pop_turn_usage_update(self, session_id: str) -> Any:
235+
"""Consume per-turn UsageUpdate synchronization state for a session."""
236+
self._usage_received.pop(session_id, None)
237+
return self._turn_usage_updates.pop(session_id, None)
238+
234239
# -- Client protocol methods ------------------------------------------
235240

236241
async def session_update(
@@ -261,14 +266,13 @@ async def session_update(
261266
if isinstance(update.content, TextContentBlock):
262267
self.accumulated_thoughts.append(update.content.text)
263268
elif isinstance(update, UsageUpdate):
264-
# Update context window size
269+
# Store the update for step()/ask_agent() to process in one place.
265270
self._context_window = update.size
266-
# Record incremental cost
267-
if update.cost is not None and self._llm_ref is not None:
268-
delta = update.cost.amount - self._last_cost
269-
if delta > 0:
270-
self._llm_ref.metrics.add_cost(delta)
271-
self._last_cost = update.cost.amount
271+
self._context_window_by_session[session_id] = update.size
272+
self._turn_usage_updates[session_id] = update
273+
event = self._usage_received.get(session_id)
274+
if event is not None:
275+
event.set()
272276
elif isinstance(update, ToolCallStart):
273277
self.accumulated_tool_calls.append(
274278
{
@@ -458,14 +462,24 @@ def _record_usage(
458462
response: PromptResponse | None,
459463
session_id: str,
460464
elapsed: float | None = None,
465+
usage_update: UsageUpdate | None = None,
461466
) -> None:
462-
"""Record token usage, latency, and notify stats callback from a PromptResponse.
467+
"""Record cost, token usage, latency, and notify stats callback once.
463468
464469
Args:
465470
response: The ACP PromptResponse (may carry a ``usage`` field).
466471
session_id: Session identifier used as the response_id for metrics.
467472
elapsed: Wall-clock seconds for this prompt round-trip (optional).
473+
usage_update: The synchronized ACP UsageUpdate for this turn, if any.
468474
"""
475+
if usage_update is not None and usage_update.cost is not None:
476+
last_cost = self._client._last_cost_by_session.get(session_id, 0.0)
477+
delta = usage_update.cost.amount - last_cost
478+
if delta > 0:
479+
self.llm.metrics.add_cost(delta)
480+
self._client._last_cost_by_session[session_id] = usage_update.cost.amount
481+
self._client._last_cost = usage_update.cost.amount
482+
469483
if response is not None and response.usage is not None:
470484
usage = response.usage
471485
self.llm.metrics.add_token_usage(
@@ -474,7 +488,9 @@ def _record_usage(
474488
cache_read_tokens=usage.cached_read_tokens or 0,
475489
cache_write_tokens=usage.cached_write_tokens or 0,
476490
reasoning_tokens=usage.thought_tokens or 0,
477-
context_window=self._client._context_window,
491+
context_window=self._client._context_window_by_session.get(
492+
session_id, self._client._context_window
493+
),
478494
response_id=session_id,
479495
)
480496

@@ -568,7 +584,6 @@ def init_state(
568584
def _start_acp_server(self, state: ConversationState) -> None:
569585
"""Start the ACP subprocess and initialize the session."""
570586
client = _OpenHandsACPBridge()
571-
client._llm_ref = self.llm
572587
self._client = client
573588

574589
# Build environment: inherit current env + ACP extras
@@ -712,11 +727,22 @@ def step(
712727
try:
713728

714729
async def _prompt() -> PromptResponse:
730+
usage_sync = self._client.prepare_usage_sync(self._session_id or "")
715731
response = await self._conn.prompt(
716732
[text_block(user_message)],
717733
self._session_id,
718734
)
719-
await _drain_notifications()
735+
if self._client.get_turn_usage_update(self._session_id or "") is None:
736+
try:
737+
await asyncio.wait_for(
738+
usage_sync.wait(), timeout=_USAGE_UPDATE_TIMEOUT
739+
)
740+
except TimeoutError:
741+
logger.warning(
742+
"UsageUpdate not received within %.1fs for session %s",
743+
_USAGE_UPDATE_TIMEOUT,
744+
self._session_id,
745+
)
720746
return response
721747

722748
# Send prompt to ACP server with retry logic for connection errors.
@@ -736,9 +762,8 @@ async def _prompt() -> PromptResponse:
736762
response = self._executor.run_async(
737763
_prompt, timeout=self.acp_prompt_timeout
738764
)
739-
break # Success, exit retry loop
765+
break
740766
except TimeoutError:
741-
# Timeout is handled separately below, don't retry
742767
raise
743768
except _RETRIABLE_CONNECTION_ERRORS as e:
744769
if attempt < max_retries:
@@ -754,17 +779,22 @@ async def _prompt() -> PromptResponse:
754779
e,
755780
)
756781
time.sleep(delay)
757-
# Reset accumulators for retry (partial state may be stale)
758782
self._client.reset()
759783
self._client.on_token = on_token
760784
else:
761-
# Max retries exceeded
762785
raise
763786

764787
elapsed = time.monotonic() - t0
765788
logger.info("ACP prompt returned in %.1fs", elapsed)
766789

767-
self._record_usage(response, self._session_id or "", elapsed=elapsed)
790+
session_id = self._session_id or ""
791+
usage_update = self._client.pop_turn_usage_update(session_id)
792+
self._record_usage(
793+
response,
794+
session_id,
795+
elapsed=elapsed,
796+
usage_update=usage_update,
797+
)
768798

769799
# Emit ACPToolCallEvents for each accumulated tool call
770800
for tc in self._client.accumulated_tool_calls:
@@ -911,15 +941,32 @@ async def _fork_and_prompt() -> str:
911941
client._fork_accumulated_text.clear()
912942
try:
913943
fork_t0 = time.monotonic()
944+
usage_sync = client.prepare_usage_sync(fork_session_id)
914945
response = await self._conn.prompt(
915946
[text_block(question)],
916947
fork_session_id,
917948
)
918-
await _drain_notifications()
949+
if client.get_turn_usage_update(fork_session_id) is None:
950+
try:
951+
await asyncio.wait_for(
952+
usage_sync.wait(), timeout=_USAGE_UPDATE_TIMEOUT
953+
)
954+
except TimeoutError:
955+
logger.warning(
956+
"UsageUpdate not received within %.1fs for fork session %s",
957+
_USAGE_UPDATE_TIMEOUT,
958+
fork_session_id,
959+
)
919960
fork_elapsed = time.monotonic() - fork_t0
920961

921962
result = "".join(client._fork_accumulated_text)
922-
self._record_usage(response, fork_session_id, elapsed=fork_elapsed)
963+
usage_update = client.pop_turn_usage_update(fork_session_id)
964+
self._record_usage(
965+
response,
966+
fork_session_id,
967+
elapsed=fork_elapsed,
968+
usage_update=usage_update,
969+
)
923970
return result
924971
finally:
925972
client._fork_session_id = None

openhands-sdk/openhands/sdk/conversation/impl/remote_conversation.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -437,12 +437,17 @@ def _get_conversation_info(self) -> dict:
437437
return self._cached_state
438438

439439
# Fallback to REST API if no cached state
440-
resp = _send_request(
441-
self._client,
442-
"GET",
443-
f"{self._conversation_info_base_path}/{self._conversation_id}",
444-
)
445-
state = resp.json()
440+
return self.refresh_from_server()
441+
442+
def refresh_from_server(self) -> dict:
443+
"""Fetch and cache the latest authoritative conversation state."""
444+
resp = _send_request(
445+
self._client,
446+
"GET",
447+
f"{self._conversation_info_base_path}/{self._conversation_id}",
448+
)
449+
state = resp.json()
450+
with self._lock:
446451
self._cached_state = state
447452
return state
448453

@@ -1039,6 +1044,7 @@ def _wait_for_run_completion(
10391044
ws_status,
10401045
elapsed,
10411046
)
1047+
self._state.refresh_from_server()
10421048
return
10431049
except Empty:
10441050
pass # Queue.get() timed out, fall through to REST polling
@@ -1064,11 +1070,15 @@ def _wait_for_run_completion(
10641070
logger.info(
10651071
"Run completed via REST fallback after %d consecutive "
10661072
"terminal polls (status: %s, elapsed: %.1fs). "
1067-
"Reconciling events...",
1073+
"Refreshing final state and reconciling events...",
10681074
consecutive_terminal_polls,
10691075
status,
10701076
elapsed,
10711077
)
1078+
final_info = self._state.refresh_from_server()
1079+
self._handle_conversation_status(
1080+
final_info.get("execution_status")
1081+
)
10721082
# Reconcile events to catch any that were missed via WS.
10731083
# This is only called in the fallback path, so it doesn't
10741084
# add overhead in the common case where WS works.

0 commit comments

Comments
 (0)