Skip to content

Commit a6c7e0e

Browse files
committed
stage eval code ( to be tested )
1 parent 968c2cf commit a6c7e0e

File tree

6 files changed

+352
-118
lines changed

6 files changed

+352
-118
lines changed

ajet/context_tracker/base_tracker.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,21 @@ def __init__(self, config, tokenizer, workflow_task: WorkflowTask, **kwargs):
148148
<= max_model_len
149149
)
150150

151+
def reset(self):
152+
self.saved_timelines: List[List[ExtendedMessage]] = []
153+
self.current_context_status = ""
154+
self.terminal_rewards_dict = {}
155+
self.discarded = False
156+
self.is_terminated = False
157+
self.reward_structure: Union[Reward, None] = None
158+
self.context_time_cost = 0
159+
self.tag = ""
160+
self.current_batch_success_rate: float = float("-inf")
161+
self.current_batch_reward: float = float("-inf")
162+
self.already_mad_flag: bool = False
163+
self.round_cnt = 0
164+
self.log_metrics: Optional[Dict[str, Union[float, List[float], Dict[str, Any]]]] = None
165+
151166
def group_tokenize(self):
152167
raise NotImplementedError
153168

ajet/task_runner/tinkerscript_runner.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
class TinkerScriptRunner(BaseAgentRunner):
2525

26-
def register_episode_and_wait_output(self, episode_uuid: str, openai_base_url: str, openai_api_key: str) -> WorkflowOutput:
26+
def register_episode_and_wait_output(self, episode_uuid: str, openai_base_url: str, openai_api_key: str, context_tracker: BaseContextTracker) -> WorkflowOutput:
2727
"""Register the episode as ready in the TinkerScript data interchange center."""
2828
# parse episode_uuid, openai_base_url, openai_api_key
2929
zmq_listen_result_addr, ipc_path = get_zmq_socket(self.config, episode_uuid, tag="workflow")
@@ -39,15 +39,30 @@ def register_episode_and_wait_output(self, episode_uuid: str, openai_base_url: s
3939
# begin wait for result
4040
zmq_socket = zmq.Context().socket(zmq.REP)
4141
zmq_socket.bind(zmq_listen_result_addr)
42-
43-
# <wait for>:
44-
# <from_sourcefile>: ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py
45-
# <from_code>: socket.send_string(workflow_output.model_dump_json())
46-
# <expect>: workflow_output: WorkflowOutput
47-
message = zmq_socket.recv_string()
42+
speicial_messages = [
43+
"RUNNER.RESET_CONTEXT_TRACKER"
44+
]
45+
while True:
46+
# <wait for 1/2>:
47+
# <from_sourcefile>: ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py
48+
# <from_code>: socket.send_string(workflow_output.model_dump_json())
49+
# <expect>: workflow_output: WorkflowOutput
50+
# <wait for 2/2>:
51+
# <from_sourcefile>: ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py
52+
# <from_code>: socket.send_string("RUNNER.SPECIAL.RESET_CONTEXT_TRACKER")
53+
# <expect>: "RUNNER.SPECIAL.RESET_CONTEXT_TRACKER"
54+
message = zmq_socket.recv_string()
55+
if message not in speicial_messages:
56+
zmq_socket.send_string("ack")
57+
break
58+
elif message == "RUNNER.SPECIAL.RESET_CONTEXT_TRACKER":
59+
logger.warning(f"Received reset command for episode {episode_uuid}.")
60+
context_tracker.reset()
61+
zmq_socket.send_string("ack")
62+
else:
63+
raise RuntimeError(f"Unknown special message received: {message}")
4864

4965
logger.success(f"Received workflow output for episode {episode_uuid}")
50-
zmq_socket.send_string("ack")
5166
zmq_socket.close()
5267
if ipc_path and os.path.exists(ipc_path): os.remove(ipc_path)
5368

@@ -85,6 +100,7 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker:
85100
episode_uuid=context_tracker.episode_uuid,
86101
openai_base_url=base_url,
87102
openai_api_key=api_key,
103+
context_tracker=context_tracker,
88104
)
89105

90106
if workflow_output.reward is not None:

ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(self, server_url: str):
2828
self.previous_warning_time = 0
2929

3030

31-
def begin_episode(self, allow_discard_timeout=60) -> Tuple[str, OpenaiBaseUrlAndApiKey]:
31+
def begin_episode(self, allow_discard_timeout=60, episode_type="train") -> Tuple[str, OpenaiBaseUrlAndApiKey]:
3232
"""
3333
Block until an episode is claimed.
3434
Return (episode_uuid, openai_base_url, openai_api_key)
@@ -37,7 +37,7 @@ def begin_episode(self, allow_discard_timeout=60) -> Tuple[str, OpenaiBaseUrlAnd
3737
try:
3838
req_obj = ClaimEpisodeRequest(
3939
client_uuid=self.client_uuid,
40-
episode_type="default",
40+
episode_type=episode_type,
4141
allow_discard_timeout=allow_discard_timeout,
4242
)
4343
resp = httpx.post(
@@ -161,15 +161,15 @@ def start_engine(self):
161161
raise
162162

163163
# Poll until engine status is "ENGINE.ROLLING"
164-
self._wait_until_avail()
164+
self._wait_until_status_change_to(desired_status="ENGINE.ROLLING")
165165
logger.success("Training engine is now ROLLING and ready.")
166166

167-
def _wait_until_avail(self):
167+
def _wait_until_status_change_to(self, desired_status="ENGINE.ROLLING"):
168168
"""
169-
Poll engine status until it reaches ENGINE.ROLLING state.
169+
Poll engine status until it reaches desired_status.
170170
Reports status every 5 seconds while waiting.
171171
"""
172-
logger.info("Polling engine status until ENGINE.ROLLING...")
172+
logger.info(f"Polling engine status until {desired_status}...")
173173
last_report_time = time.time()
174174
init_poll_time = last_report_time
175175

@@ -184,8 +184,8 @@ def _wait_until_avail(self):
184184
last_report_time = current_time
185185

186186
# Check if engine has reached the desired status
187-
if current_status == "ENGINE.ROLLING":
188-
logger.info("Engine status is ENGINE.ROLLING - engine is ready")
187+
if current_status == desired_status:
188+
logger.info(f"Engine status is {desired_status}.")
189189
break
190190

191191
# Wait a bit before next poll
@@ -256,7 +256,34 @@ def auto_sync_train_config_and_start_engine(self, agent_jet_job: AgentJetJob):
256256
logger.info("Engine is already ROLLING. No action needed.")
257257
elif current_status == "ENGINE.BOOTING":
258258
logger.info("Engine is BOOTING. Waiting until it becomes ROLLING...")
259-
self._wait_until_avail()
259+
self._wait_until_status_change_to(desired_status="ENGINE.ROLLING")
260260
logger.success("Training engine is now ROLLING and ready.")
261261
else:
262262
raise RuntimeError(f"Cannot sync train config or start engine when engine is in status: {current_status}")
263+
264+
def stop_engine(self):
265+
"""
266+
Stop the training engine on the TinkerScript server.
267+
This triggers the server to stop the training process.
268+
"""
269+
current_status = self.get_engine_status()
270+
if current_status == "ENGINE.OFFLINE":
271+
logger.info("Engine is already OFFLINE. No action needed.")
272+
return
273+
274+
try:
275+
resp = httpx.post(
276+
f"{self.server_url}/stop_engine",
277+
json={},
278+
timeout=600
279+
)
280+
resp.raise_for_status()
281+
result = resp.json()
282+
if result.get("success"):
283+
logger.info("Successfully stopped training engine on TinkerScript server")
284+
else:
285+
logger.error("Failed to stop training engine")
286+
self._wait_until_status_change_to(desired_status="ENGINE.OFFLINE")
287+
except Exception as e:
288+
logger.error(f"Error stopping engine: {e}")
289+

0 commit comments

Comments
 (0)