
Commit 7660007

union_gen_batch_via_task_id is to be tested
1 parent a6c7e0e commit 7660007

10 files changed: +94 −52 lines

ajet/backbone/trainer_verl.py

Lines changed: 26 additions & 22 deletions
@@ -99,16 +99,20 @@ def parse_reward_from_dataproto(data: DataProto, return_dict=False) -> dict | to
         return reward_tensor
 
 
-def union_gen_batch_via_task_id(tasks, batch: DataProto, gen_batch_output: DataProto):
+def union_gen_batch_via_task_id(tasks, batch: DataProto, gen_batch_output: DataProto, discard_original_batch=False):
     """
     Union the gen_batch_output with the batch based on task_id.
     """
-    map_task_id_to_index = {t.task_id: i for i, t in enumerate(tasks)}
-    gen_task_task_ids = gen_batch_output.non_tensor_batch["task_ids"]
-    indices = [map_task_id_to_index[tid] for tid in gen_task_task_ids]
-    batch_extend = batch.select_idxs(indices)
-    batch_final = batch_extend.union(gen_batch_output)
-    return batch_final
+    if not discard_original_batch:
+        map_task_id_to_index = {t.task_id: i for i, t in enumerate(tasks)}
+        gen_task_task_ids = gen_batch_output.non_tensor_batch["task_ids"]
+        indices = [map_task_id_to_index[tid] for tid in gen_task_task_ids]
+        batch_extend = batch.select_idxs(indices)
+        batch_final = batch_extend.union(gen_batch_output)
+        return batch_final
+    else:
+        gen_batch_output.non_tensor_batch["uid"] = gen_batch_output.non_tensor_batch["task_ids"]
+        return gen_batch_output
 
 
 def compute_advantage(
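For context, the alignment this function performs can be sketched in isolation. FakeBatch below is a hypothetical stand-in for verl's DataProto that only stubs select_idxs and union; everything beyond the names taken from the diff is an assumption.

from dataclasses import dataclass, field
from typing import List


@dataclass
class FakeBatch:
    rows: List[dict] = field(default_factory=list)

    def select_idxs(self, indices):
        # reorder/duplicate rows so they line up with the generated batch
        return FakeBatch([self.rows[i] for i in indices])

    def union(self, other):
        # merge columns row-by-row (DataProto.union merges tensor/non-tensor data)
        return FakeBatch([{**a, **b} for a, b in zip(self.rows, other.rows)])


@dataclass
class FakeTask:
    task_id: str


tasks = [FakeTask("t0"), FakeTask("t1")]
batch = FakeBatch([{"prompt": "p0"}, {"prompt": "p1"}])
# two rollouts per task, arriving out of order:
gen = FakeBatch([{"task_ids": "t1", "resp": "a"},
                 {"task_ids": "t0", "resp": "b"},
                 {"task_ids": "t1", "resp": "c"},
                 {"task_ids": "t0", "resp": "d"}])

map_task_id_to_index = {t.task_id: i for i, t in enumerate(tasks)}
indices = [map_task_id_to_index[r["task_ids"]] for r in gen.rows]
aligned = batch.select_idxs(indices).union(gen)
print([r["prompt"] for r in aligned.rows])  # ['p1', 'p0', 'p1', 'p0']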
@@ -550,16 +554,17 @@ def fit(self): # noqa: C901
             # pass global_steps to trace
             gen_batch.meta_info["global_steps"] = self.global_steps
             is_last_step = self.global_steps >= self.total_training_steps
-
+            from ajet import bp
+            bp("BATCH")
             with marked_timer("step", timing_raw):
                 # generate a batch
-                logger.info("=== + rollout step begin ===")
+                logger.info("rollout step begin")
                 with marked_timer("gen", timing_raw, color="red"):
                     assert self.async_rollout_mode
-                    logger.info("=== wake up begin ===")
+                    logger.info("wake up begin")
                     self.async_rollout_manager.wake_up()
                     self._update_interchange_server_status_flag("ENGINE.ROLLING")
-                    logger.info("=== wake up end ===")
+                    logger.info("wake up end")
                     tasks: List[Task] = [
                         dict_to_ajet_task(dict(
                             task_id=gen_batch.non_tensor_batch["task_id"][i],

@@ -578,16 +583,14 @@ def fit(self): # noqa: C901
                         ]
                     )
                 )
-                logger.info("=" * 10 + "start fit rollout" + "=" * 10)
+                logger.info("start fit rollout")
                 self.parallel_env.current_global_steps = self.global_steps
                 context_tracker_arr: List[BaseContextTracker] = self.parallel_env.rollout(
                     tasks, mode="sample", epoch=f"train.{epoch}"
                 )
-                logger.info("=" * 10 + "end fit rollout" + "=" * 10)
-                self._update_interchange_server_status_flag("ENGINE.WEIGHT_SYNCING")
-                logger.info("begin to convert context_tracker_arr to dataproto")
+                logger.info("end fit rollout")
                 gen_batch_output = self.parallel_env.to_dataproto(context_tracker_arr)
-                logger.info("end conversion")
+                logger.info("end dataproto conversion")
 
                 success_rate = [
                     traj.reward_structure.success_rate for traj in context_tracker_arr

@@ -630,17 +633,17 @@ def fit(self): # noqa: C901
                 logger.info(
                     f"gen_batch_output.info batch.keys={gen_batch_output.batch.keys()}"
                 )
+                self._update_interchange_server_status_flag("ENGINE.WEIGHT_SYNCING")
                 self.async_rollout_manager.sleep()
-            logger.info("=== - rollout step end ===")
+            logger.info("rollout step end")
 
-            if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
-                raise NotImplementedError("REMAX is not supported in GRPO yet.")
 
             batch.non_tensor_batch["uid"] = np.array(
                 [str(uuid.uuid4()) for _ in range(len(batch.batch))],
                 dtype=object,
             )
-            batch = union_gen_batch_via_task_id(tasks, batch, gen_batch_output)
+            discard_original_batch = self.config.ajet.enable_tinkerscript_mode
+            batch = union_gen_batch_via_task_id(tasks, batch, gen_batch_output, discard_original_batch)
             batch.batch["response_mask"] = compute_response_mask(batch)
 
             if "response_mask" not in batch.batch.keys():

@@ -674,7 +677,7 @@ def fit(self): # noqa: C901
             )
 
             # recompute old_log_probs
-            logger.info("=== + compute log_probs begin ===")
+            logger.info("+ compute log_probs begin")
             with marked_timer("old_log_prob", timing_raw, color="blue"):
                 old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
                 entropys = old_log_prob.batch["entropys"]

@@ -946,7 +949,8 @@ def _validate(self):
             dtype=object,
         )
         tasks = tasks[: len(main_val_dataset)]
-        test_batch = union_gen_batch_via_task_id(tasks, test_batch, test_output_gen_batch)
+        discard_original_batch = self.config.ajet.enable_tinkerscript_mode
+        test_batch = union_gen_batch_via_task_id(tasks, test_batch, test_output_gen_batch, discard_original_batch)
         # test_batch = test_batch.union(test_output_gen_batch)
         test_batch.meta_info["validate"] = True
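The uid column assigned in fit() is what GRPO groups on downstream: rollouts sharing a uid are advantage-normalized together, which is why tinkerscript mode reuses task_ids as uids. A rough sketch of that grouping follows; only the column names come from the diff, the mean-baseline advantage is the standard GRPO idea with the usual std normalization omitted.

from collections import defaultdict

# one reward and one uid per rollout; uid == task_id in tinkerscript mode
rewards = [1.0, 0.0, 0.5, 0.5]
uids = ["t0", "t0", "t1", "t1"]

groups = defaultdict(list)
for i, uid in enumerate(uids):
    groups[uid].append(i)

advantages = [0.0] * len(rewards)
for uid, idxs in groups.items():
    mean = sum(rewards[i] for i in idxs) / len(idxs)  # group baseline
    for i in idxs:
        advantages[i] = rewards[i] - mean

print(advantages)  # [0.5, -0.5, 0.0, 0.0]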

ajet/launcher.py

Lines changed: 3 additions & 3 deletions
@@ -176,9 +176,9 @@ def setup_environment_vars(args, exp_config, main_yaml_fp):
         env["RAY_record_task_actor_creation_sites"] = "true"
         # assert exp_config["ajet"]["rollout"]["max_env_worker"] <= 4, "parallel worker too many for debugging mode"  # type: ignore
         if exp_config["ajet"]["rollout"]["max_env_worker"] > 1:  # type: ignore
-            exp_config["ajet"]["rollout"]["max_env_worker"] = 1
+            # exp_config["ajet"]["rollout"]["max_env_worker"] = 1
             logger.warning(
-                "For debugging mode, max_env_worker is set to 1 to facilitate debugging."
+                "For debugging mode, please set max_env_worker to 1 to facilitate debugging."
             )
         logger.warning("Debug mode is ON")
     else:

@@ -206,7 +206,7 @@ def start_tinkerscript_server(env, config):
     assert config.ajet.enable_experimental_interchange_server, \
         "Please enable_experimental_interchange_server in config to start tinkerscript server."
     from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import start_interchange_server
-    start_interchange_server(config, blocking=True)
+    start_interchange_server(config, blocking=True, env=env)
 
 
 def main():

ajet/schema/task.py

Lines changed: 5 additions & 5 deletions
@@ -8,11 +8,11 @@
 
 
 class Task(BaseModel):
-    main_query: str = Field(default="")
-    init_messages: List[dict] = Field(default=[])
-    task_id: str = Field(default="")
-    env_type: str = Field(default="")
-    metadata: dict = Field(default_factory=dict)
+    main_query: str = Field(default="", description="Main query or instruction for the task; may be absent if the task has valid init_messages.")
+    init_messages: List[dict] = Field(default=[], description="Initial messages for the task; may be absent if the task has a valid main_query.")
+    task_id: str = Field(default="", description="The same task_id means the same task, and therefore the same GRPO group.")
+    env_type: str = Field(default="", description="Valid when the task needs to interact with a gym env.")
+    metadata: dict = Field(default_factory=dict, description="Additional metadata for the task, e.g., a reference answer for eval tasks.")
 
 
 """

ajet/task_runner/tinkerscript_runner.py

Lines changed: 7 additions & 1 deletion
@@ -40,7 +40,7 @@ def register_episode_and_wait_output(self, episode_uuid: str, openai_base_url: s
         zmq_socket = zmq.Context().socket(zmq.REP)
         zmq_socket.bind(zmq_listen_result_addr)
         speicial_messages = [
-            "RUNNER.RESET_CONTEXT_TRACKER"
+            "RUNNER.SPECIAL.RESET_CONTEXT_TRACKER"
         ]
         while True:
             # <wait for 1/2>:

@@ -103,6 +103,12 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker:
             context_tracker=context_tracker,
         )
 
+        # Crucially, pin task_id to the client's task_id: propagate it to
+        # both workflow_task and context_tracker.
+        assert "task_id" in workflow_output.metadata, "workflow_output.metadata must contain task_id"
+        task_id = workflow_output.metadata.get("task_id", "")
+        workflow_task.task_id = task_id
+        context_tracker.task_id = task_id
+
         if workflow_output.reward is not None:
             raw_reward, is_success = (
                 workflow_output.reward,

ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py

Lines changed: 14 additions & 3 deletions
@@ -271,7 +271,7 @@ async def serve_with_monitor(additional_coro):
 
 
 # Convenience function for quick server startup
-def start_interchange_server(config, blocking=False) -> int:
+def start_interchange_server(config, blocking=False, env={}) -> int:
     # Read config
     already_started = config.ajet.interchange_server.already_started
     experiment_dir = config.ajet.experiment_dir

@@ -293,6 +293,9 @@ def start_interchange_server(config, blocking=False) -> int:
 
     # init interchange server sub-process
     if not already_started:
+        # apply env vars
+        os.environ.update(env)
+        # start interchange server
         interchange_server = InterchangeServer(
             experiment_dir,
             port,

@@ -342,6 +345,14 @@ def start_interchange_server(config, blocking=False) -> int:
                     f"URL 1: {localhost_url}\n------\n"
                     f"URL 2: {host_url}\n------\n"
                     f"Press Ctrl+C to stop.")
-    if interchange_server:
-        interchange_server.join()
+    try:
+        if interchange_server:
+            interchange_server.join()
+    except KeyboardInterrupt:
+        logger.info("Shutting down interchange server...")
+        try: httpx.get(f"http://127.0.0.1:{port}/stop_engine", timeout=8).status_code
+        except Exception: pass
+
+        if interchange_server:
+            interchange_server.terminate()
     return -1
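The join/terminate pattern added above can be sketched in isolation. The plain multiprocessing.Process below is a hypothetical stand-in for InterchangeServer, and the port is made up; only the control flow mirrors the diff.

import multiprocessing
import time

import httpx


def _serve():
    while True:  # stand-in server loop
        time.sleep(1)


if __name__ == "__main__":
    server = multiprocessing.Process(target=_serve, daemon=True)
    server.start()
    try:
        server.join()  # block until Ctrl+C
    except KeyboardInterrupt:
        try:
            # best-effort: ask the engine to stop over HTTP first
            httpx.get("http://127.0.0.1:10086/stop_engine", timeout=8)
        except Exception:
            pass
        server.terminate()  # then hard-stop the subprocess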

ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py

Lines changed: 6 additions & 3 deletions
@@ -4,7 +4,7 @@
 import yaml
 from typing import List, Tuple
 from loguru import logger
-from ajet.schema.task import WorkflowOutput
+from ajet.schema.task import WorkflowOutput, Task
 from ajet.copilot.job import AgentJetJob
 from ajet.tuner_lib.weight_tuner.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey
 from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import (

@@ -76,16 +76,19 @@ def begin_episode(self, allow_discard_timeout=60, episode_type="train") -> Tuple
                 logger.error(f"Error claiming episode: {e}. Retrying in 5s...")
                 time.sleep(5)
 
-    def end_episode(self, episode_uuid: str, workflow_output: WorkflowOutput):
+    def end_episode(self, task: Task, episode_uuid: str, workflow_output: WorkflowOutput):
         if not episode_uuid:
             logger.error("No episode to end.")
             return
 
         try:
+            task_id = task.task_id
+            workflow_output.metadata["task_id"] = task_id
             req_obj = EndEpisodeRequest(
                 client_uuid=self.client_uuid,
                 episode_uuid=episode_uuid,
-                workflow_output=workflow_output
+                workflow_output=workflow_output,
+                task_id=task_id
             )
 
             resp = httpx.post(
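End to end, the new task_id plumbing is a stamp-and-verify handshake: the client writes the task_id into workflow_output.metadata before posting, and the server and runner assert it on receipt. A sketch with a stand-in dataclass (the real code uses ajet's WorkflowOutput, whose fields are not shown in this diff):

from dataclasses import dataclass, field


@dataclass
class FakeWorkflowOutput:  # stand-in for ajet.schema.task.WorkflowOutput
    reward: float = 0.0
    metadata: dict = field(default_factory=dict)


task_id = "math.gsm8k.0042"
out = FakeWorkflowOutput(reward=1.0)
out.metadata["task_id"] = task_id            # client: end_episode stamps it
assert out.metadata["task_id"] == task_id    # server: end_episode verifies it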

ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py

Lines changed: 22 additions & 8 deletions
@@ -148,10 +148,11 @@ def _register_final_episode_output(episode_uuid, workflow_output, shared_mem_dic
 # --------------------------------------------------------------------------------------
 
 async def register_episode_ready_listener():
-    while True:
-        read_all_episode_status()
-        await asyncio.sleep(10)  # check every 10 seconds
-        find_claimed_episodes_that_need_to_be_unclaimed()
+    pass
+    # while True:
+    #     read_all_episode_status()
+    #     await asyncio.sleep(10)  # check every 10 seconds
+    #     find_claimed_episodes_that_need_to_be_unclaimed()
 
 def read_all_episode_status() -> Optional[EpisodeStatus]:
     print_buffer = []

@@ -242,17 +243,26 @@ async def start_engine():
 
     # Create args namespace
     args = SimpleNamespace(
-        conf=main_yaml_fp, backbone=backbone, exp_dir=exp_dir_final, with_logview=False, debug=False,
+        conf=main_yaml_fp, backbone=backbone, exp_dir=exp_dir_final, with_logview=False,
+        debug=False,
     )
+    # get debug param
+    should_debug = os.environ.get("RAY_DEBUG_POST_MORTEM", "0") == "1"
+    debug_tags = os.environ.get("DEBUG_TAGS", "")
+    if should_debug:
+        args.debug = debug_tags
+
+    def override_param_callback(config):
+        config['ajet']['interchange_server']['already_started'] = True
+        config['ajet']['interchange_server']['interchange_server_port'] = int(os.getenv("AJET_DAT_INTERCHANGE_PORT"))  # type: ignore
+        return config
 
     # Finalize experiment config
     main_yaml_fp, exe_exp_base, exp_name, exp_config = prepare_experiment_config(
-        main_yaml_fp, exp_dir_final, backbone
+        main_yaml_fp, exp_dir_final, backbone, override_param_callback
    )
 
     # Setup environment variables
-    exp_config['ajet']['interchange_server']['already_started'] = True
-    exp_config['ajet']['interchange_server']['interchange_server_port'] = int(os.getenv("AJET_DAT_INTERCHANGE_PORT"))  # type: ignore
     env, exp_config = setup_environment_vars(args, exp_config, main_yaml_fp)
 
     # Start ray if not already started

@@ -421,6 +431,10 @@ async def end_episode(req: EndEpisodeRequest):
     client_uuid = req.client_uuid
     episode_uuid = req.episode_uuid
     workflow_output = req.workflow_output
+    task_id = req.task_id
+
+    assert "task_id" in workflow_output.metadata, "workflow_output.metadata must contain task_id"
+    assert workflow_output.metadata["task_id"] == task_id, "workflow_output.metadata.task_id must match req.task_id"
 
     if 'episodes' not in shared_mem_dict:
         logger.error(f"[server] No episodes registered yet.")

ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ class EndEpisodeRequest(BaseModel):
     client_uuid: str
     episode_uuid: str
     workflow_output: WorkflowOutput
+    task_id: str
 
 class EndEpisodeResponse(BaseModel):
     success: bool
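Since task_id is declared without a default, pydantic now rejects any request that omits it. A minimal sketch; the model below is a trimmed stand-in with the workflow_output field dropped:

from pydantic import BaseModel, ValidationError


class FakeEndEpisodeRequest(BaseModel):  # trimmed stand-in
    client_uuid: str
    episode_uuid: str
    task_id: str


try:
    FakeEndEpisodeRequest(client_uuid="c", episode_uuid="e")
except ValidationError as err:
    print(err.errors()[0]["loc"])  # ('task_id',)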

ajet/utils/config_utils.py

Lines changed: 6 additions & 3 deletions
@@ -168,7 +168,7 @@ def config_safe_guard(config: dict, backbone: str) -> dict:
 
 
 def read_ajet_hierarchical_config(
-    yaml_fp, exp_name, backbone, write_to=None, exp_dir="saved_experiments"
+    yaml_fp, exp_name, backbone, write_to=None, exp_dir="saved_experiments", override_param_callback=None
 ):
     if yaml_fp is None:
         config = {

@@ -210,6 +210,9 @@ def read_ajet_hierarchical_config(
         config["defaults"].remove("trinity_default")
         config["hydra"]["searchpath"].remove("file://ajet/default_config/trinity")
 
+    if override_param_callback is not None:
+        config = override_param_callback(config)
+
     if write_to:
         with open(write_to, "w") as file:
             yaml.dump(config, file)

@@ -239,7 +242,7 @@ def expand_ajet_hierarchical_config(config, write_to=None):
     return config_final
 
 
-def prepare_experiment_config(yaml_path, exp_dir, backbone):
+def prepare_experiment_config(yaml_path, exp_dir, backbone, override_param_callback=None):
     """
     Prepare experiment configuration by reading YAML, setting up backup directories,
     and copying necessary files for the experiment.

@@ -317,7 +320,7 @@ def prepare_experiment_config(yaml_path, exp_dir, backbone):
 
     ## 4. edit new yaml
     config = read_ajet_hierarchical_config(
-        yaml_backup_dst, exp_name, backbone, write_to=yaml_backup_dst, exp_dir=exp_dir
+        yaml_backup_dst, exp_name, backbone, write_to=yaml_backup_dst, exp_dir=exp_dir, override_param_callback=override_param_callback
    )
     config_final = expand_ajet_hierarchical_config(config, write_to=yaml_backup_dst)
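Any callable that takes and returns the config dict works as override_param_callback, which runs just before the YAML is written back out. A sketch with a made-up override and a toy config:

def force_port_4399(config):  # hypothetical override
    config["ajet"]["interchange_server"]["interchange_server_port"] = 4399
    return config


cfg = {"ajet": {"interchange_server": {"interchange_server_port": 0}}}
cfg = force_port_4399(cfg)  # what read_ajet_hierarchical_config now applies
assert cfg["ajet"]["interchange_server"]["interchange_server_port"] == 4399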

ajet_tinkerscript_threading.py

Lines changed: 4 additions & 4 deletions
@@ -10,6 +10,7 @@
 from ajet import WorkflowOutput
 from ajet.task_reader import RouterTaskReader
 from ajet.utils.retry import retry_with_backoff
+from ajet.schema.task import Task
 from concurrent.futures import ThreadPoolExecutor
 
 # --------- configurations that take effect locally -------------

@@ -44,6 +45,7 @@ def main():
 
     # Hand shake with remote tinkerscript server
     tinkerscript_remote = TinkerScriptClient(REMOTE_TINKERJET_URL)
+    tinkerscript_remote.stop_engine()
     tinkerscript_remote.auto_sync_train_config_and_start_engine(
         AgentJetJob(
             algorithm="grpo",

@@ -52,8 +54,6 @@ def main():
             grpo_n=LOCAL_GRPO_N,
         )
     )
-    # tinkerscript_remote.stop_engine()
-
     # tinkerscript_remote = connect_to_tinkerscript_server(sync_train_config=False, start_engine=False)
     submit_sem = threading.BoundedSemaphore(LOCAL_MAX_PARALLEL)

@@ -67,7 +67,7 @@ def rollout(task):
         # execute agent
         workflow_output = execute_agent(task, api_baseurl_key)
         # report output back to tinkerscript remote
-        tinkerscript_remote.end_episode(episode_uuid, workflow_output)
+        tinkerscript_remote.end_episode(task, episode_uuid, workflow_output)
         # collect reward
         group_reward.append(workflow_output.reward)
         print(f"Group reward mean & std: {sum(group_reward)/len(group_reward)} +/- { (max(group_reward)-min(group_reward))/2 }")

@@ -94,7 +94,7 @@ def rollout(task):
 
 
 @retry_with_backoff(max_retry=2)
-def execute_agent(task, api_baseurl_key: OpenaiBaseUrlAndApiKey):
+def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey):
     # Prepare base_url, api_key
     base_url, api_key = (api_baseurl_key.base_url, api_baseurl_key.api_key)
     # Read dataset item
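The submission pattern this script relies on, in isolation: a BoundedSemaphore caps in-flight rollouts at LOCAL_MAX_PARALLEL while a thread pool drains the task stream. The integer tasks and the empty rollout body below are stand-ins.

import threading
from concurrent.futures import ThreadPoolExecutor

LOCAL_MAX_PARALLEL = 4
submit_sem = threading.BoundedSemaphore(LOCAL_MAX_PARALLEL)


def rollout(task):
    try:
        pass  # begin_episode / execute_agent / end_episode would go here
    finally:
        submit_sem.release()  # free a slot for the next task


with ThreadPoolExecutor(max_workers=LOCAL_MAX_PARALLEL) as pool:
    for task in range(16):  # stand-in task stream
        submit_sem.acquire()  # blocks while LOCAL_MAX_PARALLEL are in flight
        pool.submit(rollout, task)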
