6969 )
7070
7171
72- # Seconds to wait after prompt() for pending session_update notifications
73- # to be processed. This is a best-effort workaround: the ACP protocol does
74- # not currently signal when all notifications for a turn have been delivered,
75- # so we yield to the event loop and then sleep briefly to allow in-flight
76- # handlers to finish. Override via ACP_NOTIFICATION_DRAIN_DELAY for slow or
77- # remote servers.
78- # TODO(https://github.com/agentclientprotocol/agent-client-protocol/issues/554):
79- # Replace with protocol-level synchronization once ACP supports a
80- # "turn complete" notification.
81- _NOTIFICATION_DRAIN_DELAY : float = float (
82- os .environ .get ("ACP_NOTIFICATION_DRAIN_DELAY" , "0.1" )
83- )
72+ # Maximum seconds to wait for a UsageUpdate notification after prompt()
73+ # returns. The ACP server writes UsageUpdate to the wire before the
74+ # PromptResponse, so under normal conditions the notification handler
75+ # completes almost immediately. This timeout is a safety net for slow
76+ # or remote servers.
77+ _USAGE_UPDATE_TIMEOUT : float = float (os .environ .get ("ACP_USAGE_UPDATE_TIMEOUT" , "2.0" ))
8478
8579# Retry configuration for transient ACP connection errors.
8680# These errors can occur when the connection drops mid-conversation but the
10397_STREAM_READER_LIMIT : int = 100 * 1024 * 1024 # 100 MiB
10498
10599
106- async def _drain_notifications () -> None :
107- """Best-effort drain of pending ``session_update`` notifications.
108-
109- ACP does not yet signal when all notifications for a turn have been
110- delivered (see TODO above). We yield to the event loop so already-queued
111- handlers run, then sleep briefly to allow in-flight IO handlers to finish.
112- """
113- await asyncio .sleep (0 )
114- await asyncio .sleep (_NOTIFICATION_DRAIN_DELAY )
115-
116-
117100def _make_dummy_llm () -> LLM :
118101 """Create a dummy LLM that should never be called directly."""
119102 return LLM (model = "acp-managed" )
@@ -215,8 +198,12 @@ def __init__(self) -> None:
215198 self .on_token : Any = None # ConversationTokenCallbackType | None
216199 # Telemetry state from UsageUpdate (persists across turns)
217200 self ._last_cost : float = 0.0 # last cumulative cost seen
218- self ._context_window : int = 0 # context window size from ACP
219- self ._llm_ref : Any = None # reference to the sentinel LLM
201+ self ._last_cost_by_session : dict [str , float ] = {}
202+ self ._context_window : int = 0 # last context window seen
203+ self ._context_window_by_session : dict [str , int ] = {}
204+ # Per-turn synchronization for UsageUpdate notifications.
205+ self ._turn_usage_updates : dict [str , Any ] = {}
206+ self ._usage_received : dict [str , asyncio .Event ] = {}
220207 # Fork session state for ask_agent() — guarded by _fork_lock to
221208 # prevent concurrent ask_agent() calls from colliding.
222209 self ._fork_lock = threading .Lock ()
@@ -228,9 +215,27 @@ def reset(self) -> None:
228215 self .accumulated_thoughts .clear ()
229216 self .accumulated_tool_calls .clear ()
230217 self .on_token = None
218+ self ._turn_usage_updates .clear ()
219+ self ._usage_received .clear ()
231220 # Note: telemetry state (_last_cost, _context_window, etc.)
232221 # is intentionally NOT cleared — it accumulates across turns.
233222
223+ def prepare_usage_sync (self , session_id : str ) -> asyncio .Event :
224+ """Prepare per-turn UsageUpdate synchronization for a session."""
225+ event = asyncio .Event ()
226+ self ._usage_received [session_id ] = event
227+ self ._turn_usage_updates .pop (session_id , None )
228+ return event
229+
230+ def get_turn_usage_update (self , session_id : str ) -> Any :
231+ """Return the latest UsageUpdate observed for the current turn."""
232+ return self ._turn_usage_updates .get (session_id )
233+
234+ def pop_turn_usage_update (self , session_id : str ) -> Any :
235+ """Consume per-turn UsageUpdate synchronization state for a session."""
236+ self ._usage_received .pop (session_id , None )
237+ return self ._turn_usage_updates .pop (session_id , None )
238+
234239 # -- Client protocol methods ------------------------------------------
235240
236241 async def session_update (
@@ -261,14 +266,13 @@ async def session_update(
261266 if isinstance (update .content , TextContentBlock ):
262267 self .accumulated_thoughts .append (update .content .text )
263268 elif isinstance (update , UsageUpdate ):
264- # Update context window size
269+ # Store the update for step()/ask_agent() to process in one place.
265270 self ._context_window = update .size
266- # Record incremental cost
267- if update .cost is not None and self ._llm_ref is not None :
268- delta = update .cost .amount - self ._last_cost
269- if delta > 0 :
270- self ._llm_ref .metrics .add_cost (delta )
271- self ._last_cost = update .cost .amount
271+ self ._context_window_by_session [session_id ] = update .size
272+ self ._turn_usage_updates [session_id ] = update
273+ event = self ._usage_received .get (session_id )
274+ if event is not None :
275+ event .set ()
272276 elif isinstance (update , ToolCallStart ):
273277 self .accumulated_tool_calls .append (
274278 {
@@ -458,14 +462,24 @@ def _record_usage(
458462 response : PromptResponse | None ,
459463 session_id : str ,
460464 elapsed : float | None = None ,
465+ usage_update : UsageUpdate | None = None ,
461466 ) -> None :
462- """Record token usage, latency, and notify stats callback from a PromptResponse .
467+ """Record cost, token usage, latency, and notify stats callback once .
463468
464469 Args:
465470 response: The ACP PromptResponse (may carry a ``usage`` field).
466471 session_id: Session identifier used as the response_id for metrics.
467472 elapsed: Wall-clock seconds for this prompt round-trip (optional).
473+ usage_update: The synchronized ACP UsageUpdate for this turn, if any.
468474 """
475+ if usage_update is not None and usage_update .cost is not None :
476+ last_cost = self ._client ._last_cost_by_session .get (session_id , 0.0 )
477+ delta = usage_update .cost .amount - last_cost
478+ if delta > 0 :
479+ self .llm .metrics .add_cost (delta )
480+ self ._client ._last_cost_by_session [session_id ] = usage_update .cost .amount
481+ self ._client ._last_cost = usage_update .cost .amount
482+
469483 if response is not None and response .usage is not None :
470484 usage = response .usage
471485 self .llm .metrics .add_token_usage (
@@ -474,7 +488,9 @@ def _record_usage(
474488 cache_read_tokens = usage .cached_read_tokens or 0 ,
475489 cache_write_tokens = usage .cached_write_tokens or 0 ,
476490 reasoning_tokens = usage .thought_tokens or 0 ,
477- context_window = self ._client ._context_window ,
491+ context_window = self ._client ._context_window_by_session .get (
492+ session_id , self ._client ._context_window
493+ ),
478494 response_id = session_id ,
479495 )
480496
@@ -568,7 +584,6 @@ def init_state(
568584 def _start_acp_server (self , state : ConversationState ) -> None :
569585 """Start the ACP subprocess and initialize the session."""
570586 client = _OpenHandsACPBridge ()
571- client ._llm_ref = self .llm
572587 self ._client = client
573588
574589 # Build environment: inherit current env + ACP extras
@@ -712,11 +727,22 @@ def step(
712727 try :
713728
714729 async def _prompt () -> PromptResponse :
730+ usage_sync = self ._client .prepare_usage_sync (self ._session_id or "" )
715731 response = await self ._conn .prompt (
716732 [text_block (user_message )],
717733 self ._session_id ,
718734 )
719- await _drain_notifications ()
735+ if self ._client .get_turn_usage_update (self ._session_id or "" ) is None :
736+ try :
737+ await asyncio .wait_for (
738+ usage_sync .wait (), timeout = _USAGE_UPDATE_TIMEOUT
739+ )
740+ except TimeoutError :
741+ logger .warning (
742+ "UsageUpdate not received within %.1fs for session %s" ,
743+ _USAGE_UPDATE_TIMEOUT ,
744+ self ._session_id ,
745+ )
720746 return response
721747
722748 # Send prompt to ACP server with retry logic for connection errors.
@@ -736,9 +762,8 @@ async def _prompt() -> PromptResponse:
736762 response = self ._executor .run_async (
737763 _prompt , timeout = self .acp_prompt_timeout
738764 )
739- break # Success, exit retry loop
765+ break
740766 except TimeoutError :
741- # Timeout is handled separately below, don't retry
742767 raise
743768 except _RETRIABLE_CONNECTION_ERRORS as e :
744769 if attempt < max_retries :
@@ -754,17 +779,22 @@ async def _prompt() -> PromptResponse:
754779 e ,
755780 )
756781 time .sleep (delay )
757- # Reset accumulators for retry (partial state may be stale)
758782 self ._client .reset ()
759783 self ._client .on_token = on_token
760784 else :
761- # Max retries exceeded
762785 raise
763786
764787 elapsed = time .monotonic () - t0
765788 logger .info ("ACP prompt returned in %.1fs" , elapsed )
766789
767- self ._record_usage (response , self ._session_id or "" , elapsed = elapsed )
790+ session_id = self ._session_id or ""
791+ usage_update = self ._client .pop_turn_usage_update (session_id )
792+ self ._record_usage (
793+ response ,
794+ session_id ,
795+ elapsed = elapsed ,
796+ usage_update = usage_update ,
797+ )
768798
769799 # Emit ACPToolCallEvents for each accumulated tool call
770800 for tc in self ._client .accumulated_tool_calls :
@@ -911,15 +941,32 @@ async def _fork_and_prompt() -> str:
911941 client ._fork_accumulated_text .clear ()
912942 try :
913943 fork_t0 = time .monotonic ()
944+ usage_sync = client .prepare_usage_sync (fork_session_id )
914945 response = await self ._conn .prompt (
915946 [text_block (question )],
916947 fork_session_id ,
917948 )
918- await _drain_notifications ()
949+ if client .get_turn_usage_update (fork_session_id ) is None :
950+ try :
951+ await asyncio .wait_for (
952+ usage_sync .wait (), timeout = _USAGE_UPDATE_TIMEOUT
953+ )
954+ except TimeoutError :
955+ logger .warning (
956+ "UsageUpdate not received within %.1fs for fork session %s" ,
957+ _USAGE_UPDATE_TIMEOUT ,
958+ fork_session_id ,
959+ )
919960 fork_elapsed = time .monotonic () - fork_t0
920961
921962 result = "" .join (client ._fork_accumulated_text )
922- self ._record_usage (response , fork_session_id , elapsed = fork_elapsed )
963+ usage_update = client .pop_turn_usage_update (fork_session_id )
964+ self ._record_usage (
965+ response ,
966+ fork_session_id ,
967+ elapsed = fork_elapsed ,
968+ usage_update = usage_update ,
969+ )
923970 return result
924971 finally :
925972 client ._fork_session_id = None
0 commit comments