diff --git a/docs/sphinx_doc/source/main.md b/docs/sphinx_doc/source/main.md index 59fcaf62f1..21b715ff6c 100644 --- a/docs/sphinx_doc/source/main.md +++ b/docs/sphinx_doc/source/main.md @@ -186,8 +186,16 @@ You may customize the configurations in [`examples`](https://github.com/modelsco model: model_path: $MODEL_PATH/{model_name} -data: - dataset_path: $DATASET_PATH/{dataset_name} +buffer: + explorer_input: + taskset: + name: $TASKSET_NAME + path: $DATASET_PATH/{dataset_name} + format: + prompt_key: 'question' + response_key: 'answer' + default_workflow_type: $WORKFLOW_NAME + default_reward_fn_type: $REWARD_FN_NAME ``` Please refer to [`examples`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/) for more details. diff --git a/docs/sphinx_doc/source/tutorial/example_async_mode.md b/docs/sphinx_doc/source/tutorial/example_async_mode.md index a96ff29ff6..ed0fc87596 100644 --- a/docs/sphinx_doc/source/tutorial/example_async_mode.md +++ b/docs/sphinx_doc/source/tutorial/example_async_mode.md @@ -10,7 +10,7 @@ In addition, we need to configure the following parameters in both files. The model weights of the explorer and trainer are synchronized once every `sync_iteration_interval * batch_size` tasks. ```yaml -data: +global_config: batch_size: # The same checkpoint path model: @@ -18,10 +18,11 @@ model: # The same data_base path buffer: - train_dataset: - name: gsm8k_buffer - storage_type: queue - path: 'sqlite:///gsm8k.db' + trainer_input: + experience_buffer: + name: gsm8k_buffer + storage_type: queue + path: 'sqlite:///gsm8k.db' synchronizer: sync_method: 'checkpoint' diff --git a/docs/sphinx_doc/source/tutorial/example_data_functionalities.md b/docs/sphinx_doc/source/tutorial/example_data_functionalities.md index 20242312f2..35c70389cb 100644 --- a/docs/sphinx_doc/source/tutorial/example_data_functionalities.md +++ b/docs/sphinx_doc/source/tutorial/example_data_functionalities.md @@ -31,31 +31,24 @@ Trinity-RFT uses a unified config file to manage all config items. For the data In this example, assume that you need to rank all math questions and corresponding answers by their difficulties. So you can set these config items like the following example: ```yaml -data: +data_processor: # basic info - dataset_path: '/path/to/gsm8k' - dataset_config: + source_data_path: '/path/to/gsm8k' + load_kwargs: split: 'train' # only need the train split - format_config: # set the field mappings + format: # set the field mappings prompt_key: 'question' response_key: 'answer' # database related. The result dataset will be stored in the database. db_url: 'postgresql://{user_name}@localhost:5432/{db_name}' - # downstream loading related - total_epochs: 1 - batch_size: 96 - default_workflow_type: 'math_workflow' ``` Here you can set the basic information for the GSM-8K dataset, database information that is used to store the result dataset, and some other items about downstream dataset loading for exploring and training: -+ `dataset_path`: the path to the raw dataset. -+ `dataset_config`: extra config arguments for loading the raw dataset. Mainly for the `load_dataset` method in HuggingFace `datasets` library. -+ `format_config`: some dataset format config items, which are used to map original data field names to unified ones. ++ `source_data_path`: the path to the raw dataset. ++ `load_kwargs`: extra config arguments for loading the raw dataset. Mainly for the `load_dataset` method in HuggingFace `datasets` library. 
++ `format`: some dataset format config items, which are used to map original data field names to unified ones. + `db_url`: the URL of the postgresql database to store the result dataset. -+ `total_epochs`: the total number of epochs to train on this dataset. -+ `batch_size`: the training batch size. -+ `default_workflow_type`: the default exploring workflow type. Please refer to [programming guide](trinity_programming_guide.md) for more details. In addition, there are several config items related to the data active iterator, which is used to prepare a better dataset. The core part of the data active iterator, Data-Juicer, provides tens of operators to help clean or calculate key information for each sample in the dataset. You can configure this part depending on how familiar you are with Data-Juicer. @@ -63,20 +56,16 @@ In addition, there are several config items related to the data active iterator, If you are not familiar with Data-Juicer, the data module provides a natural-language-based method to config the data processing recipe. What you need to do is only describe the demands of how you want to prepare for the raw dataset, and an agent will be invoked to arrange the data processing recipe for you. Here is an example: ```yaml -data: +data_processor: # basic info - dataset_path: '/path/to/gsm8k' - dataset_config: + source_data_path: '/path/to/gsm8k' + load_kwargs: split: 'train' # only need the train split - format_config: # set the field mappings + format: # set the field mappings prompt_key: 'question' response_key: 'answer' # database related. The result dataset will be stored in the database. db_url: 'postgresql://{user_name}@localhost:5432/{db_name}' - # downstream loading related - total_epochs: 1 - batch_size: 96 - default_workflow_type: 'math_workflow' #### new part about data active iterator dj_process_desc: 'Please compute difficulty scores for these math questions.' @@ -109,20 +98,16 @@ process: After preparing the Data-Juicer data processing recipe, you can set the `dj_config_path` item in the Trinity-RFT config file to the path to this recipe. For example: ```yaml -data: +data_processor: # basic info - dataset_path: '/path/to/gsm8k' - dataset_config: + source_data_path: '/path/to/gsm8k' + load_kwargs: split: 'train' # only need the train split - format_config: # set the field mappings + format: # set the field mappings prompt_key: 'question' response_key: 'answer' # database related. The result dataset will be stored in the database. db_url: 'postgresql://{user_name}@localhost:5432/{db_name}' - # downstream loading related - total_epochs: 1 - batch_size: 96 - default_workflow_type: 'math_workflow' #### new part about data active iterator dj_config_path: '/path/to/the/Data-Juicer/data/processing/recipe/above.yaml' @@ -185,12 +170,12 @@ Trinity-RFT uses a unified config file to manage all config items. For the data In this example, assume that you need to rank all math questions and corresponding answers by their difficulties. So you can set these config items like the following example: ```yaml -data: +data_processor: # basic info - dataset_path: 'tests/test_data/test_human_annotator' - dataset_config: + source_data_path: 'tests/test_data/test_human_annotator' + load_kwargs: split: 'train' # only need the train split - format_config: # set the field mappings + format: # set the field mappings prompt_key: 'prompt' chosen_key: 'chosen' rejected_key: 'rejected' @@ -198,10 +183,6 @@ data: dj_config_path: 'tests/test_configs/human_annotator_test_dj_cfg.yaml' # database related. 
The result dataset will be stored in the database. db_url: 'postgresql://{user_name}@localhost:5432/{db_name}' - # downstream loading related - total_epochs: 20 - batch_size: 32 - default_workflow_type: 'math_workflow' ``` Here you can set the basic information for the example dataset, database information that is used to store the result dataset, and some other items about downstream dataset loading for exploring and training, which is similar to the example above. diff --git a/docs/sphinx_doc/source/tutorial/example_dpo.md b/docs/sphinx_doc/source/tutorial/example_dpo.md index 0e01f879d3..5d274fbb47 100644 --- a/docs/sphinx_doc/source/tutorial/example_dpo.md +++ b/docs/sphinx_doc/source/tutorial/example_dpo.md @@ -51,7 +51,7 @@ buffer: train_dataset: storage_type: file path: <$DATASET_PATH/human_like_dpo_dataset> - kwargs: + format: prompt_type: # messages/plaintext prompt_key: chosen_key: diff --git a/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md b/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md index e9a6d9b594..de785d18fe 100644 --- a/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md +++ b/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md @@ -84,7 +84,7 @@ buffer: sft_warmup_dataset: storage_type: file path: <$DATASET_PATH/{sft_data}> - kwargs: + format: prompt_type: # messages/plaintext/chatpair prompt_key: response_key: diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index 6983162cc7..74194ad57d 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -2,6 +2,21 @@ The following is the main config file for Trinity-RFT. Take `countdown.yaml` as an example. +## Global Config + +```yaml +mode: both +global_config: + total_epochs: 1 + batch_size: 96 + eval_interval: 1000 +``` + +- `mode`: The mode of the experiment, chosen from `both`, `train`, `explore` or `bench`. `both` means both trainer and explorer are launched; `train` means only trainer is launched; `explore` means only explorer is launched; `bench` conducts benchmark evaluation. Default is `both`. +- `global_config.total_epochs`: The total number of epochs. It should be checked manually. +- `global_config.batch_size`: The batch size used for training. It should be checked manually. +- `global_config.eval_interval`: The interval steps between two evaluations. Default is `1000`. + ## Monitor @@ -15,45 +30,32 @@ monitor: - `monitor.name`: The name of the experiment. It must be set manually. -## Data +## Data Processing ```yaml -data: - dataset_path: '/PATH/TO/DATASET' - train_split: 'train' - eval_split: '' - dataset_config: - split: 'train' - format_config: +data_processor: + source_data_path: '/PATH/TO/DATASET' + load_kwargs: + split: 'train' # only need the train split + format: prompt_key: 'question' response_key: 'answer' - db_url: '' - max_retry_times: 3 - max_retry_interval: 1 - - total_epochs: 20 - batch_size: 96 - default_workflow_type: 'math_workflow' - default_reward_fn_type: 'countdown_reward' + # cleaner related + dj_config_path: 'tests/test_configs/active_iterator_test_dj_cfg.yaml' + clean_strategy: 'iterative' + # db related + db_url: 'postgresql://{username}@localhost:5432/{db_name}' ``` -- `data.dataset_path`: The path to the dataset. -- `data.train_split`: The split name of the dataset used for training. Default is `train`. -- `data.eval_split`: The split name of the dataset used for eval. 
-- `data.dataset_config`: The configuration for the dataset. -- `data.format_config`: The configuration for the format of the dataset. +- `data.source_data_path`: The path to the source dataset. +- `data.load_kwargs`: The kwargs used in `datasets.load_dataset`. +- `data.format`: The format of the source dataset. It includes `prompt_key` and `response_key`. +- `data.dj_config_path`: The path to the Data-Juicer configuration. +- `data.clean_strategy`: The cleaning strategy used for `DataCleaner`, which iteratively cleans dataset until targets are met. - `data.db_url`: The URL of the database. -- `data.max_retry_times`: The maximum number of retries when loading the dataset from database. -- `data.max_retry_interval`: The maximum interval between retries when loading the dataset from database. -- `data.total_epochs`: The total number of epochs to explore the dataset. Default is `1`. It should be set manually. -- `data.batch_size`: The number of `Task` in one training batch. The real batch size used in training is `data.batch_size` * `explorer.repeat_times`. It should be set manually. -- `data.default_workflow_type`: The default workflow type used for training. -- `data.default_reward_fn_type`: The default reward function type used for training. - - ## Model @@ -93,18 +95,40 @@ cluster: buffer: max_retry_times: 3 max_retry_interval: 1 - train_dataset: - name: countdown_buffer - storage_type: queue - algorithm_type: ppo - path: 'sqlite:///countdown.db' - sft_warmup_dataset: null + explorer_input: + taskset: + name: countdown + path: 'countdown_dataset/oneshot-split' + split: train + format: + prompt_key: 'question' + response_key: 'answer' + eval_tasksets: [] + default_workflow_type: 'math_workflow' + default_reward_fn_type: 'countdown_reward' + trainer_input: + experience_buffer: + name: countdown_buffer + storage_type: queue + path: 'sqlite:///countdown.db' + sft_warmup_dataset: null ``` -- `buffer.max_retry_times`: The maximum number of retries when loading the dataset from database. -- `buffer.max_retry_interval`: The maximum interval between retries when loading the dataset from database. -- `buffer.train_dataset`: The configuration of the training dataset. -- `buffer.sft_warmup_dataset`: The configuration of the SFT warmup dataset. +- `buffer.max_retry_times`: The maximum number of retries when loading the data from database. +- `buffer.max_retry_interval`: The maximum interval between retries when loading the data from database. +- `buffer.explorer_input.taskset`: The configuration of the taskset. +- `buffer.explorer_input.taskset.name`: The name of the taskset. +- `buffer.explorer_input.taskset.path`: The path to the taskset. +- `buffer.explorer_input.taskset.split`: The split name of the taskset used for training. Default is `train`. +- `buffer.explorer_input.taskset.format`: The format of the taskset. It includes `prompt_key`, `response_key`, `workflow_key` and `reward_fn_key`. +- `buffer.explorer_input.eval_tasksets`: The configuration of the eval tasksets. It is a list of tasksets which will be used for evaluation. And it is empty by default. +- `buffer.explorer_input.default_workflow_type`: The default workflow type for `taskset` and `eval_tasksets`. +- `buffer.explorer_input.default_reward_fn_type`: The default reward function type for `taskset` and `eval_tasksets`. +- `buffer.trainer_input.experience_buffer`: The configuration of experience_buffer. +- `buffer.trainer_input.experience_buffer.name`: The name of the experience buffer. 
+- `buffer.trainer_input.experience_buffer.storage_type`: The storage type of the experience buffer. Default is `queue`.
+- `buffer.trainer_input.experience_buffer.path`: The SQL path used to persist the experience buffer. It can be left empty to skip saving to the database.
+- `buffer.trainer_input.sft_warmup_dataset`: The configuration of the SFT warmup dataset. The structure of `sft_warmup_dataset` is similar to that of `buffer.explorer_input.taskset`.

## Explorer
@@ -157,7 +181,7 @@ synchronizer:
- `synchronizer.sync_method`: The synchronization method between `trainer` and `explorer`. Support `nccl` and `checkpoint`, `nccl` represents that model weights in `explorer` will be synchronized from `trainer` through `nccl`, `checkpoint` represents that `explorer` will load the newest checkpoints saved by `trainer` then update its model weights. Default is `nccl`.
-- `synchronizer.sync_interval`: The interval between two synchronizations. Default is `10`. It should be set manually.
+- `synchronizer.sync_interval`: The interval steps between two synchronizations. Default is `10`. It should be set manually.
- `synchronizer.sync_timeout`: The timeout of the synchronization. Default is `1200`.

## Trainer
@@ -176,8 +200,8 @@ trainer:
- `trainer.algorithm_type`: The type of the algorithm, Support `ppo`, `grpo`, `opmd` and `dpo`.
- `trainer.trainer_config_path`: The path to the trainer configuration file. It must be set manually.
- `trainer.sft_warmup_steps`: The number of steps to warm up the model. Default is `0`.
-- `trainer.eval_interval`: The interval between two evaluations. Default is `1000`.
-- `trainer.save_interval`: The interval between two checkpoints. Default is `100`.
+- `trainer.eval_interval`: The interval steps between two evaluations. Default is `1000`.
+- `trainer.save_interval`: The interval steps between two checkpoints. Default is `100`.

### veRL Trainer Configuration

diff --git a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
index 582e99159d..f1c82a9f25 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
@@ -102,13 +102,18 @@ class ExampleWorkflow(Workflow):

### Step 3: Modify Configuration File

-After completing the development of the `Workflow`, you need to modify the configuration file to set the `default_workflow_type` in the `data` domain to the newly registered `Workflow` name.
+After completing the development of the `Workflow`, you need to modify the configuration file to set the `default_workflow_type` in the `buffer.explorer_input` domain to the newly registered `Workflow` name.
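The string placed in `default_workflow_type` must match the name under which the workflow class was registered. A minimal sketch of what that registration might look like, assuming the decorator-based `WORKFLOWS` registry pattern used elsewhere in Trinity-RFT (the exact import path and decorator signature are assumptions, not confirmed by this guide):

```python
# Hypothetical sketch: the decorator form and import path are assumptions
# based on Trinity-RFT's registry pattern (WORKFLOWS.get(...) is used elsewhere).
from trinity.common.workflows import WORKFLOWS, Workflow


@WORKFLOWS.register_module("example_workflow")  # this string is the value used in the config
class ExampleWorkflow(Workflow):
    ...  # implementation shown in Step 1 and Step 2 above
```

With the registration in place, the configuration change looks like this: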
```yaml -data: - # Other fields - default_workflow_type: example_workflow +buffer: # Other fields + explorer_input: + taskset: + name: taskset_name + path: 'path/to/taskset' + # Other fields + eval_tasksets: [] + default_workflow_type: example_workflow # Other fields ``` diff --git a/examples/async_gsm8k/explorer.yaml b/examples/async_gsm8k/explorer.yaml index 82ad223baa..65dc61ce94 100644 --- a/examples/async_gsm8k/explorer.yaml +++ b/examples/async_gsm8k/explorer.yaml @@ -1,17 +1,8 @@ mode: explore -data: - # basic info - dataset_path: /PATH/TO/DATASET/ - subset_name: '' - train_split: 'train' - eval_split: 'test' - format_config: - prompt_key: 'question' - response_key: 'answer' - # downstream loading related +global_config: total_epochs: 20 batch_size: 96 - default_workflow_type: 'math_workflow' + eval_interval: 10 model: model_path: /PATH/TO/MODEL/ max_prompt_tokens: 256 @@ -23,10 +14,21 @@ cluster: buffer: max_retry_times: 3 max_retry_interval: 1 - train_dataset: - name: gsm8k_buffer - storage_type: queue - path: 'sqlite:///gsm8k.db' + explorer_input: + taskset: + name: gsm8k + storage_type: file + path: /PATH/TO/DATASET/ + split: train + format: + prompt_key: 'question' + response_key: 'answer' + default_workflow_type: 'math_workflow' + trainer_input: + experience_buffer: + name: gsm8k_buffer + storage_type: queue + path: 'sqlite:///gsm8k.db' explorer: engine_type: vllm_async engine_num: 2 @@ -51,7 +53,6 @@ trainer: algorithm_type: grpo trainer_config_path: examples/async_gsm8k/verl_config.yaml sft_warmup_steps: 0 # Set to integer to enable sft warmup - eval_interval: 10 monitor: cache_root_dir: "" project: "Trinity-RFT-gsm8k" diff --git a/examples/async_gsm8k/trainer.yaml b/examples/async_gsm8k/trainer.yaml index e67d325ca2..85a9afbf11 100644 --- a/examples/async_gsm8k/trainer.yaml +++ b/examples/async_gsm8k/trainer.yaml @@ -1,17 +1,8 @@ mode: train -data: - # basic info - dataset_path: /PATH/TO/DATASET/ - subset_name: '' - train_split: 'train' - eval_split: 'test' - format_config: - prompt_key: 'question' - response_key: 'answer' - # downstream loading related +global_config: total_epochs: 20 batch_size: 96 - default_workflow_type: 'math_workflow' + eval_interval: 10 model: model_path: /PATH/TO/MODEL/ max_prompt_tokens: 256 @@ -23,10 +14,20 @@ cluster: buffer: max_retry_times: 3 max_retry_interval: 1 - train_dataset: - name: gsm8k_buffer - storage_type: queue - path: 'sqlite:///gsm8k.db' + explorer_input: + taskset: + name: gsm8k + storage_type: file + path: /PATH/TO/DATASET/ + format: + prompt_key: 'question' + response_key: 'answer' + default_workflow_type: 'math_workflow' + trainer_input: + experience_buffer: + name: gsm8k_buffer + storage_type: queue + path: 'sqlite:///gsm8k.db' explorer: engine_type: vllm_async engine_num: 2 @@ -51,7 +52,6 @@ trainer: algorithm_type: grpo trainer_config_path: examples/async_gsm8k/verl_config.yaml sft_warmup_steps: 0 # Set to integer to enable sft warmup - eval_interval: 10 monitor: cache_root_dir: "" project: "Trinity-RFT-gsm8k" diff --git a/examples/dpo_humanlike/dpo.yaml b/examples/dpo_humanlike/dpo.yaml index 6763254dd9..1e812201a0 100644 --- a/examples/dpo_humanlike/dpo.yaml +++ b/examples/dpo_humanlike/dpo.yaml @@ -1,13 +1,7 @@ mode: train -data: +global_config: total_epochs: 20 batch_size: 32 # NOTE - train_split: "train" - dataset_path: '' - default_workflow_type: 'math_workflow' - format_config: - prompt_key: '' - response_key: '' model: model_path: '/PATH/TO/MODEL/CHECKPOINT/' # NOTE max_prompt_tokens: 1792 @@ -19,15 +13,16 @@ 
cluster: buffer: max_retry_times: 3 max_retry_interval: 1 - train_dataset: - name: dpo_buffer - storage_type: file - path: '/PATH/TO/DATASET/' - kwargs: - prompt_type: plaintext # plaintext/messages - prompt_key: prompt - chosen_key: chosen - rejected_key: rejected + trainer_input: + experience_buffer: + name: dpo_buffer + storage_type: file + path: '/PATH/TO/DATASET/' + format: + prompt_type: plaintext # plaintext/messages/chatpair + prompt_key: prompt + chosen_key: chosen + rejected_key: rejected explorer: engine_type: vllm_async engine_num: 0 diff --git a/examples/grpo_alfworld/alfworld.yaml b/examples/grpo_alfworld/alfworld.yaml index 6c70a12d75..b083d7874d 100644 --- a/examples/grpo_alfworld/alfworld.yaml +++ b/examples/grpo_alfworld/alfworld.yaml @@ -1,12 +1,6 @@ -data: +global_config: total_epochs: 20 batch_size: 4 - dataset_path: 'scripts/data_prepare/alfworld_data' - default_workflow_type: 'alfworld_workflow' - train_split: 'train' - eval_split: '' - format_config: - prompt_key: 'game_file' model: model_path: '/PATH/TO/MODEL/CHECKPOINT/' max_prompt_tokens: 4096 @@ -18,10 +12,19 @@ cluster: buffer: max_retry_times: 3 max_retry_interval: 1 - train_dataset: - name: alfworld_buffer - storage_type: queue - path: 'sqlite:///alfworld.db' + explorer_input: + taskset: + name: alfworld + storage_type: file + path: 'scripts/data_prepare/alfworld_data' + format: + prompt_key: 'game_file' + default_workflow_type: 'alfworld_workflow' + trainer_input: + experience_buffer: + name: alfworld_buffer + storage_type: queue + path: 'sqlite:///alfworld.db' explorer: engine_type: vllm_async engine_num: 2 diff --git a/examples/grpo_gsm8k/gsm8k.yaml b/examples/grpo_gsm8k/gsm8k.yaml index a5ea536bff..7e25bfaaee 100644 --- a/examples/grpo_gsm8k/gsm8k.yaml +++ b/examples/grpo_gsm8k/gsm8k.yaml @@ -1,12 +1,6 @@ -data: +data_processor: # basic info - dataset_path: 'openai/gsm8k' - subset_name: "main" - train_split: 'train' - eval_split: 'test' - format_config: - prompt_key: 'question' - response_key: 'answer' + source_data_path: 'openai/gsm8k' # data active iterator related dj_process_desc: 'Please compute difficulty scores for these math questions.' 
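  # An agent turns this natural-language description into a Data-Juicer processing recipe
  # (see the data functionalities tutorial); `agent_model_name` below presumably selects that agent model.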
agent_model_name: 'qwen-max' @@ -17,10 +11,10 @@ data: clean_strategy: 'iterative' # db related db_url: '' - # downstream loading related +global_config: total_epochs: 1 batch_size: 96 - default_workflow_type: 'math_workflow' + eval_interval: 50 model: model_path: '/PATH/TO/MODEL/' max_prompt_tokens: 256 @@ -32,16 +26,37 @@ cluster: buffer: max_retry_times: 3 max_retry_interval: 1 - train_dataset: - name: gsm8k_buffer - storage_type: queue - path: 'sqlite:///gsm8k.db' - # sft_warmup_dataset: # Uncomment these to enable sft warmup - # name: warmup_data - # storage_type: file - # path: '/PATH/TO/WARMUP_DATA/' - # kwargs: - # prompt_type: plaintext + explorer_input: + taskset: + name: gsm8k + storage_type: file + path: 'openai/gsm8k' + subset_name: 'main' + split: 'train' + format: + prompt_key: 'question' + response_key: 'answer' + eval_tasksets: + - name: gsm8k-eval + storage_type: file + path: 'openai/gsm8k' + subset_name: 'main' + split: 'test' + format: + prompt_key: 'question' + response_key: 'answer' + default_workflow_type: 'math_workflow' + trainer_input: + experience_buffer: + name: gsm8k_buffer + storage_type: queue + path: 'sqlite:///gsm8k.db' + # sft_warmup_dataset: # Uncomment these to enable sft warmup + # name: warmup_data + # storage_type: file + # path: '/PATH/TO/WARMUP_DATA/' + # kwargs: + # prompt_type: plaintext explorer: engine_type: vllm_async engine_num: 2 @@ -67,7 +82,6 @@ trainer: algorithm_type: grpo trainer_config_path: 'examples/grpo_gsm8k/train_gsm8k.yaml' sft_warmup_steps: 0 # Set to integer to enable sft warmup - eval_interval: 50 save_interval: 100 # get_exp_strategy: 'LFU' monitor: diff --git a/examples/grpo_math/math.yaml b/examples/grpo_math/math.yaml index c1d1bc0f15..06f88fe818 100644 --- a/examples/grpo_math/math.yaml +++ b/examples/grpo_math/math.yaml @@ -1,18 +1,7 @@ -data: - # basic info - dataset_path: /PATH/TO/DATASET/ - # dataset_config: - train_split: train - eval_split: test - format_config: - prompt_key: 'question' - response_key: 'gt_answer' - # db related - db_url: '' - # downstream loading related +global_config: total_epochs: 20 batch_size: 288 - default_workflow_type: 'math_workflow' + eval_interval: 10 model: model_path: /PATH/TO/MODEL/ max_prompt_tokens: 1024 @@ -24,10 +13,20 @@ cluster: buffer: max_retry_times: 3 max_retry_interval: 1 - train_dataset: - name: math_buffer - storage_type: queue - path: 'sqlite:///math.db' + explorer_input: + taskset: + name: math + storage_type: file + path: /PATH/TO/DATASET/ + format: + prompt_key: 'question' + response_key: 'gt_answer' + default_workflow_type: 'math_workflow' + trainer_input: + experience_buffer: + name: math_buffer + storage_type: queue + path: 'sqlite:///math.db' explorer: engine_type: vllm_async engine_num: 2 @@ -53,7 +52,6 @@ trainer: algorithm_type: grpo trainer_config_path: 'examples/grpo_math/train_math.yaml' sft_warmup_steps: 0 # Set to integer to enable sft warmup - eval_interval: 10 save_interval: 100 monitor: cache_root_dir: "" diff --git a/examples/grpo_sciworld/sciworld.yaml b/examples/grpo_sciworld/sciworld.yaml index 1b85571e23..ffe30d44f0 100644 --- a/examples/grpo_sciworld/sciworld.yaml +++ b/examples/grpo_sciworld/sciworld.yaml @@ -1,12 +1,6 @@ -data: +global_config: total_epochs: 20 batch_size: 4 - dataset_path: 'scripts/data_prepare/sciworld_data' - default_workflow_type: 'sciworld_workflow' - train_split: 'train' - eval_split: '' - format_config: - prompt_key: 'game_file' model: model_path: '/PATH/TO/MODEL/CHECKPOINT/' max_prompt_tokens: 4096 @@ -18,10 +12,19 @@ 
cluster: buffer: max_retry_times: 3 max_retry_interval: 1 - train_dataset: - name: sciworld_buffer - storage_type: queue - path: 'sqlite:///sciworld.db' + explorer_input: + taskset: + name: sciworld + storage_type: file + path: 'scripts/data_prepare/sciworld_data' + format: + prompt_key: 'game_file' + default_workflow_type: 'sciworld_workflow' + trainer_input: + experience_buffer: + name: sciworld_buffer + storage_type: queue + path: 'sqlite:///sciworld.db' explorer: engine_type: vllm_async engine_num: 2 diff --git a/examples/grpo_webshop/webshop.yaml b/examples/grpo_webshop/webshop.yaml index 7bdebcf2fa..451495433f 100644 --- a/examples/grpo_webshop/webshop.yaml +++ b/examples/grpo_webshop/webshop.yaml @@ -1,12 +1,6 @@ -data: +global_config: total_epochs: 20 batch_size: 4 - dataset_path: 'scripts/data_prepare/webshop_data' - default_workflow_type: 'webshop_workflow' - train_split: 'train' - eval_split: '' - format_config: - prompt_key: 'task_id' model: model_path: '/PATH/TO/MODEL/CHECKPOINT/' max_prompt_tokens: 4096 @@ -18,10 +12,19 @@ cluster: buffer: max_retry_times: 3 max_retry_interval: 1 - train_dataset: - name: webshop_buffer - storage_type: queue - path: 'sqlite:///webshop.db' + explorer_input: + taskset: + name: webshop + storage_type: file + path: 'scripts/data_prepare/webshop_data' + format: + prompt_key: 'task_id' + default_workflow_type: 'webshop_workflow' + trainer_input: + experience_buffer: + name: webshop_buffer + storage_type: queue + path: 'sqlite:///webshop.db' explorer: engine_type: vllm_async engine_num: 2 diff --git a/examples/opmd_gsm8k/opmd_gsm8k.yaml b/examples/opmd_gsm8k/opmd_gsm8k.yaml index 35a2cfe169..dcfeee47db 100644 --- a/examples/opmd_gsm8k/opmd_gsm8k.yaml +++ b/examples/opmd_gsm8k/opmd_gsm8k.yaml @@ -1,11 +1,6 @@ -data: +global_config: total_epochs: 1 batch_size: 96 - dataset_path: '{path to datasets}/gsm8k' - default_workflow_type: 'math_workflow' - format_config: - prompt_key: 'question' - response_key: 'answer' model: model_path: '{path to models}/Qwen2.5-1.5B-Inst' max_prompt_tokens: 256 @@ -17,10 +12,20 @@ cluster: buffer: max_retry_times: 3 max_retry_interval: 1 - train_dataset: - name: gsm8k_buffer - storage_type: queue - path: 'sqlite:///gsm8k_opmd.db' + explorer_input: + taskset: + name: gsm8k + storage_type: file + path: '{path to datasets}/gsm8k' + format: + prompt_key: 'question' + response_key: 'answer' + default_workflow_type: 'math_workflow' + trainer_input: + experience_buffer: + name: gsm8k_buffer + storage_type: queue + path: 'sqlite:///gsm8k_opmd.db' explorer: engine_type: vllm_async engine_num: 2 diff --git a/examples/ppo_countdown/countdown.yaml b/examples/ppo_countdown/countdown.yaml index f1c1b4b31d..941c0ef97b 100644 --- a/examples/ppo_countdown/countdown.yaml +++ b/examples/ppo_countdown/countdown.yaml @@ -1,14 +1,7 @@ -data: +global_config: total_epochs: 20 batch_size: 96 - dataset_path: 'countdown_dataset/oneshot-split' - default_workflow_type: 'math_workflow' - train_split: 'train' - eval_split: '' - default_reward_fn_type: 'countdown_reward' - format_config: - prompt_key: 'question' - response_key: 'answer' + eval_interval: 1000 model: model_path: '/PATH/TO/MODEL/CHECKPOINT/' max_prompt_tokens: 256 @@ -20,10 +13,21 @@ cluster: buffer: max_retry_times: 3 max_retry_interval: 1 - train_dataset: - name: countdown_buffer - storage_type: queue - path: 'sqlite:///countdown.db' + explorer_input: + taskset: + name: countdown + storage_type: file + path: 'countdown_dataset/oneshot-split' + format: + prompt_key: 'question' + 
response_key: 'answer' + default_workflow_type: 'math_workflow' + default_reward_fn_type: 'countdown_reward' + trainer_input: + experience_buffer: + name: countdown_buffer + storage_type: queue + path: 'sqlite:///countdown.db' explorer: engine_type: vllm_async engine_num: 2 @@ -49,7 +53,6 @@ trainer: algorithm_type: ppo trainer_config_path: 'examples/ppo_countdown/train_countdown.yaml' sft_warmup_steps: 0 - eval_interval: 1000 save_interval: 100 monitor: cache_root_dir: "" diff --git a/tests/buffer/queue_test.py b/tests/buffer/queue_test.py index deffa9d68a..e06b133256 100644 --- a/tests/buffer/queue_test.py +++ b/tests/buffer/queue_test.py @@ -3,7 +3,7 @@ from tests.tools import RayUnittestBase from trinity.buffer.reader.queue_reader import QueueReader from trinity.buffer.writer.queue_writer import QueueWriter -from trinity.common.config import BufferConfig, DatasetConfig +from trinity.common.config import BufferConfig, StorageConfig from trinity.common.constants import AlgorithmType, StorageType from trinity.common.experience import Experience @@ -13,9 +13,8 @@ def test_queue_buffer(self): total_num = 8 put_batch_size = 2 read_batch_size = 4 - meta = DatasetConfig( + meta = StorageConfig( name="test_buffer", - namespace="test_namespace", algorithm_type=AlgorithmType.PPO, storage_type=StorageType.QUEUE, ) diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py index 13eb657585..61ebc46315 100644 --- a/tests/buffer/sql_test.py +++ b/tests/buffer/sql_test.py @@ -5,7 +5,7 @@ from trinity.buffer.reader.sql_reader import SQLReader from trinity.buffer.writer.sql_writer import SQLWriter -from trinity.common.config import BufferConfig, DatasetConfig +from trinity.common.config import BufferConfig, StorageConfig from trinity.common.constants import AlgorithmType, StorageType from trinity.common.experience import Experience @@ -17,7 +17,7 @@ def test_create_sql_buffer(self) -> None: total_num = 8 put_batch_size = 2 read_batch_size = 4 - meta = DatasetConfig( + meta = StorageConfig( name="test_buffer", algorithm_type=AlgorithmType.PPO, path=f"sqlite:///{db_path}", diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index 0f0ef93e76..9b1d5f9997 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -136,6 +136,7 @@ def setUp(self): self.config.explorer.tensor_parallel_size = 1 self.config.explorer.engine_num = 2 self.config.explorer.repeat_times = 2 + self.config.explorer.use_v1 = False self.config.explorer.chat_template = CHAT_TEMPLATE self.engines = create_rollout_models(self.config) self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm") diff --git a/tests/data/core/dataset_test.py b/tests/data/core/dataset_test.py index 522abb13a2..be6e765fbd 100644 --- a/tests/data/core/dataset_test.py +++ b/tests/data/core/dataset_test.py @@ -3,7 +3,7 @@ import os import unittest -from trinity.common.config import DataConfig, FormatConfig +from trinity.common.config import DataProcessorConfig, FormatConfig from trinity.common.rewards import AccuracyReward from trinity.common.task import TaskSet from trinity.common.workflows import MathWorkflow, SimpleWorkflow @@ -15,31 +15,29 @@ class TestRftDataset(unittest.TestCase): """Test cases for RftDataset""" def setUp(self) -> None: - self.data_config = DataConfig( - dataset_path=os.path.join( + self.data_config = DataProcessorConfig( + source_data_path=os.path.join( os.path.dirname(os.path.realpath(__file__)), "..", "..", "test_data", "test_10", ), - dataset_config={"split": "train"}, - 
format_config=FormatConfig( + format=FormatConfig( prompt_key="problem", response_key="solution", solution_key="solution", ), ) - self.data_config_sample_level_setting = DataConfig( - dataset_path=os.path.join( + self.data_config_sample_level_setting = DataProcessorConfig( + source_data_path=os.path.join( os.path.dirname(os.path.realpath(__file__)), "..", "..", "test_data", "test_10_with_rewfn_workflow", ), - dataset_config={"split": "train"}, - format_config=FormatConfig( + format=FormatConfig( prompt_key="problem", response_key="solution", solution_key="solution", @@ -64,8 +62,8 @@ def test_format_dataset(self): # apply formatters dataset.format( formatters=[ - BoxedMathAnswerFormatter(config=self.data_config.format_config), - RLHFFormatter(config=self.data_config.format_config), + BoxedMathAnswerFormatter(config=self.data_config.format), + RLHFFormatter(config=self.data_config.format), ] ) self.assertNotEqual(dataset.data, original_data) diff --git a/tests/data/core/formatter_test.py b/tests/data/core/formatter_test.py index 0c6dedc163..363c736ed9 100644 --- a/tests/data/core/formatter_test.py +++ b/tests/data/core/formatter_test.py @@ -3,7 +3,7 @@ import os import unittest -from trinity.common.config import DataConfig, FormatConfig +from trinity.common.config import DataProcessorConfig, FormatConfig from trinity.data.core.dataset import RftDataset from trinity.data.core.formatter import ( BoxedMathAnswerFormatter, @@ -18,16 +18,15 @@ class TestBoxedMathDataset(unittest.TestCase): """Test cases for RftDataset""" def setUp(self) -> None: - self.data_config = DataConfig( - dataset_path=os.path.join( + self.data_config = DataProcessorConfig( + source_data_path=os.path.join( os.path.dirname(os.path.realpath(__file__)), "..", "..", "test_data", "test_10", ), - dataset_config={"split": "train"}, - format_config=FormatConfig( + format=FormatConfig( prompt_key="problem", response_key="answer", solution_key="solution", @@ -36,7 +35,7 @@ def setUp(self) -> None: ) def test_init(self): - formatter = BoxedMathAnswerFormatter(config=self.data_config.format_config) + formatter = BoxedMathAnswerFormatter(config=self.data_config.format) # test for existing configs self.assertEqual(formatter.config.prompt_key, "problem") self.assertEqual(formatter.config.response_key, "answer") @@ -50,7 +49,7 @@ def test_init(self): def test_transform(self): dataset = RftDataset(data_config=self.data_config, reward_schema="default") - formatter = BoxedMathAnswerFormatter(config=self.data_config.format_config) + formatter = BoxedMathAnswerFormatter(config=self.data_config.format) self.assertNotIn(formatter.config.response_key, dataset.data.column_names) dataset.format(formatter) self.assertIn(formatter.config.response_key, dataset.data.column_names) @@ -60,16 +59,15 @@ class TestRLHFFormatter(unittest.TestCase): """Test cases for RLHFFormatter""" def setUp(self) -> None: - self.data_config = DataConfig( - dataset_path=os.path.join( + self.data_config = DataProcessorConfig( + source_data_path=os.path.join( os.path.dirname(os.path.realpath(__file__)), "..", "..", "test_data", "test_10", ), - dataset_config={"split": "train"}, - format_config=FormatConfig( + format=FormatConfig( prompt_key="problem", chat_template="User: {}\nAssistant: ", ), @@ -79,7 +77,7 @@ def test_render_template(self): sample = { "problem": "What is the capital of France?", } - formatter = RLHFFormatter(config=self.data_config.format_config) + formatter = RLHFFormatter(config=self.data_config.format) res_sample = formatter._render_template(sample) 
self.assertEqual( res_sample[formatter.config.prompt_key], @@ -87,14 +85,14 @@ def test_render_template(self): ) def test_render_template_without_chat_template(self): - self.data_config.format_config.chat_template = "" + self.data_config.format.chat_template = "" sample = { "problem": "What is the capital of France?", } - formatter = RLHFFormatter(config=self.data_config.format_config) + formatter = RLHFFormatter(config=self.data_config.format) res_sample = formatter._render_template(sample) self.assertEqual(res_sample[formatter.config.prompt_key], "What is the capital of France?") - self.data_config.format_config.chat_template = "User: {}\nAssistant: " + self.data_config.format.chat_template = "User: {}\nAssistant: " def test_render_template_with_tokenizer(self): # TODO @@ -109,16 +107,15 @@ class TestRewardFormatter(unittest.TestCase): """Test cases for RewardFormatter""" def setUp(self) -> None: - self.data_config = DataConfig( - dataset_path=os.path.join( + self.data_config = DataProcessorConfig( + source_data_path=os.path.join( os.path.dirname(os.path.realpath(__file__)), "..", "..", "test_data", "test_10", ), - dataset_config={"split": "train"}, - format_config=FormatConfig( + format=FormatConfig( prompt_key="problem", chosen_key="chosen", rejected_key="rejected", @@ -132,7 +129,7 @@ def test_render_template(self): "chosen": "Paris", "rejected": "London", } - formatter = RewardFormatter(config=self.data_config.format_config) + formatter = RewardFormatter(config=self.data_config.format) res_sample = formatter._render_template(sample) self.assertEqual( res_sample[formatter.config.prompt_key], @@ -142,13 +139,13 @@ def test_render_template(self): self.assertEqual(res_sample[formatter.config.rejected_key], "London") def test_render_template_without_chat_template(self): - self.data_config.format_config.chat_template = "" + self.data_config.format.chat_template = "" sample = { "problem": "What is the capital of France?", "chosen": "Paris", "rejected": "London", } - formatter = RewardFormatter(config=self.data_config.format_config) + formatter = RewardFormatter(config=self.data_config.format) res_sample = formatter._render_template(sample) self.assertEqual(res_sample[formatter.config.prompt_key], "What is the capital of France?") self.assertEqual(res_sample[formatter.config.chosen_key], "Paris") @@ -167,16 +164,15 @@ class TestSFTFormatter(unittest.TestCase): """Test cases for SFTFormatter""" def setUp(self) -> None: - self.data_config = DataConfig( - dataset_path=os.path.join( + self.data_config = DataProcessorConfig( + source_data_path=os.path.join( os.path.dirname(os.path.realpath(__file__)), "..", "..", "test_data", "test_10", ), - dataset_config={"split": "train"}, - format_config=FormatConfig( + format=FormatConfig( prompt_key="problem", response_key="answer", chat_template="User: {}\nAssistant: ", @@ -188,7 +184,7 @@ def test_render_template(self): "problem": "What is the capital of France?", "answer": "Paris", } - formatter = SFTFormatter(config=self.data_config.format_config) + formatter = SFTFormatter(config=self.data_config.format) res_sample = formatter._render_template(sample) self.assertEqual( res_sample[formatter.config.prompt_key], @@ -197,16 +193,16 @@ def test_render_template(self): self.assertEqual(res_sample[formatter.config.response_key], "Paris") def test_render_template_without_chat_template(self): - self.data_config.format_config.chat_template = "" + self.data_config.format.chat_template = "" sample = { "problem": "What is the capital of France?", "answer": "Paris", 
} - formatter = SFTFormatter(config=self.data_config.format_config) + formatter = SFTFormatter(config=self.data_config.format) res_sample = formatter._render_template(sample) self.assertEqual(res_sample[formatter.config.prompt_key], "What is the capital of France?") self.assertEqual(res_sample[formatter.config.response_key], "Paris") - self.data_config.format_config.chat_template = "User: {}\nAssistant: " + self.data_config.format.chat_template = "User: {}\nAssistant: " def test_render_template_with_tokenizer(self): # TODO @@ -221,16 +217,15 @@ class TestComposedFormatter(unittest.TestCase): """Test cases for ComposedFormatter""" def setUp(self) -> None: - self.data_config = DataConfig( - dataset_path=os.path.join( + self.data_config = DataProcessorConfig( + source_data_path=os.path.join( os.path.dirname(os.path.realpath(__file__)), "..", "..", "test_data", "test_10", ), - dataset_config={"split": "train"}, - format_config=FormatConfig( + format=FormatConfig( prompt_key="problem", response_key="answer", solution_key="solution", @@ -245,8 +240,8 @@ def setUp(self) -> None: def test_compose(self): composed_formatter = ComposedFormatter( formatters=[ - BoxedMathAnswerFormatter(config=self.data_config.format_config), - SFTFormatter(config=self.data_config.format_config), + BoxedMathAnswerFormatter(config=self.data_config.format), + SFTFormatter(config=self.data_config.format), ] ) self.assertNotIn(composed_formatter.config.response_key, self.sample) diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py index 99ec5a739c..013aa29f64 100644 --- a/tests/explorer/explorer_test.py +++ b/tests/explorer/explorer_test.py @@ -18,6 +18,8 @@ class BaseExplorerCase(RayUnittestBase): def setUp(self): self.config = get_template_config() + self.config.global_config.total_epochs = 2 + self.config.global_config.batch_size = 4 self.config.model.model_path = get_model_path() self.config.explorer.engine_type = "vllm_async" self.config.explorer.repeat_times = 2 @@ -25,8 +27,7 @@ def setUp(self): self.config.monitor.project = "Trinity-unittest" self.config.model.checkpoint_path = get_checkpoint_path() self.config.synchronizer.sync_interval = 2 - self.config.explorer.eval_interval = 4 - self.config.trainer.eval_interval = 4 + self.config.global_config.eval_interval = 4 @abstractmethod def test_explorer(self): @@ -35,7 +36,10 @@ def test_explorer(self): class TestExplorerCountdownEval(BaseExplorerCase): def test_explorer(self): - self.config.data = get_unittest_dataset_config("countdown") + self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown") + self.config.buffer.explorer_input.eval_tasksets.append( + get_unittest_dataset_config("countdown", "test") + ) self.config.monitor.name = f"explore-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}" self.config.explorer.use_v1 = True self.config.check_and_update() @@ -51,9 +55,8 @@ def test_explorer(self): class TestExplorerCountdownNoEval(BaseExplorerCase): def test_explorer(self): - self.config.data = get_unittest_dataset_config("countdown") + self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown") self.config.monitor.name = f"explore-no-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}" - self.config.data.eval_split = None self.config.explorer.use_v1 = False self.config.check_and_update() explore(self.config) diff --git a/tests/explorer/runner_pool_test.py b/tests/explorer/runner_pool_test.py index f5160d3083..49f6f3d924 100644 --- a/tests/explorer/runner_pool_test.py +++ 
b/tests/explorer/runner_pool_test.py @@ -7,7 +7,7 @@ import torch from trinity.buffer.reader.queue_reader import QueueReader -from trinity.common.config import DatasetConfig, load_config +from trinity.common.config import StorageConfig, load_config from trinity.common.constants import AlgorithmType, StorageType from trinity.common.experience import Experience from trinity.common.models.model import InferenceModel @@ -68,13 +68,16 @@ def setUp(self): self.config.explorer.max_timeout = 5 self.config.buffer.read_batch_size = 2 self.config.buffer.pad_token_id = 0 - self.config.buffer.train_dataset = DatasetConfig( + self.config.buffer.explorer_output = ( + self.config.buffer.trainer_input.experience_buffer + ) = StorageConfig( name="test", - namespace="test_runner_pool", storage_type=StorageType.QUEUE, algorithm_type=AlgorithmType.PPO, ) - self.queue = QueueReader(self.config.buffer.train_dataset, self.config.buffer) + self.queue = QueueReader( + self.config.buffer.trainer_input.experience_buffer, self.config.buffer + ) def test_runner_pool(self): pool = RunnerPool(self.config, [DummyModel.remote(), DummyModel.remote()]) diff --git a/tests/template/config.yaml b/tests/template/config.yaml index 1f52409477..aee283bccb 100644 --- a/tests/template/config.yaml +++ b/tests/template/config.yaml @@ -1,12 +1,8 @@ mode: both -data: - dataset_path: '' +global_config: total_epochs: 1 batch_size: 4 - train_split: 'train' - eval_split: '' - default_workflow_type: '' - default_reward_fn_type: '' + eval_interval: 1000 model: model_path: '' max_prompt_tokens: 2048 @@ -18,6 +14,14 @@ cluster: # 2 for explorer, 2 for trainer buffer: max_retry_times: 3 max_retry_interval: 1 + explorer_input: + taskset: + name: taskset + storage_type: file + path: '' + split: 'train' + default_workflow_type: '' + default_reward_fn_type: '' explorer: engine_type: vllm_async engine_num: 2 @@ -37,7 +41,6 @@ trainer: trainer_type: verl trainer_config_path: tests/template/verl_config.yaml sft_warmup_steps: 0 - eval_interval: 1000 save_interval: 100 monitor: project: unittest diff --git a/tests/test_configs/active_iterator_test_cfg.yaml b/tests/test_configs/active_iterator_test_cfg.yaml index 94fa8e2960..3b105e1f66 100644 --- a/tests/test_configs/active_iterator_test_cfg.yaml +++ b/tests/test_configs/active_iterator_test_cfg.yaml @@ -1,13 +1,13 @@ -data: +data_processor: # basic info - dataset_path: '/Users/lielin/Projects/research/large_models/Trinity-RFT/tests/test_data/test_10/' - dataset_config: + source_data_path: 'tests/test_data/test_10/' + load_kwargs: split: 'train' - format_config: + format: prompt_key: 'problem' response_key: 'solution' # cleaner related - dj_config_path: '/Users/lielin/Projects/research/large_models/Trinity-RFT/tests/test_configs/active_iterator_test_dj_cfg.yaml' + dj_config_path: 'tests/test_configs/active_iterator_test_dj_cfg.yaml' clean_strategy: 'iterative' # db related db_url: 'postgresql://{username}@localhost:5432/{db_name}' diff --git a/tests/test_configs/cleaner_test_rft_cfg.yaml b/tests/test_configs/cleaner_test_rft_cfg.yaml index 7a0e0998e7..7f8581c0ef 100644 --- a/tests/test_configs/cleaner_test_rft_cfg.yaml +++ b/tests/test_configs/cleaner_test_rft_cfg.yaml @@ -1,5 +1,5 @@ -data: - dataset_path: './tests/test_data/test_cleaner' - dataset_config: {"split": "train"} +data_processor: + source_data_path: './tests/test_data/test_cleaner' + load_kwargs: {"split": "train"} dj_config_path: './tests/test_configs/cleaner_test_dj_cfg.yaml' clean_strategy: 'iterative' diff --git 
a/tests/test_configs/human_annotator_test_rft_cfg.yaml b/tests/test_configs/human_annotator_test_rft_cfg.yaml index dd6a248dea..79d8b8108b 100644 --- a/tests/test_configs/human_annotator_test_rft_cfg.yaml +++ b/tests/test_configs/human_annotator_test_rft_cfg.yaml @@ -1,8 +1,8 @@ -data: - dataset_path: './tests/test_data/test_human_annotator' - dataset_config: {"split": "train"} +data_processor: + source_data_path: './tests/test_data/test_human_annotator' + load_kwargs: {"split": "train"} dj_config_path: './tests/test_configs/human_annotator_test_dj_cfg.yaml' - format_config: + format: prompt_key: 'prompt' chosen_key: 'chosen' rejected_key: 'rejected' diff --git a/tests/test_data/template.yaml b/tests/test_data/template.yaml index 56f6220d75..2ac9addf28 100644 --- a/tests/test_data/template.yaml +++ b/tests/test_data/template.yaml @@ -1,5 +1,4 @@ -data: - dataset_path: '' +global_config: batch_size: 32 model: max_prompt_tokens: 2048 @@ -11,6 +10,13 @@ cluster: buffer: max_retry_times: 3 max_retry_interval: 1 + explorer_input: + taskset: + name: taskset + storage_type: file + path: '' + default_workflow_type: '' + default_reward_fn_type: '' explorer: engine_type: vllm engine_num: 2 diff --git a/tests/tools.py b/tests/tools.py index 1bbcc767ef..ff4488f857 100644 --- a/tests/tools.py +++ b/tests/tools.py @@ -6,7 +6,7 @@ import ray from tensorboard.backend.event_processing.event_accumulator import EventAccumulator -from trinity.common.config import Config, DataConfig, FormatConfig, load_config +from trinity.common.config import Config, FormatConfig, StorageConfig, load_config def get_template_config() -> Config: @@ -32,21 +32,21 @@ def get_checkpoint_path() -> str: return path -def get_unittest_dataset_config(dataset_name: str = "countdown") -> DataConfig: +def get_unittest_dataset_config( + dataset_name: str = "countdown", split: str = "train" +) -> StorageConfig: """Countdown sample dataset for 8 steps""" if dataset_name == "countdown": - return DataConfig( - total_epochs=2, - batch_size=4, - default_workflow_type="math_workflow", - default_reward_fn_type="countdown_reward", - dataset_path=os.path.join(os.path.dirname(__file__), "template", "data", "countdown"), - train_split="train", - eval_split="test", - format_config=FormatConfig( + return StorageConfig( + name=dataset_name, + path=os.path.join(os.path.dirname(__file__), "template", "data", "countdown"), + split=split, + format=FormatConfig( prompt_key="question", response_key="answer", ), + default_workflow_type="math_workflow", + default_reward_fn_type="countdown_reward", ) else: raise ValueError(f"Unknown dataset name: {dataset_name}") diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 9e68f71d70..35dada1074 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -22,17 +22,18 @@ class BaseTrainerCase(RayUnittestBase): def setUp(self): ray.init(ignore_reinit_error=True) self.config = get_template_config() + self.config.global_config.total_epochs = 2 + self.config.global_config.batch_size = 4 self.config.model.model_path = get_model_path() - self.config.trainer.engine_type = "vllm_async" - self.config.trainer.repeat_times = 3 + self.config.explorer.engine_type = "vllm_async" + self.config.explorer.repeat_times = 3 self.config.monitor.monitor_type = MonitorType.TENSORBOARD self.config.model.checkpoint_path = os.path.join( get_checkpoint_path(), f"train-{datetime.now().strftime('%Y%m%d%H%M%S')}" ) self.config.synchronizer.sync_interval = 2 self.config.synchronizer.sync_method = 
SyncMethod.NCCL - self.config.explorer.eval_interval = 4 - self.config.trainer.eval_interval = 4 + self.config.global_config.eval_interval = 4 @abstractmethod def test_trainer(self): @@ -42,7 +43,10 @@ def test_trainer(self): class TestTrainerCountdown(BaseTrainerCase): def test_trainer(self): """Test the trainer.""" - self.config.data = get_unittest_dataset_config("countdown") + self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown") + self.config.buffer.explorer_input.eval_tasksets.append( + get_unittest_dataset_config("countdown", "test") + ) self.config.check_and_update() self.config.trainer.trainer_config.trainer.save_freq = 8 both(self.config) diff --git a/trinity/buffer/buffer.py b/trinity/buffer/buffer.py index c01ca621f1..09ff663c47 100644 --- a/trinity/buffer/buffer.py +++ b/trinity/buffer/buffer.py @@ -4,7 +4,7 @@ from trinity.buffer.buffer_reader import BufferReader from trinity.buffer.buffer_writer import BufferWriter -from trinity.common.config import BufferConfig, Config, DatasetConfig +from trinity.common.config import BufferConfig, Config, StorageConfig from trinity.common.constants import StorageType @@ -13,48 +13,53 @@ class Buffer: """Responsible for storing experiences.""" def __init__(self, config: Config): - self.buffer_mapping: dict[str, DatasetConfig] = {} + self.buffer_mapping: dict[str, StorageConfig] = {} self._register_from_config(config) - def get_dataset_info(self, dataset_name: str) -> DatasetConfig: - dataset_config = self.buffer_mapping.get(dataset_name, None) - if dataset_config is None: + def get_dataset_info(self, dataset_name: str) -> StorageConfig: + storage_config = self.buffer_mapping.get(dataset_name, None) + if storage_config is None: raise ValueError(f"{dataset_name} not found.") - return dataset_config + return storage_config - def register_dataset(self, dataset_config: DatasetConfig) -> None: - if dataset_config.name in self.buffer_mapping: - raise ValueError(f"{dataset_config.name} already exists.") - self.buffer_mapping[dataset_config.name] = dataset_config + def register_dataset(self, storage_config: StorageConfig) -> None: + if storage_config.name in self.buffer_mapping: + raise ValueError(f"{storage_config.name} already exists.") + self.buffer_mapping[storage_config.name] = storage_config -def get_buffer_reader(dataset_config: DatasetConfig, buffer_config: BufferConfig) -> BufferReader: +def get_buffer_reader(storage_config: StorageConfig, buffer_config: BufferConfig) -> BufferReader: """Get a buffer reader for the given dataset name.""" - if dataset_config.storage_type == StorageType.SQL: + if storage_config.storage_type == StorageType.SQL: from trinity.buffer.reader.sql_reader import SQLReader - return SQLReader(dataset_config, buffer_config) - elif dataset_config.storage_type == StorageType.QUEUE: + return SQLReader(storage_config, buffer_config) + elif storage_config.storage_type == StorageType.QUEUE: from trinity.buffer.reader.queue_reader import QueueReader - return QueueReader(dataset_config, buffer_config) - elif dataset_config.storage_type == StorageType.FILE: - from trinity.buffer.reader.file_reader import FileReader + return QueueReader(storage_config, buffer_config) + elif storage_config.storage_type == StorageType.FILE: + from trinity.buffer.reader.file_reader import FILE_READERS - return FileReader(dataset_config, buffer_config) + file_read_type = storage_config.algorithm_type + if file_read_type is not None: + file_read_type = file_read_type.value + else: + file_read_type = "rollout" + return 
FILE_READERS.get(file_read_type)(storage_config, buffer_config) else: - raise ValueError(f"{dataset_config.storage_type} not supported.") + raise ValueError(f"{storage_config.storage_type} not supported.") -def get_buffer_writer(dataset_config: DatasetConfig, buffer_config: BufferConfig) -> BufferWriter: +def get_buffer_writer(storage_config: StorageConfig, buffer_config: BufferConfig) -> BufferWriter: """Get a buffer writer for the given dataset name.""" - if dataset_config.storage_type == StorageType.SQL: + if storage_config.storage_type == StorageType.SQL: from trinity.buffer.writer.sql_writer import SQLWriter - return SQLWriter(dataset_config, buffer_config) - elif dataset_config.storage_type == StorageType.QUEUE: + return SQLWriter(storage_config, buffer_config) + elif storage_config.storage_type == StorageType.QUEUE: from trinity.buffer.writer.queue_writer import QueueWriter - return QueueWriter(dataset_config, buffer_config) + return QueueWriter(storage_config, buffer_config) else: - raise ValueError(f"{dataset_config.storage_type} not supported.") + raise ValueError(f"{storage_config.storage_type} not supported.") diff --git a/trinity/buffer/queue.py b/trinity/buffer/queue.py index 0f135af0bc..a360182f07 100644 --- a/trinity/buffer/queue.py +++ b/trinity/buffer/queue.py @@ -6,7 +6,7 @@ import ray from trinity.buffer.writer.sql_writer import SQLWriter -from trinity.common.config import BufferConfig, DatasetConfig +from trinity.common.config import BufferConfig, StorageConfig from trinity.common.constants import StorageType @@ -16,12 +16,12 @@ class QueueActor: FINISH_MESSAGE = "$FINISH$" - def __init__(self, dataset_config: DatasetConfig, config: BufferConfig) -> None: + def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: self.config = config self.capacity = getattr(config, "capacity", 10000) self.queue = asyncio.Queue(self.capacity) - if dataset_config.path is not None and len(dataset_config.path) > 0: - sql_config = deepcopy(dataset_config) + if storage_config.path is not None and len(storage_config.path) > 0: + sql_config = deepcopy(storage_config) sql_config.storage_type = StorageType.SQL self.sql_writer = SQLWriter(sql_config, self.config) else: diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py index 21eba09d04..0ca202b2fa 100644 --- a/trinity/buffer/reader/file_reader.py +++ b/trinity/buffer/reader/file_reader.py @@ -2,55 +2,41 @@ from typing import List, Optional +import datasets import transformers from datasets import load_dataset from trinity.buffer.buffer_reader import BufferReader -from trinity.common.config import BufferConfig, DatasetConfig -from trinity.common.constants import ( - AlgorithmType, - PromptType, - ReadStrategy, - StorageType, -) +from trinity.common.config import BufferConfig, StorageConfig +from trinity.common.constants import AlgorithmType, PromptType, ReadStrategy, TaskType from trinity.common.experience import Experience +from trinity.common.rewards import REWARD_FUNCTIONS +from trinity.common.task import Task +from trinity.common.workflows import WORKFLOWS +from trinity.utils.registry import Registry +FILE_READERS = Registry("file_readers") -class FileReader(BufferReader): - """Reader of the File buffer.""" - def __init__(self, meta: DatasetConfig, config: BufferConfig) -> None: - assert meta.storage_type == StorageType.FILE - if meta.algorithm_type == AlgorithmType.SFT: - self.reader = SFTDataReader(meta, config) - elif meta.algorithm_type == AlgorithmType.DPO: - self.reader = 
DPODataReader(meta, config) - else: - # TODO: support read rollout task - raise ValueError(f"Unsupported algorithm type: {meta.algorithm_type}") - - def read(self, strategy: Optional[ReadStrategy] = None) -> List: - """Read data from the buffer.""" - if strategy is not None and strategy != ReadStrategy.FIFO: - raise ValueError(f"Unsupported read strategy: {strategy}") - return self.reader.read() - - -class SFTDataReader: +@FILE_READERS.register_module(AlgorithmType.SFT.value) +class SFTDataReader(BufferReader): """Reader for SFT file data.""" - def __init__(self, meta: DatasetConfig, config: BufferConfig): - self.train_split = meta.kwargs.get("train_split", "train") - self.prompt_type = PromptType(meta.kwargs.get("prompt_type", "messages")) - self.messages_key = meta.kwargs.get("messages_key", "messages") - self.prompt_key = meta.kwargs.get("prompt_key", "prompt") - self.response_key = meta.kwargs.get("response_key", "response") + def __init__(self, meta: StorageConfig, config: BufferConfig): + self.split = meta.split + subset_name = meta.subset_name + self.prompt_type = meta.format.prompt_type + self.messages_key = meta.format.messages_key + self.prompt_key = meta.format.prompt_key + self.response_key = meta.format.response_key self.read_batch_size = config.read_batch_size - self.dataset = load_dataset(meta.path)[self.train_split] + self.dataset = load_dataset( + meta.path, name=subset_name, split=self.split + ) # TODO: support resume self.data_iter = self.dataset.iter(self.read_batch_size, drop_last_batch=True) self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path) - def read(self) -> List: + def read(self, strategy: Optional[ReadStrategy] = None) -> List: try: batch_data = next(self.data_iter) except StopIteration: @@ -111,15 +97,19 @@ def read(self) -> List: return exp_list -class DPODataReader: - def __init__(self, meta: DatasetConfig, config: BufferConfig): - self.train_split = meta.kwargs.get("train_split", "train") - self.prompt_type = PromptType(meta.kwargs.get("prompt_type", "messages")) - self.prompt_key = meta.kwargs.get("prompt_key", "prompt") - self.chosen_key = meta.kwargs.get("chosen_key", "chosen") - self.rejected_key = meta.kwargs.get("rejected_key", "rejected") +@FILE_READERS.register_module(AlgorithmType.DPO.value) +class DPODataReader(BufferReader): + def __init__(self, meta: StorageConfig, config: BufferConfig): + self.split = meta.split + subset_name = meta.subset_name + self.prompt_type = meta.format.prompt_type + self.prompt_key = meta.format.prompt_key + self.chosen_key = meta.format.chosen_key + self.rejected_key = meta.format.rejected_key self.read_batch_size = config.read_batch_size - self.dataset = load_dataset(meta.path)[self.train_split] + self.dataset = load_dataset( + meta.path, name=subset_name, split=self.split + ) # TODO: support resume self.data_iter = self.dataset.iter(self.read_batch_size, drop_last_batch=True) self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path) @@ -131,7 +121,7 @@ def _get_assistant_message(self, item) -> dict: else: return item - def read(self) -> List: + def read(self, strategy: Optional[ReadStrategy] = None) -> List: try: batch_data = next(self.data_iter) except StopIteration: @@ -178,3 +168,66 @@ def read(self) -> List: ) exp_list.append(experience) return exp_list + + +@FILE_READERS.register_module("rollout") +class RolloutDataReader(BufferReader): + def __init__(self, meta: StorageConfig, config: BufferConfig): + self.name = meta.name + self.split = meta.split + 
subset_name = meta.subset_name
+        # disable datasets caching to avoid reuse old-version dataset
+        datasets.disable_caching()
+        self.dataset = load_dataset(
+            meta.path, name=subset_name, split=self.split
+        )  # TODO: may from db_url
+        # if task_type != TaskType.EVAL and config.db_url != "":
+        #     logger.info(f"Loading dataset from database with url: {config.db_url}")
+        #     db_type = config.db_url.split(":")[0]
+        #     db_name = config.db_url.split("/")[-1]
+        #     dataset = Dataset.from_sql(RftDatasetModel.__tablename__, f"{db_type}:///{db_name}")
+        datasets.enable_caching()
+        self.index = meta.index  # TODO: apply shuffle
+
+        self.prompt_key = meta.format.prompt_key
+        self.response_key = meta.format.response_key
+        self.workflow_key = meta.format.workflow_key
+        self.reward_fn_key = meta.format.reward_fn_key
+
+        self.task_type = meta.task_type
+        self.default_workflow_cls = WORKFLOWS.get(meta.default_workflow_type)
+        self.default_reward_fn_cls = REWARD_FUNCTIONS.get(meta.default_reward_fn_type)
+        self.total_epochs = meta.total_epochs if self.task_type == TaskType.EXPLORE else 1
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def read(self, strategy: Optional[ReadStrategy] = None):
+        if self.index >= len(self.dataset) * self.total_epochs:
+            raise StopIteration
+        sample = self.dataset[self.index % len(self.dataset)]
+        task_desc = sample[self.prompt_key] if self.prompt_key in sample else None
+        truth = sample[self.response_key] if self.response_key in sample else None
+        workflow_class = (
+            WORKFLOWS.get(sample[self.workflow_key])
+            if self.workflow_key in sample
+            else self.default_workflow_cls
+        )
+        reward_fn = (
+            REWARD_FUNCTIONS.get(sample[self.reward_fn_key])
+            if self.reward_fn_key in sample
+            else self.default_reward_fn_cls
+        )
+        assert workflow_class is not None, "`default_workflow_type` or `workflow_key` is required"
+        task = Task(
+            task_desc=task_desc,
+            truth=truth,
+            workflow=workflow_class,
+            reward_fn=reward_fn,
+            raw=sample,
+            task_type=self.task_type,
+        )
+        self.index += 1
+        if self.task_type == TaskType.EVAL and self.index == len(self.dataset):
+            self.index = 0
+        return task
diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py
index 519479848c..ffd013d4ef 100644
--- a/trinity/buffer/reader/queue_reader.py
+++ b/trinity/buffer/reader/queue_reader.py
@@ -6,7 +6,7 @@
 from trinity.buffer.buffer_reader import BufferReader
 from trinity.buffer.queue import QueueActor
-from trinity.common.config import BufferConfig, DatasetConfig
+from trinity.common.config import BufferConfig, StorageConfig
 from trinity.common.constants import ReadStrategy, StorageType
 from trinity.utils.log import get_logger
@@ -16,12 +16,11 @@ class QueueReader(BufferReader):
     """Reader of the Queue buffer."""
 
-    def __init__(self, meta: DatasetConfig, config: BufferConfig):
+    def __init__(self, meta: StorageConfig, config: BufferConfig):
         assert meta.storage_type == StorageType.QUEUE
         self.config = config
         self.queue = QueueActor.options(
             name=f"queue-{meta.name}",
-            namespace=meta.namespace,
             get_if_exists=True,
         ).remote(meta, config)
diff --git a/trinity/buffer/reader/sql_reader.py b/trinity/buffer/reader/sql_reader.py
index e5c249f441..4da2920816 100644
--- a/trinity/buffer/reader/sql_reader.py
+++ b/trinity/buffer/reader/sql_reader.py
@@ -11,7 +11,7 @@
 from trinity.buffer.buffer_reader import BufferReader
 from trinity.buffer.schema import Base, create_dynamic_table
 from trinity.buffer.utils import retry_session
-from trinity.common.config import BufferConfig, DatasetConfig
+from trinity.common.config import BufferConfig, StorageConfig
 from trinity.common.constants import ReadStrategy, StorageType
 from trinity.utils.log import get_logger
@@ -21,7 +21,7 @@ class SQLReader(BufferReader):
     """Reader of the SQL buffer."""
 
-    def __init__(self, meta: DatasetConfig, config: BufferConfig) -> None:
+    def __init__(self, meta: StorageConfig, config: BufferConfig) -> None:
         assert meta.storage_type == StorageType.SQL
         self.engine = create_engine(meta.path, poolclass=NullPool)
diff --git a/trinity/buffer/schema/sql_schema.py b/trinity/buffer/schema/sql_schema.py
index febbb83a2f..db2e4ca137 100644
--- a/trinity/buffer/schema/sql_schema.py
+++ b/trinity/buffer/schema/sql_schema.py
@@ -1,6 +1,6 @@
 """Schema for SQLAlchemy models."""
-from typing import Any, Optional
+from typing import Any, Optional, Union
 
 from sqlalchemy import Column, Float, Integer, LargeBinary, String
 from sqlalchemy.ext.declarative import declarative_base
@@ -126,7 +126,7 @@ def to_experience(self) -> Experience:
 
 SCHEMA_MAPPING = {
-    AlgorithmType.ROLLOUT: TaskModel,
+    None: TaskModel,
     AlgorithmType.SFT: SFTDataModel,
     AlgorithmType.PPO: ExperienceModel,
     AlgorithmType.GRPO: ExperienceModel,
@@ -135,7 +135,7 @@ def to_experience(self) -> Experience:
 }
 
-def create_dynamic_table(algorithm_type: AlgorithmType, table_name: str) -> Any:
+def create_dynamic_table(algorithm_type: Union[AlgorithmType, None], table_name: str) -> Any:
     """Create a dynamic table based on the provided algorithm type and table name."""
     if algorithm_type not in SCHEMA_MAPPING:
         raise ValueError(f"Unknown schema: {algorithm_type}")
diff --git a/trinity/buffer/writer/queue_writer.py b/trinity/buffer/writer/queue_writer.py
index 14fd9e5b28..5cd24877d9 100644
--- a/trinity/buffer/writer/queue_writer.py
+++ b/trinity/buffer/writer/queue_writer.py
@@ -5,7 +5,7 @@
 from trinity.buffer.buffer_writer import BufferWriter
 from trinity.buffer.queue import QueueActor
-from trinity.common.config import BufferConfig, DatasetConfig
+from trinity.common.config import BufferConfig, StorageConfig
 from trinity.common.constants import StorageType
 from trinity.utils.log import get_logger
@@ -15,12 +15,11 @@ class QueueWriter(BufferWriter):
     """Writer of the Queue buffer."""
 
-    def __init__(self, meta: DatasetConfig, config: BufferConfig):
+    def __init__(self, meta: StorageConfig, config: BufferConfig):
         assert meta.storage_type == StorageType.QUEUE
         self.config = config
         self.queue = QueueActor.options(
             name=f"queue-{meta.name}",
-            namespace=meta.namespace,
             get_if_exists=True,
         ).remote(meta, config)
diff --git a/trinity/buffer/writer/sql_writer.py b/trinity/buffer/writer/sql_writer.py
index a2abb1e399..7464064037 100644
--- a/trinity/buffer/writer/sql_writer.py
+++ b/trinity/buffer/writer/sql_writer.py
@@ -8,7 +8,7 @@
 from trinity.buffer.buffer_writer import BufferWriter
 from trinity.buffer.schema import Base, create_dynamic_table
 from trinity.buffer.utils import retry_session
-from trinity.common.config import BufferConfig, DatasetConfig
+from trinity.common.config import BufferConfig, StorageConfig
 from trinity.common.constants import StorageType
 from trinity.utils.log import get_logger
@@ -18,7 +18,7 @@ class SQLWriter(BufferWriter):
    """Writer of the SQL buffer."""
 
-    def __init__(self, meta: DatasetConfig, config: BufferConfig) -> None:
+    def __init__(self, meta: StorageConfig, config: BufferConfig) -> None:
         assert meta.storage_type == StorageType.SQL
         # we only support write RFT algorithm buffer for now
         # TODO: support other algorithms
diff --git
a/trinity/cli/launcher.py b/trinity/cli/launcher.py index 21780e0e03..482224918d 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -122,7 +122,7 @@ def both(config: Config) -> None: logger.error(e) logger.error("Training stopped due to exception.") raise e - if train_step_num % config.trainer.eval_interval == 0: + if explore_step_num % config.global_config.eval_interval == 0: try: ray.get(explorer.eval.remote()) logger.info("Evaluation finished.") @@ -152,13 +152,13 @@ def run(config_path: str): config = load_config(config_path) config.check_and_update() # try to activate data module - data_config = config.data - if data_config.data_workflow_url and ( - data_config.dj_config_path or data_config.dj_process_desc + data_processor_config = config.data_processor + if data_processor_config.data_workflow_url and ( + data_processor_config.dj_config_path or data_processor_config.dj_process_desc ): - activate_data_module(data_config.data_workflow_url, config_path) + activate_data_module(data_processor_config.data_workflow_url, config_path) if not ray.is_initialized(): - ray.init() + ray.init(namespace=f"{config.monitor.project}-{config.monitor.name}") if config.mode == "explore": explore(config) elif config.mode == "train": diff --git a/trinity/common/config.py b/trinity/common/config.py index 0c44f7c2e6..00bdb153a4 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -2,7 +2,7 @@ """Configs for RFT.""" import os from dataclasses import dataclass, field -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from omegaconf import OmegaConf @@ -12,6 +12,7 @@ PromptType, StorageType, SyncMethod, + TaskType, ) from trinity.utils.log import get_logger @@ -22,8 +23,11 @@ class FormatConfig: """Configuration for data formatting""" - prompt_key: str = "" - response_key: str = "" + prompt_type: PromptType = PromptType.MESSAGES + + prompt_key: str = "prompt" + response_key: str = "response" + messages_key: str = "message" chat_template: str = "" # for sample-level task controlling @@ -36,27 +40,49 @@ class FormatConfig: reward_key: str = "" # for dpo dataset - chosen_key: str = "" - rejected_key: str = "" + chosen_key: str = "chosen" + rejected_key: str = "rejected" # for unpaired preference dataset label_key: str = "" @dataclass -class DataConfig: - """Data config""" +class StorageConfig: + """Storage config.""" - data_workflow_url: Optional[str] = None + name: str = "" + storage_type: StorageType = StorageType.FILE + algorithm_type: Optional[AlgorithmType] = None # automatically set + path: Optional[str] = None - dataset_path: str = "" - train_split: str = "train" + # used for StorageType.FILE + split: str = "train" subset_name: Optional[str] = None - eval_split: Optional[str] = None # TODO: check data format - format_config: FormatConfig = field(default_factory=FormatConfig) + format: FormatConfig = field(default_factory=FormatConfig) + index: int = 0 + + # used for algorithm_type is None + task_type: TaskType = TaskType.EXPLORE + default_workflow_type: Optional[str] = None + default_reward_fn_type: Optional[str] = None + total_epochs: int = 1 # automatically set + # used for algorithm_type is None and TaskType.EVAL + eval_repeat_times: int = 1 # TODO + eval_temperature: float = 0.1 # TODO + + +@dataclass +class DataProcessorConfig: + """Data-Juicer config""" + + data_workflow_url: Optional[str] = None + + source_data_path: str = "" + format: FormatConfig = field(default_factory=FormatConfig) # data active iterator related - 
dataset_config: Dict[str, Any] = field(default_factory=dict) + load_kwargs: Dict[str, Any] = field(default_factory=dict) dj_config_path: Optional[str] = None # The path to Data-Juicer config file. dj_process_desc: Optional[ str @@ -74,11 +100,13 @@ class DataConfig: max_retry_times: int = 3 max_retry_interval: int = 1 + +@dataclass +class GlobalConfig: # downstream loading related total_epochs: int = 1 batch_size: int = 1 - default_workflow_type: str = "" - default_reward_fn_type: str = "" + eval_interval: int = 100 @dataclass @@ -104,36 +132,36 @@ class ClusterConfig: @dataclass -class DatasetConfig: - """The config for a dataset.""" +class ExplorerInput: + """Config for explorer input.""" - name: str - storage_type: StorageType - algorithm_type: AlgorithmType = AlgorithmType.PPO - path: Optional[str] = None - namespace: str = "" # automatically generated - kwargs: Dict[str, Any] = field(default_factory=dict) + taskset: StorageConfig = field(default_factory=StorageConfig) + eval_tasksets: List[StorageConfig] = field(default_factory=list) + default_workflow_type: Optional[str] = None + default_reward_fn_type: Optional[str] = None + + +@dataclass +class TrainerInput: + """Config for trainer input.""" + + experience_buffer: Optional[StorageConfig] = None + sft_warmup_dataset: Optional[StorageConfig] = None @dataclass class BufferConfig: """Config for experience buffer.""" - db_url: Optional[str] = None # Is deprecated, please set `buffer.train_dataset.path` instead. read_batch_size: int = 32 max_retry_times: int = 3 max_retry_interval: int = 1 - tokenizer_path: Optional[str] = None - pad_token_id: Optional[int] = None + tokenizer_path: Optional[str] = None # automatically set + pad_token_id: Optional[int] = None # automatically set - train_dataset: Optional[DatasetConfig] = None - sft_warmup_dataset: Optional[DatasetConfig] = None - - # remove in the future - prompt_type: PromptType = PromptType.MESSAGES - messages_key: str = "messages" - prompt_key: str = "prompt" - response_key: str = "response" + explorer_input: ExplorerInput = field(default_factory=ExplorerInput) + explorer_output: Optional[StorageConfig] = None # currently do not set + trainer_input: TrainerInput = field(default_factory=TrainerInput) @dataclass @@ -157,10 +185,6 @@ class ExplorerConfig: # for rollout tokneize chat_template: Optional[str] = None - # for evaluation - # TODO: remove trainer.eval_interval - eval_interval: int = 100 - # for vLLM tensor_parallel_size: int = 1 enable_prefix_caching: bool = False @@ -189,13 +213,12 @@ class ExplorerConfig: class TrainerConfig: trainer_type: str = "verl" trainer_config_path: str = "" - eval_interval: int = 100 save_interval: int = 0 enable_preview: bool = True # enable rollout preview in wandb trainer_config: Any = field(default_factory=dict) # train algorithm - algorithm_type: AlgorithmType = AlgorithmType.PPO # automatically set + algorithm_type: AlgorithmType = AlgorithmType.PPO get_exp_strategy: Optional[str] = None # warmup config @@ -241,7 +264,8 @@ class Config: """Global Configuration""" mode: str = "both" # `explore`, `train`, `both` or `bench` - data: DataConfig = field(default_factory=DataConfig) + data_processor: DataProcessorConfig = field(default_factory=DataProcessorConfig) + global_config: GlobalConfig = field(default_factory=GlobalConfig) model: ModelConfig = field(default_factory=ModelConfig) cluster: ClusterConfig = field(default_factory=ClusterConfig) buffer: BufferConfig = field(default_factory=BufferConfig) @@ -255,16 +279,111 @@ def save(self, config_path: 
str) -> None: with open(config_path, "w", encoding="utf-8") as f: OmegaConf.save(self, f) - def _check_buffer(self) -> None: - if self.trainer.sft_warmup_steps > 0 and self.buffer.sft_warmup_dataset is None: + def _check_deprecated(self) -> None: + if self.synchronizer.sync_iteration_interval is not None: + logger.warning( + f"`synchronizer.sync_iteration_interval` is deprecated, please use `synchronizer.sync_interval` instead. " + f"And `synchronizer.sync_interval` will set to {self.synchronizer.sync_iteration_interval} instead." + ) + self.synchronizer.sync_interval = self.synchronizer.sync_iteration_interval + + if self.trainer.sft_warmup_iteration is not None: + logger.warning( + f"`trainer.sft_warmup_iteration` is deprecated, please use `trainer.sft_warmup_steps` instead. " + f"And `trainer.sft_warmup_steps` will be set to {self.trainer.sft_warmup_iteration} instead." + ) + self.trainer.sft_warmup_steps = self.trainer.sft_warmup_iteration + + def _check_interval(self) -> None: + assert self.synchronizer.sync_interval > 0 + + # check eval_interval + if ( + self.trainer.algorithm_type != AlgorithmType.DPO + and self.global_config.eval_interval % self.synchronizer.sync_interval != 0 + ): + self.global_config.eval_interval = ( + max(self.global_config.eval_interval // self.synchronizer.sync_interval, 1) + ) * self.synchronizer.sync_interval + logger.warning( + f"`eval_interval` is not a multiple of `sync_interval`; adjusted to the nearest integer={self.global_config.eval_interval}." + ) + + # check save_interval + if ( + self.trainer.algorithm_type != AlgorithmType.DPO + and self.synchronizer.sync_method == SyncMethod.CHECKPOINT + ): + if self.trainer.save_interval != self.synchronizer.sync_interval: + logger.warning( + f"When `trainer.algorithm_type != DPO` and `synchronizer.sync_method == checkpoint`, " + f"`trainer.save_interval` will be set to " + f"`synchronizer.sync_interval = {self.synchronizer.sync_interval}`." + ) + self.trainer.save_interval = self.synchronizer.sync_interval + + def _check_buffer(self) -> None: # noqa: C901 + # check explorer_input + if self.mode != "train" and self.buffer.explorer_input.taskset.path is None: raise ValueError( - "buffer.sft_warmup_dataset is required when trainer.sft_warmup_steps > 0" + "`buffer.explorer_input.taskset.path` is required, please set it to the path of the taskset." 
) - if self.buffer.db_url: + self.buffer.explorer_input.taskset.task_type = TaskType.EXPLORE + self.buffer.explorer_input.taskset.total_epochs = self.global_config.total_epochs + if self.buffer.explorer_input.taskset.default_workflow_type is None: + self.buffer.explorer_input.taskset.default_workflow_type = ( + self.buffer.explorer_input.default_workflow_type + ) + if self.buffer.explorer_input.taskset.default_reward_fn_type is None: + self.buffer.explorer_input.taskset.default_reward_fn_type = ( + self.buffer.explorer_input.default_reward_fn_type + ) + + for dataset in self.buffer.explorer_input.eval_tasksets: + dataset.task_type = TaskType.EVAL + if dataset.default_workflow_type is None: + dataset.default_workflow_type = self.buffer.explorer_input.default_workflow_type + if dataset.default_reward_fn_type is None: + dataset.default_reward_fn_type = self.buffer.explorer_input.default_reward_fn_type + + # check trainer_input.experience_buffer + if self.mode == "both": + if self.buffer.trainer_input.experience_buffer is None: + self.buffer.trainer_input.experience_buffer = StorageConfig( + name="experience_buffer", + storage_type=StorageType.QUEUE, + ) + logger.info( + f"Auto set `buffer.trainer_input.experience_buffer` to {self.buffer.trainer_input.experience_buffer}" + ) + else: # TODO: to be check + if self.trainer.algorithm_type.is_dpo(): + if ( + self.buffer.trainer_input.experience_buffer is None + or not self.buffer.trainer_input.experience_buffer.path + ): + raise ValueError( + "`buffer.trainer_input.experience_buffer.path` is required when `trainer.algorithm_type == AlgorithmType.DPO`" + ) + self.buffer.trainer_input.experience_buffer.algorithm_type = self.trainer.algorithm_type + + # set buffer.explorer_output + if self.buffer.explorer_output is None: + self.buffer.explorer_output = self.buffer.trainer_input.experience_buffer + + # check trainer_input.sft_warmup_dataset + if ( + self.trainer.sft_warmup_steps > 0 + and self.buffer.trainer_input.sft_warmup_dataset is None + ): raise ValueError( - "`buffer.db_url` is deprecated, please set `buffer.train_dataset.path` instead." 
+ "buffer.trainer_input.sft_warmup_dataset is required when trainer.sft_warmup_steps > 0" ) + if self.buffer.trainer_input.sft_warmup_dataset is not None: + self.buffer.trainer_input.sft_warmup_dataset.algorithm_type = AlgorithmType.SFT + # set read_batch_size / pad_token_id / tokenizer_path + self.buffer.read_batch_size = self.global_config.batch_size * self.explorer.repeat_times if self.buffer.pad_token_id is None: from transformers import AutoTokenizer @@ -277,32 +396,10 @@ def _check_buffer(self) -> None: self.buffer.pad_token_id = 0 self.buffer.tokenizer_path = self.model.model_path - if self.mode == "both": - if self.buffer.train_dataset is None: - self.buffer.train_dataset = DatasetConfig( - name="experience_buffer", - storage_type=StorageType.QUEUE, - ) - logger.info(f"Auto set `buffer.train_dataset` to {self.buffer.train_dataset}") - else: # TODO: to be check - if self.mode == "train" and self.trainer.algorithm_type == AlgorithmType.DPO: - if self.buffer.train_dataset is None and self.data.dataset_path.strip(): - self.buffer.train_dataset = DatasetConfig( - name="dpo_train_dataset", - storage_type=StorageType.FILE, - ) - logger.info(f"Auto set `buffer.train_dataset` to {self.buffer.train_dataset}") - if self.buffer.train_dataset is None: - raise ValueError("buffer.train_dataset is required when mode is not 'both'") - self.buffer.train_dataset.algorithm_type = self.trainer.algorithm_type - self.buffer.train_dataset.namespace = f"{self.monitor.project}-{self.monitor.name}" - if self.buffer.sft_warmup_dataset is not None: - self.buffer.sft_warmup_dataset.namespace = f"{self.monitor.project}-{self.monitor.name}" - self.buffer.sft_warmup_dataset.algorithm_type = AlgorithmType.SFT - self.buffer.read_batch_size = self.data.batch_size * self.explorer.repeat_times - def check_and_update(self) -> None: # noqa: C901 """Check and update the config.""" + self._check_deprecated() + # check mode if self.mode not in ["explore", "train", "both", "bench"]: raise ValueError(f"Invalid mode: {self.mode}") @@ -316,13 +413,6 @@ def check_and_update(self) -> None: # noqa: C901 self.model.critic_model_path = self.model.model_path # check synchronizer - if self.synchronizer.sync_iteration_interval is not None: - logger.warning( - f"`synchronizer.sync_iteration_interval` is deprecated, please use `synchronizer.sync_interval` instead. " - f"And `synchronizer.sync_interval` will set to {self.synchronizer.sync_iteration_interval} instead." - ) - self.synchronizer.sync_interval = self.synchronizer.sync_iteration_interval - assert self.synchronizer.sync_interval > 0 self.synchronizer.explorer_world_size = ( self.explorer.engine_num * self.explorer.tensor_parallel_size ) @@ -343,35 +433,7 @@ def check_and_update(self) -> None: # noqa: C901 if self.synchronizer.sync_method == SyncMethod.NCCL and self.mode != "both": raise ValueError("`nccl` synchronization is only supported in both mode.") - # check eval_interval - if ( - self.trainer.algorithm_type != AlgorithmType.DPO - and self.trainer.eval_interval % self.synchronizer.sync_interval != 0 - ): - self.trainer.eval_interval = ( - max(self.trainer.eval_interval // self.synchronizer.sync_interval, 1) - ) * self.synchronizer.sync_interval - logger.warning( - f"`eval_interval` is not a multiple of `sync_interval`; adjusted to the nearest integer={self.trainer.eval_interval}." 
- ) - if self.explorer.eval_interval != self.trainer.eval_interval: - self.explorer.eval_interval = self.trainer.eval_interval - logger.warning( - f"`explorer.eval_interval` is not equal to `trainer.eval_interval`; adjusted to the same value={self.trainer.eval_interval}." - ) - - # check save_interval - if ( - self.trainer.algorithm_type != AlgorithmType.DPO - and self.synchronizer.sync_method == SyncMethod.CHECKPOINT - ): - if self.trainer.save_interval != self.synchronizer.sync_interval: - logger.warning( - f"When `trainer.algorithm_type != DPO` and `synchronizer.sync_method == checkpoint`, " - f"`trainer.save_interval` will be set to " - f"`synchronizer.sync_interval = {self.synchronizer.sync_interval}`." - ) - self.trainer.save_interval = self.synchronizer.sync_interval + self._check_interval() # check monitor if not self.monitor.cache_root_dir: @@ -389,17 +451,10 @@ def check_and_update(self) -> None: # noqa: C901 f"your checkpoint path: {self.model.checkpoint_path}" ) - if self.trainer.sft_warmup_iteration is not None: - logger.warning( - f"`trainer.sft_warmup_iteration` is deprecated, please use `trainer.sft_warmup_steps` instead. " - f"And `trainer.sft_warmup_steps` will be set to {self.trainer.sft_warmup_iteration} instead." - ) - self.trainer.sft_warmup_steps = self.trainer.sft_warmup_iteration - # check buffer self._check_buffer() # check and update trainer - if self.mode != "explore": + if self.mode in {"both", "train"}: if self.trainer.trainer_type == "verl": if self.trainer.trainer_config: from trinity.common.verl_config import veRLConfig diff --git a/trinity/common/constants.py b/trinity/common/constants.py index 7fcc7b4f23..860cd39027 100644 --- a/trinity/common/constants.py +++ b/trinity/common/constants.py @@ -65,7 +65,6 @@ class StorageType(CaseInsensitiveEnum): class AlgorithmType(CaseInsensitiveEnum): """Algorithm Type.""" - ROLLOUT = "rollout" SFT = "sft" PPO = "ppo" GRPO = "grpo" @@ -86,10 +85,6 @@ def is_sft(self) -> bool: """Check if the algorithm is SFT.""" return self == AlgorithmType.SFT - def is_rollout(self) -> bool: - """Check if the algorithm is ROLLOUT.""" - return self == AlgorithmType.ROLLOUT - def is_dpo(self) -> bool: """Check if the algorithm is DPO.""" return self == AlgorithmType.DPO diff --git a/trinity/common/models/__init__.py b/trinity/common/models/__init__.py index eeecd4b741..87a2dd8ced 100644 --- a/trinity/common/models/__init__.py +++ b/trinity/common/models/__init__.py @@ -32,7 +32,7 @@ def create_rollout_models( else: raise ValueError(f"Unknown engine type: {config.explorer.engine_type}") - bundles = [{"GPU": 1} for _ in range(engine_num * tensor_parallel_size)] + bundles = [{"GPU": 1, "CPU": 1} for _ in range(engine_num * tensor_parallel_size)] pg = placement_group(bundles, strategy="PACK") ray.get(pg.ready()) diff --git a/trinity/common/task.py b/trinity/common/task.py index 781e755739..b0582ab936 100644 --- a/trinity/common/task.py +++ b/trinity/common/task.py @@ -3,17 +3,11 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Iterator, Optional, Type +from typing import Any, Optional, Type -import datasets -from datasets import Dataset, load_dataset - -from trinity.common.config import Config, DataConfig +from trinity.common.config import Config from trinity.common.constants import TaskType -from trinity.common.rewards import REWARD_FUNCTIONS from trinity.common.rewards.reward_fn import RewardFn -from trinity.common.schema import RftDatasetModel -from trinity.common.workflows import WORKFLOWS 
from trinity.common.workflows.workflow import Workflow from trinity.utils.log import get_logger @@ -55,162 +49,3 @@ def to_workflow(self, model: Any, config: Config) -> Workflow: config=config, is_eval=self.task_type == TaskType.EVAL, ) - - -def task_generator( - dataset, - start_index: int, - config: DataConfig, - default_workflow: Optional[Type[Workflow]], - default_reward_fn: Optional[Type[RewardFn]], - task_type: Optional[TaskType], -) -> Iterator[Task]: - """Get a generator of tasks from the dataset.""" - for i, sample in enumerate(dataset): - if i < start_index: - continue - - task_desc = ( - sample[config.format_config.prompt_key] - if config.format_config.prompt_key in sample - else None - ) - truth = ( - sample[config.format_config.response_key] - if config.format_config.response_key in sample - else None - ) - workflow_class = ( - WORKFLOWS.get(sample[config.format_config.workflow_key]) - if config.format_config.workflow_key in sample - else default_workflow - ) - reward_fn = ( - REWARD_FUNCTIONS.get(sample[config.format_config.reward_fn_key]) - if config.format_config.reward_fn_key in sample - else default_reward_fn - ) - task = Task( - task_desc=task_desc, - truth=truth, - workflow=workflow_class, - reward_fn=reward_fn, - raw=sample, - task_type=task_type, - ) - yield task - - -def load_hf_dataset(config: DataConfig, split: str): - """Load a Hugging Face dataset with optional configuration name.""" - if config.subset_name is not None: - hf_dataset = load_dataset(config.dataset_path, config.subset_name, split=split) - else: - hf_dataset = load_dataset(config.dataset_path, split=split) - return hf_dataset - - -@dataclass -class TaskSet: - """A TaskSet class that defines a set of tasks and their associated reward functions.""" - - dataset: Any # the source huggingface dataset - config: DataConfig - reward_fn: Optional[Type[RewardFn]] = None - workflow: Optional[Type[Workflow]] = None - task_type: Optional[TaskType] = None - default_index: int = 0 - default_epoch: int = 0 - total_epochs: int = 1 - _tasks: Iterator[Task] = None - _index: int = 0 - _epoch: int = 0 - - @classmethod - def load( - cls, config: DataConfig, latest_task_index: int = 0, task_type: TaskType = None - ) -> TaskSet: - """Load the RFT taskset through config.""" - # disable datasets caching to avoid reuse old-version dataset - datasets.disable_caching() - if task_type == TaskType.EVAL: - assert config.eval_split is not None, "eval_split must be provided for eval taskset." 
- dataset = load_hf_dataset(config, config.eval_split) - else: # default - if task_type != TaskType.EVAL and config.db_url != "": - logger.info(f"Loading dataset from database with url: {config.db_url}") - db_type = config.db_url.split(":")[0] - db_name = config.db_url.split("/")[-1] - dataset = Dataset.from_sql(RftDatasetModel.__tablename__, f"{db_type}:///{db_name}") - elif config.dataset_path != "": - logger.info(f"Loading dataset from local file with path: {config.dataset_path}.") - dataset = load_hf_dataset(config, config.train_split) - else: - raise ValueError("No dataset path or db url provided.") - datasets.enable_caching() - dataset_len = len(dataset) - default_workflow_cls = WORKFLOWS.get(config.default_workflow_type) - default_reward_fn_cls = REWARD_FUNCTIONS.get(config.default_reward_fn_type) - return cls( - dataset=dataset, - config=config, - workflow=default_workflow_cls, - reward_fn=default_reward_fn_cls, - task_type=task_type, - default_index=latest_task_index % dataset_len, - default_epoch=latest_task_index // dataset_len, - total_epochs=config.total_epochs if task_type == TaskType.EXPLORE else 1, - ) - - def __iter__(self) -> Iterator[Task]: - """Initialize the iterator.""" - self._index = self.default_index - self._epoch = self.default_epoch - self._tasks = task_generator( - self.dataset, - self.default_index, - self.config, - self.workflow, - self.reward_fn, - self.task_type, - ) - return self - - @property - def index(self) -> int: - """Get the current index.""" - return self._index - - @property - def epoch(self) -> int: - """Get the current epoch.""" - return self._epoch - - def __next__(self) -> Task: - """Iterate through the tasks in the taskset.""" - if self._epoch >= self.total_epochs: - raise StopIteration - - try: - task = next(self._tasks) - if task.reward_fn is None: - task.reward_fn = self.reward_fn - if task.workflow is None: - task.workflow = self.workflow - self._index += 1 - return task - except StopIteration: - # Reset the task generator and increment the epoch - self._epoch += 1 - self._index += 1 - if self._epoch >= self.total_epochs: - raise StopIteration - self._tasks = task_generator( - self.dataset, - 0, - self.config, - self.workflow, - self.reward_fn, - self.task_type, - ) - return next(self._tasks) diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index 93a3c82d9e..d44440c826 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -176,6 +176,7 @@ class Critic: grad_clip: float = 0.0 cliprange_value: float = 0.0 checkpoint: Checkpoint = field(default_factory=Checkpoint) + rollout_n: int = 1 @dataclass @@ -288,22 +289,22 @@ def synchronize_config(self, config: Config) -> None: self.actor_rollout_ref.synchronizer = config.synchronizer self.buffer = config.buffer world_size = self.trainer.nnodes * self.trainer.n_gpus_per_node - if config.data.batch_size % world_size != 0: + if config.global_config.batch_size % world_size != 0: raise ValueError( - f"batch_size ({config.data.batch_size}) must be divisible by ({world_size})" + f"batch_size ({config.global_config.batch_size}) must be divisible by ({world_size})" ) # TODO: use dynamic read_batch_size to support multi-round scenarios # Get the experiences of one explore step - self.buffer.pad_token_id = config.buffer.pad_token_id self.trainer.project_name = config.monitor.project self.trainer.experiment_name = config.monitor.name - self.data.train_batch_size = config.data.batch_size + self.data.train_batch_size = config.global_config.batch_size 
self.trainer.default_local_dir = config.model.checkpoint_path self.trainer.sft_warmup_steps = config.trainer.sft_warmup_steps - self.actor_rollout_ref.actor.ppo_mini_batch_size = config.data.batch_size + self.actor_rollout_ref.actor.ppo_mini_batch_size = config.global_config.batch_size self.actor_rollout_ref.rollout.temperature = config.explorer.temperature self.actor_rollout_ref.rollout.n = config.explorer.repeat_times - self.critic.ppo_mini_batch_size = config.data.batch_size + self.critic.ppo_mini_batch_size = config.global_config.batch_size + self.critic.rollout_n = config.explorer.repeat_times self.actor_rollout_ref.actor.algorithm_type = config.trainer.algorithm_type if config.trainer.algorithm_type == AlgorithmType.PPO: diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index 40fc7aa3ce..4fa3a7c2f5 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -138,6 +138,7 @@ def __init__( ): if kwargs.get("reward_fn", None) is None: kwargs["reward_fn"] = MathRewardFn + if kwargs["reward_fn"] == MathRewardFn and kwargs.get("system_prompt", None) is None: kwargs[ "system_prompt" ] = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., diff --git a/trinity/data/controllers/active_iterator.py b/trinity/data/controllers/active_iterator.py index 014967e334..40da73384b 100644 --- a/trinity/data/controllers/active_iterator.py +++ b/trinity/data/controllers/active_iterator.py @@ -67,17 +67,17 @@ def __init__( # 2. input_keys: [prompt_key, response_key] if they are available # 3. 
field_names: [prompt_key, response_key] if they are available self.updated_op_args = { - "text_key": self.data_config.format_config.prompt_key, + "text_key": self.data_config.format.prompt_key, "input_keys": [ - self.data_config.format_config.prompt_key, + self.data_config.format.prompt_key, ], "field_names": [ - self.data_config.format_config.prompt_key, + self.data_config.format.prompt_key, ], } - if self.data_config.format_config.response_key != "": - self.updated_op_args["input_keys"].append(self.data_config.format_config.response_key) - self.updated_op_args["field_names"].append(self.data_config.format_config.response_key) + if self.data_config.format.response_key != "": + self.updated_op_args["input_keys"].append(self.data_config.format.response_key) + self.updated_op_args["field_names"].append(self.data_config.format.response_key) # flake8: noqa: C901 def run(self): diff --git a/trinity/data/core/dataset.py b/trinity/data/core/dataset.py index de6fa7281d..3e4af0fe12 100644 --- a/trinity/data/core/dataset.py +++ b/trinity/data/core/dataset.py @@ -6,7 +6,7 @@ from data_juicer.core.data.dj_dataset import Dataset from datasets import load_dataset -from trinity.common.config import DataConfig +from trinity.common.config import DataProcessorConfig from trinity.common.rewards import REWARD_FUNCTIONS from trinity.common.task import TaskSet from trinity.common.workflows import WORKFLOWS @@ -38,18 +38,18 @@ class RftDataset: def __init__( self, - data_config: DataConfig, + data_config: DataProcessorConfig, reward_schema: Union[str, Dict] = "default", track_lineage: bool = True, ): self.config = data_config - dataset_path = data_config.dataset_path - if not dataset_path: - raise ValueError("dataset_path is not specified in DJ config") - dataset_config = data_config.dataset_config - self.data = load_dataset(dataset_path, trust_remote_code=True, **dataset_config) + source_data_path = data_config.source_data_path + if not source_data_path: + raise ValueError("source_data_path is not specified in DJ config") + load_kwargs = data_config.load_kwargs + self.data = load_dataset(source_data_path, trust_remote_code=True, **load_kwargs) - self.format_config = data_config.format_config + self.format = data_config.format self.reward_schema = self._init_reward_schema(reward_schema) self.stats: Dict[str, Any] = {} diff --git a/trinity/data/core/dataset_db.py b/trinity/data/core/dataset_db.py index c557c0e020..f47b138995 100644 --- a/trinity/data/core/dataset_db.py +++ b/trinity/data/core/dataset_db.py @@ -6,7 +6,7 @@ from sqlalchemy.pool import NullPool from trinity.buffer.utils import retry_session -from trinity.common.config import DataConfig +from trinity.common.config import DataProcessorConfig from trinity.common.schema import Base, RftDatasetModel from trinity.data.core.dataset import RftDataset from trinity.utils.log import get_logger @@ -24,7 +24,7 @@ def rft_dataset_to_model(dataset: RftDataset) -> List[RftDatasetModel]: # - for other keys, we just need to check if they are in the dataset data = dataset.data features = data.features - content_key_mapping = dataset.format_config.__dict__ + content_key_mapping = dataset.format.__dict__ schema_keys = {key for key in RftDatasetModel.__dict__.keys() if not key.startswith("_")} for schema_key in schema_keys: key = schema_key @@ -44,7 +44,7 @@ def rft_dataset_to_model(dataset: RftDataset) -> List[RftDatasetModel]: class RftDatasetDB: - def __init__(self, config: DataConfig) -> None: + def __init__(self, config: DataProcessorConfig) -> None: self.db_url = 
config.db_url self.engine = create_engine(self.db_url, poolclass=NullPool) self.config = config diff --git a/trinity/data/readme.md b/trinity/data/readme.md index db1b3f443b..3294819f43 100644 --- a/trinity/data/readme.md +++ b/trinity/data/readme.md @@ -35,11 +35,11 @@ A data processing engine designed for Reinforcement Fine-Tuning (RFT) of Large L ```python from trinity.common.rewards import AccuracyReward from trinity.common.workflows import MathWorkflow -from trinity.common.config import DataConfig +from trinity.common.config import DataProcessorConfig from trinity.data.core.dataset import RftDataset from trinity.data.core.formatter import BoxedMathAnswerFormatter, RLHFFormatter -data_config: DataConfig = ... +data_config: DataProcessorConfig = ... # initialize the dataset according to the data config dataset = RftDataset(data_config) @@ -47,8 +47,8 @@ dataset = RftDataset(data_config) # format it for the target data and training format # e.g. format for a boxed-tagged MATH data and RLHF format dataset.format([ - BoxedMathAnswerFormatter(data_config.format_config), - RLHFFormatter(data_config.format_config), + BoxedMathAnswerFormatter(data_config.format), + RLHFFormatter(data_config.format), ]) # convert to a task set with global reward function and workflow @@ -85,7 +85,7 @@ synth_data = synthesizer.process(clean_data) - You can either run `scripts/start_servers.py` or run `trinity/data/server.py` to start the data server. - Before running this config file, you need to replace the `username` and `db_name` with your own username and database name. - When requesting it, the server will load the dataset, clean it, compute priority scores from different dimensions, and export the result dataset to the database. -- Then you need to prepare the `data` section in the config file (e.g. [test_cfg.yaml](tests/test_configs/active_iterator_test_cfg.yaml)) +- Then you need to prepare the `data_processor` section in the config file (e.g. [test_cfg.yaml](tests/test_configs/active_iterator_test_cfg.yaml)) - For the `dj_config_path` argument in it, you can either specify a data-juicer config file path (e.g. [test_dj_cfg.yaml](tests/test_configs/active_iterator_test_dj_cfg.yaml)), or write the demand in `dj_process_desc` argument in natural language and our agent will help you to organize the data-juicer config. 
- Finally you can send requests to the data server to start an active iterator to process datasets in many ways: - Request with `curl`: `curl "http://127.0.0.1:5000/data_workflow?configPath=tests%2Ftest_configs%2Factive_iterator_test_cfg.yaml"` diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 784d7b526c..6bed934f33 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -9,18 +9,14 @@ import torch from trinity.buffer import get_buffer_writer +from trinity.buffer.buffer import get_buffer_reader from trinity.common.config import Config -from trinity.common.constants import ( - ROLLOUT_WEIGHT_SYNC_GROUP_NAME, - SyncMethod, - TaskType, -) +from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME, SyncMethod from trinity.common.models import create_rollout_models from trinity.common.models.utils import ( get_checkpoint_dir_with_step_num, load_state_dict, ) -from trinity.common.task import TaskSet from trinity.explorer.runner_pool import RunnerPool from trinity.manager.manager import CacheManager from trinity.utils.log import get_logger @@ -39,17 +35,16 @@ def __init__(self, config: Config): self.config = config self.models = create_rollout_models(config) self.experience_buffer = get_buffer_writer( - self.config.buffer.train_dataset, # type: ignore + self.config.buffer.explorer_output, # type: ignore self.config.buffer, ) - self.taskset = TaskSet.load( - self.config.data, explorer_meta.get("latest_task_index", 0), TaskType.EXPLORE + self.config.buffer.explorer_input.taskset.index = explorer_meta.get("latest_task_index", 0) + self.taskset = get_buffer_reader( + self.config.buffer.explorer_input.taskset, self.config.buffer ) - if self.config.data.eval_split: - self.eval_taskset = TaskSet.load(self.config.data, task_type=TaskType.EVAL) - else: - self.eval_taskset = None - self.task_iter = None + self.eval_tasksets = [] + for eval_taskset_config in self.config.buffer.explorer_input.eval_tasksets: + self.eval_tasksets.append(get_buffer_reader(eval_taskset_config, self.config.buffer)) self.runner_pool = self._init_runner_pool() self.monitor = Monitor( project=self.config.monitor.project, @@ -59,8 +54,10 @@ def __init__(self, config: Config): ) self.max_pending_task_num = self.config.explorer.runner_num self.max_waiting_steps = max(1, int(self.config.explorer.max_waiting_steps)) - self.batch_size = config.data.batch_size - self.update_interval = self.config.synchronizer.sync_interval * self.config.data.batch_size + self.batch_size = config.global_config.batch_size + self.update_interval = ( + self.config.synchronizer.sync_interval * self.config.global_config.batch_size + ) self.use_checkpoint_weights_update = ( self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT ) @@ -164,7 +161,7 @@ def explore(self) -> None: if not explore_status: break self.sync_weight() - if explore_iter % self.config.explorer.eval_interval == 0: + if explore_iter % self.config.global_config.eval_interval == 0: self.eval() self.logger.info("Evaluation finished.") self.logger.info("Explorer finished.") @@ -179,17 +176,17 @@ def explore_one_period(self) -> Tuple[bool, int]: explore_status: whether there are more tasks to explore. 
explore_step_num: the number of explore steps """ - if self.task_iter is None: - self.task_iter = iter(self.taskset) - task_num_per_period = self.config.synchronizer.sync_interval * self.config.data.batch_size + task_num_per_period = ( + self.config.synchronizer.sync_interval * self.config.global_config.batch_size + ) st = time.time() all_metrics = defaultdict(list) # submit tasks of this step try: - tasks = [next(self.task_iter) for _ in range(task_num_per_period)] # type: ignore - self.runner_pool.run_tasks(tasks) + tasks = [self.taskset.read() for _ in range(task_num_per_period)] + self.runner_pool.run_tasks(tasks) # type: ignore except StopIteration: self.experience_buffer.finish() self.logger.warning("No more tasks in the task set. Stop exploring.") @@ -205,7 +202,7 @@ def explore_one_period(self) -> Tuple[bool, int]: self.logger.error(f"Error when running task: {status.message}") try: # submit another task to replace the failed task - self.runner_pool.run_tasks(next(self.task_iter)) # type: ignore + self.runner_pool.run_tasks(self.taskset.read()) except StopIteration: self.logger.warning("No more tasks in the task set. Stop exploring.") return False, self.step_num @@ -230,30 +227,37 @@ def explore_one_period(self) -> Tuple[bool, int]: def eval(self) -> Tuple[bool, int]: """Evaluation on all evaluation data samples.""" - if self.eval_taskset is None: + if len(self.eval_tasksets) == 0: self.logger.warning("No evaluation data samples. Skip evaluation.") return True, self.step_num self.logger.info("Evaluation started.") - st = time.time() - all_metrics = defaultdict(list) - - tasks = [task for task in self.eval_taskset] - self.runner_pool.run_tasks(tasks) + all_st = time.time() + log_metrics = {} + for eval_taskset in self.eval_tasksets: + st = time.time() + all_metrics = defaultdict(list) - while self.runner_pool.has_next(): - # TODO: use unordered queue to avoid blocking - status_list = self.runner_pool.get_next_unorder() - if not isinstance(status_list, list): - status_list = [status_list] - for status in status_list: - if not status.ok: - self.logger.error(f"Error when running task: {status.message}") - else: - for metric_name, metric_value in status.metric.items(): - all_metrics[metric_name].append(metric_value) + def wait(): + status_list = self.runner_pool.get_next_unorder() + if not isinstance(status_list, list): + status_list = [status_list] + for status in status_list: + if not status.ok: + self.logger.error(f"Error when running task: {status.message}") + else: + for metric_name, metric_value in status.metric.items(): + all_metrics[metric_name].append(metric_value) - log_metrics = self.monitor.calculate_metrics(all_metrics, prefix="eval") # type: ignore - log_metrics["eval/total_time"] = time.time() - st + for _ in range(len(eval_taskset)): # type: ignore + if not self.runner_pool.has_free(): + wait() + self.runner_pool.run_tasks([eval_taskset.read()]) # type: ignore + while self.runner_pool.has_next(): + wait() + metrics = self.monitor.calculate_metrics(all_metrics, prefix=f"eval/{eval_taskset.name}") # type: ignore + log_metrics.update(metrics) + log_metrics[f"eval/{eval_taskset.name}/time"] = time.time() - st + log_metrics["eval/total_time"] = time.time() - all_st self.monitor.log(log_metrics, step=self.step_num) # type: ignore return True, self.step_num diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index b2c0c11804..86e4003df4 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -34,7 +34,7 @@ 
class WorkflowRunner: def __init__(self, config: Config, model: InferenceModel) -> None: self.config = config self.experience_buffer = get_buffer_writer( - self.config.buffer.train_dataset, # type: ignore + self.config.buffer.explorer_output, # type: ignore self.config.buffer, ) self.model = model diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py index 0045deac25..1f55b3c702 100644 --- a/trinity/manager/config_manager.py +++ b/trinity/manager/config_manager.py @@ -61,23 +61,29 @@ def _init_default_config(self): "trainer_gpu_num": 6, "max_prompt_tokens": 1024, "max_response_tokens": 1024, - # Data Configs + # Global Configs "total_epochs": 20, "_train_batch_size_per_gpu": 16, "train_batch_size": 96, - "dataset_path": "", - "subset_name": None, - "train_split": "train", - "eval_split": "", - "prompt_key": "question", - "response_key": "answer", + "eval_interval": 1000, + # Taskset Configs + "taskset_path": "", + "taskset_subset_name": None, + "taskset_split": "train", + "taskset_prompt_key": "question", + "taskset_response_key": "answer", + # Eval Taskset Configs + # TODO + # Task Workflow Configs "default_workflow_type": "math_workflow", "default_reward_fn_type": "math_reward", - # Buffer Configs - "_is_dpo_storage_type": StorageType.FILE.value, + # Experience Buffer Configs + "_dpo_storage_type": StorageType.FILE.value, "_not_dpo_storage_type": StorageType.QUEUE.value, "storage_type": StorageType.QUEUE.value, - "train_dataset_path": "", + "_dpo_experience_buffer_path": "", + "_not_dpo_experience_buffer_path": "", + "experience_buffer_path": "", "buffer_max_retry_times": 3, "max_retry_interval": 1, "dpo_dataset_train_split": "train", @@ -85,6 +91,7 @@ def _init_default_config(self): "dpo_dataset_prompt_key": "prompt", "dpo_dataset_chosen_key": "chosen", "dpo_dataset_rejected_key": "rejected", + # SFT Warmup Dataset Configs "sft_warmup_dataset_path": "", "sft_warmup_train_split": "train", "sft_warmup_prompt_type": PromptType.MESSAGES.value, @@ -98,7 +105,6 @@ def _init_default_config(self): "_grouped_adv_repeat_times": 2, "_not_grouped_adv_repeat_times": 1, "repeat_times": 1, - "eval_interval": 1000, "tensor_parallel_size": 1, "enable_prefix_caching": False, "enforce_eager": True, @@ -299,26 +305,25 @@ def _check_train_batch_size(self): self.unfinished_fields.add("train_batch_size") st.warning(self._str_for_train_batch_size) - def _set_dataset_path(self): - st.text_input("Dataset Path", key="dataset_path") - if not st.session_state["dataset_path"].strip(): - self.unfinished_fields.add("dataset_path") - st.warning("Please input dataset path.") + def _set_taskset_path(self): + st.text_input("Taskset Path", key="taskset_path") + if not st.session_state["taskset_path"].strip(): + self.unfinished_fields.add("taskset_path") + st.warning("Please input taskset path.") - def _set_dataset_args(self): - if st.session_state["dataset_path"] and "://" not in st.session_state["dataset_path"]: - subset_name_col, train_split_col, eval_split_col = st.columns(3) + def _set_taskset_args(self): + if st.session_state["taskset_path"] and "://" not in st.session_state["taskset_path"]: + subset_name_col, split_col = st.columns(2) subset_name_col.text_input( - "Subset Name :orange-badge[(Needs review)]", key="subset_name" - ) - train_split_col.text_input( - "Train Split :orange-badge[(Needs review)]", key="train_split" + "Subset Name :orange-badge[(Needs review)]", key="taskset_subset_name" ) - eval_split_col.text_input("Eval Split :orange-badge[(Needs review)]", key="eval_split") + 
split_col.text_input("Train Split :orange-badge[(Needs review)]", key="taskset_split") prompt_key_col, response_key_col = st.columns(2) - prompt_key_col.text_input("Prompt Key :orange-badge[(Needs review)]", key="prompt_key") + prompt_key_col.text_input( + "Prompt Key :orange-badge[(Needs review)]", key="taskset_prompt_key" + ) response_key_col.text_input( - "Response Key :orange-badge[(Needs review)]", key="response_key" + "Response Key :orange-badge[(Needs review)]", key="taskset_response_key" ) def _set_default_workflow_type(self): @@ -349,7 +354,7 @@ def _set_default_reward_fn_type(self): def _set_storage_type(self): if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: - st.session_state["storage_type"] = st.session_state["_is_dpo_storage_type"] + st.session_state["storage_type"] = st.session_state["_dpo_storage_type"] storage_candidates = [StorageType.FILE.value, StorageType.SQL.value] else: st.session_state["storage_type"] = st.session_state["_not_dpo_storage_type"] @@ -357,7 +362,7 @@ def _set_storage_type(self): def on_change(): if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: - st.session_state["_is_dpo_storage_type"] = st.session_state["storage_type"] + st.session_state["_dpo_storage_type"] = st.session_state["storage_type"] else: st.session_state["_not_dpo_storage_type"] = st.session_state["storage_type"] @@ -368,19 +373,48 @@ def on_change(): on_change=on_change, ) - def _set_train_dataset_path(self): # TODO + def _set_experience_buffer_path(self): # TODO + if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + st.session_state["experience_buffer_path"] = st.session_state[ + "_dpo_experience_buffer_path" + ] + title = "DPO Dataset Path" + help_msg = r"""This path to DPO dataset, + +if `storage_type == StorageType.FILE`, this should be a path to a file, + +if `storage_type == StorageType.SQL`, this should be a path to database.""" + else: + st.session_state["experience_buffer_path"] = st.session_state[ + "_not_dpo_experience_buffer_path" + ] + title = "Experience Buffer Path" + help_msg = r"""This path is used for `trainer`, + +if `storage_type == StorageType.QUEUE`, default to `None`, + +if `storage_type == StorageType.SQL`, default to `sqlite:///{os.path.join(checkpoint_path, '.cache', project_name, experiment_name)}/data.db`.""" + + def on_change(): + if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + st.session_state["_dpo_experience_buffer_path"] = st.session_state[ + "experience_buffer_path" + ] + else: + st.session_state["_not_dpo_experience_buffer_path"] = st.session_state[ + "experience_buffer_path" + ] + st.text_input( - "Train Dataset Path", - key="train_dataset_path", - help=r"This path is used for `trainer`, " - r"if `storage_type == StorageType.QUEUE`, default to `None`, " - r"if `storage_type == StorageType.FILE`, this should be a path to a file, " - r"if `storage_type == StorageType.SQL`, default to `sqlite:///{os.path.join(checkpoint_path, '.cache', project_name, experiment_name)}/data.db`.", - ) - if st.session_state["storage_type"] == StorageType.FILE.value: - if not st.session_state["train_dataset_path"].strip(): - self.unfinished_fields.add("train_dataset_path") - st.warning("Please input train dataset path.") + title, + key="experience_buffer_path", + help=help_msg, + on_change=on_change, + ) + if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + if not st.session_state["experience_buffer_path"].strip(): + self.unfinished_fields.add("experience_buffer_path") + st.warning("Please 
input DPO dataset path.") def _set_buffer_max_retry_times(self): st.number_input("Max Retry Times", key="buffer_max_retry_times", min_value=1) @@ -1025,7 +1059,7 @@ def beginner_mode(self): self._set_checkpoint_path() - self._set_dataset_path() + self._set_taskset_path() self._set_configs_with_st_columns(["algorithm_type", "sft_warmup_steps", "monitor_type"]) if st.session_state["sft_warmup_steps"] > 0: @@ -1055,7 +1089,7 @@ def beginner_mode(self): ) if st.session_state["algorithm_type"] != AlgorithmType.DPO.value: - self._set_dataset_args() + self._set_taskset_args() else: self._set_dpo_dataset_kwargs() @@ -1094,15 +1128,30 @@ def _expert_model_part(self): self._set_configs_with_st_columns(["max_prompt_tokens", "max_response_tokens"]) def _expert_buffer_part(self): - self._set_configs_with_st_columns(["total_epochs", "train_batch_size", "storage_type"]) + self._set_configs_with_st_columns(["total_epochs", "train_batch_size"]) self._check_train_batch_size() - self._set_dataset_path() - if st.session_state["algorithm_type"] != AlgorithmType.DPO.value: - self._set_dataset_args() + with st.expander("Taskset Configs", expanded=True): + self._set_taskset_path() + self._set_taskset_args() else: - self._set_dpo_dataset_kwargs() + with st.expander("DPO Dataset Configs", expanded=True): + self._set_experience_buffer_path() + self._set_dpo_dataset_kwargs() + + with st.expander("Eval Tasksets Configs", expanded=True): + # TODO: + pass + + with st.expander("SFT Dataset Configs"): + self._set_sft_warmup_dataset_path() + self._set_sft_warmup_dataset_args() + + if st.session_state["algorithm_type"] != AlgorithmType.DPO.value: + with st.expander("Experiences Buffer Configs", expanded=True): + self._set_storage_type() + self._set_experience_buffer_path() self._set_configs_with_st_columns(["default_workflow_type", "default_reward_fn_type"]) @@ -1110,9 +1159,6 @@ def _expert_buffer_part(self): with self.buffer_advanced_tab: self._set_configs_with_st_columns(["buffer_max_retry_times", "max_retry_interval"]) - self._set_sft_warmup_dataset_path() - self._set_sft_warmup_dataset_args() - def _expert_explorer_part(self): self._set_configs_with_st_columns( ["engine_type", "engine_num", "tensor_parallel_size", "repeat_times"] @@ -1233,7 +1279,7 @@ def _expert_verl_trainer_part(self): def expert_mode(self): tab2func = { "Model": self._expert_model_part, - "Data": self._expert_buffer_part, + "Buffer": self._expert_buffer_part, "Explorer and Synchronizer": self._expert_explorer_part, "Trainer": self._expert_trainer_part, } @@ -1492,15 +1538,19 @@ def generate_config(self): ) if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: - train_dataset_path = ( - st.session_state["train_dataset_path"].strip() - if st.session_state["train_dataset_path"].strip() - else st.session_state["dataset_path"].strip() - ) + pass + # experience_buffer_path = ( + # st.session_state["experience_buffer_path"].strip() + # if st.session_state["experience_buffer_path"].strip() + # else st.session_state["dataset_path"].strip() + # ) else: # not dpo algorithms - train_dataset_path = st.session_state["train_dataset_path"].strip() - if not train_dataset_path and st.session_state["storage_type"] == StorageType.SQL.value: - train_dataset_path = f"sqlite:///{os.path.join(st.session_state['checkpoint_path'], '.cache', st.session_state['project'], st.session_state['exp_name'])}/data.db" + experience_buffer_path = st.session_state["experience_buffer_path"].strip() + if ( + not experience_buffer_path + and st.session_state["storage_type"] == 
StorageType.SQL.value + ): + experience_buffer_path = f"sqlite:///{os.path.join(st.session_state['checkpoint_path'], '.cache', st.session_state['project'], st.session_state['exp_name'])}/data.db" sft_storage_type = ( StorageType.SQL.value @@ -1534,18 +1584,10 @@ def generate_config(self): if st.session_state.config_generated: config = { "mode": st.session_state["mode"], - "data": { + "global_config": { "total_epochs": st.session_state["total_epochs"], "batch_size": st.session_state["train_batch_size"], - "dataset_path": st.session_state["dataset_path"], - "default_workflow_type": st.session_state["default_workflow_type"], - "default_reward_fn_type": st.session_state["default_reward_fn_type"], - "train_split": st.session_state["train_split"], - "eval_split": st.session_state["eval_split"], - "format_config": { - "prompt_key": st.session_state["prompt_key"], - "response_key": st.session_state["response_key"], - }, + "eval_interval": st.session_state["eval_interval"], }, "model": { "model_path": st.session_state["model_path"], @@ -1561,21 +1603,27 @@ def generate_config(self): "buffer": { "max_retry_times": st.session_state["buffer_max_retry_times"], "max_retry_interval": st.session_state["max_retry_interval"], - "train_dataset": { - "name": "experience_buffer", # TODO - "storage_type": st.session_state["storage_type"], - "path": train_dataset_path, + "explorer_input": { + "taskset": { + "name": "taskset", + "storage_type": StorageType.FILE.value, + "path": st.session_state["taskset_path"], + "split": st.session_state["taskset_split"], + "subset_name": st.session_state["taskset_subset_name"], + "format": { + "prompt_key": st.session_state["taskset_prompt_key"], + "response_key": st.session_state["taskset_response_key"], + }, + }, + "eval_tasksets": [], # TODO: add eval tasksets + "default_workflow_type": st.session_state["default_workflow_type"], + "default_reward_fn_type": st.session_state["default_reward_fn_type"], }, - "sft_warmup_dataset": { - "name": "sft_warmup_dataset", - "storage_type": sft_storage_type, - "path": st.session_state["sft_warmup_dataset_path"], - "kwargs": { - "train_split": st.session_state["sft_warmup_train_split"], - "prompt_type": st.session_state["sft_warmup_prompt_type"], - "messages_key": st.session_state["sft_warmup_messages_key"], - "prompt_key": st.session_state["sft_warmup_prompt_key"], - "response_key": st.session_state["sft_warmup_response_key"], + "trainer_input": { + "experience_buffer": { + "name": "experience_buffer", + "storage_type": st.session_state["storage_type"], + "path": experience_buffer_path, }, }, }, @@ -1585,7 +1633,6 @@ def generate_config(self): "runner_num": st.session_state["runner_num"], "repeat_times": st.session_state["repeat_times"], # "chat_template": None, # TODO: add chat template - "eval_interval": st.session_state["eval_interval"], "tensor_parallel_size": st.session_state["tensor_parallel_size"], "enable_prefix_caching": st.session_state["enable_prefix_caching"], "enforce_eager": st.session_state["enforce_eager"], @@ -1615,7 +1662,6 @@ def generate_config(self): "algorithm_type": st.session_state["algorithm_type"], "trainer_config": trainer_config, "sft_warmup_steps": st.session_state["sft_warmup_steps"], - "eval_interval": st.session_state["eval_interval"], "save_interval": st.session_state["save_interval"], }, "monitor": { @@ -1626,12 +1672,26 @@ def generate_config(self): } if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: - config["buffer"]["train_dataset"]["kwargs"] = { - "dpo_dataset_train_split": 
st.session_state["dpo_dataset_train_split"], - "dpo_dataset_prompt_type": st.session_state["dpo_dataset_prompt_type"], - "dpo_dataset_prompt_key": st.session_state["dpo_dataset_prompt_key"], - "dpo_dataset_chosen_key": st.session_state["dpo_dataset_chosen_key"], - "dpo_dataset_rejected_key": st.session_state["dpo_dataset_rejected_key"], + experience_buffer = config["buffer"]["trainer_input"]["experience_buffer"] + experience_buffer["split"] = st.session_state["dpo_dataset_train_split"] + experience_buffer["format"] = { + "prompt_type": st.session_state["dpo_dataset_prompt_type"], + "prompt_key": st.session_state["dpo_dataset_prompt_key"], + "chosen_key": st.session_state["dpo_dataset_chosen_key"], + "rejected_key": st.session_state["dpo_dataset_rejected_key"], + } + if st.session_state["sft_warmup_dataset_path"].strip(): + config["buffer"]["trainer_input"]["sft_warmup_dataset"] = { + "name": "sft_warmup_dataset", + "storage_type": sft_storage_type, + "path": st.session_state["sft_warmup_dataset_path"], + "split": st.session_state["sft_warmup_train_split"], + "format": { + "prompt_type": st.session_state["sft_warmup_prompt_type"], + "messages_key": st.session_state["sft_warmup_messages_key"], + "prompt_key": st.session_state["sft_warmup_prompt_key"], + "response_key": st.session_state["sft_warmup_response_key"], + }, } st.session_state.config_generated = True st.header("Generated Config File") diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index 607b01dc7d..f6edb4e6fb 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -26,12 +26,12 @@ def __init__(self, config: Config) -> None: self.config = config self.logger = get_logger(__name__) self.train_buffer = get_buffer_reader( - self.config.buffer.train_dataset, # type: ignore + self.config.buffer.trainer_input.experience_buffer, # type: ignore self.config.buffer, ) self.sft_warmup_buffer = ( get_buffer_reader( - self.config.buffer.sft_warmup_dataset, # type: ignore + self.config.buffer.trainer_input.sft_warmup_dataset, # type: ignore self.config.buffer, ) if self.config.trainer.sft_warmup_steps > 0 diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index 57e9849f9d..3cd1a53e13 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -866,6 +866,7 @@ def __init__(self, config): self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload # normalize config + self.config.ppo_mini_batch_size *= self.config.rollout_n self.config.ppo_mini_batch_size //= ( torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size )
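For reference, the new line in `fsdp_workers.py` rescales the configured task-level `ppo_mini_batch_size` into sample-level units (each task is rolled out `rollout_n` times) before the existing division by the data-parallel group size. Below is a minimal sketch of that arithmetic; the standalone helper and the example numbers are illustrative only and not part of the codebase.

def normalized_ppo_mini_batch_size(
    ppo_mini_batch_size: int,
    rollout_n: int,
    world_size: int,
    ulysses_sequence_parallel_size: int,
) -> int:
    # Scale the task-level mini-batch to sample-level units: each task
    # contributes `rollout_n` rollout samples.
    sample_level = ppo_mini_batch_size * rollout_n
    # Split across the data-parallel group; ranks in the same Ulysses
    # sequence-parallel group share the same data shard.
    dp_size = world_size // ulysses_sequence_parallel_size
    return sample_level // dp_size


# Example: 32 tasks per mini-batch, 8 rollouts per task, 16 GPUs, no sequence
# parallelism -> 32 * 8 = 256 samples, split over 16 ranks -> 16 per rank.
assert normalized_ppo_mini_batch_size(32, 8, 16, 1) == 16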