-
Notifications
You must be signed in to change notification settings - Fork 217
Expand file tree
/
Copy pathstate.py
More file actions
508 lines (434 loc) · 19.5 KB
/
state.py
File metadata and controls
508 lines (434 loc) · 19.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
# state.py
import json
from collections.abc import Sequence
from enum import Enum
from pathlib import Path
from typing import Any, Self
from pydantic import Field, PrivateAttr, model_validator
from openhands.sdk.agent.base import AgentBase
from openhands.sdk.conversation.conversation_stats import ConversationStats
from openhands.sdk.conversation.event_store import EventLog
from openhands.sdk.conversation.fifo_lock import FIFOLock
from openhands.sdk.conversation.persistence_const import BASE_STATE, EVENTS_DIR
from openhands.sdk.conversation.secret_registry import SecretRegistry
from openhands.sdk.conversation.types import ConversationCallbackType, ConversationID
from openhands.sdk.event import ActionEvent, ObservationEvent, UserRejectObservation
from openhands.sdk.event.base import Event
from openhands.sdk.event.types import EventID
from openhands.sdk.io import FileStore, InMemoryFileStore, LocalFileStore
from openhands.sdk.logger import get_logger
from openhands.sdk.security.analyzer import SecurityAnalyzerBase
from openhands.sdk.security.confirmation_policy import (
ConfirmationPolicyBase,
NeverConfirm,
)
from openhands.sdk.utils.cipher import Cipher
from openhands.sdk.utils.deprecation import warn_deprecated
from openhands.sdk.utils.models import OpenHandsModel
from openhands.sdk.workspace.base import BaseWorkspace
# Module-level logger, named after this module via the SDK's get_logger helper.
logger = get_logger(__name__)
class ConversationExecutionStatus(str, Enum):
    """Lifecycle status of a conversation's execution.

    Because the enum mixes in ``str``, every member compares equal to its
    lowercase string value (e.g. ``ConversationExecutionStatus.IDLE == "idle"``).
    """

    # Ready to receive tasks; this is the status a conversation starts in.
    IDLE = "idle"
    # Actively processing.
    RUNNING = "running"
    # Execution paused by the user.
    PAUSED = "paused"
    # Waiting for the user to confirm pending actions.
    WAITING_FOR_CONFIRMATION = "waiting_for_confirmation"
    # The current task has been completed.
    FINISHED = "finished"
    # An error was encountered (optional for future use).
    ERROR = "error"
    # Stuck in a loop or otherwise unable to proceed.
    STUCK = "stuck"
    # The conversation is being deleted.
    DELETING = "deleting"

    def is_terminal(self) -> bool:
        """Tell whether this status means the run is over.

        Only FINISHED, ERROR and STUCK count as terminal: the agent is no
        longer actively processing once one of them is reached.

        IDLE is intentionally excluded. It is the status a conversation has
        before any run starts, and treating it as terminal would produce
        false positives when the WebSocket delivers the initial state update
        while connecting.

        Returns:
            True for FINISHED, ERROR or STUCK; False for every other status.
        """
        # Kept as a local tuple: a class-level attribute here would be
        # turned into an enum member.
        terminal_states = (
            ConversationExecutionStatus.FINISHED,
            ConversationExecutionStatus.ERROR,
            ConversationExecutionStatus.STUCK,
        )
        return self in terminal_states
class ConversationState(OpenHandsModel):
    # Persistent state of a single conversation.
    #
    # - Public fields are pydantic-validated.
    # - Once autosave is enabled (done at the end of ``create``), every change
    #   to a public field rewrites the base-state snapshot through the
    #   attached FileStore and, if one is set, invokes the state-change
    #   callback.
    # - Events are stored separately in a file-backed EventLog.
    # - The object also delegates to a FIFOLock, so it can be used via
    #   acquire()/release() or as a context manager.

    # ===== Public, validated fields =====
    id: ConversationID = Field(description="Unique conversation ID")
    agent: AgentBase = Field(
        ...,
        description=(
            "The agent running in the conversation. "
            "This is persisted to allow resuming conversations and "
            "check agent configuration to handle e.g., tool changes, "
            "LLM changes, etc."
        ),
    )
    workspace: BaseWorkspace = Field(
        ...,
        description=(
            "Workspace used by the agent to execute commands and read/write files. "
            "Not the process working directory."
        ),
    )
    persistence_dir: str | None = Field(
        default="workspace/conversations",
        description="Directory for persisting conversation state and events. "
        "If None, conversation will not be persisted.",
    )
    max_iterations: int = Field(
        default=500,
        gt=0,
        description="Maximum number of iterations the agent can "
        "perform in a single run.",
    )
    stuck_detection: bool = Field(
        default=True,
        description="Whether to enable stuck detection for the agent.",
    )

    # Enum-based state management
    execution_status: ConversationExecutionStatus = Field(
        default=ConversationExecutionStatus.IDLE
    )
    confirmation_policy: ConfirmationPolicyBase = NeverConfirm()
    security_analyzer: SecurityAnalyzerBase | None = Field(
        default=None,
        description="Optional security analyzer to evaluate action risks.",
    )
    activated_knowledge_skills: list[str] = Field(
        default_factory=list,
        description="List of activated knowledge skills name",
    )

    # Hook-blocked actions: action_id -> blocking reason
    blocked_actions: dict[str, str] = Field(
        default_factory=dict,
        description="Actions blocked by PreToolUse hooks, keyed by action ID",
    )

    # Hook-blocked messages: message_id -> blocking reason
    blocked_messages: dict[str, str] = Field(
        default_factory=dict,
        description="Messages blocked by UserPromptSubmit hooks, keyed by message ID",
    )

    # Track the most recent user MessageEvent ID to avoid event log scans.
    last_user_message_id: EventID | None = Field(
        default=None,
        description=(
            "Most recent user MessageEvent id for hook block checks. "
            "Updated when user messages are emitted so Agent.step can pop "
            "blocked_messages without scanning the event log. If None, "
            "hook-blocked checks are skipped (legacy conversations)."
        ),
    )

    # Conversation statistics for LLM usage tracking
    stats: ConversationStats = Field(
        default_factory=ConversationStats,
        description="Conversation statistics for tracking LLM metrics",
    )

    # Secret registry for handling sensitive data
    secret_registry: SecretRegistry = Field(
        default_factory=SecretRegistry,
        description="Registry for handling secrets and sensitive data",
    )

    # Agent-specific runtime state (simple dict for flexibility)
    agent_state: dict[str, Any] = Field(
        default_factory=dict,
        description="Dictionary for agent-specific runtime state that persists across "
        "iterations. Agents can store feature-specific state using string keys. "
        "To trigger autosave, always reassign: "
        "state.agent_state = {**state.agent_state, key: value}. "
        "See https://docs.openhands.dev/sdk/guides/convo-persistence#how-state-persistence-works",
    )

    # ===== Private attrs (NOT Fields) =====
    # Not serialized and not autosaved (see __setattr__).
    _fs: FileStore = PrivateAttr()  # filestore used for persistence
    _events: EventLog = PrivateAttr()  # file-backed event storage
    _cipher: Cipher | None = PrivateAttr(default=None)  # cipher for secret encryption
    _autosave_enabled: bool = PrivateAttr(
        default=False
    )  # stays False during construction to avoid recursive saves
    _on_state_change: ConversationCallbackType | None = PrivateAttr(
        default=None
    )  # callback invoked after a persisted field change
    _lock: FIFOLock = PrivateAttr(
        default_factory=FIFOLock
    )  # FIFO lock for thread safety

    @model_validator(mode="before")
    @classmethod
    def _handle_legacy_fields(cls, data: Any) -> Any:
        """Handle legacy field names for backward compatibility.

        Renames the deprecated ``secrets_manager`` key to ``secret_registry``
        and emits a deprecation warning. If both keys are present, the legacy
        value wins.

        The rename is applied to a shallow copy, so the dict passed by the
        caller (e.g. to ``model_validate``) is never mutated.
        """
        if not isinstance(data, dict) or "secrets_manager" not in data:
            return data

        warn_deprecated(
            "ConversationState.secrets_manager",
            deprecated_in="1.12.0",
            removed_in="1.15.0",
            details=(
                "The 'secrets_manager' field has been renamed to "
                "'secret_registry'. Please update your code to use "
                "'secret_registry' instead."
            ),
            stacklevel=4,
        )
        data = dict(data)  # don't mutate the caller's input
        data["secret_registry"] = data.pop("secrets_manager")
        return data

    @property
    def events(self) -> EventLog:
        """The file-backed event log attached to this state."""
        return self._events

    @property
    def env_observation_persistence_dir(self) -> str | None:
        """Directory for persisting environment observation files.

        This is ``<persistence_dir>/observations``, or None when the
        conversation is not persisted.
        """
        if self.persistence_dir is None:
            return None
        return str(Path(self.persistence_dir) / "observations")

    def set_on_state_change(self, callback: ConversationCallbackType | None) -> None:
        """Set a callback to be called when state changes.

        Args:
            callback: A function that takes an Event (ConversationStateUpdateEvent)
                or None to remove the callback
        """
        self._on_state_change = callback

    # ===== Base snapshot helpers =====
    def _save_base_state(self, fs: FileStore) -> None:
        """Persist the base-state snapshot (events are stored separately).

        If a cipher is configured, secrets are encrypted. Otherwise they are
        redacted (serialized as '**********'), and a warning is logged when
        any secrets exist, because they will be lost on restore.
        """
        context = {"cipher": self._cipher} if self._cipher else None
        # Warn if secrets exist but no cipher is configured
        if not self._cipher and self.secret_registry.secret_sources:
            logger.warning(
                f"Saving conversation state without cipher - "
                f"{len(self.secret_registry.secret_sources)} secret(s) will be "
                "redacted and lost on restore. Consider providing a cipher to "
                "preserve secrets."
            )
        payload = self.model_dump_json(exclude_none=True, context=context)
        fs.write(BASE_STATE, payload)

    # ===== Factory: open-or-create =====
    @classmethod
    def create(
        cls: type["ConversationState"],
        id: ConversationID,
        agent: AgentBase,
        workspace: BaseWorkspace,
        persistence_dir: str | None = None,
        max_iterations: int = 500,
        stuck_detection: bool = True,
        cipher: Cipher | None = None,
    ) -> "ConversationState":
        """Create a new conversation state or resume from persistence.

        This factory handles both creating a new conversation and resuming
        from persisted state.

        **New conversation:**
        The provided Agent is used directly. Pydantic validation happens via
        the cls() constructor, and an initial snapshot is written.

        **Restored conversation:**
        The provided Agent is validated against the persisted agent using
        agent.verify(). Tools must match, since they may have been used in
        the conversation history. All other configuration can be freely
        changed: LLM, agent_context, condenser, system prompts, etc.

        The runtime values ``agent``, ``workspace``, ``max_iterations`` and
        ``stuck_detection`` replace the persisted ones. Accumulated ``stats``
        are kept.

        Args:
            id: Unique conversation identifier
            agent: The Agent to use (tools must match persisted on restore)
            workspace: Working directory for agent operations
            persistence_dir: Directory for persisting state and events.
                If falsy, an in-memory file store is used.
            max_iterations: Maximum iterations per run
            stuck_detection: Whether to enable stuck detection
            cipher: Optional cipher for encrypting/decrypting secrets in
                persisted state. If provided, secrets are encrypted when
                saving and decrypted when loading. If not provided, secrets
                are redacted (lost) on serialization.

        Returns:
            ConversationState ready for use

        Raises:
            ValueError: If the conversation ID mismatches on restore, or if
                agent is None for a new conversation
            ValidationError: If agent or other fields fail Pydantic validation
        """
        file_store = (
            LocalFileStore(persistence_dir, cache_limit_size=max_iterations)
            if persistence_dir
            else InMemoryFileStore()
        )

        try:
            base_text = file_store.read(BASE_STATE)
        except FileNotFoundError:
            base_text = None

        # ---- Resume path ----
        if base_text:
            # Use cipher context for decrypting secrets if provided
            context = {"cipher": cipher} if cipher else None
            state = cls.model_validate(json.loads(base_text), context=context)

            # Restore the conversation with the same id
            if state.id != id:
                raise ValueError(
                    f"Conversation ID mismatch: provided {id}, "
                    f"but persisted state has {state.id}"
                )

            # Attach event log early so we can read history for tool verification
            state._fs = file_store
            state._events = EventLog(file_store, dir_path=EVENTS_DIR)
            state._cipher = cipher

            # Verify compatibility (agent class + tools)
            agent.verify(state.agent, events=state._events)

            # Commit runtime-provided values (may autosave)
            state._autosave_enabled = True
            state.agent = agent
            state.workspace = workspace
            state.max_iterations = max_iterations
            # Apply the runtime setting like max_iterations above; previously
            # the argument was silently ignored on resume.
            state.stuck_detection = stuck_detection

            # Note: stats are already deserialized from base_state.json above.
            # Do NOT reset stats here - this would lose accumulated metrics.

            logger.info(
                f"Resumed conversation {state.id} from persistent storage.\n"
                f"State: {state.model_dump(exclude={'agent'})}\n"
                f"Agent: {state.agent.model_dump_succint()}"
            )
            return state

        # ---- Fresh path ----
        if agent is None:
            raise ValueError(
                "agent is required when initializing a new ConversationState"
            )

        state = cls(
            id=id,
            agent=agent,
            workspace=workspace,
            persistence_dir=persistence_dir,
            max_iterations=max_iterations,
            stuck_detection=stuck_detection,
        )
        state._fs = file_store
        state._events = EventLog(file_store, dir_path=EVENTS_DIR)
        state._cipher = cipher
        state.stats = ConversationStats()

        state._save_base_state(file_store)  # initial snapshot
        # Enabled only after the initial snapshot, so the assignments above
        # don't trigger redundant saves.
        state._autosave_enabled = True
        logger.info(
            f"Created new conversation {state.id}\n"
            f"State: {state.model_dump(exclude={'agent'})}\n"
            f"Agent: {state.agent.model_dump_succint()}"
        )
        return state

    # ===== Auto-persist base on public field changes =====
    def __setattr__(self, name: str, value: Any) -> None:
        """Set an attribute and, for changed public fields, persist and notify.

        After the assignment, a new snapshot is written and the state-change
        callback is invoked only when all of the following hold:

        - autosave is enabled (set post-init)
        - the attribute is a *public field* (not a PrivateAttr)
        - a filestore is attached
        - the value actually changed

        A failed save is logged and re-raised. Callback failures are logged
        and swallowed.
        """
        _sentinel = object()
        old = getattr(self, name, _sentinel)
        super().__setattr__(name, value)

        is_field = name in type(self).model_fields
        autosave_enabled = getattr(self, "_autosave_enabled", False)
        fs = getattr(self, "_fs", None)

        if not (autosave_enabled and is_field and fs is not None):
            return

        # Unchanged value: no snapshot and no notification.
        if old is not _sentinel and old == value:
            return

        try:
            self._save_base_state(fs)
        except Exception:
            logger.exception("Auto-persist base_state failed")
            raise

        # Call state change callback if set (only for attributes that existed
        # before this assignment).
        callback = getattr(self, "_on_state_change", None)
        if callback is not None and old is not _sentinel:
            try:
                # Import here to avoid circular imports
                from openhands.sdk.event.conversation_state import (
                    ConversationStateUpdateEvent,
                )

                # Create a ConversationStateUpdateEvent with the changed field
                state_update_event = ConversationStateUpdateEvent(
                    key=name, value=value
                )
                callback(state_update_event)
            except Exception:
                # A misbehaving callback must not break state mutation.
                logger.exception(f"State change callback failed for field {name}")

    def block_action(self, action_id: str, reason: str) -> None:
        """Persistently record a hook-blocked action."""
        # Reassign (not mutate in place) so __setattr__ autosaves.
        self.blocked_actions = {**self.blocked_actions, action_id: reason}

    def pop_blocked_action(self, action_id: str) -> str | None:
        """Remove and return a hook-blocked action reason, if present."""
        if action_id not in self.blocked_actions:
            return None
        updated = dict(self.blocked_actions)
        reason = updated.pop(action_id)
        self.blocked_actions = updated  # reassign to trigger autosave
        return reason

    def block_message(self, message_id: str, reason: str) -> None:
        """Persistently record a hook-blocked user message."""
        # Reassign (not mutate in place) so __setattr__ autosaves.
        self.blocked_messages = {**self.blocked_messages, message_id: reason}

    def pop_blocked_message(self, message_id: str) -> str | None:
        """Remove and return a hook-blocked message reason, if present."""
        if message_id not in self.blocked_messages:
            return None
        updated = dict(self.blocked_messages)
        reason = updated.pop(message_id)
        self.blocked_messages = updated  # reassign to trigger autosave
        return reason

    @staticmethod
    def get_unmatched_actions(events: Sequence[Event]) -> list[ActionEvent]:
        """Find actions in the event history that don't have matching observations.

        Identifies ActionEvents with no corresponding ObservationEvent or
        UserRejectObservation; these are typically actions pending
        confirmation or execution. Only ActionEvents whose ``action`` is not
        None are considered.

        Args:
            events: Events to search through, in chronological order

        Returns:
            List of ActionEvent objects that don't have corresponding
            observations, in chronological order
        """
        observed_action_ids = set()
        unmatched_actions: list[ActionEvent] = []
        # Walk backwards so each observation is seen before the action it
        # answers.
        for event in reversed(events):
            if isinstance(event, (ObservationEvent, UserRejectObservation)):
                observed_action_ids.add(event.action_id)
            elif isinstance(event, ActionEvent):
                # Only executable actions (validated) are considered pending
                if event.action is not None and event.id not in observed_action_ids:
                    unmatched_actions.append(event)
        # Collected newest-first; flip once (O(n)) rather than insert(0, ...)
        # on every match (O(n^2)).
        unmatched_actions.reverse()
        return unmatched_actions

    # ===== FIFOLock delegation methods =====
    def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
        """Acquire the lock.

        Args:
            blocking: If True, block until lock is acquired. If False, return
                immediately.
            timeout: Maximum time to wait for lock (ignored if blocking=False).
                -1 means wait indefinitely.

        Returns:
            True if lock was acquired, False otherwise.
        """
        return self._lock.acquire(blocking=blocking, timeout=timeout)

    def release(self) -> None:
        """Release the lock.

        Raises:
            RuntimeError: If the current thread doesn't own the lock.
        """
        self._lock.release()

    def __enter__(self: Self) -> Self:
        """Context manager entry: acquire the lock (blocking)."""
        self._lock.acquire()
        return self

    def __exit__(self, exc_type: Any, _exc_val: Any, _exc_tb: Any) -> None:
        """Context manager exit: release the lock; exceptions propagate."""
        self._lock.release()

    def locked(self) -> bool:
        """Return True if the lock is currently held by any thread."""
        return self._lock.locked()

    def owned(self) -> bool:
        """Return True if the lock is currently held by the calling thread."""
        return self._lock.owned()