Merged
Changes from 4 commits
@@ -151,7 +151,7 @@ ray start --head
ray start --address=<master_address>

# run RFT
-as-rft run --config <Trinity-RFT_config_path>
+trinity run --config <Trinity-RFT_config_path>
```

If you follow the steps above, Trinity-RFT sends a request to the data module server, which activates the data active iterator to compute a difficulty score for each sample in the raw dataset. The data module server then stores the resulting dataset in the database. When exploration begins, the prepared dataset is loaded and the downstream steps continue.
@@ -17,7 +17,7 @@ The algorithm design and analysis can be found in this [technical report](./opmd

To try out the OPMD algorithm:
```shell
-as-rft run --config scripts/config/gsm8k_opmd.yaml
+trinity run --config scripts/config/gsm8k_opmd.yaml
```

Note that in this config file, `sync_iteration_interval` is set to 10, i.e., the model weights of explorer and trainer are synchronized only once every 10 training steps, which leads to a challenging off-policy scenario (potentially with abrupt distribution shift during the RFT process).
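
To make the off-policy effect of `sync_iteration_interval` concrete, the following is a minimal sketch, not Trinity-RFT code. The helper names `sync_steps` and `staleness` are hypothetical, and the sketch assumes 1-based training steps with synchronization happening right after every step that is a multiple of the interval.

```python
def sync_steps(total_steps: int, sync_iteration_interval: int) -> list[int]:
    """Return the 1-based training steps after which weights are synchronized.

    Assumption: synchronization happens right after every step that is a
    multiple of the interval.
    """
    return [
        step
        for step in range(1, total_steps + 1)
        if step % sync_iteration_interval == 0
    ]


def staleness(step: int, sync_iteration_interval: int) -> int:
    """Return how many trainer updates the explorer's weights lag behind
    while training step `step` runs (0 right after a synchronization)."""
    return (step - 1) % sync_iteration_interval


if __name__ == "__main__":
    interval = 10  # the value set in the config discussed above
    print(sync_steps(30, interval))
    print([staleness(s, interval) for s in range(1, 13)])
```

Under these assumptions the lag grows from 0 to 9 updates within each window and resets after every synchronization, which is where the abrupt distribution shift mentioned above can come from.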
18 changes: 18 additions & 0 deletions trinity/cli/launcher.py
@@ -95,6 +95,22 @@ def both(config: Config) -> None:
logger.info("Eval step finished.")


def activate_data_module(config: Config, config_path: str):
    """Check whether to activate data module and preprocess datasets."""
    data_config = config.data
    if data_config.dj_config_path or data_config.dj_process_desc:
        from trinity.data.client import LOCAL_SERVER_URL, request

        logger.info("Activating data module...")
        res = request(
            url=LOCAL_SERVER_URL,
            configPath=config_path,
        )
        if res["return_code"] != 0:
            logger.error(f"Failed to activate data module: {res['return_msg']}.")
            return


def main() -> None:
    """The main entrypoint."""
    parser = argparse.ArgumentParser()
@@ -111,6 +127,8 @@ def main() -> None:
    # TODO: support parse all args from command line
    config = load_config(args.config)
    config.check_and_update()
    # try to activate data module
    activate_data_module(config, args.config)
    ray.init()
    if config.mode == "explore":
        explore(config)
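
The return-code handling in `activate_data_module` can be exercised without a running data module server. The sketch below is a self-contained illustration under stated assumptions: `fake_request` is a hypothetical stand-in for `trinity.data.client.request`, `activate` and the placeholder URL are invented for the example, and only the `return_code` / `return_msg` keys are taken from the diff above.

```python
import logging

logger = logging.getLogger("launcher_sketch")

# Placeholder only; the real LOCAL_SERVER_URL lives in trinity.data.client.
PLACEHOLDER_URL = "http://localhost:0/placeholder"


def fake_request(url: str, **params) -> dict:
    """Hypothetical stand-in for the data module client; performs no network I/O."""
    if not params.get("configPath"):
        return {"return_code": 1, "return_msg": "missing configPath"}
    return {"return_code": 0, "return_msg": "ok"}


def activate(config_path: str, request=fake_request, url: str = PLACEHOLDER_URL) -> bool:
    """Send the activation request and report whether it succeeded."""
    res = request(url=url, configPath=config_path)
    if res["return_code"] != 0:
        logger.error("Failed to activate data module: %s.", res["return_msg"])
        return False
    return True


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    print(activate("scripts/config/gsm8k_opmd.yaml"))
    print(activate(""))
```

Unlike the function in the diff, which logs the error and lets the launcher continue, this variant returns a boolean so a caller could decide whether to abort. That is a design alternative for discussion, not the behavior of the merged code.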