Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions environments/data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ dependencies:
- python=3.10
- pip:
- py-data-juicer
- agentscope
- flask
- omegaconf
- sqlalchemy
- psycopg2
- networkx
- transformers
- "-e ..[dev]"
5 changes: 5 additions & 0 deletions environments/env_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,10 @@
"env_name": "trinity_data",
"env_yaml": "environments/data.yaml",
"env_entry": "trinity/data/server.py"
},
"trinity.training": {
"env_name": "trinity",
"env_yaml": "environments/training.yaml",
"env_entry": "trinity/cli/server.py"
}
}
7 changes: 7 additions & 0 deletions environments/training.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
name: trinity
channels:
- defaults
dependencies:
- python=3.10
- pip:
- "-e ..[dev]"
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ dependencies = [
"math_verify",
"ninja",
"fire",
"flask",
]

[project.scripts]
Expand Down
8 changes: 4 additions & 4 deletions scripts/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ def main():
env_mapping = json.load(f)
for env_path, env_config in env_mapping.items():
env_name = env_config["env_name"]
print(f"Installing dependencies for module {env_name}...")
print(f"Installing dependencies for module [{env_name}]...")
# check if it's existing
res = subprocess.run(
f"{env_mng} env list | grep {env_name}", shell=True, text=True, stdout=subprocess.PIPE
)
if res.returncode == 0 and env_name in res.stdout:
print(f"Environment {env_name} already exists. Skipping...")
print(f"Environment [{env_name}] already exists. Skipping...")
else:
res = subprocess.run(
f'{env_mng} env create -f {env_config["env_yaml"]}'
Expand All @@ -39,9 +39,9 @@ def main():
shell=True,
)
if res.returncode == 0:
print(f"Environment {env_name} created successfully.")
print(f"Environment [{env_name}] created successfully.")
else:
print(f"Failed to create environment {env_name} with exit code {res.returncode}.")
print(f"Failed to create environment [{env_name}] with exit code {res.returncode}.")


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions scripts/start_servers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def main():
env_mapping = json.load(f)
for env_path, env_config in env_mapping.items():
env_name = env_config["env_name"]
print(f"Starting server for module {env_name}...")
print(f"Starting server for module [{env_name}]...")
timestamp = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
with open(os.path.join(args.log_dir, f"{env_name}_{timestamp}_log.txt"), "w") as log_file:
server = subprocess.Popen(
Expand All @@ -38,7 +38,7 @@ def main():
shell=True,
)
servers.append(server)
print(f"Server of module {env_name} is started with PID {server.pid}")
print(f"Server of module [{env_name}] is started with PID {server.pid}")
for server in servers:
server.wait()

Expand Down
11 changes: 7 additions & 4 deletions trinity/data/client.py → trinity/cli/client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import requests

LOCAL_SERVER_URL = "http://127.0.0.1:5000/data_workflow"
LOCAL_DATA_WORKFLOW_SERVER_URL = "http://127.0.0.1:5005/data_workflow"
LOCAL_TRINITY_TRAINING_SERVER_URL = "http://127.0.0.1:5006/trinity_training"


def send_get_request(url: str, params: dict) -> None:
def send_get_request(url: str, params: dict):
"""
Send GET request with parameters.

Expand Down Expand Up @@ -33,7 +34,9 @@ def request(url, **kwargs):

if __name__ == "__main__":
res = request(
url=LOCAL_SERVER_URL,
url=LOCAL_DATA_WORKFLOW_SERVER_URL,
configPath="examples/grpo_gsm8k/gsm8k.yaml",
)
print(res)
if res:
print(res)
print(res["message"])
35 changes: 20 additions & 15 deletions trinity/cli/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,18 +106,35 @@ def both(config: Config) -> None:

def activate_data_module(config_path: str):
    """Ask the local data-workflow server to preprocess datasets.

    Sends ``config_path`` as the ``configPath`` parameter to
    ``LOCAL_DATA_WORKFLOW_SERVER_URL`` and logs an error when the server does
    not report success. A failed activation is only logged; nothing is raised
    for it here.

    Args:
        config_path: Path of the config file forwarded to the data server.
    """
    # Function-scope import, as in the original block.
    from trinity.cli.client import LOCAL_DATA_WORKFLOW_SERVER_URL, request

    logger.info("Activating data module...")
    res = request(
        url=LOCAL_DATA_WORKFLOW_SERVER_URL,
        configPath=config_path,
    )
    if not res:
        # The client's own entry point guards with ``if res:``, so an empty
        # result is possible; indexing it below would raise instead of logging.
        logger.error("Failed to activate data module: no response from the data server.")
        return
    if res["return_code"] != 0:
        # NOTE(review): this function historically read ``return_msg`` while
        # the client and the training server use ``message``; accept either.
        # Assumes ``res`` is a dict decoded from JSON -- TODO confirm.
        reason = res.get("return_msg", res.get("message"))
        logger.error(f"Failed to activate data module: {reason}.")


def run(config_path: str):
    """Load the config at ``config_path`` and start the mode it selects.

    The config is loaded and validated with ``check_and_update()``. The data
    module is activated when the data section sets ``dj_config_path`` or
    ``dj_process_desc``. Ray is then initialised and ``explore``, ``train`` or
    ``both`` is called according to ``config.mode``; any other mode value
    starts nothing.

    Args:
        config_path: Path of the config file to load.
    """
    # TODO: support parse all args from command line
    cfg = load_config(config_path)
    cfg.check_and_update()

    # try to activate data module
    data_cfg = cfg.data
    wants_data_module = data_cfg.dj_config_path or data_cfg.dj_process_desc
    if wants_data_module:
        activate_data_module(config_path)

    ray.init()

    # Checked in the same order as the original if/elif chain.
    launchers = (("explore", explore), ("train", train), ("both", both))
    for mode_name, launch in launchers:
        if cfg.mode == mode_name:
            launch(cfg)
            break


def main() -> None:
"""The main entrypoint."""
parser = argparse.ArgumentParser()
Expand All @@ -132,19 +149,7 @@ def main() -> None:
args = parser.parse_args()
if args.command == "run":
# TODO: support parse all args from command line
config = load_config(args.config)
config.check_and_update()
# try to activate data module
data_config = config.data
if data_config.dj_config_path or data_config.dj_process_desc:
activate_data_module(args.config)
ray.init()
if config.mode == "explore":
explore(config)
elif config.mode == "train":
train(config)
elif config.mode == "both":
both(config)
run(args.config)


if __name__ == "__main__":
Expand Down
28 changes: 28 additions & 0 deletions trinity/cli/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import traceback

from flask import Flask, jsonify, request

# Flask application object; the route in this module is registered on it.
app = Flask(__name__)

# APP_NAME is used as the URL path of the route; PORT is passed to app.run().
APP_NAME = "trinity_training"
PORT = 5006


@app.route(f"/{APP_NAME}", methods=["GET"])
def trinity_training():
    """Run the launcher for the config named by the ``configPath`` query arg.

    Returns:
        A JSON response with ``return_code`` (0 on success, 1 on failure) and
        ``message`` (a success note, or the failure reason / traceback text).
    """
    config_path = request.args.get("configPath")
    if not config_path:
        # Fail fast with a readable message instead of handing None (or an
        # empty string) to the launcher.
        return jsonify(
            {
                "return_code": 1,
                "message": "Missing required query parameter: configPath.",
            }
        )
    try:
        # Imported inside the handler, as in the original block.
        from trinity.cli.launcher import run

        run(config_path)
    except:  # noqa: E722
        # Kept as a bare except (as in the original) so that every failure is
        # reported back to the caller as JSON rather than escaping the handler.
        traceback.print_exc()
        return jsonify({"return_code": 1, "message": traceback.format_exc()})
    return jsonify({"return_code": 0, "message": "Training Success."})


if __name__ == "__main__":
    # Start the Flask server on PORT with debug mode switched on.
    app.run(debug=True, port=PORT)
2 changes: 2 additions & 0 deletions trinity/data/controllers/default_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
},
"llm_quality_score_filter": {
"api_or_hf_model": "qwen2.5-72b-instruct",
"min_score": 0.0,
"enable_vllm": False,
},
"perplexity_filter": {
Expand All @@ -66,6 +67,7 @@
},
"llm_difficulty_score_filter": {
"api_or_hf_model": "qwen2.5-72b-instruct",
"min_score": 0.0,
"enable_vllm": False,
},
# human annotators
Expand Down
10 changes: 7 additions & 3 deletions trinity/data/core/dataset_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool

from trinity.buffer.utils import retry_session
from trinity.common.config import DataConfig
from trinity.common.schema import Base, RftDatasetModel
from trinity.data.core.dataset import RftDataset
from trinity.manager.sql_storage import retry_session
from trinity.utils.log import get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -55,7 +55,9 @@ def __init__(self, config: DataConfig) -> None:
self.session = sessionmaker(bind=self.engine)

def add_entries(self, dataset: RftDataset):
    """Add every entry of ``dataset`` to the database in one session.

    The dataset is converted with ``rft_dataset_to_model`` and the resulting
    objects are passed to ``session.add_all``. The session comes from
    ``retry_session``, which is given ``max_retry_times`` and
    ``max_retry_interval`` from ``self.config``.

    Args:
        dataset: The dataset whose entries are stored.
    """
    # Open a single session: the retry settings must be passed explicitly,
    # so the older ``retry_session(self)`` form is not used here.
    with retry_session(
        self, self.config.max_retry_times, self.config.max_retry_interval
    ) as session:
        session.add_all(rft_dataset_to_model(dataset))

def get_entries(self, num_entries: int, order_by: str = None, ascending: bool = False):
Expand All @@ -65,7 +67,9 @@ def get_entries(self, num_entries: int, order_by: str = None, ascending: bool =
order_by_key = asc(order_by_key) if ascending else desc(order_by_key)
else:
order_by_key = None
with retry_session(self) as session:
with retry_session(
self, self.config.max_retry_times, self.config.max_retry_interval
) as session:
entries = (
session.query(RftDatasetModel)
.order_by(order_by_key)
Expand Down
4 changes: 2 additions & 2 deletions trinity/data/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ synth_data = synthesizer.process(clean_data)
- Request using our simple client:

```python
from trinity.data.client import request
from trinity.cli.client import request

res = request(
url="http://127.0.0.1:5000/data_workflow",
url="http://127.0.0.1:5005/data_workflow",
configPath="tests/test_configs/active_iterator_test_cfg.yaml"
)

Expand Down
13 changes: 8 additions & 5 deletions trinity/data/server.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from flask import Flask, jsonify, request

from trinity.common.config import load_config
from trinity.data.controllers.active_iterator import DataActiveIterator

app = Flask(__name__)

APP_NAME = "data_workflow"
PORT = 5005


@app.route("/data_workflow", methods=["GET"])
@app.route(f"/{APP_NAME}", methods=["GET"])
def data_workflow():
from trinity.common.config import load_config
from trinity.data.controllers.active_iterator import DataActiveIterator

config_path = request.args.get("configPath")
config = load_config(config_path)

Expand All @@ -17,4 +20,4 @@ def data_workflow():


if __name__ == "__main__":
app.run(debug=True)
app.run(port=PORT, debug=True)