-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Expand file tree
/
Copy pathattempt.py
More file actions
251 lines (220 loc) · 11.1 KB
/
attempt.py
File metadata and controls
251 lines (220 loc) · 11.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
# Copyright (c) Microsoft. All rights reserved.
from __future__ import annotations
import hashlib
import logging
import time
import uuid
from dataclasses import InitVar
from typing import Any, Dict, List, Optional
from sqlalchemy import JSON, Float, Integer, String, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from sqlalchemy.orm import Mapped, mapped_column
from agentlightning.types import Attempt
from .base import AttemptStatusUpdateMessage, SqlAlchemyBase
logger = logging.getLogger(__name__)
def _generate_attempt_id() -> str:
"""We don't need that long because attempts are limited to rollouts."""
short_id = hashlib.sha1(uuid.uuid4().bytes).hexdigest()[:8]
return "at-" + short_id
class AttemptInDB(SqlAlchemyBase):
    """ORM row tracking one execution attempt of a rollout.

    Mirrors the public ``agentlightning.types.Attempt`` plus processing-only
    columns (``max_duration``, ``max_heartbeat_interval``) and a
    ``version_id`` column used for optimistic concurrency control.
    """

    __tablename__ = "attempts"

    # Rollout this attempt belongs to.
    rollout_id: Mapped[str] = mapped_column(String, nullable=False)
    # Short random primary key, e.g. "at-1a2b3c4d" (see _generate_attempt_id).
    attempt_id: Mapped[str] = mapped_column(String, primary_key=True, default_factory=_generate_attempt_id)
    # Ordinal of this attempt within its rollout.
    sequence_id: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    start_time: Mapped[float] = mapped_column(Float, default_factory=time.time, nullable=False)
    # Set by update_status() when the attempt reaches a finished status.
    end_time: Mapped[Optional[float]] = mapped_column(Float, nullable=True, default=None)
    status: Mapped[str] = mapped_column(String, default="preparing", nullable=False)
    worker_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, default=None)
    # NOTE(review): annotated Optional yet nullable=False with a default_factory;
    # is_unresponsive() still guards against None — confirm intended nullability.
    last_heartbeat_time: Mapped[Optional[float]] = mapped_column(Float, nullable=False, default_factory=time.time)
    # Stored under a different name because "metadata" is reserved on
    # SQLAlchemy declarative classes; as_attempt() maps it back to "metadata".
    attempt_metadata: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True, default=None)
    # additional columns for processing (excluded from the public Attempt view)
    max_duration: Mapped[Optional[float]] = mapped_column(
        Float, nullable=True, default=None
    )  # maximum duration allowed for this attempt in seconds
    max_heartbeat_interval: Mapped[Optional[float]] = mapped_column(
        Float, nullable=True, default=None
    )  # maximum allowed heartbeat interval in seconds
    # Incremented by SQLAlchemy on every UPDATE; concurrent writers conflict
    # instead of silently overwriting each other (optimistic locking).
    version_id: Mapped[int] = mapped_column(Integer, nullable=False, default=1)

    __mapper_args__ = {
        "version_id_col": version_id,
    }

    def is_unresponsive(self, current_time: float) -> bool:
        """Check if the attempt is unresponsive based on the last heartbeat time and max_heartbeat_interval.

        Returns False when either the interval bound or the heartbeat is unset.
        """
        if self.max_heartbeat_interval is None:
            return False
        if self.last_heartbeat_time is None:
            return False
        return (current_time - self.last_heartbeat_time) > self.max_heartbeat_interval

    def is_timed_out(self, current_time: float) -> bool:
        """Check if the attempt has timed out based on the start time and max_duration.

        Returns False when no max_duration is configured.
        """
        if self.max_duration is None:
            return False
        return (current_time - self.start_time) > self.max_duration

    def as_attempt(self) -> Attempt:
        """Convert this row to the public ``Attempt`` type.

        Processing-only columns are excluded and ``attempt_metadata`` is
        exposed as ``metadata``. NOTE(review): ``model_dump`` (and its
        ``mapper`` argument) is presumably provided by ``SqlAlchemyBase`` —
        not visible here; confirm its contract there.
        """
        return Attempt(
            **self.model_dump(
                exclude={"max_duration", "max_heartbeat_interval", "version_id"},
                mapper={"metadata": lambda obj: obj.attempt_metadata},  # type: ignore
            )
        )

    def _validate_status_message(self, msg: Dict[str, Any]) -> None:
        """Validate (and normalize) a status update message from the caller.

        Mutates ``msg`` in place: fills in ``timestamp`` with the current time
        when absent. Raises ValueError if the message is invalid.
        """
        if "event" not in msg:
            raise ValueError("Status update message must contain 'event' field.")
        if "timestamp" not in msg:
            msg["timestamp"] = time.time()
        if msg["event"] not in [
            "user_update",  # user update attempt status via dbstore.update_attempt()
            "span_received",  # new span received
            "single_step_timeout",  # single step timeout detected (from last span heartbeat)
            "overall_timeout",  # overall timeout detected
        ]:
            raise ValueError(f"Unsupported event type: {msg['event']}")
        if msg["event"] == "user_update" and "new_status" not in msg:
            raise ValueError("User update event must contain 'new_status' field.")

    def get_finished_statuses(self) -> List[str]:
        """Return the list of statuses that are considered finished (terminal)."""
        return [
            "succeeded",
            "failed",
            "timeout",
        ]

    def update_status(self, msg: Dict[str, Any]) -> Optional[AttemptStatusUpdateMessage]:
        """Update the status of this attempt based on an event message.

        Args:
            msg: A dictionary describing the status update. It must contain an
                "event" field; "new_status" is required for the "user_update"
                event; "timestamp" is filled with time.time() when absent and
                is used for heartbeat/end-time bookkeeping. See
                `_validate_status_message()` for the full format.

        Returns:
            An AttemptStatusUpdateMessage describing the transition
            (attempt/rollout ids, timestamp, old and new status), or None when
            no meaningful status update is performed.

        Raises:
            ValueError: If the event is not recognized or the status transition is invalid.
            NotImplementedError: If the event handling is not implemented for the current status.
            RuntimeError: If the new status is not set after processing the event.
        """
        self._validate_status_message(msg)
        event = msg["event"]
        current_time = msg.get("timestamp", time.time())
        old_status = self.status
        new_status = msg.get("new_status", None)
        # Step 1: Determine the new status based on the event and current status
        if event == "user_update":
            if not new_status:
                raise ValueError("new_status must be provided for user_update event.")
        elif event == "span_received":
            # Any received span counts as a heartbeat, regardless of status.
            self.last_heartbeat_time = current_time
            if old_status in ["preparing", "unresponsive", "running"]:
                new_status = "running"
            elif old_status in self.get_finished_statuses():
                logger.warning(
                    f"Span received after attempt is already in status {self.status}. No status update performed."
                )
                return  # no further status update needed
            else:
                raise NotImplementedError(f"Event {event} is not implemented for status {old_status}.")
        elif event == "single_step_timeout":
            if old_status in [
                "preparing",
                "running",
            ]:
                new_status = "unresponsive"
            else:
                logger.warning(
                    f"Single step timeout detected but attempt is in status {self.status}. No status update performed."
                )
                return  # no further status update needed
        elif event == "overall_timeout":
            # Overall timeout overrides any non-terminal status.
            if old_status not in self.get_finished_statuses():
                new_status = "timeout"
            else:
                logger.warning(
                    f"Overall timeout detected but attempt is in status {self.status}. No status update performed."
                )
                return  # no further status update needed
        else:
            raise NotImplementedError(f"Event {event} is not implemented for status update.")
        # Step 2: Update the status
        if not new_status:
            raise RuntimeError(
                f"new_status should not be {new_status} after processing event for {event} on status {old_status}."
            )
        if new_status == old_status:
            return  # no status change
        if new_status in self.get_finished_statuses():
            # when attempt is finished, set end_time
            self.end_time = current_time
        self.status = new_status
        # Step 3: Return the status update info for further processing
        return AttemptStatusUpdateMessage(
            attempt_id=self.attempt_id,
            rollout_id=self.rollout_id,
            timestamp=current_time,
            old_status=old_status,
            new_status=new_status,
        )

    @classmethod
    async def get_latest_attempt_for_rollout(
        cls: type[AttemptInDB], session_factory: async_sessionmaker[AsyncSession], rollout_id: str
    ) -> Optional[Attempt]:
        """Return the attempt with the highest sequence_id for a rollout, or None if absent."""
        async with session_factory() as session:
            async with session.begin():
                result = await session.scalars(
                    select(cls).where(cls.rollout_id == rollout_id).order_by(cls.sequence_id.desc()).limit(1)
                )
                attempt_obj = result.one_or_none()
                if attempt_obj is None:
                    return None
                # Convert while the session is still open so attribute access
                # does not hit an expired/detached instance.
                return attempt_obj.as_attempt()

    @classmethod
    async def get_attempts_for_rollout(
        cls: type[AttemptInDB], session_factory: async_sessionmaker[AsyncSession], rollout_id: str
    ) -> List[Attempt]:
        """Return all attempts for a rollout, ordered by ascending sequence_id."""
        async with session_factory() as session:
            async with session.begin():
                result = await session.scalars(
                    select(cls).where(cls.rollout_id == rollout_id).order_by(cls.sequence_id.asc())
                )
                return [attempt.as_attempt() for attempt in result.all()]
class SpanSeqIdInDB(SqlAlchemyBase):
    """Per-rollout counter row that hands out monotonically increasing span sequence IDs."""

    __tablename__ = "span_sequence"

    rollout_id: Mapped[str] = mapped_column(nullable=False, primary_key=True)
    # FIXME InMemoryLightningStore let all attempts under the same rollout share the same span sequence for sorting
    # attempt_id: Mapped[str] = mapped_column(nullable=False)
    attempt_id: InitVar[str]  # not mapped column, just for type hinting
    # Next value to hand out; rows start at 1.
    current_sequence: Mapped[int] = mapped_column(default=1, nullable=False)
    # Versioning for optimistic concurrency control
    version_id: Mapped[int] = mapped_column(Integer, nullable=False, default=1)

    __mapper_args__ = {
        "version_id_col": version_id,
        # "primary_key": [rollout_id, attempt_id],
        # "primary_key": [rollout_id],
    }

    @classmethod
    async def get_next_sequence_id(
        cls: type[SpanSeqIdInDB],
        session_factory: async_sessionmaker[AsyncSession],
        rollout_id: str,
        attempt_id: str,
        external_seq_id: Optional[int] = None,
    ) -> int:
        """Get the next sequence ID with retries to handle race conditions.

        If external_seq_id is provided and is greater than current_sequence,
        the returned value jumps to external_seq_id (and the counter advances
        past it).

        NOTE(review): no retry loop is visible in this method — the version_id
        column makes a conflicting concurrent update raise instead; retries are
        presumably handled by the caller. Confirm. ``attempt_id`` is accepted
        but unused (see the FIXME on the class).

        Raises:
            ValueError: If no counter row exists for rollout_id.
        """
        async with session_factory() as session:
            async with session.begin():
                seq_obj = await session.get(cls, rollout_id)
                # seq_obj = await session.get(cls, [rollout_id, attempt_id])
                if seq_obj is None:
                    raise ValueError(f"Rollout {rollout_id} not found")
                else:
                    current_seq = (
                        external_seq_id
                        if external_seq_id is not None and external_seq_id > seq_obj.current_sequence
                        else seq_obj.current_sequence
                    )
                    seq_obj.current_sequence = current_seq + 1
                    # Flush inside the transaction so a version_id conflict
                    # surfaces here rather than at commit.
                    await session.flush()
                    return current_seq