Skip to content

Commit 54a4721

Browse files
committed
+ add server for training
* unify client
1 parent 65f23ab commit 54a4721

File tree

5 files changed

+60
-25
lines changed

5 files changed

+60
-25
lines changed
Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import requests
22

3-
LOCAL_SERVER_URL = "http://127.0.0.1:5000/data_workflow"
3+
LOCAL_DATA_WORKFLOW_SERVER_URL = "http://127.0.0.1:5005/data_workflow"
4+
LOCAL_TRINITY_TRAINING_SERVER_URL = "http://127.0.0.1:5006/trinity_training"
45

56

67
def send_get_request(url: str, params: dict) -> None:
@@ -33,7 +34,9 @@ def request(url, **kwargs):
3334

3435
if __name__ == "__main__":
3536
res = request(
36-
url=LOCAL_SERVER_URL,
37+
url=LOCAL_DATA_WORKFLOW_SERVER_URL,
3738
configPath="scripts/config/gsm8k.yaml",
3839
)
39-
print(res)
40+
if res:
41+
print(res)
42+
print(res['message'])

trinity/cli/launcher.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -100,17 +100,33 @@ def both(config: Config) -> None:
100100

101101
def activate_data_module(config_path: str):
102102
"""Check whether to activate data module and preprocess datasets."""
103-
from trinity.data.client import LOCAL_SERVER_URL, request
103+
from trinity.cli.client import LOCAL_DATA_WORKFLOW_SERVER_URL, request
104104

105105
logger.info("Activating data module...")
106106
res = request(
107-
url=LOCAL_SERVER_URL,
107+
url=LOCAL_DATA_WORKFLOW_SERVER_URL,
108108
configPath=config_path,
109109
)
110110
if res["return_code"] != 0:
111111
logger.error(f"Failed to activate data module: {res['return_msg']}.")
112112
return
113113

114+
def run(config_path: str):
115+
# TODO: support parse all args from command line
116+
config = load_config(config_path)
117+
config.check_and_update()
118+
# try to activate data module
119+
data_config = config.data
120+
if data_config.dj_config_path or data_config.dj_process_desc:
121+
activate_data_module(config_path)
122+
ray.init()
123+
if config.mode == "explore":
124+
explore(config)
125+
elif config.mode == "train":
126+
train(config)
127+
elif config.mode == "both":
128+
both(config)
129+
114130

115131
def main() -> None:
116132
"""The main entrypoint."""
@@ -126,19 +142,7 @@ def main() -> None:
126142
args = parser.parse_args()
127143
if args.command == "run":
128144
# TODO: support parse all args from command line
129-
config = load_config(args.config)
130-
config.check_and_update()
131-
# try to activate data module
132-
data_config = config.data
133-
if data_config.dj_config_path or data_config.dj_process_desc:
134-
activate_data_module(args.config)
135-
ray.init()
136-
if config.mode == "explore":
137-
explore(config)
138-
elif config.mode == "train":
139-
train(config)
140-
elif config.mode == "both":
141-
both(config)
145+
run(args.config)
142146

143147

144148
if __name__ == "__main__":

trinity/cli/server.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import traceback
2+
from flask import Flask, jsonify, request
3+
4+
app = Flask(__name__)
5+
6+
APP_NAME = 'trinity_training'
7+
PORT = 5006
8+
9+
@app.route(f"/{APP_NAME}", methods=["GET"])
10+
def trinity_training():
11+
config_path = request.args.get("configPath")
12+
try:
13+
from trinity.cli.launcher import run
14+
15+
run(config_path)
16+
ret = 0
17+
msg = "Training Success."
18+
except:
19+
traceback.print_exc()
20+
msg = traceback.format_exc()
21+
ret = 1
22+
return jsonify({"return_code": ret, "message": msg})
23+
24+
25+
if __name__ == "__main__":
26+
app.run(port=PORT, debug=True)

trinity/data/readme.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,10 @@ synth_data = synthesizer.process(clean_data)
9292
- Request using our simple client:
9393

9494
```python
95-
from trinity.data.client import request
95+
from trinity.cli.client import request
9696

9797
res = request(
98-
url="http://127.0.0.1:5000/data_workflow",
98+
url="http://127.0.0.1:5005/data_workflow",
9999
configPath="tests/test_configs/active_iterator_test_cfg.yaml"
100100
)
101101

trinity/data/server.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from flask import Flask, jsonify, request
22

3-
from trinity.common.config import load_config
4-
from trinity.data.controllers.active_iterator import DataActiveIterator
5-
63
app = Flask(__name__)
74

5+
APP_NAME = 'data_workflow'
6+
PORT = 5005
87

9-
@app.route("/data_workflow", methods=["GET"])
8+
@app.route(f"/{APP_NAME}", methods=["GET"])
109
def data_workflow():
10+
from trinity.common.config import load_config
11+
from trinity.data.controllers.active_iterator import DataActiveIterator
12+
1113
config_path = request.args.get("configPath")
1214
config = load_config(config_path)
1315

@@ -17,4 +19,4 @@ def data_workflow():
1719

1820

1921
if __name__ == "__main__":
20-
app.run(debug=True)
22+
app.run(port=PORT, debug=True)

0 commit comments

Comments
 (0)