diff --git a/environments/data.yaml b/environments/data.yaml index 879e74ac79..6acdf04dc9 100644 --- a/environments/data.yaml +++ b/environments/data.yaml @@ -5,9 +5,11 @@ dependencies: - python=3.10 - pip: - py-data-juicer + - agentscope - flask - omegaconf - sqlalchemy - psycopg2 - networkx - transformers + - "-e ..[dev]" diff --git a/environments/env_mapping.json b/environments/env_mapping.json index 385b3cfe79..4532c9d20d 100644 --- a/environments/env_mapping.json +++ b/environments/env_mapping.json @@ -3,5 +3,10 @@ "env_name": "trinity_data", "env_yaml": "environments/data.yaml", "env_entry": "trinity/data/server.py" + }, + "trinity.training": { + "env_name": "trinity", + "env_yaml": "environments/training.yaml", + "env_entry": "trinity/cli/server.py" } } diff --git a/environments/training.yaml b/environments/training.yaml new file mode 100644 index 0000000000..436a75e778 --- /dev/null +++ b/environments/training.yaml @@ -0,0 +1,7 @@ +name: trinity +channels: + - defaults +dependencies: + - python=3.10 + - pip: + - "-e ..[dev]" diff --git a/pyproject.toml b/pyproject.toml index 89c247e227..3cb55daafb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,8 @@ dependencies = [ "math_verify", "ninja", "fire", + "flask", + "requests", ] [project.scripts] diff --git a/scripts/install.py b/scripts/install.py index 8a4e318b23..7144eb1059 100644 --- a/scripts/install.py +++ b/scripts/install.py @@ -24,13 +24,13 @@ def main(): env_mapping = json.load(f) for env_path, env_config in env_mapping.items(): env_name = env_config["env_name"] - print(f"Installing dependencies for module {env_name}...") + print(f"Installing dependencies for module [{env_name}]...") # check if it's existing res = subprocess.run( f"{env_mng} env list | grep {env_name}", shell=True, text=True, stdout=subprocess.PIPE ) if res.returncode == 0 and env_name in res.stdout: - print(f"Environment {env_name} already exists. Skipping...") + print(f"Environment [{env_name}] already exists. Skipping...") else: res = subprocess.run( f'{env_mng} env create -f {env_config["env_yaml"]}' @@ -39,9 +39,9 @@ def main(): shell=True, ) if res.returncode == 0: - print(f"Environment {env_name} created successfully.") + print(f"Environment [{env_name}] created successfully.") else: - print(f"Failed to create environment {env_name} with exit code {res.returncode}.") + print(f"Failed to create environment [{env_name}] with exit code {res.returncode}.") if __name__ == "__main__": diff --git a/scripts/start_servers.py b/scripts/start_servers.py index d54e165847..2e6e74961f 100644 --- a/scripts/start_servers.py +++ b/scripts/start_servers.py @@ -28,7 +28,7 @@ def main(): env_mapping = json.load(f) for env_path, env_config in env_mapping.items(): env_name = env_config["env_name"] - print(f"Starting server for module {env_name}...") + print(f"Starting server for module [{env_name}]...") timestamp = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())) with open(os.path.join(args.log_dir, f"{env_name}_{timestamp}_log.txt"), "w") as log_file: server = subprocess.Popen( @@ -38,7 +38,7 @@ def main(): shell=True, ) servers.append(server) - print(f"Server of module {env_name} is started with PID {server.pid}") + print(f"Server of module [{env_name}] is started with PID {server.pid}") for server in servers: server.wait() diff --git a/trinity/data/client.py b/trinity/cli/client.py similarity index 68% rename from trinity/data/client.py rename to trinity/cli/client.py index 5c5ca9e530..311de1b9d8 100644 --- a/trinity/data/client.py +++ b/trinity/cli/client.py @@ -1,9 +1,7 @@ import requests -LOCAL_SERVER_URL = "http://127.0.0.1:5000/data_workflow" - -def send_get_request(url: str, params: dict) -> None: +def send_get_request(url: str, params: dict): """ Send GET request with parameters. @@ -32,8 +30,15 @@ def request(url, **kwargs): if __name__ == "__main__": + # --- only for local testing + LOCAL_DATA_WORKFLOW_SERVER_URL = "http://127.0.0.1:5005/data_workflow" + LOCAL_TRINITY_TRAINING_SERVER_URL = "http://127.0.0.1:5006/trinity_rft" + # --- only for local testing + res = request( - url=LOCAL_SERVER_URL, + url=LOCAL_DATA_WORKFLOW_SERVER_URL, configPath="examples/grpo_gsm8k/gsm8k.yaml", ) - print(res) + if res: + print(res) + print(res["message"]) diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index 2e980632e5..d36f190428 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -104,13 +104,13 @@ def both(config: Config) -> None: raise e -def activate_data_module(config_path: str): +def activate_data_module(data_workflow_url: str, config_path: str): """Check whether to activate data module and preprocess datasets.""" - from trinity.data.client import LOCAL_SERVER_URL, request + from trinity.cli.client import request logger.info("Activating data module...") res = request( - url=LOCAL_SERVER_URL, + url=data_workflow_url, configPath=config_path, ) if res["return_code"] != 0: @@ -118,6 +118,24 @@ def activate_data_module(config_path: str): return +def run(config_path: str): + config = load_config(config_path) + config.check_and_update() + # try to activate data module + data_config = config.data + if data_config.data_workflow_url and ( + data_config.dj_config_path or data_config.dj_process_desc + ): + activate_data_module(data_config.data_workflow_url, config_path) + ray.init() + if config.mode == "explore": + explore(config) + elif config.mode == "train": + train(config) + elif config.mode == "both": + both(config) + + def main() -> None: """The main entrypoint.""" parser = argparse.ArgumentParser() @@ -132,19 +150,7 @@ def main() -> None: args = parser.parse_args() if args.command == "run": # TODO: support parse all args from command line - config = load_config(args.config) - config.check_and_update() - # try to activate data module - data_config = config.data - if data_config.dj_config_path or data_config.dj_process_desc: - activate_data_module(args.config) - ray.init() - if config.mode == "explore": - explore(config) - elif config.mode == "train": - train(config) - elif config.mode == "both": - both(config) + run(args.config) if __name__ == "__main__": diff --git a/trinity/cli/server.py b/trinity/cli/server.py new file mode 100644 index 0000000000..bb792752ca --- /dev/null +++ b/trinity/cli/server.py @@ -0,0 +1,32 @@ +import traceback + +import fire +from flask import Flask, jsonify, request + +app = Flask(__name__) + +APP_NAME = "trinity_rft" + + +@app.route(f"/{APP_NAME}", methods=["GET"]) +def trinity_training(): + config_path = request.args.get("configPath") + try: + from trinity.cli.launcher import run + + run(config_path) + ret = 0 + msg = "Training Success." + except: # noqa: E722 + traceback.print_exc() + msg = traceback.format_exc() + ret = 1 + return jsonify({"return_code": ret, "message": msg}) + + +def main(port=5006): + app.run(port=port, debug=True) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/trinity/common/config.py b/trinity/common/config.py index d0026ed1e6..1434e7a833 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -41,7 +41,8 @@ class FormatConfig: class DataConfig: """Data config""" - # TODO: add more + data_workflow_url: Optional[str] = None + dataset_path: str = "" train_split: str = "train" eval_split: Optional[str] = None # TODO: check data format diff --git a/trinity/data/controllers/default_ops.py b/trinity/data/controllers/default_ops.py index 9b46f38751..547aac5b35 100644 --- a/trinity/data/controllers/default_ops.py +++ b/trinity/data/controllers/default_ops.py @@ -56,6 +56,7 @@ }, "llm_quality_score_filter": { "api_or_hf_model": "qwen2.5-72b-instruct", + "min_score": 0.0, "enable_vllm": False, }, "perplexity_filter": { @@ -66,6 +67,7 @@ }, "llm_difficulty_score_filter": { "api_or_hf_model": "qwen2.5-72b-instruct", + "min_score": 0.0, "enable_vllm": False, }, # human annotators diff --git a/trinity/data/core/dataset_db.py b/trinity/data/core/dataset_db.py index e4e95eb882..c557c0e020 100644 --- a/trinity/data/core/dataset_db.py +++ b/trinity/data/core/dataset_db.py @@ -5,10 +5,10 @@ from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import NullPool +from trinity.buffer.utils import retry_session from trinity.common.config import DataConfig from trinity.common.schema import Base, RftDatasetModel from trinity.data.core.dataset import RftDataset -from trinity.manager.sql_storage import retry_session from trinity.utils.log import get_logger logger = get_logger(__name__) @@ -55,7 +55,9 @@ def __init__(self, config: DataConfig) -> None: self.session = sessionmaker(bind=self.engine) def add_entries(self, dataset: RftDataset): - with retry_session(self) as session: + with retry_session( + self, self.config.max_retry_times, self.config.max_retry_interval + ) as session: session.add_all(rft_dataset_to_model(dataset)) def get_entries(self, num_entries: int, order_by: str = None, ascending: bool = False): @@ -65,7 +67,9 @@ def get_entries(self, num_entries: int, order_by: str = None, ascending: bool = order_by_key = asc(order_by_key) if ascending else desc(order_by_key) else: order_by_key = None - with retry_session(self) as session: + with retry_session( + self, self.config.max_retry_times, self.config.max_retry_interval + ) as session: entries = ( session.query(RftDatasetModel) .order_by(order_by_key) diff --git a/trinity/data/readme.md b/trinity/data/readme.md index d2ea3745cf..e331a5726f 100644 --- a/trinity/data/readme.md +++ b/trinity/data/readme.md @@ -92,10 +92,10 @@ synth_data = synthesizer.process(clean_data) - Request using our simple client: ```python - from trinity.data.client import request + from trinity.cli.client import request res = request( - url="http://127.0.0.1:5000/data_workflow", + url="http://127.0.0.1:5005/data_workflow", configPath="tests/test_configs/active_iterator_test_cfg.yaml" ) diff --git a/trinity/data/server.py b/trinity/data/server.py index 0cfdeb766b..08ca5ebfea 100644 --- a/trinity/data/server.py +++ b/trinity/data/server.py @@ -1,13 +1,16 @@ +import fire from flask import Flask, jsonify, request -from trinity.common.config import load_config -from trinity.data.controllers.active_iterator import DataActiveIterator - app = Flask(__name__) +APP_NAME = "data_workflow" + -@app.route("/data_workflow", methods=["GET"]) +@app.route(f"/{APP_NAME}", methods=["GET"]) def data_workflow(): + from trinity.common.config import load_config + from trinity.data.controllers.active_iterator import DataActiveIterator + config_path = request.args.get("configPath") config = load_config(config_path) @@ -16,5 +19,9 @@ def data_workflow(): return jsonify({"return_code": ret, "message": msg}) +def main(port=5005): + app.run(port=port, debug=True) + + if __name__ == "__main__": - app.run(debug=True) + fire.Fire(main)