Skip to content

Commit f0a6024

Browse files
authored
Add training service (#13)
1 parent d93c816 commit f0a6024

File tree

14 files changed

+111
-38
lines changed

14 files changed

+111
-38
lines changed

environments/data.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@ dependencies:
55
- python=3.10
66
- pip:
77
- py-data-juicer
8+
- agentscope
89
- flask
910
- omegaconf
1011
- sqlalchemy
1112
- psycopg2
1213
- networkx
1314
- transformers
15+
- "-e ..[dev]"

environments/env_mapping.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,10 @@
33
"env_name": "trinity_data",
44
"env_yaml": "environments/data.yaml",
55
"env_entry": "trinity/data/server.py"
6+
},
7+
"trinity.training": {
8+
"env_name": "trinity",
9+
"env_yaml": "environments/training.yaml",
10+
"env_entry": "trinity/cli/server.py"
611
}
712
}

environments/training.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
name: trinity
2+
channels:
3+
- defaults
4+
dependencies:
5+
- python=3.10
6+
- pip:
7+
- "-e ..[dev]"

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ dependencies = [
3232
"math_verify",
3333
"ninja",
3434
"fire",
35+
"flask",
36+
"requests",
3537
]
3638

3739
[project.scripts]

scripts/install.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@ def main():
2424
env_mapping = json.load(f)
2525
for env_path, env_config in env_mapping.items():
2626
env_name = env_config["env_name"]
27-
print(f"Installing dependencies for module {env_name}...")
27+
print(f"Installing dependencies for module [{env_name}]...")
2828
# check if it's existing
2929
res = subprocess.run(
3030
f"{env_mng} env list | grep {env_name}", shell=True, text=True, stdout=subprocess.PIPE
3131
)
3232
if res.returncode == 0 and env_name in res.stdout:
33-
print(f"Environment {env_name} already exists. Skipping...")
33+
print(f"Environment [{env_name}] already exists. Skipping...")
3434
else:
3535
res = subprocess.run(
3636
f'{env_mng} env create -f {env_config["env_yaml"]}'
@@ -39,9 +39,9 @@ def main():
3939
shell=True,
4040
)
4141
if res.returncode == 0:
42-
print(f"Environment {env_name} created successfully.")
42+
print(f"Environment [{env_name}] created successfully.")
4343
else:
44-
print(f"Failed to create environment {env_name} with exit code {res.returncode}.")
44+
print(f"Failed to create environment [{env_name}] with exit code {res.returncode}.")
4545

4646

4747
if __name__ == "__main__":

scripts/start_servers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def main():
2828
env_mapping = json.load(f)
2929
for env_path, env_config in env_mapping.items():
3030
env_name = env_config["env_name"]
31-
print(f"Starting server for module {env_name}...")
31+
print(f"Starting server for module [{env_name}]...")
3232
timestamp = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
3333
with open(os.path.join(args.log_dir, f"{env_name}_{timestamp}_log.txt"), "w") as log_file:
3434
server = subprocess.Popen(
@@ -38,7 +38,7 @@ def main():
3838
shell=True,
3939
)
4040
servers.append(server)
41-
print(f"Server of module {env_name} is started with PID {server.pid}")
41+
print(f"Server of module [{env_name}] is started with PID {server.pid}")
4242
for server in servers:
4343
server.wait()
4444

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
import requests
22

3-
LOCAL_SERVER_URL = "http://127.0.0.1:5000/data_workflow"
43

5-
6-
def send_get_request(url: str, params: dict) -> None:
4+
def send_get_request(url: str, params: dict):
75
"""
86
Send GET request with parameters.
97
@@ -32,8 +30,15 @@ def request(url, **kwargs):
3230

3331

3432
if __name__ == "__main__":
33+
# --- only for local testing
34+
LOCAL_DATA_WORKFLOW_SERVER_URL = "http://127.0.0.1:5005/data_workflow"
35+
LOCAL_TRINITY_TRAINING_SERVER_URL = "http://127.0.0.1:5006/trinity_rft"
36+
# --- only for local testing
37+
3538
res = request(
36-
url=LOCAL_SERVER_URL,
39+
url=LOCAL_DATA_WORKFLOW_SERVER_URL,
3740
configPath="examples/grpo_gsm8k/gsm8k.yaml",
3841
)
39-
print(res)
42+
if res:
43+
print(res)
44+
print(res["message"])

trinity/cli/launcher.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -107,20 +107,38 @@ def both(config: Config) -> None:
107107
ray.get(trainer.log_finalize.remote(step=train_iter_num))
108108

109109

110-
def activate_data_module(config_path: str):
110+
def activate_data_module(data_workflow_url: str, config_path: str):
111111
"""Check whether to activate data module and preprocess datasets."""
112-
from trinity.data.client import LOCAL_SERVER_URL, request
112+
from trinity.cli.client import request
113113

114114
logger.info("Activating data module...")
115115
res = request(
116-
url=LOCAL_SERVER_URL,
116+
url=data_workflow_url,
117117
configPath=config_path,
118118
)
119119
if res["return_code"] != 0:
120120
logger.error(f"Failed to activate data module: {res['return_msg']}.")
121121
return
122122

123123

124+
def run(config_path: str):
125+
config = load_config(config_path)
126+
config.check_and_update()
127+
# try to activate data module
128+
data_config = config.data
129+
if data_config.data_workflow_url and (
130+
data_config.dj_config_path or data_config.dj_process_desc
131+
):
132+
activate_data_module(data_config.data_workflow_url, config_path)
133+
ray.init()
134+
if config.mode == "explore":
135+
explore(config)
136+
elif config.mode == "train":
137+
train(config)
138+
elif config.mode == "both":
139+
both(config)
140+
141+
124142
def main() -> None:
125143
"""The main entrypoint."""
126144
parser = argparse.ArgumentParser()
@@ -135,19 +153,7 @@ def main() -> None:
135153
args = parser.parse_args()
136154
if args.command == "run":
137155
# TODO: support parse all args from command line
138-
config = load_config(args.config)
139-
config.check_and_update()
140-
# try to activate data module
141-
data_config = config.data
142-
if data_config.dj_config_path or data_config.dj_process_desc:
143-
activate_data_module(args.config)
144-
ray.init()
145-
if config.mode == "explore":
146-
explore(config)
147-
elif config.mode == "train":
148-
train(config)
149-
elif config.mode == "both":
150-
both(config)
156+
run(args.config)
151157

152158

153159
if __name__ == "__main__":

trinity/cli/server.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import traceback
2+
3+
import fire
4+
from flask import Flask, jsonify, request
5+
6+
app = Flask(__name__)
7+
8+
APP_NAME = "trinity_rft"
9+
10+
11+
@app.route(f"/{APP_NAME}", methods=["GET"])
12+
def trinity_training():
13+
config_path = request.args.get("configPath")
14+
try:
15+
from trinity.cli.launcher import run
16+
17+
run(config_path)
18+
ret = 0
19+
msg = "Training Success."
20+
except: # noqa: E722
21+
traceback.print_exc()
22+
msg = traceback.format_exc()
23+
ret = 1
24+
return jsonify({"return_code": ret, "message": msg})
25+
26+
27+
def main(port=5006):
28+
app.run(port=port, debug=True)
29+
30+
31+
if __name__ == "__main__":
32+
fire.Fire(main)

trinity/common/config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ class FormatConfig:
4141
class DataConfig:
4242
"""Data config"""
4343

44-
# TODO: add more
44+
data_workflow_url: Optional[str] = None
45+
4546
dataset_path: str = ""
4647
train_split: str = "train"
4748
eval_split: Optional[str] = None # TODO: check data format

0 commit comments

Comments
 (0)