Skip to content

Commit 5cc7297

Browse files
committed
feat: enhance TinkerScript integration with improved engine status handling and configuration updates
1 parent ae13326 commit 5cc7297

File tree

13 files changed

+165
-347
lines changed

13 files changed

+165
-347
lines changed

ajet/backbone/main_verl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import hydra
2323
import ray
2424
from beast_logger import print_dict
25-
from loguru import logger
2625
from omegaconf import OmegaConf
2726
from verl.trainer.ppo.reward import load_reward_manager
2827
from verl.utils.device import is_cuda_available
@@ -110,6 +109,7 @@ def run(self, config):
110109
# Print the initial configuration. `resolve=True` will evaluate symbolic values.
111110
from pprint import pprint
112111

112+
from loguru import logger
113113
from omegaconf import OmegaConf
114114
from verl.utils.fs import copy_to_local
115115

ajet/backbone/main_vllm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def main(config):
191191
start_interchange_server(config)
192192
if config.ajet.enable_tinkerscript_mode:
193193
from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import http_change_engine_status
194-
http_change_engine_status(config, "ROLLING")
194+
http_change_engine_status(config, "ENGINE.ROLLING")
195195

196196
def companion_launch():
197197
import torch

ajet/backbone/trainer_verl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -559,7 +559,7 @@ def fit(self): # noqa: C901
559559
assert self.async_rollout_mode
560560
logger.info("=== wake up begin ===")
561561
self.async_rollout_manager.wake_up()
562-
self._update_interchange_server_status_flag("ROLLING")
562+
self._update_interchange_server_status_flag("ENGINE.ROLLING")
563563
logger.info("=== wake up end ===")
564564
tasks: List[Task] = [
565565
dict_to_ajet_task(dict(
@@ -585,7 +585,7 @@ def fit(self): # noqa: C901
585585
tasks, mode="sample", epoch=f"train.{epoch}"
586586
)
587587
logger.info("=" * 10 + "end fit rollout" + "=" * 10)
588-
self._update_interchange_server_status_flag("UPDATE_WEIGHT")
588+
self._update_interchange_server_status_flag("ENGINE.WEIGHT_SYNCING")
589589
logger.info("begin to convert context_tracker_arr to dataproto")
590590
gen_batch_output = self.parallel_env.to_dataproto(context_tracker_arr)
591591
logger.info("end convertion")

ajet/copilot/job.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,23 @@ def __init__(
4242
n_gpu: int = 8,
4343
algorithm: str = "grpo",
4444
n_gpu_for_infer: int | None = None, # only for trinity backbone
45+
grpo_n: int = 8,
46+
tinkerscript_mode: bool = True,
4547
*kwargs,
4648
) -> None:
4749
self.backbone = backbone
48-
self.config_as_dict: dict = self.build_job_from_yaml(None)
50+
if tinkerscript_mode:
51+
default_yaml = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', "default_config/ajet_ts_default.yaml"))
52+
else:
53+
default_yaml = None
54+
self.config_as_dict: dict = self.build_job_from_yaml(default_yaml)
4955
self.config = Config.update_from_dict_recursive(Config(), self.config_as_dict)
5056

5157
self.config.ajet.backbone = backbone
5258
self.config.ajet.model.path = model
5359
self.config.ajet.trainer_common.n_gpus_per_node = n_gpu
5460
self.config.ajet.trainer_common.algorithm.adv_estimator = algorithm
61+
self.config.ajet.rollout.num_repeat = grpo_n
5562
if n_gpu_for_infer is None and backbone == "trinity":
5663
raise ValueError("Please specify `n_gpu_for_infer` (n_gpu_for_infer < n_gpu) for trinity backbone.")
5764
if (n_gpu_for_infer is not None) and backbone == "verl":

ajet/default_config/ajet_default.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class AjetRollout:
3030
user_workflow: str = "tutorial.example_appworld.appworld->ExampleAgentScopeWorkflow"
3131
n_vllm_engine: int = 1
3232
tensor_model_parallel_size: int = 1
33+
num_repeat: int = 8
3334

3435

3536
@dataclass

ajet/default_config/ajet_default.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,7 @@ ajet:
292292
num_fastapi_process: 2 # 1, 2 or 4 is fine
293293
max_fastapi_threads: 128 # 64 or 128 is fine
294294
max_inference_tracker_threads: 64 # recommend to be equal to `ajet.rollout.max_env_worker`
295+
already_started: False # do not edit, used by `tinkerscript`
295296

296297

297298
task_runner:

0 commit comments

Comments
 (0)