Skip to content

Commit 968c2cf

Browse files
committed
feat: enhance TinkerScript functionality with improved engine status handling and episode management
1 parent 5cc7297 commit 968c2cf

File tree

7 files changed

+212
-103
lines changed

7 files changed

+212
-103
lines changed

ajet/backbone/trainer_verl.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,6 @@ def init_workers(self):
444444
)
445445

446446
def _update_interchange_server_status_flag(self, status: str):
447-
# if interchange server is enabled, change engine status to ROLLING
448447
if self.config.ajet.enable_experimental_interchange_server:
449448
if self.config.ajet.enable_tinkerscript_mode:
450449
from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import http_change_engine_status

ajet/task_runner/tinkerscript_runner.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616
from loguru import logger
1717
from ajet import Workflow
1818

19+
DEBUG = False
20+
1921
context = zmq.Context()
2022
atexit.register(context.term)
21-
DEBUG = True
2223

2324
class TinkerScriptRunner(BaseAgentRunner):
2425

@@ -33,12 +34,18 @@ def register_episode_and_wait_output(self, episode_uuid: str, openai_base_url: s
3334
openai_api_key=openai_api_key,
3435
zmq_listen_result_addr=zmq_listen_result_addr,
3536
)
36-
logger.info(f"zmq_listen_result_addr: {zmq_listen_result_addr}")
37+
if DEBUG: logger.info(f"zmq_listen_result_addr: {zmq_listen_result_addr}")
3738

3839
# begin wait for result
3940
zmq_socket = zmq.Context().socket(zmq.REP)
4041
zmq_socket.bind(zmq_listen_result_addr)
42+
43+
# <wait for>:
44+
# <from_sourcefile>: ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py
45+
# <from_code>: socket.send_string(workflow_output.model_dump_json())
46+
# <expect>: workflow_output: WorkflowOutput
4147
message = zmq_socket.recv_string()
48+
4249
logger.success(f"Received workflow output for episode {episode_uuid}")
4350
zmq_socket.send_string("ack")
4451
zmq_socket.close()

ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,9 @@ async def chat_completions(request: Request, authorization: str = Header(None)):
157157
if enable_tinkerscript_mode:
158158
assert shared_mem_dict is not None
159159
assert shared_mem_dict_lock is not None
160-
if shared_mem_dict['engine_status'] != "ROLLING":
161-
logger.error(f"The server is not in ROLLING status (current status: [{shared_mem_dict['engine_status']}]), cannot accept new requests.")
162-
raise HTTPException(status_code=503, detail="The server is not in ROLLING status, cannot accept new requests.")
160+
if shared_mem_dict['engine_status'] != "ENGINE.ROLLING":
161+
logger.error(f"The server is not in ENGINE.ROLLING status (current status: [{shared_mem_dict['engine_status']}]), cannot accept new requests.")
162+
raise HTTPException(status_code=503, detail="The server is not in ENGINE.ROLLING status, cannot accept new requests.")
163163
if (f"episodes-{episode_uuid}") not in shared_mem_dict:
164164
raise HTTPException(status_code=404, detail=f"Episode {episode_uuid} not found.")
165165
# update activate timestamp

ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class TinkerScriptClient(object):
2525
def __init__(self, server_url: str):
2626
self.server_url = server_url
2727
self.client_uuid = str(uuid.uuid4())
28+
self.previous_warning_time = 0
2829

2930

3031
def begin_episode(self, allow_discard_timeout=60) -> Tuple[str, OpenaiBaseUrlAndApiKey]:
@@ -59,8 +60,18 @@ def begin_episode(self, allow_discard_timeout=60) -> Tuple[str, OpenaiBaseUrlAnd
5960
episode_uuid=episode_uuid
6061
)
6162
else:
62-
logger.info(f"Failed to claim episode: {data.fail_cause}. Retrying in 5s...")
63-
time.sleep(5)
63+
need_wait_scenarios =[
64+
"Engine is syncing weights",
65+
"No available episodes to claim.",
66+
]
67+
if any(scenario in data.fail_cause for scenario in need_wait_scenarios):
68+
if time.time() - self.previous_warning_time > 60:
69+
logger.info(f"{data.fail_cause}. Retrying in 30s...")
70+
self.previous_warning_time = time.time()
71+
time.sleep(30)
72+
else:
73+
logger.warning(f"Failed to claim episode: {data.fail_cause}. Retrying in 5s...")
74+
time.sleep(5)
6475
except Exception as e:
6576
logger.error(f"Error claiming episode: {e}. Retrying in 5s...")
6677
time.sleep(5)
@@ -98,6 +109,11 @@ def sync_train_config(self, agent_jet_job: AgentJetJob):
98109
Sync training configuration to the TinkerScript server.
99110
This sends the AgentJetJob config as YAML to the remote server.
100111
"""
112+
# try get init status
113+
current_status = self.get_engine_status()
114+
if current_status != "ENGINE.OFFLINE":
115+
raise RuntimeError(f"Cannot sync train config when engine is NOT ENGINE.OFFLINE. (current status: {current_status})")
116+
101117
try:
102118
config_dict = agent_jet_job.config.to_dict()
103119
yaml_str = yaml.safe_dump(config_dict, sort_keys=False)
@@ -121,6 +137,12 @@ def start_engine(self):
121137
This triggers the server to begin the training process.
122138
Polls until engine status is "ENGINE.ROLLING".
123139
"""
140+
# try get init status
141+
current_status = self.get_engine_status()
142+
if current_status != "ENGINE.OFFLINE":
143+
raise RuntimeError(f"Cannot start engine when engine is NOT ENGINE.OFFLINE. (current status: {current_status})")
144+
145+
# Send start engine request
124146
try:
125147
resp = httpx.post(
126148
f"{self.server_url}/start_engine",
@@ -139,8 +161,17 @@ def start_engine(self):
139161
raise
140162

141163
# Poll until engine status is "ENGINE.ROLLING"
164+
self._wait_until_avail()
165+
logger.success("Training engine is now ROLLING and ready.")
166+
167+
def _wait_until_avail(self):
168+
"""
169+
Poll engine status until it reaches ENGINE.ROLLING state.
170+
Reports status every 5 seconds while waiting.
171+
"""
142172
logger.info("Polling engine status until ENGINE.ROLLING...")
143173
last_report_time = time.time()
174+
init_poll_time = last_report_time
144175

145176
while True:
146177
try:
@@ -149,7 +180,7 @@ def start_engine(self):
149180

150181
# Report status every 5 seconds
151182
if current_time - last_report_time >= 5:
152-
logger.info(f"Current engine status: {current_status}")
183+
logger.info(f"Current engine status (already waited {current_time - init_poll_time:.1f}s): {current_status}")
153184
last_report_time = current_time
154185

155186
# Check if engine has reached the desired status
@@ -210,3 +241,22 @@ def get_episode_buffer(self) -> List[EpisodeStatus]:
210241
except Exception as e:
211242
logger.error(f"Error getting episode buffer: {e}")
212243
return []
244+
245+
def auto_sync_train_config_and_start_engine(self, agent_jet_job: AgentJetJob):
    """
    Bring the remote engine to ENGINE.ROLLING, doing only what is still needed.

    Behavior by the status returned from ``get_engine_status()``:
      - ENGINE.ROLLING: nothing to do.
      - ENGINE.OFFLINE: sync the train config, then start the engine.
      - ENGINE.BOOTING: wait until the engine reports ROLLING.

    Raises:
        RuntimeError: if the engine reports any other status.
    """
    status = self.get_engine_status()

    if status == "ENGINE.ROLLING":
        logger.info("Engine is already ROLLING. No action needed.")
        return

    if status == "ENGINE.OFFLINE":
        logger.info("Engine is OFFLINE. Syncing train config and starting engine...")
        self.sync_train_config(agent_jet_job)
        self.start_engine()
        return

    if status == "ENGINE.BOOTING":
        logger.info("Engine is BOOTING. Waiting until it becomes ROLLING...")
        self._wait_until_avail()
        logger.success("Training engine is now ROLLING and ready.")
        return

    # NOTE(review): statuses such as ENGINE.WEIGHT_SYNCING / ENGINE.WEIGHT_EXPORTING
    # also end up here — confirm that raising (rather than waiting) is intended.
    raise RuntimeError(f"Cannot sync train config or start engine when engine is in status: {status}")

ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py

Lines changed: 111 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
UpdateEngineStatusRequest,
2929
)
3030

31-
DEBUG = True
31+
DEBUG = False
3232

3333
def register_enable_tinkerscript_mode_routes(
3434
app,
@@ -43,6 +43,84 @@ def register_enable_tinkerscript_mode_routes(
4343
if 'unclaimed_episodes' not in shared_mem_dict:
4444
shared_mem_dict['unclaimed_episodes'] = []
4545

46+
def find_claimed_episodes_that_need_to_be_unclaimed() -> List[str]:
    """Revert every claimed episode whose client has been inactive for too long.

    An episode counts as expired when its status is "claimed" and the time
    elapsed since its ``latest_activity_timestamp`` is greater than its own
    ``allow_discard_timeout``.

    Returns:
        The uuids of all episodes found expired during the scan. An uuid is
        included even if the subsequent revert turns out to be a no-op.
    """
    scan_time = time.time()

    # Only keys prefixed with "episodes-" hold episode records; the prefix
    # test comes first so other entries are never inspected as episodes.
    expired_uuids = [
        status.episode_uuid
        for key, status in shared_mem_dict.items()
        if key.startswith("episodes-")
        and status.episode_status == "claimed"
        and (scan_time - status.latest_activity_timestamp) > status.allow_discard_timeout
    ]

    # Revert only after the scan has finished.
    for expired_uuid in expired_uuids:
        _revert_episode_to_unclaimed(expired_uuid)

    return expired_uuids
61+
62+
def _revert_episode_to_unclaimed(episode_uuid: str):
    """Put a claimed episode back into the unclaimed pool.

    The episode is reset to the "registered" state, its client binding is
    cleared, its activity timestamp is refreshed, and its uuid is appended
    to ``shared_mem_dict['unclaimed_episodes']``. The call is a no-op when
    the episode no longer exists or is no longer in the "claimed" state.

    Args:
        episode_uuid: uuid of the episode to revert (without the
            "episodes-" key prefix).
    """
    episode_key = f"episodes-{episode_uuid}"

    with shared_mem_dict_lock:
        # The entry may have disappeared between the caller's scan and now;
        # check for presence BEFORE indexing so a missing key is a no-op
        # instead of a KeyError.
        if episode_key not in shared_mem_dict:
            return

        es: EpisodeStatus = shared_mem_dict[episode_key]

        # check status again, because other thread may have changed it
        if es.episode_status != "claimed":
            return

        # revert
        logger.warning(f"Reverting episode {episode_uuid} to unclaimed due to client timeout.")
        es.episode_status = "registered"
        es.client_uuid = ""
        es.latest_activity_timestamp = time.time()
        es.allow_discard_timeout = -1
        # Store the modified record back under its key (presumably required
        # for the shared dict to observe the change — TODO confirm).
        shared_mem_dict[episode_key] = es
        shared_mem_dict['unclaimed_episodes'] += [episode_uuid]
78+
79+
80+
async def register_episode_ready_listener():
    """Background loop: report episode statuses and expire stale claims.

    Each iteration logs the current episode statuses, sleeps 10 seconds,
    then reverts claimed episodes whose client has timed out. The loop
    never returns on its own.

    A failure in either step is logged and the loop keeps running, so one
    bad iteration cannot permanently disable timeout enforcement.
    """
    while True:
        try:
            read_all_episode_status()
        except Exception:
            logger.exception("Failed to report episode statuses.")

        # Kept outside the try blocks so task cancellation propagates
        # normally and a failure above can never turn this into a busy loop.
        await asyncio.sleep(10)  # check every 10 seconds

        try:
            find_claimed_episodes_that_need_to_be_unclaimed()
        except Exception:
            logger.exception("Failed to revert timed-out episodes.")
85+
86+
87+
def read_all_episode_status() -> Optional[EpisodeStatus]:
    """Log the engine status and all tracked episodes, grouped by episode status.

    For every group one header line is produced, followed by one line that
    lists each episode as ``<first 6 chars of uuid>(<seconds since last
    activity>s)`` separated by tabs. When no episode is tracked, "[NA]" is
    logged instead.

    Returns:
        Always ``None``; the function is used for its logging side effect.
    """
    # Bucket episode records by their status, keeping first-seen order.
    episodes_by_status = {}
    for key, record in shared_mem_dict.items():
        if not key.startswith("episodes-"):
            continue
        episodes_by_status.setdefault(record.episode_status, []).append(record)

    report_lines = []
    for status, episodes in episodes_by_status.items():
        report_lines.append(f"--- {status} (time since last activity) ---")
        report_lines.append(
            "".join(
                f"{ep.episode_uuid[:6]}({time.time() - ep.latest_activity_timestamp:.1f}s)\t"
                for ep in episodes
            )
        )

    logger.info(f"Current engine status: [{shared_mem_dict['engine_status']}]")
    if report_lines:
        report = "\n".join(report_lines)
        logger.info(f"Current episode statuses:\n{report}")
    else:
        logger.info("Current episode statuses: [NA]")

    return None
114+
115+
116+
# Example of one status line (6-char uuid prefix + seconds since last activity, tab-separated): hiefwu(15.1s)	abcdef(20.3s)	q1w2e3(5.0s)
117+
118+
119+
120+
# --------------------------------------------------------------------
121+
# -------------------------- fastapi routes --------------------------
122+
# --------------------------------------------------------------------
123+
46124
@app.post("/sync_train_config")
47125
async def sync_train_config(req: SyncTrainConfigRequest):
48126
"""
@@ -120,7 +198,7 @@ async def start_engine():
120198

121199
# Setup environment variables
122200
exp_config['ajet']['interchange_server']['already_started'] = True
123-
exp_config['ajet']['interchange_server']['interchange_server_port'] = int(os.getenv("AJET_DAT_INTERCHANGE_PORT"))
201+
exp_config['ajet']['interchange_server']['interchange_server_port'] = int(os.getenv("AJET_DAT_INTERCHANGE_PORT")) # type: ignore
124202
env, exp_config = setup_environment_vars(args, exp_config, main_yaml_fp)
125203

126204
# Start ray if not already started
@@ -163,11 +241,12 @@ async def start_engine():
163241

164242

165243
# --- engine status ---
166-
shared_mem_dict['engine_status'] = "ENGINE.OFF"
244+
shared_mem_dict['engine_status'] = "ENGINE.OFFLINE"
167245
@app.post("/update_engine_status", response_model=BoolResponse)
168246
async def update_engine_status(req: UpdateEngineStatusRequest):
247+
"""Update the current engine status."""
169248
if req.engine_status not in [
170-
"ENGINE.OFF",
249+
"ENGINE.OFFLINE",
171250
"ENGINE.BOOTING",
172251
"ENGINE.ROLLING",
173252
"ENGINE.WEIGHT_SYNCING",
@@ -180,14 +259,15 @@ async def update_engine_status(req: UpdateEngineStatusRequest):
180259

181260
@app.get("/get_engine_status")
async def get_engine_status():
    """Get the current engine status."""
    # Read the value on every request so callers always see the latest
    # status written to the shared dict.
    return {"engine_status": shared_mem_dict['engine_status']}
185265

186266

187267
# --- episode status ---
188268
@app.post("/register_episode", response_model=BoolResponse)
189269
async def register_episode(req: RegisterEpisodeRequest):
190-
270+
"""(From task_runner) Register a new episode as ready to roll."""
191271
episode_uuid = req.episode_uuid
192272
es = EpisodeStatus(
193273
episode_uuid=req.episode_uuid,
@@ -210,8 +290,30 @@ async def register_episode(req: RegisterEpisodeRequest):
210290

211291
@app.post("/claim_episode", response_model=ClaimEpisodeResponse)
212292
async def claim_episode(req: ClaimEpisodeRequest):
293+
"""(From client) Claim an available episode to rollout."""
213294
find_claimed_episodes_that_need_to_be_unclaimed()
214295

296+
engine_status = shared_mem_dict['engine_status']
297+
if engine_status != "ENGINE.ROLLING":
298+
fail_cause = f"Engine not ready. Current status: [{engine_status}]."
299+
advise = ""
300+
if engine_status == "ENGINE.OFFLINE":
301+
advise = "Please start the engine first. Please use one of the client to run `client.sync_train_config() + client.start_engine()` to start the engine."
302+
elif engine_status == "ENGINE.BOOTING":
303+
advise = "Please wait until the engine is fully booted. Try again (maybe 1 minute) later."
304+
elif engine_status == "ENGINE.WEIGHT_SYNCING":
305+
advise = "Engine is syncing weights. Try again (maybe 1 minute) later."
306+
elif engine_status == "ENGINE.WEIGHT_EXPORTING":
307+
advise = "Engine is exporting weights (fsdp -> hf safetensor). Try again (maybe 1 minute) later."
308+
return ClaimEpisodeResponse(
309+
success=False,
310+
client_uuid=req.client_uuid,
311+
episode_uuid="",
312+
openai_base_url="",
313+
openai_api_key="",
314+
fail_cause=fail_cause + " " + advise,
315+
)
316+
215317
with shared_mem_dict_lock:
216318
if len(shared_mem_dict['unclaimed_episodes']) <= 0:
217319
return ClaimEpisodeResponse(
@@ -248,41 +350,6 @@ async def claim_episode(req: ClaimEpisodeRequest):
248350
)
249351

250352

251-
def find_claimed_episodes_that_need_to_be_unclaimed() -> List[str]:
252-
result = []
253-
current_time = time.time()
254-
255-
for k, v in shared_mem_dict.items():
256-
if k.startswith("episodes-"):
257-
es:EpisodeStatus = v
258-
if es.episode_status == "claimed":
259-
if (current_time - es.latest_activity_timestamp) > es.allow_discard_timeout:
260-
result.append(es.episode_uuid)
261-
262-
for episode_uuid in result:
263-
_revert_episode_to_unclaimed(episode_uuid)
264-
265-
return result
266-
267-
268-
def _revert_episode_to_unclaimed(episode_uuid: str):
269-
with shared_mem_dict_lock:
270-
# check status again, because other thread may have changed it
271-
if shared_mem_dict[f"episodes-{episode_uuid}"].episode_status != "claimed":
272-
return
273-
274-
# revert
275-
logger.info(f"Reverting episode {episode_uuid} to unclaimed due to client timeout.")
276-
if f"episodes-{episode_uuid}" in shared_mem_dict:
277-
es:EpisodeStatus = shared_mem_dict[f"episodes-{episode_uuid}"]
278-
es.episode_status = "registered"
279-
es.client_uuid = ""
280-
es.latest_activity_timestamp = time.time()
281-
es.allow_discard_timeout = -1
282-
shared_mem_dict[f"episodes-{episode_uuid}"] = es
283-
shared_mem_dict['unclaimed_episodes'] += [episode_uuid]
284-
285-
286353
@app.post("/end_episode", response_model=EndEpisodeResponse)
287354
async def end_episode(req: EndEpisodeRequest):
288355
# receive workflow output data
@@ -312,6 +379,10 @@ async def end_episode(req: EndEpisodeRequest):
312379
for _ in range(5): # max 5 minutes wait
313380
try:
314381
if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string begin.")
382+
# <wait for>:
383+
# <from_sourcefile>: ajet/task_runner/tinkerscript_runner.py
384+
# <from_code>: zmq_socket.send_string("ack")
385+
# <expect>: "ack"
315386
result_str = socket.recv_string()
316387
break
317388
except zmq.Again as e:
@@ -345,9 +416,4 @@ async def get_episode_buffer():
345416
return EpisodeBufferResponse(buffer=result)
346417

347418

348-
349-
async def register_episode_ready_listener():
350-
pass
351-
352-
353419
return app, register_episode_ready_listener()

0 commit comments

Comments
 (0)