2828 UpdateEngineStatusRequest ,
2929)
3030
31- DEBUG = True
31+ DEBUG = False
3232
3333def register_enable_tinkerscript_mode_routes (
3434 app ,
@@ -43,6 +43,84 @@ def register_enable_tinkerscript_mode_routes(
4343 if 'unclaimed_episodes' not in shared_mem_dict :
4444 shared_mem_dict ['unclaimed_episodes' ] = []
4545
def find_claimed_episodes_that_need_to_be_unclaimed() -> List[str]:
    """Revert every claimed episode whose client has been silent too long.

    Scans all ``"episodes-"`` entries of ``shared_mem_dict`` and selects the
    ones whose status is "claimed" and whose time since
    ``latest_activity_timestamp`` exceeds their own ``allow_discard_timeout``.
    Each selected episode is then handed to ``_revert_episode_to_unclaimed``.

    Returns:
        The uuids of the episodes that were selected for reverting.
    """
    # One timestamp is taken up front so every episode is judged against the
    # same instant.
    now = time.time()

    expired_uuids = [
        status.episode_uuid
        for key, status in shared_mem_dict.items()
        if key.startswith("episodes-")
        and status.episode_status == "claimed"
        and (now - status.latest_activity_timestamp) > status.allow_discard_timeout
    ]

    # Reverting happens after the scan, not while iterating the dict.
    for expired_uuid in expired_uuids:
        _revert_episode_to_unclaimed(expired_uuid)

    return expired_uuids
61+
def _revert_episode_to_unclaimed(episode_uuid: str):
    """Put a claimed episode back into the unclaimed pool.

    While holding ``shared_mem_dict_lock``, the episode's status entry is
    reset to "registered", its client binding is cleared, its activity
    timestamp is refreshed, its discard timeout is set to -1, and its uuid is
    appended to ``shared_mem_dict['unclaimed_episodes']``.

    The call is a no-op when the episode entry no longer exists or is no
    longer in the "claimed" state.

    Args:
        episode_uuid: uuid of the episode to revert.
    """
    episode_key = f"episodes-{ episode_uuid } "

    with shared_mem_dict_lock:
        # Membership has to be tested BEFORE subscripting; subscripting a
        # missing key raises KeyError, which would make this guard useless.
        if episode_key not in shared_mem_dict:
            return

        es: EpisodeStatus = shared_mem_dict[episode_key]

        # check status again, because other thread may have changed it
        if es.episode_status != "claimed":
            return

        # revert
        logger.warning(f"Reverting episode { episode_uuid } to unclaimed due to client timeout.")
        es.episode_status = "registered"
        es.client_uuid = ""
        es.latest_activity_timestamp = time.time()
        es.allow_discard_timeout = - 1

        # The modified status object is written back explicitly and the list is
        # re-assigned via `+=` (presumably because shared_mem_dict is a
        # cross-process proxy that only sees item assignment -- confirm).
        shared_mem_dict[episode_key] = es
        shared_mem_dict['unclaimed_episodes'] += [episode_uuid]
78+
79+
async def register_episode_ready_listener():
    """Run the periodic episode housekeeping loop; never returns.

    Each cycle logs the current episode statuses, sleeps, and then reverts
    claimed episodes whose clients have timed out.
    """
    poll_interval_seconds = 10  # check every 10 seconds

    while True:
        read_all_episode_status()
        await asyncio.sleep(poll_interval_seconds)
        find_claimed_episodes_that_need_to_be_unclaimed()
85+
86+
def read_all_episode_status() -> Optional[EpisodeStatus]:
    """Log the engine status and a per-status summary of all episodes.

    Episodes are read from the ``"episodes-"`` entries of ``shared_mem_dict``
    and grouped by ``episode_status``.  For every group a header line is
    emitted, followed by one line listing the first six characters of each
    episode uuid together with the seconds elapsed since its last activity.

    Returns:
        Always ``None``; the function is used for its logging side effect.
    """
    # status -> episodes in that status, in first-seen order
    grouped = {}
    for key, value in shared_mem_dict.items():
        if not key.startswith("episodes-"):
            continue
        episode: EpisodeStatus = value
        grouped.setdefault(episode.episode_status, []).append(episode)

    lines = []
    for status, es_list in grouped.items():
        lines.append(f"--- { status } (time since last activity) ---")
        cells = []
        for es in es_list:
            # Elapsed time is measured per episode at the moment it is formatted.
            time_since_last_activity = time.time() - es.latest_activity_timestamp
            cells.append(f"{ es .episode_uuid [:6 ]} ({ time_since_last_activity :.1f} s)\t ")
        lines.append("".join(cells))

    print_buffer_str = "\n ".join(lines)
    logger.info(f"Current engine status: [{ shared_mem_dict ['engine_status' ]} ]")
    if not lines:
        logger.info(f"Current episode statuses: [NA]")
    else:
        logger.info(f"Current episode statuses:\n { print_buffer_str } ")

    return None
114+
115+
# Example of one per-status line (uuid prefix followed by time since last activity):
# hiefwu1(15.1s ago) hiefwu2(20.3s ago) hiefwu3(5.0s ago)
117+
118+
119+
120+ # --------------------------------------------------------------------
121+ # -------------------------- fastapi routes --------------------------
122+ # --------------------------------------------------------------------
123+
46124 @app .post ("/sync_train_config" )
47125 async def sync_train_config (req : SyncTrainConfigRequest ):
48126 """
@@ -120,7 +198,7 @@ async def start_engine():
120198
121199 # Setup environment variables
122200 exp_config ['ajet' ]['interchange_server' ]['already_started' ] = True
123- exp_config ['ajet' ]['interchange_server' ]['interchange_server_port' ] = int (os .getenv ("AJET_DAT_INTERCHANGE_PORT" ))
201+ exp_config ['ajet' ]['interchange_server' ]['interchange_server_port' ] = int (os .getenv ("AJET_DAT_INTERCHANGE_PORT" )) # type: ignore
124202 env , exp_config = setup_environment_vars (args , exp_config , main_yaml_fp )
125203
126204 # Start ray if not already started
@@ -163,11 +241,12 @@ async def start_engine():
163241
164242
165243 # --- engine status ---
166- shared_mem_dict ['engine_status' ] = "ENGINE.OFF "
244+ shared_mem_dict ['engine_status' ] = "ENGINE.OFFLINE "
167245 @app .post ("/update_engine_status" , response_model = BoolResponse )
168246 async def update_engine_status (req : UpdateEngineStatusRequest ):
247+ """Update the current engine status."""
169248 if req .engine_status not in [
170- "ENGINE.OFF " ,
249+ "ENGINE.OFFLINE " ,
171250 "ENGINE.BOOTING" ,
172251 "ENGINE.ROLLING" ,
173252 "ENGINE.WEIGHT_SYNCING" ,
@@ -180,14 +259,15 @@ async def update_engine_status(req: UpdateEngineStatusRequest):
180259
@app.get("/get_engine_status")
async def get_engine_status():
    """Return the engine status currently stored in ``shared_mem_dict``.

    Returns:
        A dict with the single key ``"engine_status"``.
    """
    return {"engine_status": shared_mem_dict['engine_status']}
185265
186266
187267 # --- episode status ---
188268 @app .post ("/register_episode" , response_model = BoolResponse )
189269 async def register_episode (req : RegisterEpisodeRequest ):
190-
270+ """(From task_runner) Register a new episode as ready to roll."""
191271 episode_uuid = req .episode_uuid
192272 es = EpisodeStatus (
193273 episode_uuid = req .episode_uuid ,
@@ -210,8 +290,30 @@ async def register_episode(req: RegisterEpisodeRequest):
210290
211291 @app .post ("/claim_episode" , response_model = ClaimEpisodeResponse )
212292 async def claim_episode (req : ClaimEpisodeRequest ):
293+ """(From client) Claim an available episode to rollout."""
213294 find_claimed_episodes_that_need_to_be_unclaimed ()
214295
296+ engine_status = shared_mem_dict ['engine_status' ]
297+ if engine_status != "ENGINE.ROLLING" :
298+ fail_cause = f"Engine not ready. Current status: [{ engine_status } ]."
299+ advise = ""
300+ if engine_status == "ENGINE.OFFLINE" :
301+ advise = "Please start the engine first. Please use one of the client to run `client.sync_train_config() + client.start_engine()` to start the engine."
302+ elif engine_status == "ENGINE.BOOTING" :
303+ advise = "Please wait until the engine is fully booted. Try again (maybe 1 minute) later."
304+ elif engine_status == "ENGINE.WEIGHT_SYNCING" :
305+ advise = "Engine is syncing weights. Try again (maybe 1 minute) later."
306+ elif engine_status == "ENGINE.WEIGHT_EXPORTING" :
307+ advise = "Engine is exporting weights (fsdp -> hf safetensor). Try again (maybe 1 minute) later."
308+ return ClaimEpisodeResponse (
309+ success = False ,
310+ client_uuid = req .client_uuid ,
311+ episode_uuid = "" ,
312+ openai_base_url = "" ,
313+ openai_api_key = "" ,
314+ fail_cause = fail_cause + " " + advise ,
315+ )
316+
215317 with shared_mem_dict_lock :
216318 if len (shared_mem_dict ['unclaimed_episodes' ]) <= 0 :
217319 return ClaimEpisodeResponse (
@@ -248,41 +350,6 @@ async def claim_episode(req: ClaimEpisodeRequest):
248350 )
249351
250352
251- def find_claimed_episodes_that_need_to_be_unclaimed () -> List [str ]:
252- result = []
253- current_time = time .time ()
254-
255- for k , v in shared_mem_dict .items ():
256- if k .startswith ("episodes-" ):
257- es :EpisodeStatus = v
258- if es .episode_status == "claimed" :
259- if (current_time - es .latest_activity_timestamp ) > es .allow_discard_timeout :
260- result .append (es .episode_uuid )
261-
262- for episode_uuid in result :
263- _revert_episode_to_unclaimed (episode_uuid )
264-
265- return result
266-
267-
268- def _revert_episode_to_unclaimed (episode_uuid : str ):
269- with shared_mem_dict_lock :
270- # check status again, because other thread may have changed it
271- if shared_mem_dict [f"episodes-{ episode_uuid } " ].episode_status != "claimed" :
272- return
273-
274- # revert
275- logger .info (f"Reverting episode { episode_uuid } to unclaimed due to client timeout." )
276- if f"episodes-{ episode_uuid } " in shared_mem_dict :
277- es :EpisodeStatus = shared_mem_dict [f"episodes-{ episode_uuid } " ]
278- es .episode_status = "registered"
279- es .client_uuid = ""
280- es .latest_activity_timestamp = time .time ()
281- es .allow_discard_timeout = - 1
282- shared_mem_dict [f"episodes-{ episode_uuid } " ] = es
283- shared_mem_dict ['unclaimed_episodes' ] += [episode_uuid ]
284-
285-
286353 @app .post ("/end_episode" , response_model = EndEpisodeResponse )
287354 async def end_episode (req : EndEpisodeRequest ):
288355 # receive workflow output data
@@ -312,6 +379,10 @@ async def end_episode(req: EndEpisodeRequest):
312379 for _ in range (5 ): # max 5 minutes wait
313380 try :
314381 if DEBUG : logger .info (f"[server] episode_uuid: { episode_uuid } | recv_string begin." )
382+ # <wait for>:
383+ # <from_sourcefile>: ajet/task_runner/tinkerscript_runner.py
384+ # <from_code>: zmq_socket.send_string("ack")
385+ # <expect>: "ack"
315386 result_str = socket .recv_string ()
316387 break
317388 except zmq .Again as e :
@@ -345,9 +416,4 @@ async def get_episode_buffer():
345416 return EpisodeBufferResponse (buffer = result )
346417
347418
348-
349- async def register_episode_ready_listener ():
350- pass
351-
352-
353419 return app , register_episode_ready_listener ()
0 commit comments