@@ -28,7 +28,7 @@ def __init__(self, server_url: str):
2828 self .previous_warning_time = 0
2929
3030
31- def begin_episode (self , allow_discard_timeout = 60 ) -> Tuple [str , OpenaiBaseUrlAndApiKey ]:
31+ def begin_episode (self , allow_discard_timeout = 60 , episode_type = "train" ) -> Tuple [str , OpenaiBaseUrlAndApiKey ]:
3232 """
3333 Block until an episode is claimed.
3434 Return (episode_uuid, openai_base_url, openai_api_key)
@@ -37,7 +37,7 @@ def begin_episode(self, allow_discard_timeout=60) -> Tuple[str, OpenaiBaseUrlAnd
3737 try :
3838 req_obj = ClaimEpisodeRequest (
3939 client_uuid = self .client_uuid ,
40- episode_type = "default" ,
40+ episode_type = episode_type ,
4141 allow_discard_timeout = allow_discard_timeout ,
4242 )
4343 resp = httpx .post (
@@ -161,15 +161,15 @@ def start_engine(self):
161161 raise
162162
163163 # Poll until engine status is "ENGINE.ROLLING"
164- self ._wait_until_avail ( )
164+ self ._wait_until_status_change_to ( desired_status = "ENGINE.ROLLING" )
165165 logger .success ("Training engine is now ROLLING and ready." )
166166
167- def _wait_until_avail (self ):
167+ def _wait_until_status_change_to (self , desired_status = "ENGINE.ROLLING" ):
168168 """
169- Poll engine status until it reaches ENGINE.ROLLING state .
169+ Poll engine status until it reaches desired_status .
170170 Reports status every 5 seconds while waiting.
171171 """
172- logger .info ("Polling engine status until ENGINE.ROLLING ..." )
172+ logger .info (f "Polling engine status until { desired_status } ..." )
173173 last_report_time = time .time ()
174174 init_poll_time = last_report_time
175175
@@ -184,8 +184,8 @@ def _wait_until_avail(self):
184184 last_report_time = current_time
185185
186186 # Check if engine has reached the desired status
187- if current_status == "ENGINE.ROLLING" :
188- logger .info ("Engine status is ENGINE.ROLLING - engine is ready " )
187+ if current_status == desired_status :
188+ logger .info (f "Engine status is { desired_status } . " )
189189 break
190190
191191 # Wait a bit before next poll
@@ -256,7 +256,34 @@ def auto_sync_train_config_and_start_engine(self, agent_jet_job: AgentJetJob):
256256 logger .info ("Engine is already ROLLING. No action needed." )
257257 elif current_status == "ENGINE.BOOTING" :
258258 logger .info ("Engine is BOOTING. Waiting until it becomes ROLLING..." )
259- self ._wait_until_avail ( )
259+ self ._wait_until_status_change_to ( desired_status = "ENGINE.ROLLING" )
260260 logger .success ("Training engine is now ROLLING and ready." )
261261 else :
262262 raise RuntimeError (f"Cannot sync train config or start engine when engine is in status: { current_status } " )
263+
264+ def stop_engine (self ):
265+ """
266+ Stop the training engine on the TinkerScript server.
267+ This triggers the server to stop the training process.
268+ """
269+ current_status = self .get_engine_status ()
270+ if current_status == "ENGINE.OFFLINE" :
271+ logger .info ("Engine is already OFFLINE. No action needed." )
272+ return
273+
274+ try :
275+ resp = httpx .post (
276+ f"{ self .server_url } /stop_engine" ,
277+ json = {},
278+ timeout = 600
279+ )
280+ resp .raise_for_status ()
281+ result = resp .json ()
282+ if result .get ("success" ):
283+ logger .info ("Successfully stopped training engine on TinkerScript server" )
284+ else :
285+ logger .error ("Failed to stop training engine" )
286+ self ._wait_until_status_change_to (desired_status = "ENGINE.OFFLINE" )
287+ except Exception as e :
288+ logger .error (f"Error stopping engine: { e } " )
289+
0 commit comments