diff --git a/docs/sphinx_doc/source/tutorial/example_data_functionalities.md b/docs/sphinx_doc/source/tutorial/example_data_functionalities.md index 9c50f5aafa..9f1b681208 100644 --- a/docs/sphinx_doc/source/tutorial/example_data_functionalities.md +++ b/docs/sphinx_doc/source/tutorial/example_data_functionalities.md @@ -151,7 +151,7 @@ ray start --head ray start --address= # run RFT -as-rft run --config +trinity run --config ``` If you follow the steps above, Trinity-RFT will send a request to the data module server; the data active iterator will then be activated and will compute difficulty scores for each sample in the raw dataset. After that, the data module server stores the resulting dataset in the database. When exploring begins, it will load the prepared dataset and continue with the downstream steps. diff --git a/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md b/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md index d188ffb70b..9e8843a5b6 100644 --- a/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md +++ b/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md @@ -17,7 +17,7 @@ The algorithm design and analysis can be found in this [technical report](./opmd To try out the OPMD algorithm: ```shell -as-rft run --config scripts/config/gsm8k_opmd.yaml +trinity run --config scripts/config/gsm8k_opmd.yaml ``` Note that in this config file, `sync_iteration_interval` is set to 10, i.e., the model weights of explorer and trainer are synchronized only once every 10 training steps, which leads to a challenging off-policy scenario (potentially with abrupt distribution shift during the RFT process). 
diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index daebabd4d1..f1c7e832ef 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -95,6 +95,20 @@ def both(config: Config) -> None: logger.info("Eval step finished.") +def activate_data_module(config_path: str): + """Send ``config_path`` to the local data module server (``LOCAL_SERVER_URL``) and log an error, without raising, if the response's ``return_code`` is non-zero.""" + from trinity.data.client import LOCAL_SERVER_URL, request + + logger.info("Activating data module...") + res = request( + url=LOCAL_SERVER_URL, + configPath=config_path, + ) + if res["return_code"] != 0: + logger.error(f"Failed to activate data module: {res['return_msg']}.") + return + + def main() -> None: """The main entrypoint.""" parser = argparse.ArgumentParser() @@ -111,6 +125,10 @@ def main() -> None: # TODO: support parse all args from command line config = load_config(args.config) config.check_and_update() + # try to activate data module + data_config = config.data + if data_config.dj_config_path or data_config.dj_process_desc: + activate_data_module(args.config) ray.init() if config.mode == "explore": explore(config)