diff --git a/docs/sphinx_doc/source/tutorial/example_data_functionalities.md b/docs/sphinx_doc/source/tutorial/example_data_functionalities.md
index 27b5fb26bf..f0ce9b1eb9 100644
--- a/docs/sphinx_doc/source/tutorial/example_data_functionalities.md
+++ b/docs/sphinx_doc/source/tutorial/example_data_functionalities.md
@@ -8,7 +8,8 @@ In this example, you will learn how to apply the data processor of Trinity-RFT t
 2. how to configure the data processor
 3. what the data processor can do

-Before getting started, you need to prepare the main environment of Trinity-RFT according to the [installation section of the README file](../main.md).
+Before getting started, you need to prepare the main environment of Trinity-RFT according to the [installation section of the README file](../main.md),
+and, if necessary for agentic or API-model usage, store the base URL and API key in the environment variables `OPENAI_BASE_URL` and `OPENAI_API_KEY`.

 ### Data Preparation

@@ -103,8 +104,6 @@ If you are familiar with Data-Juicer, you will realize that Data-Juicer provides
 # This is a Data-Juicer data processing recipe
 project_name: 'gsm-8k-difficulty'

-export_path: '/path/to/the/result/processed-dataset.jsonl'
-
 process:
   - llm_difficulty_score_filter:
       api_or_hf_model: "qwen2.5-72b-instruct"  # use "qwen2.5-72b-instruct" to calculate the difficulty scores.
@@ -143,7 +142,7 @@ And you can set the `clean_strategy` to 'iterative' to get a better dataset.

-All config items in the `data` section can be found [here](trinity_configs.md). A prepared config file for this example of GSM-8K can be found in [the config file of gsm8k](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k/gsm8k.yaml).
+All config items in the `data` section can be found [here](trinity_configs.md). A prepared config file for this example of GSM-8K can be found in [the config file of gsm8k](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k_task_pipeline/gsm8k.yaml).

@@ -167,6 +166,99 @@ trinity run --config

 If you follow the steps above, Trinity-RFT will send a request to the data processor server, and the data active iterator will be activated to compute difficulty scores for each sample in the raw dataset and rank the dataset according to those scores. After that, the data processor server stores the result dataset in the output buffer; when exploring begins, the explorer loads the prepared dataset and continues the downstream steps.

+## Example: Data Processor for Experience Pipeline
+
+In this example, you will learn how to apply the data processor of Trinity-RFT to reshape the rewards of experiences after exploring. It takes GSM-8K as the example dataset and shows how to reshape the rewards of experiences from the explorer, based on the quality of the generated responses, before they are sent to the trainer.
+
+Before getting started, you need to prepare the main environment of Trinity-RFT and start the server for the data processor according to the first subsection of the previous example.
+
+### Configure the Data Processor
+
+In this example, assume that you need to add an extra reward item to the experiences outputted by the explorer, one that assesses the quality of the generated responses.
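Conceptually, each reward-shaping rule combines the original reward with a stat computed by a Data-Juicer operator. Here is a minimal plain-Python sketch of the idea (the function name and sample values are hypothetical; the actual logic lives in `DataActiveIterator._reward_shaping_single`, shown later in this diff):

```python
def reshape_reward(reward: float, stats: dict, stats_key: str, op_type: str, weight: float) -> float:
    """Combine the original reward with one weighted stat, mirroring the ADD/SUB/MUL/DIV op types."""
    if stats_key not in stats:
        # if the target stat was not computed for this sample, keep the reward unchanged
        return reward
    value = weight * stats[stats_key]
    if op_type == "ADD":
        return reward + value
    if op_type == "SUB":
        return reward - value
    if op_type == "MUL":
        return reward * value
    if op_type == "DIV":
        return reward / value
    return reward

# e.g. a correct answer (reward 1.0) with llm_quality_score 0.8 and weight 1.0:
print(reshape_reward(1.0, {"llm_quality_score": 0.8}, "llm_quality_score", "ADD", 1.0))  # 1.8
```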
So you can set the `experience_pipeline` config like the following example:

```yaml
data_processor:
  data_processor_url: 'http://127.0.0.1:5005/data_processor'
  # experience pipeline related
  experience_pipeline:
    # I/O buffers
    input_buffers:
      - name: gsm8k_exp_output
    output_buffer:
      name: reshaped_gsm8k_exp_input
    # format mapping
    format:
      reward_key: 'reward'  # the key name of the reward in the experience
    # data active iterator related
    dj_config_path: 'examples/grpo_gsm8k_experience_pipeline/dj_scoring_exp.yaml'
    clean_strategy: 'iterative'
    # reward shaping
    reward_shaping:
      - stats_key: 'llm_quality_score'
        op_type: ADD
        weight: 1.0

# the buffer config
buffer:
  ...
  explorer_output:
    name: gsm8k_exp_output
    storage_type: queue
    path: 'sqlite:///gsm8k_exp_output.db'
  trainer_input:
    experience_buffer:
      name: reshaped_gsm8k_exp_input
      storage_type: queue
      path: 'sqlite:///reshaped_gsm8k_exp_input.db'
```

Here you can set the input/output buffers for the experience pipeline, together with some other items for reward shaping:

+ `data_processor_url`: the URL of the data processor service, which is started in the previous step.
+ `experience_pipeline`: the configs for the experience pipeline, which processes the experiences outputted by the explorer (e.g. reward shaping, data filtering, and augmentation). It consists of several inner configs:
  + `input_buffers`: the input buffers for the experience pipeline. The pipeline usually loads from the explorer output buffer, so we need to specify `explorer_output` in the `buffer` config, and here we only need to specify a name that is aligned with the `explorer_output`. Multiple input buffers are allowed, but for now we only need one.
  + `output_buffer`: the output buffer for the experience pipeline. The pipeline usually writes results to the input buffer of the trainer, so we only need to specify a buffer name that is aligned with the `trainer_input` in the `buffer` config.
  + `format`: some dataset format config items, which map original data field names to unified ones. Here we only need to specify the field name that stores the original reward information.
  + `reward_shaping`: the methods to reshape the reward. Usually, stats computed by Data-Juicer operators serve as new reward items. It's a list, so multiple reshaping methods are allowed. Each item in the list has the following config items:
    + `stats_key`: which stat to use as the new reward item.
    + `op_type`: the operator used to combine the new reward item with the original reward. For now, ["ADD", "SUB", "MUL", "DIV"] are supported.
    + `weight`: the weight of the new reward item.

In addition, there are several config items related to the data active iterator in the `experience_pipeline` part, which compute the stats used to reshape rewards. This part is similar to the `task_pipeline` part in the previous example. The Data-Juicer config used here is:
```yaml
# This is a Data-Juicer data processing recipe
project_name: 'gsm-8k-experience-quality'

np: 32

process:
  - llm_quality_score_filter:
      api_or_hf_model: "qwen2.5-32b-instruct"  # use "qwen2.5-32b-instruct" to calculate the quality scores.
      min_score: 0.0
      input_keys: ["prompt_text", "response_text"]  # set input_keys and field_names to the existing key names in the experiences. Here we calculate the quality scores according to both prompts and responses.
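      # note: input_keys and field_names are zipped positionally, so the experience's
      # `prompt_text` / `response_text` fields are exposed to the operator as `prompt` / `response`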
+      field_names: ["prompt", "response"]
+```
+
+All config items in the `data` section can be found [here](trinity_configs.md). A prepared config file for this example of GSM-8K can be found in [the config file of gsm8k](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k_experience_pipeline/gsm8k.yaml).
+
+### Exploring & Training
+After preparing the config files of Trinity-RFT, you can start your ray cluster and run the RFT process, including the data active iterator part, with the following commands:
+
+```shell
+# start the ray cluster
+# on master node
+ray start --head
+# on worker nodes
+ray start --address=<master_address>
+
+# run RFT
+trinity run --config <config_path>
+```
+
+If you follow the steps above, Trinity-RFT will send a request to the data processor server and prepare the experience pipeline.
+It will watch the explorer output buffer. Once there is a new batch of experiences, the data processor will compute stats for them and reshape the rewards. Then it writes the reshaped experiences to the trainer input buffer for training.
+
+
 ## Example: Human in the Loop

 Sometimes, you might need to involve human feedbacks for some raw data. In this example, you will learn how to annotate raw data to get a better dataset before training. This example takes an example Q&A dataset and tries to select the chosen and rejected ones for DPO method.
diff --git a/examples/grpo_gsm8k_experience_pipeline/README.md b/examples/grpo_gsm8k_experience_pipeline/README.md
new file mode 100644
index 0000000000..91ca63d391
--- /dev/null
+++ b/examples/grpo_gsm8k_experience_pipeline/README.md
@@ -0,0 +1,7 @@
+# GRPO on GSM8K dataset with Experience Pipeline
+
+This example shows the usage of GRPO on the GSM8K dataset, with an experience pipeline to reshape the rewards of experiences during training.
+
+For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_data_functionalities.md).
+
+The config files are located in [`gsm8k.yaml`](gsm8k.yaml) and [`train_gsm8k.yaml`](train_gsm8k.yaml).
diff --git a/examples/grpo_gsm8k_experience_pipeline/dj_scoring_exp.yaml b/examples/grpo_gsm8k_experience_pipeline/dj_scoring_exp.yaml
new file mode 100644
index 0000000000..59b786abec
--- /dev/null
+++ b/examples/grpo_gsm8k_experience_pipeline/dj_scoring_exp.yaml
@@ -0,0 +1,11 @@
+# This is a Data-Juicer data processing recipe
+project_name: 'gsm-8k-experience-quality'
+
+np: 32
+
+process:
+  - llm_quality_score_filter:
+      api_or_hf_model: "qwen2.5-32b-instruct"  # use "qwen2.5-32b-instruct" to calculate the quality scores.
+      min_score: 0.0
+      input_keys: ["prompt_text", "response_text"]  # set input_keys and field_names to the existing key names in the experiences. Here we calculate the quality scores according to both prompts and responses.
+ field_names: ["prompt", "response"] diff --git a/examples/grpo_gsm8k_experience_pipeline/gsm8k.yaml b/examples/grpo_gsm8k_experience_pipeline/gsm8k.yaml new file mode 100644 index 0000000000..25291a49cc --- /dev/null +++ b/examples/grpo_gsm8k_experience_pipeline/gsm8k.yaml @@ -0,0 +1,89 @@ +project: "Trinity-RFT-gsm8k-experience-pipeline" +name: "qwen2.5-1.5B-gsm8k-experience-pipeline" +checkpoint_root_dir: /PATH/TO/CHECKPOINT/ +algorithm: + algorithm_type: grpo + repeat_times: 8 +data_processor: + data_processor_url: 'http://127.0.0.1:5005/data_processor' + # experience pipeline related + experience_pipeline: + # I/O buffers + input_buffers: + - name: gsm8k_exp_output + output_buffer: + name: reshaped_gsm8k_exp_input + # format mapping + format: + reward_key: 'reward' # the key name of the reward in the experience + # data active iterator related + dj_config_path: 'examples/grpo_gsm8k_experience_pipeline/dj_scoring_exp.yaml' + clean_strategy: 'iterative' + # reward shaping + reward_shaping: + - stats_key: 'llm_quality_score' + op_type: ADD + weight: 1.0 + +model: + model_path: /PATH/TO/MODEL/ + max_prompt_tokens: 256 + max_response_tokens: 1024 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + total_epochs: 1 + batch_size: 96 + max_retry_times: 3 + max_retry_interval: 1 + explorer_input: + taskset: + name: gsm8k + storage_type: file + path: 'openai/gsm8k' + subset_name: 'main' + split: 'train' + format: + prompt_key: 'question' + response_key: 'answer' + rollout_args: + temperature: 1.0 + eval_tasksets: + - name: gsm8k-eval + storage_type: file + path: 'openai/gsm8k' + subset_name: 'main' + split: 'test' + format: + prompt_key: 'question' + response_key: 'answer' + default_workflow_type: 'math_workflow' + explorer_output: + name: gsm8k_exp_output + storage_type: queue + path: 'sqlite:///gsm8k_exp_output.db' + trainer_input: + experience_buffer: + name: reshaped_gsm8k_exp_input + storage_type: queue + path: 'sqlite:///reshaped_gsm8k_exp_input.db' +explorer: + eval_interval: 50 + runner_num: 32 + rollout_model: + engine_type: vllm_async + engine_num: 2 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 +synchronizer: + sync_method: 'nccl' + sync_interval: 1 + sync_timeout: 1200 +trainer: + trainer_type: 'verl' + trainer_config_path: 'examples/grpo_gsm8k_experience_pipeline/train_gsm8k.yaml' + save_interval: 100 diff --git a/examples/grpo_gsm8k_experience_pipeline/train_gsm8k.yaml b/examples/grpo_gsm8k_experience_pipeline/train_gsm8k.yaml new file mode 100644 index 0000000000..fc44fdad94 --- /dev/null +++ b/examples/grpo_gsm8k_experience_pipeline/train_gsm8k.yaml @@ -0,0 +1,50 @@ +actor_rollout_ref: + hybrid_engine: True + model: + external_lib: null + override_config: { } + enable_gradient_checkpointing: True + use_remove_padding: True # False + actor: + strategy: fsdp # This is for backward-compatibility + ppo_mini_batch_size: 128 + ppo_micro_batch_size_per_gpu: 4 + use_dynamic_bsz: True # False + ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} + grad_clip: 1.0 + ppo_epochs: 1 + shuffle: False + ulysses_sequence_parallel_size: 1 # sp size + optim: + lr: 1e-5 + lr_warmup_steps_ratio: 0. 
# the total steps will be injected during runtime + # min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + fsdp_config: + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + param_offload: False + optimizer_offload: False + fsdp_size: -1 + ref: + fsdp_config: + param_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + log_prob_micro_batch_size_per_gpu: 16 + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size + +trainer: + balance_batch: True + # total_training_steps: null + # auto: find the last ckpt to resume. If can't find, start from scratch + resume_mode: auto # or auto or resume_path if + default_hdfs_dir: null + remove_previous_ckpt_in_save: False + del_local_ckpt_after_load: False + val_before_train: False diff --git a/tests/common/config_test.py b/tests/common/config_test.py index da4fd914a0..2c5ef463b3 100644 --- a/tests/common/config_test.py +++ b/tests/common/config_test.py @@ -41,7 +41,9 @@ def test_all_examples_are_valid(self): for example_name in os.listdir(example_dir): for filename in os.listdir(os.path.join(example_dir, example_name)): if filename.endswith(".yaml") and not ( - filename.startswith("train_") or filename.startswith("verl_") + filename.startswith("train_") + or filename.startswith("verl_") + or filename.startswith("dj_") ): print(f"Checking config: {filename}") config_path = os.path.join(example_dir, example_name, filename) diff --git a/tests/common/experience_test.py b/tests/common/experience_test.py index 11ad2d1d4e..947c4d4ecb 100644 --- a/tests/common/experience_test.py +++ b/tests/common/experience_test.py @@ -5,8 +5,8 @@ import torch +from trinity.buffer.schema.sql_schema import ExperienceModel from trinity.common.experience import Experience, Experiences -from trinity.common.schema import ExperienceModel db_url = os.path.join(os.path.dirname(__file__), "tmp", "test.db") dataset_path = os.path.join(os.path.dirname(__file__), "data") diff --git a/tests/data/core/formatter_test.py b/tests/data/core/formatter_test.py index dbb73ed971..a4453efd79 100644 --- a/tests/data/core/formatter_test.py +++ b/tests/data/core/formatter_test.py @@ -47,7 +47,7 @@ def test_init(self): self.assertEqual(formatter.config.solution_key, "solution") self.assertEqual(formatter.config.chat_template, "User: {}\nAssistant: ") # test for default configs - self.assertEqual(formatter.config.reward_key, "") + self.assertEqual(formatter.config.reward_key, "reward") self.assertEqual(formatter.config.chosen_key, "chosen") self.assertEqual(formatter.config.rejected_key, "rejected") self.assertEqual(formatter.config.label_key, "") diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index 15b669d61c..1b3ba1f4bb 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -8,7 +8,13 @@ import ray -from trinity.common.config import Config, DataPipelineConfig, load_config +from trinity.common.config import Config, load_config +from trinity.common.constants import DataProcessorPipelineType +from trinity.data.utils import ( + activate_data_processor, + stop_data_processor, + validate_data_pipeline, +) from trinity.explorer.explorer import Explorer from trinity.trainer.trainer import Trainer from 
trinity.utils.log import get_logger @@ -147,61 +153,6 @@ def both(config: Config) -> None: trainer.shutdown.remote() -def activate_data_module(data_processor_url: str, config_path: str): - """Check whether to activate data module and preprocess datasets.""" - from trinity.cli.client import request - - logger.info(f"Activating data module of {data_processor_url}...") - res = request( - url=data_processor_url, - configPath=config_path, - ) - if res["return_code"] != 0: - logger.error(f"Failed to activate data module: {res['return_msg']}.") - return - - -def validate_data_pipeline(data_pipeline_config: DataPipelineConfig, pipeline_type: str): - """ - Check if the data pipeline is valid. The config should: - 1. Non-empty input buffer - 2. Different input/output buffers - - :param data_pipeline_config: the input data pipeline to be validated. - :param pipeline_type: the type of pipeline, should be one of ["task", "experience"] - """ - input_buffers = data_pipeline_config.input_buffers - output_buffer = data_pipeline_config.output_buffer - # common checks - # check if the input buffer list is empty - if len(input_buffers) == 0: - logger.warning("Empty input buffers in the data pipeline. Won't activate it.") - return False - # check if the input and output buffers are different - input_buffer_names = [buffer.name for buffer in input_buffers] - if output_buffer.name in input_buffer_names: - logger.warning("Output buffer exists in input buffers. Won't activate it.") - return False - if pipeline_type == "task": - # task pipeline specific - # "raw" field should be True for task pipeline because the data source must be raw data files - for buffer in input_buffers: - if not buffer.raw: - logger.warning( - 'Input buffers should be raw data files for task pipeline ("raw" field should be True). Won\'t activate it.' - ) - return False - elif pipeline_type == "experience": - # experience pipeline specific - raise NotImplementedError("experience_pipeline is not implemented yet.") - else: - logger.warning( - f'Invalid pipeline type: {pipeline_type}. Should be one of ["task", "experience"].' 
- ) - return False - return True - - def run(config_path: str, dlc: bool = False, plugin_dir: str = None): load_plugins(plugin_dir) config = load_config(config_path) @@ -210,21 +161,27 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None): # try to activate task pipeline for raw data data_processor_config = config.data_processor if ( - data_processor_config.data_processor_url - and data_processor_config.task_pipeline - and validate_data_pipeline(data_processor_config.task_pipeline, "task") + data_processor_config.data_processor_url is not None + and data_processor_config.task_pipeline is not None + and validate_data_pipeline( + data_processor_config.task_pipeline, DataProcessorPipelineType.TASK + ) ): - activate_data_module( - f"{data_processor_config.data_processor_url}/task_pipeline", config_path + activate_data_processor( + f"{data_processor_config.data_processor_url}/{DataProcessorPipelineType.TASK.value}", + config_path, ) # try to activate experience pipeline for experiences if ( - data_processor_config.data_processor_url - and data_processor_config.experience_pipeline - and validate_data_pipeline(data_processor_config.experience_pipeline, "experience") + data_processor_config.data_processor_url is not None + and data_processor_config.experience_pipeline is not None + and validate_data_pipeline( + data_processor_config.experience_pipeline, DataProcessorPipelineType.EXPERIENCE + ) ): - activate_data_module( - f"{data_processor_config.data_processor_url}/experience_pipeline", config_path + activate_data_processor( + f"{data_processor_config.data_processor_url}/{DataProcessorPipelineType.EXPERIENCE.value}", + config_path, ) if dlc: from trinity.utils.dlc_utils import setup_ray_cluster @@ -257,6 +214,10 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None): stop_ray_cluster(namespace=config.ray_namespace) + # stop all pipelines + if data_processor_config.data_processor_url is not None: + stop_data_processor(data_processor_config.data_processor_url) + def studio(port: int = 8501): from streamlit.web import cli as stcli diff --git a/trinity/common/config.py b/trinity/common/config.py index 540b004db4..72e9964857 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -9,6 +9,7 @@ from trinity.common.constants import ( EXPLORER_NAME, TRAINER_NAME, + OpType, PromptType, ReadStrategy, StorageType, @@ -38,10 +39,10 @@ class FormatConfig: reward_fn_key: str = "" workflow_key: str = "" # for math dataset - solution_key: str = "" + solution_key: str = "solution" # for reward dataset - reward_key: str = "" + reward_key: str = "reward" # for dpo dataset chosen_key: str = "chosen" @@ -106,6 +107,15 @@ class StorageConfig: task_type: TaskType = TaskType.EXPLORE +@dataclass +class RewardShapingConfig: + """Config for reward shaping.""" + + stats_key: str = "" + op_type: OpType = OpType.ADD + weight: float = 1.0 + + @dataclass class DataPipelineConfig: """Config for data pipeline.""" @@ -130,6 +140,9 @@ class DataPipelineConfig: priority_weights: Optional[Dict[str, float]] = None data_dist: Optional[str] = "gaussian" # one of ["gaussian", "uniform"] + # reward shaping related, only available for experience pipeline + reward_shaping: Optional[List[RewardShapingConfig]] = field(default_factory=list) + @dataclass class DataProcessorConfig: @@ -439,6 +452,8 @@ def _check_buffer(self) -> None: # noqa: C901 self.buffer.explorer_input.taskset.format.reply_prefix = ( self.buffer.explorer_input.reply_prefix ) + if 
self.buffer.explorer_input.taskset.ray_namespace is None: + self.buffer.explorer_input.taskset.ray_namespace = self.ray_namespace remained_tasksets = [] for idx, dataset in enumerate(self.buffer.explorer_input.eval_tasksets): @@ -456,6 +471,8 @@ def _check_buffer(self) -> None: # noqa: C901 dataset.format.system_prompt = self.buffer.explorer_input.system_prompt if dataset.format.reply_prefix is None: dataset.format.reply_prefix = self.buffer.explorer_input.reply_prefix + if dataset.ray_namespace is None: + dataset.ray_namespace = self.ray_namespace remained_tasksets.append(dataset) self.buffer.explorer_input.eval_tasksets = remained_tasksets @@ -480,12 +497,16 @@ def _check_buffer(self) -> None: # noqa: C901 self.buffer.trainer_input.experience_buffer.algorithm_type = ( self.algorithm.algorithm_type ) + if self.buffer.trainer_input.experience_buffer.ray_namespace is None: + self.buffer.trainer_input.experience_buffer.ray_namespace = self.ray_namespace # set buffer.explorer_output if self.buffer.explorer_output is None: self.buffer.explorer_output = self.buffer.trainer_input.experience_buffer else: self.buffer.explorer_output.algorithm_type = self.algorithm.algorithm_type + if self.buffer.explorer_output.ray_namespace is None: + self.buffer.explorer_output.ray_namespace = self.ray_namespace # check trainer_input.sft_warmup_dataset if ( @@ -497,6 +518,70 @@ def _check_buffer(self) -> None: # noqa: C901 ) if self.buffer.trainer_input.sft_warmup_dataset is not None: self.buffer.trainer_input.sft_warmup_dataset.algorithm_type = "sft" # TODO + if self.buffer.trainer_input.sft_warmup_dataset.ray_namespace is None: + self.buffer.trainer_input.sft_warmup_dataset.ray_namespace = self.ray_namespace + + # check input/output buffers in experience pipelines + if self.data_processor.experience_pipeline is not None: + # collect existing buffers for trinity + input_buffers = {} + output_buffers = {} + # - taskset + if self.buffer.explorer_input.taskset.name: + input_buffers[ + self.buffer.explorer_input.taskset.name + ] = self.buffer.explorer_input.taskset + # - explorer output + if self.buffer.explorer_output and self.buffer.explorer_output.name: + output_buffers[self.buffer.explorer_output.name] = self.buffer.explorer_output + # - trainer input: experience buffer + if ( + self.buffer.trainer_input.experience_buffer + and self.buffer.trainer_input.experience_buffer.name + ): + input_buffers[ + self.buffer.trainer_input.experience_buffer.name + ] = self.buffer.trainer_input.experience_buffer + # - trainer input: sft warmup dataset + if ( + self.buffer.trainer_input.sft_warmup_dataset + and self.buffer.trainer_input.sft_warmup_dataset.name + ): + input_buffers[ + self.buffer.trainer_input.sft_warmup_dataset.name + ] = self.buffer.trainer_input.sft_warmup_dataset + + # when experience pipeline is on, the explorer output and the + # experience buffer of trainer input should be different + if self.buffer.explorer_output == self.buffer.trainer_input.experience_buffer: + raise ValueError( + "The explorer output buffer should be different from the experience buffer of the trainer input " + "when experience pipeline is provided." 
+ ) + + # NOTICE: For now, input/output buffers for data processors should come from output/input buffers of trinity + # the input buffers in experience pipeline should come from the output buffers of trinity + exp_pipeline_input_buffers = self.data_processor.experience_pipeline.input_buffers + synced_input_buffers = [] + for input_buffer in exp_pipeline_input_buffers: + if input_buffer.name not in output_buffers: + raise ValueError( + f"The input buffer {input_buffer.name} of experience pipeline is not found in any output " + f"buffers of trinity." + ) + synced_input_buffers.append(output_buffers[input_buffer.name]) + self.data_processor.experience_pipeline.input_buffers = synced_input_buffers + # the output buffers of trinity should come from the input buffers of trinity + exp_pipeline_output_buffers = self.data_processor.experience_pipeline.output_buffer + if exp_pipeline_output_buffers.name not in input_buffers: + raise ValueError( + f"The output buffer {exp_pipeline_output_buffers.name} of experience pipeline is not found in any " + f"input buffers of trinity." + ) + else: + self.data_processor.experience_pipeline.output_buffer = input_buffers[ + exp_pipeline_output_buffers.name + ] # set read_batch_size / pad_token_id / tokenizer_path self.buffer.read_batch_size = self.buffer.batch_size * self.algorithm.repeat_times diff --git a/trinity/common/constants.py b/trinity/common/constants.py index 9a428131fe..bac4941453 100644 --- a/trinity/common/constants.py +++ b/trinity/common/constants.py @@ -103,3 +103,19 @@ class RunningStatus(Enum): RUNNING = "running" WAITING_SYNC = "waiting_sync" STOPPED = "stopped" + + +class DataProcessorPipelineType(Enum): + """Data processor pipeline type.""" + + EXPERIENCE = "experience_pipeline" + TASK = "task_pipeline" + + +class OpType(Enum): + """Operator type for reward shaping.""" + + ADD = "add" + SUB = "sub" + MUL = "mul" + DIV = "div" diff --git a/trinity/common/experience.py b/trinity/common/experience.py index a31b778563..0dcada3f04 100644 --- a/trinity/common/experience.py +++ b/trinity/common/experience.py @@ -34,6 +34,18 @@ def __post_init__(self): self.action_mask.shape == self.tokens.shape ), "The provided action_mask must have the same shape as tokens." 
+ # explicit type cast + if not isinstance(self.tokens, Tensor): + self.tokens = Tensor(self.tokens) + if self.logprobs is not None and not isinstance(self.logprobs, Tensor): + self.logprobs = Tensor(self.logprobs) + if self.action_mask is not None and not isinstance(self.action_mask, Tensor): + self.action_mask = Tensor(self.action_mask) + if self.chosen is not None and not isinstance(self.chosen, Tensor): + self.chosen = Tensor(self.chosen) + if self.rejected is not None and not isinstance(self.rejected, Tensor): + self.rejected = Tensor(self.rejected) + def serialize(self) -> bytes: """Serialize the experience to bytes.""" return pickle.dumps(self) diff --git a/trinity/common/schema.py b/trinity/common/schema.py deleted file mode 100644 index 938ec43bc4..0000000000 --- a/trinity/common/schema.py +++ /dev/null @@ -1,148 +0,0 @@ -# -*- coding: utf-8 -*- -"""Schema for different types of data.""" -from typing import Any, Optional, Type - -from sqlalchemy import JSON, Column, DateTime, Float, Integer, LargeBinary, String, Text -from sqlalchemy.ext.declarative import declarative_base - -from trinity.common.experience import Experience -from trinity.common.models.utils import tokenize_and_mask_messages_hf - -Base: Type = declarative_base() - -# TODO: create db engine and all tables in a factory class - - -class RftDatasetModel(Base): - """SQLAlchemy model for RftDataset.""" - - __tablename__ = "rft_dataset" - - # lineage - id = Column(Integer, primary_key=True, autoincrement=True) - consumed_cnt = Column(Integer, default=0) - last_modified_date = Column(DateTime, nullable=True) - from_id = Column(Integer, nullable=True) - from_model = Column(Text, nullable=True) - from_recipe = Column(Text, nullable=True) - # content - prompt = Column(Text, nullable=True) - response = Column(Text, nullable=True) - solution = Column(Text, nullable=True) - reward = Column(Float, nullable=True) - chosen = Column(Text, nullable=True) - rejected = Column(Text, nullable=True) - label = Column(Text, nullable=True) - # extra info - quality_score = Column(Float, default=0.0) - quality_score_detail = Column(JSON, nullable=True) - difficulty_score = Column(Float, default=0.0) - difficulty_score_detail = Column(JSON, nullable=True) - diversity_score = Column(Float, default=0.0) - diversity_score_detail = Column(JSON, nullable=True) - priority = Column(Float, default=0.0) - # downstream - reward_fn = Column(Text, nullable=True) - workflow = Column(Text, nullable=True) - - def to_dict(self) -> dict: - return {key: val for key, val in self.__dict__.items() if not key.startswith("_")} - - -class TaskModel(Base): - """SQLAlchemy model for Task.""" - - # TODO: Add more fields - - __tablename__ = "task_buffer" - - id = Column(Integer, primary_key=True, autoincrement=True) - task_desc = Column(String, nullable=True) - workflow_type = Column(String, nullable=True) - reward_type = Column(String, nullable=True) - - -class ExperienceModel(Base): - """SQLAlchemy model for Experience.""" - - __tablename__ = "experience_buffer" - - id = Column(Integer, primary_key=True, autoincrement=True) - serialized_exp = Column(LargeBinary, nullable=True) - prompt = Column(String, nullable=True) - response = Column(String, nullable=True) - reward = Column(Float, nullable=True) - consumed = Column(Integer, default=0) - priority = Column(Float, default=0.0) - - def to_experience(self) -> Experience: - """Load the experience from the database.""" - return Experience.deserialize(self.serialized_exp) - - @staticmethod - def from_experience(experience: 
Experience): - """Save the experience to database.""" - return ExperienceModel( - serialized_exp=experience.serialize(), - reward=experience.reward, - prompt=experience.prompt_text, - response=experience.response_text, - ) - - -class SFTDataModel(Base): - """SQLAlchemy model for SFT data.""" - - __tablename__ = "sft_data_buffer" - - id = Column(Integer, primary_key=True, autoincrement=True) - serialized_exp = Column(LargeBinary, nullable=True) - messages = Column(String, nullable=True) - consumed = Column(Integer, default=0) - - def to_experience(self) -> Experience: - """Load the experience from the database.""" - return Experience.deserialize(self.serialized_exp) - - @classmethod - def from_messages( - cls, - messages: list[dict], - tokenizer: Any, - chat_template: Optional[str] = None, - ) -> "SFTDataModel": - """Convert a list of messages into a single instance of SFT data.""" - token_ids, action_mask = tokenize_and_mask_messages_hf( - tokenizer=tokenizer, - messages=messages, - chat_template=chat_template, - ) - exp = Experience( - tokens=token_ids, - prompt_length=0, - action_mask=action_mask, - info={"response_num": sum([1 if m["role"] == "assistant" else 0 for m in messages])}, - ) - return cls( - serialized_exp=exp.serialize(), - messages=messages, - ) - - -class DPODataModel(Base): - """SQLAlchemy model for DPO data.""" - - __tablename__ = "dpo_data_buffer" - - id = Column(Integer, primary_key=True, autoincrement=True) - serialized_exp = Column(LargeBinary, nullable=True) - chosen = Column(LargeBinary, nullable=True) - rejected = Column(LargeBinary, nullable=True) - consumed = Column(Integer, default=0) - - def to_experience(self) -> Experience: - """Load the experience from the database.""" - exp = Experience.deserialize(self.serialized_exp) - exp.chosen = Experience.deserialize(self.chosen) - exp.rejected = Experience.deserialize(self.rejected) - return exp diff --git a/trinity/data/controllers/active_iterator.py b/trinity/data/controllers/active_iterator.py index 963a1015a9..227ad23b02 100644 --- a/trinity/data/controllers/active_iterator.py +++ b/trinity/data/controllers/active_iterator.py @@ -1,17 +1,24 @@ import os +import threading import traceback +from functools import partial from numbers import Number -from typing import Any, Dict, List +from typing import Any, Dict, List, Union import ray +from data_juicer.utils.constant import Fields -from trinity.common.config import BufferConfig, DataPipelineConfig +from trinity.common.config import BufferConfig, DataPipelineConfig, RewardShapingConfig +from trinity.common.constants import DataProcessorPipelineType, OpType from trinity.data.controllers.default_ops import DIMENSION_STATS_KEYS from trinity.data.controllers.task_parser import DataTaskParser from trinity.data.core.dataset import RftDataset from trinity.data.processors.cleaner import DataCleaner from trinity.data.processors.human_annotator import DataHumanAnnotator from trinity.data.processors.synthesizer import DataSynthesizer +from trinity.utils.log import get_logger + +logger = get_logger(__name__) class DataActiveIterator: @@ -23,9 +30,22 @@ def __init__( self, config: DataPipelineConfig, buffer_config: BufferConfig, + pipeline_type: Union[DataProcessorPipelineType, str] = DataProcessorPipelineType.TASK, ): + """ + The initialization method. + + :param config: the data pipeline config. + :param buffer_config: the buffer config. + :param pipeline_type: the type of the activated pipeline. 
+ """ self.config = config self.buffer_config = buffer_config + self.pipeline_type = pipeline_type + if isinstance(self.pipeline_type, str): + self.pipeline_type = DataProcessorPipelineType(pipeline_type) + + # check if the llm agent is required if self.config.agent_model_name is not None and self.config.agent_model_config is not None: # get the api key api_key = os.environ.get("OPENAI_API_KEY") @@ -42,6 +62,8 @@ def __init__( ) else: self.llm_agent = None + + # init task parser self.task_parser = DataTaskParser(config, self.llm_agent) # Priority weights @@ -77,9 +99,10 @@ def __init__( self.updated_op_args["field_names"].append(self.config.format.response_key) # flake8: noqa: C901 - def run(self): + def run(self, thread_event: threading.Event = None): """Run the active iterator.""" # step 1. parse the dj config + logger.info("Parsing the Data-Juicer config...") try: ( dj_config, @@ -91,14 +114,16 @@ def run(self): traceback.print_exc() return 1, "config parsing failed." - # step 2. load data from the input buffers + # step 2. prepare rft-dataset from the input buffers + logger.info("Preparing Rft-Dataset from input buffers...") try: dataset = RftDataset(self.config, self.buffer_config) except Exception: traceback.print_exc() return 2, "RftDataset loading failed." - # step 3. load cleaner + # step 3. load processor + logger.info("Loading data processors...") try: if hit_cleaner: cleaner = DataCleaner( @@ -120,7 +145,13 @@ def run(self): return 3, "DataCleaner loading failed." while True: + # if a stop event is set, stop! + if thread_event and thread_event.is_set(): + logger.info("Stop event is set, stopping the pipeline...") + break + # step 4. load data from the input buffers for the next batch + logger.info("Loading data from input buffers for the next batch...") try: dataset.read_from_buffer() except StopIteration: @@ -130,6 +161,7 @@ def run(self): return 4, "RftDataset loading from buffers failed." # step 5. apply processors to calculate scores of different dimensions + logger.info("Applying data processors to calculate stats...") try: res_dataset = dataset if hit_cleaner: @@ -145,6 +177,7 @@ def run(self): # step 6. calculate the average and final scores, including priority try: if hit_cleaner: + logger.info("Calculating the average and final scores...") scored_dataset = self._group_scores(res_dataset) scored_dataset = self._compute_priority_scores(scored_dataset) else: @@ -153,34 +186,55 @@ def run(self): traceback.print_exc() return 6, "Grouping and computing priority score failed." - # step 7. track lineage if they are changed + # step 7. reward shaping. Only available for experience pipeline and the reward shaping config is set + try: + if ( + self.pipeline_type == DataProcessorPipelineType.EXPERIENCE + and self.config.reward_shaping is not None + and len(self.config.reward_shaping) > 0 + ): + logger.info("Rewarding shaping...") + reshaped_dataset = self._reward_shaping(scored_dataset) + else: + reshaped_dataset = scored_dataset + except Exception: + traceback.print_exc() + return 7, "Reward shaping failed." + + # step 8. track lineage if they are changed try: - res_dataset = scored_dataset + res_dataset = reshaped_dataset except Exception: traceback.print_exc() - return 7, "Tracking lineage failed." + return 8, "Tracking lineage failed." 
- # step 8 + # step 9, sort the dataset by the computed priority try: if "priority" in res_dataset.data.features: + logger.info("Sorting samples by priority...") res_dataset.sort_by("priority", reverse=True) except Exception: traceback.print_exc() - return 8, "Sorting results by priority failed." + return 9, "Sorting results by priority failed." - # step 9. sort and export the result to the output buffer + # step 10. export the result to the output buffer try: + logger.info("Writing processed data to output buffer...") res_dataset.write_to_buffer() except Exception: traceback.print_exc() - return 9, "Exporting result to output buffer failed." + return 10, "Exporting result to output buffer failed." + + try: + dataset.release_output_buffer() + except Exception: + traceback.print_exc() + return -1, "Releasing output buffer failed." return 0, "success" def _group_scores(self, dataset: RftDataset) -> RftDataset: # for perplexity, normalize them with the max value. - from data_juicer.utils.constant import Fields - stats_min_max = {} for stats in dataset.data.features[Fields.stats]: all_stats = [ @@ -268,6 +322,45 @@ def _compute_priority_scores(self, dataset: RftDataset) -> RftDataset: dataset.data = dataset.data.map(self._compute_combined_score) return dataset + def _reward_shaping_single(self, sample, reward_shaping_config: RewardShapingConfig): + tgt_stats = reward_shaping_config.stats_key + op_type = reward_shaping_config.op_type + # if the target stats does not exist, skip this stats and return the original sample + if tgt_stats not in sample[Fields.stats]: + return sample + if op_type == OpType.ADD: + sample[self.config.format.reward_key] += ( + reward_shaping_config.weight * sample[Fields.stats][tgt_stats] + ) + elif op_type == OpType.MUL: + sample[self.config.format.reward_key] *= ( + reward_shaping_config.weight * sample[Fields.stats][tgt_stats] + ) + elif op_type == OpType.SUB: + sample[self.config.format.reward_key] -= ( + reward_shaping_config.weight * sample[Fields.stats][tgt_stats] + ) + elif op_type == OpType.DIV: + sample[self.config.format.reward_key] /= ( + reward_shaping_config.weight * sample[Fields.stats][tgt_stats] + ) + return sample + + def _reward_shaping(self, rft_dataset: RftDataset) -> RftDataset: + dataset = rft_dataset.data + # check if there is a reward column in the dataset. If not, skip! 
+ if self.config.format.reward_key not in dataset.features: + return rft_dataset + # get reward shaping configs + reward_shaping_configs = self.config.reward_shaping + for reward_shaping_config in reward_shaping_configs: + dataset = dataset.map( + partial(self._reward_shaping_single, reward_shaping_config=reward_shaping_config) + ) + + rft_dataset.data = dataset + return rft_dataset + @ray.method(num_returns=1) def select_batch(self, dataset: RftDataset, batch_size: int) -> List[Dict[str, Any]]: """Select a batch of samples for training""" diff --git a/trinity/data/core/dataset.py b/trinity/data/core/dataset.py index 6b6d126f9b..3ac8e3636f 100644 --- a/trinity/data/core/dataset.py +++ b/trinity/data/core/dataset.py @@ -1,13 +1,22 @@ from abc import ABC -from dataclasses import dataclass +from dataclasses import asdict, dataclass, fields, is_dataclass from typing import Any, Dict, List, Optional, Union import networkx as nx from datasets import Dataset, concatenate_datasets from trinity.buffer import get_buffer_reader, get_buffer_writer -from trinity.common.config import BufferConfig, DataPipelineConfig, StorageConfig +from trinity.common.config import BufferConfig, DataPipelineConfig from trinity.data.core.formatter import BaseDataFormatter +from trinity.utils.log import get_logger + +logger = get_logger(__name__) + + +def dict_to_dataclass(cls, d): + valid_keys = {f.name for f in fields(cls)} + filtered = {k: v for k, v in d.items() if k in valid_keys} + return cls(**filtered) @dataclass @@ -42,13 +51,18 @@ def __init__( ): self.config = data_pipeline_config self.buffer_config = buffer_config + # init input buffers input_buffer_configs = self.config.input_buffers if len(input_buffer_configs) == 0: raise ValueError("input_buffers is empty in data pipeline config") - self.buffers = [] + self.input_buffers = [] for input_buffer_config in input_buffer_configs: - self.buffers.append(get_buffer_reader(input_buffer_config, self.buffer_config)) + self.input_buffers.append(get_buffer_reader(input_buffer_config, self.buffer_config)) + # init output buffer + self.output_buffer = get_buffer_writer(self.config.output_buffer, self.buffer_config) + self.data = Dataset.from_list([]) + self.original_dataclass = None self.reward_schema = self._init_reward_schema(reward_schema) self.stats: Dict[str, Any] = {} @@ -71,22 +85,28 @@ def sort_by(self, key: str, reverse: bool = False, top_k: int = -1): def read_from_buffer(self): datasets = [] - for buffer in self.buffers: - datasets.append(Dataset.from_list(buffer.read())) + for buffer in self.input_buffers: + exp_list = buffer.read() + if len(exp_list) > 0 and is_dataclass(exp_list[0]): + exp_list = [asdict(exp) for exp in exp_list] + if self.original_dataclass is None: + self.original_dataclass = exp_list[0].__class__ + datasets.append(Dataset.from_list([exp for exp in exp_list])) self.data = concatenate_datasets(datasets) - - def write_to_buffer( - self, output_storage_config: StorageConfig = None, buffer_config: BufferConfig = None - ): - if output_storage_config is None: - output_storage_config = self.config.output_buffer - if buffer_config is None: - buffer_config = self.buffer_config - output_buffer = get_buffer_writer(output_storage_config, buffer_config) - output_buffer.write(self.data.to_list()) - output_buffer.release() + logger.info(f"Read {len(self.data)} samples from input buffers") + + def write_to_buffer(self): + if self.original_dataclass is not None: + exp_list = [dict_to_dataclass(self.original_dataclass, d) for d in self.data.to_list()] + 
else:
+            exp_list = self.data.to_list()
+        self.output_buffer.write(exp_list)
+        logger.info(f"Wrote {len(self.data)} samples to output buffer")
         self.data = Dataset.from_list([])

+    def release_output_buffer(self):
+        self.output_buffer.release()
+
     def to_parquet(self, path: str):
         self.data.to_parquet(path)
diff --git a/trinity/data/processors/cleaner.py b/trinity/data/processors/cleaner.py
index 10979990b1..e55ee23590 100644
--- a/trinity/data/processors/cleaner.py
+++ b/trinity/data/processors/cleaner.py
@@ -35,7 +35,7 @@ def __init__(
         dj_cfg: Optional[Namespace],
         clean_strategy: str = "iterative",
         min_size_ratio: PositiveFloat = None,
-        data_dist: str = "gaussian",
+        data_dist: Optional[str] = "gaussian",
         op_weights: dict = None,
         **kwargs,
     ):
diff --git a/trinity/data/server.py b/trinity/data/server.py
index e1f57ba81b..69c941b947 100644
--- a/trinity/data/server.py
+++ b/trinity/data/server.py
@@ -1,4 +1,8 @@
+import threading
+from typing import List
+
 import fire
+import ray
 from flask import Flask, jsonify, request
 from markupsafe import escape

@@ -6,6 +10,8 @@

 APP_NAME = "data_processor"

+EVENT_POOL: List[threading.Event] = []
+

 @app.route(f"/{APP_NAME}/<pipeline_type>", methods=["GET"])
 def data_processor(pipeline_type):
@@ -15,6 +21,10 @@ def data_processor(pipeline_type):
     config_path = request.args.get("configPath")
     pipeline_type = escape(pipeline_type)
     config = load_config(config_path)
+    config.check_and_update()
+
+    # init ray
+    ray.init(namespace=config.ray_namespace, ignore_reinit_error=True)

     pipeline_config = getattr(config.data_processor, pipeline_type)
     if pipeline_config is None:
@@ -33,9 +43,34 @@ def data_processor(pipeline_type):
         }
     )

-    iterator = DataActiveIterator(pipeline_config, config.buffer)
-    ret, msg = iterator.run()
-    return jsonify({"return_code": ret, "message": msg})
+    if pipeline_type == "task_pipeline":
+        # must be sync
+        iterator = DataActiveIterator(pipeline_config, config.buffer, pipeline_type=pipeline_type)
+        ret, msg = iterator.run()
+        return jsonify({"return_code": ret, "message": msg})
+    elif pipeline_type == "experience_pipeline":
+        # must be async
+        iterator = DataActiveIterator(pipeline_config, config.buffer, pipeline_type=pipeline_type)
+        # create an event so that the pipeline thread can be stopped later
+        event = threading.Event()
+        thread = threading.Thread(target=iterator.run, args=(event,))
+        thread.start()
+        # add this event to the event pool
+        EVENT_POOL.append(event)
+        return jsonify({"return_code": 0, "message": "Experience pipeline started successfully."})
+
+
+@app.route(f"/{APP_NAME}/stop_all", methods=["GET"])
+def stop_all():
+    try:
+        for event in EVENT_POOL:
+            event.set()
+    except Exception:
+        import traceback
+
+        traceback.print_exc()
+        return jsonify({"return_code": 1, "message": traceback.format_exc()})
+    return jsonify({"return_code": 0, "message": "All data pipelines are stopped."})


 def main(port=5005):
diff --git a/trinity/data/utils.py b/trinity/data/utils.py
new file mode 100644
index 0000000000..e3c74c05ed
--- /dev/null
+++ b/trinity/data/utils.py
@@ -0,0 +1,72 @@
+from trinity.common.config import DataPipelineConfig
+from trinity.common.constants import DataProcessorPipelineType
+from trinity.utils.log import get_logger
+
+logger = get_logger(__name__)
+
+
+def activate_data_processor(data_processor_url: str, config_path: str):
+    """Check whether to activate the data module and preprocess datasets."""
+    from trinity.cli.client import request
+
+    logger.info(f"Activating data module of {data_processor_url}...")
+    res = request(
+        url=data_processor_url,
+        configPath=config_path,
+    )
+    if res["return_code"] != 0:
+        logger.error(f"Failed to activate data module: {res.get('return_msg', res.get('message'))}.")
+        return
+
+
+def stop_data_processor(base_data_processor_url: str):
+    """Stop all pipelines in the data processor."""
+    from trinity.cli.client import request
+
+    logger.info(f"Stopping all pipelines in {base_data_processor_url}...")
+    res = request(url=f"{base_data_processor_url}/stop_all")
+    if res["return_code"] != 0:
+        logger.error(f"Failed to stop all data pipelines: {res.get('return_msg', res.get('message'))}.")
+        return
+
+
+def validate_data_pipeline(
+    data_pipeline_config: DataPipelineConfig, pipeline_type: DataProcessorPipelineType
+):
+    """
+    Check if the data pipeline is valid. The config should have:
+
+    1. a non-empty input buffer list
+    2. different input/output buffers
+
+    :param data_pipeline_config: the input data pipeline to be validated.
+    :param pipeline_type: the type of pipeline; should be one of DataProcessorPipelineType.
+    """
+    input_buffers = data_pipeline_config.input_buffers
+    output_buffer = data_pipeline_config.output_buffer
+    # common checks
+    # check if the input buffer list is empty
+    if len(input_buffers) == 0:
+        logger.warning("Empty input buffers in the data pipeline. Won't activate it.")
+        return False
+    # check if the input and output buffers are different
+    input_buffer_names = [buffer.name for buffer in input_buffers]
+    if output_buffer.name in input_buffer_names:
+        logger.warning("Output buffer exists in input buffers. Won't activate it.")
+        return False
+    if pipeline_type == DataProcessorPipelineType.TASK:
+        # task pipeline specific:
+        # the "raw" field should be True for the task pipeline because the data source must be raw data files
+        for buffer in input_buffers:
+            if not buffer.raw:
+                logger.warning(
+                    'Input buffers should be raw data files for task pipeline ("raw" field should be True). Won\'t activate it.'
+                )
+                return False
+    elif pipeline_type == DataProcessorPipelineType.EXPERIENCE:
+        # experience pipeline specific: no special items need to be checked
+        pass
+    else:
+        logger.warning(f"Invalid pipeline type: {pipeline_type}. Should be one of DataProcessorPipelineType.")
+        return False
+    return True
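For reference, the two server endpoints used by `activate_data_processor` and `stop_data_processor` can also be exercised directly. A minimal sketch using `requests` (assuming the server from `trinity/data/server.py` runs on its default port 5005; the JSON keys mirror the `jsonify` calls above):

```python
import requests

BASE = "http://127.0.0.1:5005/data_processor"

# start the experience pipeline; the server runs it in a background
# thread and returns immediately
res = requests.get(
    f"{BASE}/experience_pipeline",
    params={"configPath": "examples/grpo_gsm8k_experience_pipeline/gsm8k.yaml"},
).json()
print(res["return_code"], res["message"])

# ... exploring / training happens here ...

# ask the server to stop all running pipelines (sets each pipeline's threading.Event)
res = requests.get(f"{BASE}/stop_all").json()
print(res["return_code"], res["message"])
```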