diff --git a/.gitignore b/.gitignore
index 646848ade7..7ab517ff20 100644
--- a/.gitignore
+++ b/.gitignore
@@ -84,6 +84,7 @@ ENV/
 logs/

 # data-juicer
+tmp/
 outputs/
 # agentscope
 runs/
diff --git a/docs/sphinx_doc/source/tutorial/example_data_functionalities.md b/docs/sphinx_doc/source/tutorial/example_data_functionalities.md
index 6558efcdd9..27b5fb26bf 100644
--- a/docs/sphinx_doc/source/tutorial/example_data_functionalities.md
+++ b/docs/sphinx_doc/source/tutorial/example_data_functionalities.md
@@ -1,80 +1,97 @@
 # Data Processing

-## Example: reasoning task
+## Example: Data Processor for Task Pipeline

-In this example, you will learn how to apply the data module of Trinity-RFT to prepare the dataset before exploring and training. This example takes GSM-8K dataset as the example dataset to figure out:
+In this example, you will learn how to apply the data processor of Trinity-RFT to prepare and prioritize the dataset before task exploration and training. This example uses the GSM-8K dataset to demonstrate:

-1. how to prepare the data module
-2. how to configure the data module
-3. what the data module can do
+1. how to prepare the data processor
+2. how to configure the data processor
+3. what the data processor can do

-Before getting started, you need to prepare the main environment of Trinity-RFT according to the [installation section of the README file](../main.md), and you need to install [postgresql](https://www.postgresql.org/docs/current/tutorial-install.html) as well.
+Before getting started, you need to prepare the main environment of Trinity-RFT according to the [installation section of the README file](../main.md).

 ### Data Preparation

-#### Prepare the Data Module
+#### Prepare the Data Processor

-As the overall framework of Trinity-RFT shows, the data module is one of the high-level functions. Trinity-RFT encapsulates the data module as an independent service to avoid dependency conflict issues. Thus you need to prepare a split environment for this module and start the server.
+As the overall framework of Trinity-RFT shows, the data processor is one of the high-level functions. Trinity-RFT encapsulates the data processor as an independent service to avoid dependency conflict issues. Thus, you need to prepare a separate environment for this module and start the server.

 ```shell
-# prepare split environments, including the one of data module
+# prepare split environments, including the one of the data processor
 python scripts/install.py
 # start all split servers
 python scripts/start_servers.py
 ```
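Under the hood, exploring/training jobs talk to this service over plain HTTP. If you want to trigger the task pipeline by hand, here is a minimal sketch using the client helper from `trinity/cli/client.py` in this change; the URL is the default local address used throughout this tutorial, and the config path is assumed to point at your own RFT config:

```python
# Minimal sketch: trigger the task pipeline manually, the same way the
# launcher does in trinity/cli/launcher.py. Assumes the servers started
# above are running locally on the default port.
from trinity.cli.client import request

res = request(
    url="http://127.0.0.1:5005/data_processor/task_pipeline",
    configPath="examples/grpo_gsm8k_task_pipeline/gsm8k.yaml",  # your RFT config
)
if res["return_code"] != 0:
    print("Task pipeline failed:", res)
```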

-### Configure the Data Module
+### Configure the Data Processor

-Trinity-RFT uses a unified config file to manage all config items. For the data module, you need to focus on the `data_processor` section in the config file.
+Trinity-RFT uses a unified config file to manage all config items. For the data processor, you need to focus on the `data_processor` section in the config file.

 In this example, assume that you need to rank all math questions and their corresponding answers by difficulty. You can set the config items as in the following example:

 ```yaml
 data_processor:
-  # basic info
-  source_data_path: /PATH/TO/GSM8K/
-  load_kwargs:
-    split: 'train' # only need the train split
-  format: # set the field mappings
-    prompt_key: 'question'
-    response_key: 'answer'
-  # database related. The result dataset will be stored in the database.
-  db_url: 'postgresql://{user_name}@localhost:5432/{db_name}'
+  data_processor_url: 'http://127.0.0.1:5005/data_processor'
+  # task pipeline related
+  task_pipeline:
+    # I/O buffers
+    input_buffers:
+      - name: 'raw_input'
+        path: /PATH/TO/GSM8K/
+        storage_type: 'file'
+        raw: true
+    output_buffer:
+      name: 'raw_output'
+      path: /PATH/TO/OUTPUT/JSONL/FILE
+      storage_type: 'file'
+    # format mapping
+    format:
+      prompt_key: 'question'
+      response_key: 'answer'
 ```

-Here you can set the basic information for the GSM-8K dataset, database information that is used to store the result dataset, and some other items about downstream dataset loading for exploring and training:
+Here you can set the input and output buffers for the GSM-8K dataset, as well as some other items about downstream dataset loading for exploring and training:

-+ `source_data_path`: the path to the raw dataset.
-+ `load_kwargs`: extra config arguments for loading the raw dataset. Mainly for the `load_dataset` method in HuggingFace `datasets` library.
-+ `format`: some dataset format config items, which are used to map original data field names to unified ones.
-+ `db_url`: the URL of the postgresql database to store the result dataset.
++ `data_processor_url`: the URL of the data processor service, which was started in the previous step.
++ `task_pipeline`: the configs for the task pipeline, which processes the raw dataset. It consists of several inner configs, and a usage sketch follows this list:
+  + `input_buffers`: the input buffers for the task pipeline. We usually load from raw dataset files in this pipeline, so we need to set the dataset `path`, set the `storage_type` to "file", and set `raw` to true. Multiple input buffers are allowed, and each buffer can be named with its `name` field.
+  + `output_buffer`: the output buffer for the task pipeline. We usually store the processed dataset in files as well, so we need to set the `storage_type` to "file".
+  + `format`: some dataset format config items, which are used to map original data field names to unified ones.
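To see what an input buffer with `raw: true` yields programmatically, here is a minimal sketch that mirrors the new test in `tests/buffer/file_test.py` from this change; the placeholder path is the same one used in the config above:

```python
# Minimal sketch: read a raw file buffer (raw=True), mirroring
# tests/buffer/file_test.py from this change.
from trinity.buffer.reader.file_reader import RawDataReader
from trinity.common.config import StorageConfig
from trinity.common.constants import StorageType

meta = StorageConfig(
    name="raw_input",
    path="/PATH/TO/GSM8K/",  # same placeholder as in the config above
    storage_type=StorageType.FILE,
    raw=True,
)
reader = RawDataReader(meta, None)
samples = reader.read()  # returns the whole split as a list of dicts
print(len(samples), samples[0])
```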

-In addition, there are several config items related to the data active iterator, which is used to prepare a better dataset. The core part of the data active iterator, Data-Juicer, provides tens of operators to help clean or calculate key information for each sample in the dataset. You can configure this part depending on how familiar you are with Data-Juicer.
+In addition, there are several config items in the `task_pipeline` part related to the data active iterator, which is used to prepare a better dataset. The core part of the data active iterator, Data-Juicer, provides dozens of operators to help clean or compute key information for each sample in the dataset. You can configure this part depending on how familiar you are with Data-Juicer.

 #### Not familiar with Data-Juicer

-If you are not familiar with Data-Juicer, the data module provides a natural-language-based method to config the data processing recipe. What you need to do is only describe the demands of how you want to prepare for the raw dataset, and an agent will be invoked to arrange the data processing recipe for you.
+If you are not familiar with Data-Juicer, the data processor provides a natural-language-based way to configure the data processing recipe. All you need to do is describe how you want the raw dataset to be prepared, and an agent will be invoked to arrange the data processing recipe for you.

 Here is an example:

 ```yaml
 data_processor:
-  # basic info
-  source_data_path: /PATH/TO/GSM8K/
-  load_kwargs:
-    split: 'train' # only need the train split
-  format: # set the field mappings
-    prompt_key: 'question'
-    response_key: 'answer'
-  # database related. The result dataset will be stored in the database.
-  db_url: 'postgresql://{user_name}@localhost:5432/{db_name}'
-
-  #### new part about data active iterator
-  dj_process_desc: 'Please compute difficulty scores for these math questions.'
-  agent_model_name: 'qwen-max'
-  agent_model_config:
-    config_name: 'my-qwen-instruction'
-    model_type: 'dashscope_chat'
-    model_name: 'qwen2.5-72b-instruct'
-  clean_strategy: 'iterative'
+  data_processor_url: 'http://127.0.0.1:5005/data_processor'
+  # task pipeline related
+  task_pipeline:
+    # I/O buffers
+    input_buffers:
+      - name: 'raw_input'
+        path: /PATH/TO/GSM8K/
+        storage_type: 'file'
+        raw: true
+    output_buffer:
+      name: 'raw_output'
+      path: /PATH/TO/OUTPUT/JSONL/FILE
+      storage_type: 'file'
+    # format mapping
+    format:
+      prompt_key: 'question'
+      response_key: 'answer'
+
+    #### new part about data active iterator
+    dj_process_desc: 'Please compute difficulty scores for these math questions.'
+    agent_model_name: 'qwen-max'
+    agent_model_config:
+      config_name: 'my-qwen-instruction'
+      model_type: 'dashscope_chat'
+      model_name: 'qwen2.5-72b-instruct'
+    clean_strategy: 'iterative'
 ```

 You can write your demand description in the config item `dj_process_desc`, and set the model name and configs used for the agent in the config items `agent_model_name` and `agent_model_config`. Here we use Qwen2.5-72B-Instruct as our recipe-managing agent. You can also set `clean_strategy` to 'iterative' to get a better dataset.

@@ -99,19 +116,27 @@ After preparing the Data-Juicer data processing recipe, you can set the `dj_conf

 ```yaml
 data_processor:
-  # basic info
-  source_data_path: /PATH/TO/GSM8K/
-  load_kwargs:
-    split: 'train' # only need the train split
-  format: # set the field mappings
-    prompt_key: 'question'
-    response_key: 'answer'
-  # database related. The result dataset will be stored in the database.
-  db_url: 'postgresql://{user_name}@localhost:5432/{db_name}'
-
-  #### new part about data active iterator
-  dj_config_path: '/path/to/the/Data-Juicer/data/processing/recipe/above.yaml'
-  clean_strategy: 'iterative'
+  data_processor_url: 'http://127.0.0.1:5005/data_processor'
+  # task pipeline related
+  task_pipeline:
+    # I/O buffers
+    input_buffers:
+      - name: 'raw_input'
+        path: /PATH/TO/GSM8K/
+        storage_type: 'file'
+        raw: true
+    output_buffer:
+      name: 'raw_output'
+      path: /PATH/TO/OUTPUT/JSONL/FILE
+      storage_type: 'file'
+    # format mapping
+    format:
+      prompt_key: 'question'
+      response_key: 'answer'
+
+    #### new part about data active iterator
+    dj_config_path: '/path/to/the/Data-Juicer/data/processing/recipe/above.yaml'
+    clean_strategy: 'iterative'
 ```

 Again, you can set `clean_strategy` to 'iterative' to get a better dataset.

@@ -123,7 +148,7 @@ All config items in the `data` section can be found [here](trinity_configs.md).

 ```{note}
-Only when one of `dj_process_desc` and `dj_config_path` is provided, the data module and the data active iterator will be activated. Otherwise, this part will be skipped and it will enter into the exploring stage directly.
+Only when one of the `xxx_pipeline` configs is provided, and one of `dj_process_desc` and `dj_config_path` is provided in that pipeline config, will the data processor and the data active iterator be activated. Otherwise, this part will be skipped and the exploring stage will start directly.
 ```
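For reference, the recipe file pointed to by `dj_config_path` is itself a small Data-Juicer YAML. A minimal sketch is shown below; the operators are borrowed from this repo's test recipes and the thresholds are illustrative, so substitute operators that match your own cleaning needs:

```yaml
# illustrative Data-Juicer recipe; adapt the operator list to your needs
project_name: 'gsm8k-recipe'
text_keys: 'question'
process:
  - text_length_filter:
      min_len: 10
      max_len: 512
  - clean_email_mapper:
```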

 ### Exploring & Training

@@ -140,49 +165,54 @@ ray start --address=
 trinity run --config
 ```

-If you follow the steps above, Trinity-RFT will send a request to the data module server, the data active iterator will be activated and compute difficulty scores for each sample in the raw dataset. After that, the data module server stores the result dataset into the database, when exploring begins, it will load the prepared dataset and continue the downstream steps.
+If you follow the steps above, Trinity-RFT will send a request to the data processor server; the data active iterator will be activated, compute a difficulty score for each sample in the raw dataset, and rank the dataset by those scores. After that, the data processor server stores the result dataset in the output buffer. When exploring begins, the prepared dataset is loaded and the downstream steps continue.
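To sanity-check the result, you can peek at the output buffer file directly: the task pipeline writes plain JSONL, and this change adds `jsonlines` to the project dependencies. The field names below are assumptions based on the GSM-8K format mapping above and the `priority` score attached by the data active iterator:

```python
# Minimal sketch: inspect the first few prioritized samples in the output buffer.
import jsonlines

with jsonlines.open("./outputs/task_pipeline_output/prioritized_gsm8k.jsonl") as reader:
    for i, sample in enumerate(reader):
        print(sample.get("priority"), str(sample.get("question"))[:60])
        if i >= 4:
            break
```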

-
-
-## Example: human in the loop
+## Example: Human in the Loop

 Sometimes, you might need to involve human feedback for some raw data. In this example, you will learn how to annotate raw data to get a better dataset before training. This example takes an example Q&A dataset and selects the chosen and rejected responses for the DPO method.

-Before getting started, you need to prepare the main environment of Trinity-RFT according to the installation section of the README file, install postgresql, and [start a label-studio server](https://github.com/modelscope/data-juicer/tree/main/tools/humanops) from Data-Juicer from source.
+Before getting started, you need to prepare the main environment of Trinity-RFT according to the installation section of the README file, and [start a label-studio server](https://github.com/modelscope/data-juicer/tree/main/tools/humanops) from the Data-Juicer source.

 ### Data Preparation

-#### Prepare the Data Module
+#### Prepare the Data Processor

-As the overall framework of Trinity-RFT shows, the data module is one of the high-level functions. Trinity-RFT encapsulates the data module as an independent service to avoid dependency conflict issues. Thus you need to prepare a split environment for this module and start the server.
+As the overall framework of Trinity-RFT shows, the data processor is one of the high-level functions. Trinity-RFT encapsulates the data processor as an independent service to avoid dependency conflict issues. Thus, you need to prepare a separate environment for this module and start the server.

 ```shell
-# prepare split environments, including the one of data module
+# prepare split environments, including the one of the data processor
 python scripts/install.py
 # start all split servers
 python scripts/start_servers.py
 ```

-### Configure the Data Module
+### Configure the Data Processor

-Trinity-RFT uses a unified config file to manage all config items. For the data module, you need to focus on the `data_processor` section in the config file.
+Trinity-RFT uses a unified config file to manage all config items. For the data processor, you need to focus on the `data_processor` section in the config file.

-In this example, assume that you need to rank all math questions and corresponding answers by their difficulties. So you can set these config items like the following example:
+In this example, assume that you need to select the chosen and rejected responses for the DPO method. You can set the config items as in the following example:

 ```yaml
 data_processor:
-  # basic info
-  source_data_path: 'tests/test_data/test_human_annotator'
-  load_kwargs:
-    split: 'train' # only need the train split
-  format: # set the field mappings
-    prompt_key: 'prompt'
-    chosen_key: 'chosen'
-    rejected_key: 'rejected'
-  #### new part about data active iterator
-  dj_config_path: 'tests/test_configs/human_annotator_test_dj_cfg.yaml'
-  # database related. The result dataset will be stored in the database.
-  db_url: 'postgresql://{user_name}@localhost:5432/{db_name}'
+  data_processor_url: 'http://127.0.0.1:5005/data_processor'
+  # task pipeline related
+  task_pipeline:
+    # I/O buffers
+    input_buffers:
+      - name: 'raw_input'
+        path: 'tests/test_data/test_human_annotator'
+        storage_type: 'file'
+        raw: true
+    output_buffer:
+      name: 'raw_output'
+      path: './outputs/task_pipeline_output/annotated_dpo.jsonl'
+      storage_type: 'file'
+    format: # set the field mappings
+      prompt_key: 'prompt'
+      chosen_key: 'chosen'
+      rejected_key: 'rejected'
+    #### new part about data active iterator
+    dj_config_path: 'tests/test_configs/human_annotator_test_dj_cfg.yaml'
 ```

 Here you can set the basic information for the example dataset, the output buffer that stores the result dataset, and some other items about downstream dataset loading for exploring and training, similar to the example above.

@@ -223,7 +253,7 @@ You can set more config items for this OP (e.g. notification when annotation is

 ### Start Running

-When you start running with the RFT config, the data module will start the OP `human_preference_annotation_mapper`, and then you can find a new project on the "Projects" page of the label-studio server.
+When you start running with the RFT config, the data processor will start the OP `human_preference_annotation_mapper`, and then you can find a new project on the "Projects" page of the label-studio server.

 ![](../../assets/data-projects.png)

diff --git a/environments/data.yaml b/environments/data.yaml
index 6acdf04dc9..d43ece076b 100644
--- a/environments/data.yaml
+++ b/environments/data.yaml
@@ -6,10 +6,5 @@ dependencies:
   - pip:
     - py-data-juicer
     - agentscope
-    - flask
-    - omegaconf
-    - sqlalchemy
-    - psycopg2
-    - networkx
     - transformers
     - "-e ..[dev]"
diff --git a/examples/grpo_gsm8k/gsm8k.yaml b/examples/grpo_gsm8k/gsm8k.yaml
index 2a87ef288b..0763586457 100644
--- a/examples/grpo_gsm8k/gsm8k.yaml
+++ b/examples/grpo_gsm8k/gsm8k.yaml
@@ -4,19 +4,6 @@ checkpoint_root_dir: /PATH/TO/CHECKPOINT/
 algorithm:
   algorithm_type: grpo
   repeat_times: 8
-data_processor:
-  # basic info
-  source_data_path: 'openai/gsm8k'
-  # data active iterator related
-  dj_process_desc: 'Please compute difficulty scores for these math questions.'
- agent_model_name: 'qwen-max' - agent_model_config: - config_name: 'my-qwen-instruction' - model_type: 'dashscope_chat' - model_name: 'qwen2.5-72b-instruct' - clean_strategy: 'iterative' - # db related - db_url: '' model: model_path: /PATH/TO/MODEL/ @@ -41,9 +28,7 @@ buffer: prompt_key: 'question' response_key: 'answer' rollout_args: - n: 8 temperature: 1.0 - logprobs: 0 eval_tasksets: - name: gsm8k-eval storage_type: file diff --git a/examples/grpo_gsm8k_task_pipeline/README.md b/examples/grpo_gsm8k_task_pipeline/README.md new file mode 100644 index 0000000000..ead6a56185 --- /dev/null +++ b/examples/grpo_gsm8k_task_pipeline/README.md @@ -0,0 +1,7 @@ +# GRPO on GSM8K dataset with Task Pipeline + +This example shows the usage of GRPO on the GSM8K dataset, with a task pipeline to prioritize the raw dataset before training. + +For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_data_functionalities.md). + +The config files are located in [`gsm8k.yaml`](gsm8k.yaml) and [`train_gsm8k.yaml`](train_gsm8k.yaml). diff --git a/examples/grpo_gsm8k_task_pipeline/gsm8k.yaml b/examples/grpo_gsm8k_task_pipeline/gsm8k.yaml new file mode 100644 index 0000000000..36514e0e01 --- /dev/null +++ b/examples/grpo_gsm8k_task_pipeline/gsm8k.yaml @@ -0,0 +1,95 @@ +project: "Trinity-RFT-gsm8k-task-pipeline" +name: "qwen2.5-1.5B-gsm8k-task-pipeline" +checkpoint_root_dir: /PATH/TO/CHECKPOINT/ +algorithm: + algorithm_type: grpo + repeat_times: 8 +data_processor: + data_processor_url: 'http://127.0.0.1:5005/data_processor' + # task pipeline related + task_pipeline: + # I/O buffers + input_buffers: + - name: 'raw_input' + path: 'openai/gsm8k' + storage_type: 'file' + raw: true + output_buffer: + name: 'raw_output' + path: './outputs/task_pipeline_output/prioritized_gsm8k.jsonl' + storage_type: 'file' + # format mapping + format: + prompt_key: 'question' + response_key: 'answer' + # data active iterator related + dj_process_desc: 'Please compute difficulty scores for these math questions.' 
+ agent_model_name: 'qwen-max' + agent_model_config: + config_name: 'my-qwen-instruction' + model_type: 'dashscope_chat' + model_name: 'qwen2.5-72b-instruct' + clean_strategy: 'iterative' + +model: + model_path: /PATH/TO/MODEL/ + max_prompt_tokens: 256 + max_response_tokens: 1024 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + total_epochs: 1 + batch_size: 96 + max_retry_times: 3 + max_retry_interval: 1 + explorer_input: + taskset: + name: gsm8k + storage_type: file + path: './outputs/task_pipeline_output/' + split: 'train' + format: + prompt_key: 'question' + response_key: 'answer' + rollout_args: + temperature: 1.0 + eval_tasksets: + - name: gsm8k-eval + storage_type: file + path: 'openai/gsm8k' + subset_name: 'main' + split: 'test' + format: + prompt_key: 'question' + response_key: 'answer' + default_workflow_type: 'math_workflow' + trainer_input: + experience_buffer: + name: gsm8k_buffer + storage_type: queue + path: 'sqlite:///gsm8k.db' + # sft_warmup_steps: 0 + # sft_warmup_dataset: # Uncomment these to enable sft warmup + # name: warmup_data + # storage_type: file + # path: '/PATH/TO/WARMUP_DATA/' +explorer: + eval_interval: 50 + runner_num: 32 + rollout_model: + engine_type: vllm_async + engine_num: 2 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 +synchronizer: + sync_method: 'nccl' + sync_interval: 1 + sync_timeout: 1200 +trainer: + trainer_type: 'verl' + trainer_config_path: 'examples/grpo_gsm8k_task_pipeline/train_gsm8k.yaml' + save_interval: 100 diff --git a/examples/grpo_gsm8k_task_pipeline/train_gsm8k.yaml b/examples/grpo_gsm8k_task_pipeline/train_gsm8k.yaml new file mode 100644 index 0000000000..fc44fdad94 --- /dev/null +++ b/examples/grpo_gsm8k_task_pipeline/train_gsm8k.yaml @@ -0,0 +1,50 @@ +actor_rollout_ref: + hybrid_engine: True + model: + external_lib: null + override_config: { } + enable_gradient_checkpointing: True + use_remove_padding: True # False + actor: + strategy: fsdp # This is for backward-compatibility + ppo_mini_batch_size: 128 + ppo_micro_batch_size_per_gpu: 4 + use_dynamic_bsz: True # False + ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} + grad_clip: 1.0 + ppo_epochs: 1 + shuffle: False + ulysses_sequence_parallel_size: 1 # sp size + optim: + lr: 1e-5 + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + # min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + fsdp_config: + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + param_offload: False + optimizer_offload: False + fsdp_size: -1 + ref: + fsdp_config: + param_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + log_prob_micro_batch_size_per_gpu: 16 + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size + +trainer: + balance_batch: True + # total_training_steps: null + # auto: find the last ckpt to resume. 
If can't find, start from scratch + resume_mode: auto # or auto or resume_path if + default_hdfs_dir: null + remove_previous_ckpt_in_save: False + del_local_ckpt_after_load: False + val_before_train: False diff --git a/pyproject.toml b/pyproject.toml index c6917217ad..6ba60afab3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dependencies = [ "requests", "tensorboard", "openai", + "jsonlines", ] [project.scripts] diff --git a/tests/buffer/file_test.py b/tests/buffer/file_test.py index 363a4939ad..e53669a850 100644 --- a/tests/buffer/file_test.py +++ b/tests/buffer/file_test.py @@ -9,12 +9,54 @@ get_unittest_dataset_config, ) from trinity.buffer.buffer import get_buffer_reader, get_buffer_writer +from trinity.buffer.reader.file_reader import RawDataReader from trinity.buffer.utils import default_storage_path +from trinity.buffer.writer.file_writer import JSONWriter from trinity.common.config import StorageConfig from trinity.common.constants import StorageType class TestFileBuffer(unittest.TestCase): + temp_output_path = "tmp/test_file_buffer/" + + @classmethod + def setUpClass(cls): + super().setUpClass() + os.makedirs(cls.temp_output_path, exist_ok=True) + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + if os.path.exists(cls.temp_output_path): + os.system(f"rm -rf {cls.temp_output_path}") + + def test_file_buffer(self): + meta = StorageConfig( + name="test_buffer", + path=os.path.join(self.temp_output_path, "buffer.jsonl"), + storage_type=StorageType.FILE, + raw=True, + ) + data = [ + {"key1": 1, "key2": 2}, + {"key1": 3, "key2": 4}, + {"key1": 5, "key2": 6}, + {"key1": 7, "key2": 8}, + ] + + # test writer + writer = JSONWriter(meta, None) + writer.write(data) + writer.finish() + + # test reader + meta.path = self.temp_output_path + reader = RawDataReader(meta, None) + loaded_data = reader.read() + self.assertEqual(len(loaded_data), 4) + self.assertEqual(loaded_data, data) + self.assertRaises(StopIteration, reader.read) + def test_file_reader(self): """Test file reader.""" reader = get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer) diff --git a/tests/data/controllers/task_parser_test.py b/tests/data/controllers/task_parser_test.py index 542c491f41..af36f8777a 100644 --- a/tests/data/controllers/task_parser_test.py +++ b/tests/data/controllers/task_parser_test.py @@ -1,12 +1,13 @@ # -*- coding: utf-8 -*- """Test cases for data task parser.""" +import os import unittest import agentscope from agentscope.models import DashScopeChatWrapper from loguru import logger -from trinity.common.config import Config +from trinity.common.config import DataPipelineConfig from trinity.data.controllers.task_parser import DataTaskParser @@ -16,7 +17,7 @@ class TestTaskParser(unittest.TestCase): def setUp(self) -> None: print("setup", flush=True) - api_key = "your_dashscope_key" + api_key = os.environ.get("OPENAI_API_KEY", None) agentscope.init( model_configs=[ @@ -43,25 +44,20 @@ def _run_test(self, rft_config, return_none=False): logger.info("None dj config.") else: self.assertIsNotNone(dj_config) - op_weights = {} - for op in dj_config.process: - op_name = list(op.keys())[0] - op_weights[op_name] = op[op_name]["op_weight"] - logger.info(op_weights) def test_instruction1(self): - rft_config = Config() - rft_config.data.dj_process_desc = "Please recommend a data filtering strategy for me." + rft_config = DataPipelineConfig() + rft_config.dj_process_desc = "Please recommend a data filtering strategy for me." 
self._run_test(rft_config) def test_instruction2(self): - rft_config = Config() - rft_config.data.dj_process_desc = "Do nothing." + rft_config = DataPipelineConfig() + rft_config.dj_process_desc = "Do nothing." self._run_test(rft_config, return_none=True) def test_instruction3(self): - rft_config = Config() - rft_config.data.dj_process_desc = "Remove samples with repeat contents." + rft_config = DataPipelineConfig() + rft_config.dj_process_desc = "Remove samples with repeat contents." self._run_test(rft_config) diff --git a/tests/data/core/dataset_test.py b/tests/data/core/dataset_test.py index be6e765fbd..76758e84d6 100644 --- a/tests/data/core/dataset_test.py +++ b/tests/data/core/dataset_test.py @@ -3,10 +3,7 @@ import os import unittest -from trinity.common.config import DataProcessorConfig, FormatConfig -from trinity.common.rewards import AccuracyReward -from trinity.common.task import TaskSet -from trinity.common.workflows import MathWorkflow, SimpleWorkflow +from trinity.common.config import DataPipelineConfig, FormatConfig, StorageConfig from trinity.data.core.dataset import RewardSchema, RftDataset from trinity.data.core.formatter import BoxedMathAnswerFormatter, RLHFFormatter @@ -15,28 +12,38 @@ class TestRftDataset(unittest.TestCase): """Test cases for RftDataset""" def setUp(self) -> None: - self.data_config = DataProcessorConfig( - source_data_path=os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - "test_data", - "test_10", - ), + self.data_pipeline_config = DataPipelineConfig( + input_buffers=[ + StorageConfig( + path=os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + "test_data", + "test_10", + ), + raw=True, + ) + ], format=FormatConfig( prompt_key="problem", response_key="solution", solution_key="solution", ), ) - self.data_config_sample_level_setting = DataProcessorConfig( - source_data_path=os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - "test_data", - "test_10_with_rewfn_workflow", - ), + self.data_pipeline_config_sample_level_setting = DataPipelineConfig( + input_buffers=[ + StorageConfig( + path=os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + "test_data", + "test_10_with_rewfn_workflow", + ), + raw=True, + ) + ], format=FormatConfig( prompt_key="problem", response_key="solution", @@ -47,13 +54,19 @@ def setUp(self) -> None: ) def test_rft_dataset_init(self): - dataset = RftDataset(data_config=self.data_config, reward_schema="default") + dataset = RftDataset( + data_pipeline_config=self.data_pipeline_config, reward_schema="default" + ) + dataset.read_from_buffer() self.assertEqual(len(dataset), 10) self.assertIsInstance(dataset.reward_schema, RewardSchema) def test_format_dataset(self): - dataset = RftDataset(data_config=self.data_config, reward_schema="default") + dataset = RftDataset( + data_pipeline_config=self.data_pipeline_config, reward_schema="default" + ) + dataset.read_from_buffer() original_data = dataset.data # no formatter dataset.format(formatters=[]) @@ -62,56 +75,12 @@ def test_format_dataset(self): # apply formatters dataset.format( formatters=[ - BoxedMathAnswerFormatter(config=self.data_config.format), - RLHFFormatter(config=self.data_config.format), + BoxedMathAnswerFormatter(config=self.data_pipeline_config.format), + RLHFFormatter(config=self.data_pipeline_config.format), ] ) self.assertNotEqual(dataset.data, original_data) - def test_to_taskset(self): - dataset = RftDataset(data_config=self.data_config, reward_schema="default") - 
taskset = dataset.to_taskset() - self.assertIsInstance(taskset, TaskSet) - self.assertEqual(len(taskset), 10) - self.assertIsNone(taskset.reward_fn) - self.assertIsNone(taskset.workflow) - self.assertEqual(taskset._index, 0) - - def test_to_taskset_with_global_settings(self): - dataset = RftDataset(data_config=self.data_config, reward_schema="default") - taskset = dataset.to_taskset( - reward_fn=AccuracyReward, - workflow=SimpleWorkflow, - ) - self.assertIsInstance(taskset, TaskSet) - self.assertEqual(taskset.workflow, SimpleWorkflow) - self.assertEqual(taskset.reward_fn, AccuracyReward) - - def test_to_taskset_with_sample_level_settings(self): - dataset = RftDataset( - data_config=self.data_config_sample_level_setting, reward_schema="default" - ) - taskset = dataset.to_taskset() - self.assertIsInstance(taskset, TaskSet) - for task in taskset.tasks: - self.assertEqual(task.workflow, MathWorkflow) - self.assertEqual(task.reward_fn, AccuracyReward) - - def test_to_taskset_with_both_settings(self): - dataset = RftDataset( - data_config=self.data_config_sample_level_setting, reward_schema="default" - ) - taskset = dataset.to_taskset( - reward_fn=AccuracyReward, - workflow=SimpleWorkflow, - ) - self.assertIsInstance(taskset, TaskSet) - for task in taskset.tasks: - self.assertEqual(task.workflow, MathWorkflow) - self.assertEqual(task.reward_fn, AccuracyReward) - self.assertEqual(taskset.workflow, SimpleWorkflow) - self.assertEqual(taskset.reward_fn, AccuracyReward) - if __name__ == "__main__": unittest.main() diff --git a/tests/data/core/formatter_test.py b/tests/data/core/formatter_test.py index 363c736ed9..dbb73ed971 100644 --- a/tests/data/core/formatter_test.py +++ b/tests/data/core/formatter_test.py @@ -3,7 +3,7 @@ import os import unittest -from trinity.common.config import DataProcessorConfig, FormatConfig +from trinity.common.config import DataPipelineConfig, FormatConfig, StorageConfig from trinity.data.core.dataset import RftDataset from trinity.data.core.formatter import ( BoxedMathAnswerFormatter, @@ -18,14 +18,19 @@ class TestBoxedMathDataset(unittest.TestCase): """Test cases for RftDataset""" def setUp(self) -> None: - self.data_config = DataProcessorConfig( - source_data_path=os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - "test_data", - "test_10", - ), + self.data_config = DataPipelineConfig( + input_buffers=[ + StorageConfig( + path=os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + "test_data", + "test_10", + ), + raw=True, + ) + ], format=FormatConfig( prompt_key="problem", response_key="answer", @@ -43,12 +48,13 @@ def test_init(self): self.assertEqual(formatter.config.chat_template, "User: {}\nAssistant: ") # test for default configs self.assertEqual(formatter.config.reward_key, "") - self.assertEqual(formatter.config.chosen_key, "") - self.assertEqual(formatter.config.rejected_key, "") + self.assertEqual(formatter.config.chosen_key, "chosen") + self.assertEqual(formatter.config.rejected_key, "rejected") self.assertEqual(formatter.config.label_key, "") def test_transform(self): - dataset = RftDataset(data_config=self.data_config, reward_schema="default") + dataset = RftDataset(data_pipeline_config=self.data_config, reward_schema="default") + dataset.read_from_buffer() formatter = BoxedMathAnswerFormatter(config=self.data_config.format) self.assertNotIn(formatter.config.response_key, dataset.data.column_names) dataset.format(formatter) @@ -59,14 +65,19 @@ class TestRLHFFormatter(unittest.TestCase): """Test cases for 
RLHFFormatter""" def setUp(self) -> None: - self.data_config = DataProcessorConfig( - source_data_path=os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - "test_data", - "test_10", - ), + self.data_config = DataPipelineConfig( + input_buffers=[ + StorageConfig( + path=os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + "test_data", + "test_10", + ), + raw=True, + ) + ], format=FormatConfig( prompt_key="problem", chat_template="User: {}\nAssistant: ", @@ -107,14 +118,19 @@ class TestRewardFormatter(unittest.TestCase): """Test cases for RewardFormatter""" def setUp(self) -> None: - self.data_config = DataProcessorConfig( - source_data_path=os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - "test_data", - "test_10", - ), + self.data_config = DataPipelineConfig( + input_buffers=[ + StorageConfig( + path=os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + "test_data", + "test_10", + ), + raw=True, + ) + ], format=FormatConfig( prompt_key="problem", chosen_key="chosen", @@ -164,14 +180,19 @@ class TestSFTFormatter(unittest.TestCase): """Test cases for SFTFormatter""" def setUp(self) -> None: - self.data_config = DataProcessorConfig( - source_data_path=os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - "test_data", - "test_10", - ), + self.data_config = DataPipelineConfig( + input_buffers=[ + StorageConfig( + path=os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + "test_data", + "test_10", + ), + raw=True, + ) + ], format=FormatConfig( prompt_key="problem", response_key="answer", @@ -217,14 +238,19 @@ class TestComposedFormatter(unittest.TestCase): """Test cases for ComposedFormatter""" def setUp(self) -> None: - self.data_config = DataProcessorConfig( - source_data_path=os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - "test_data", - "test_10", - ), + self.data_config = DataPipelineConfig( + input_buffers=[ + StorageConfig( + path=os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + "test_data", + "test_10", + ), + raw=True, + ) + ], format=FormatConfig( prompt_key="problem", response_key="answer", diff --git a/tests/data/processor/cleaner_test.py b/tests/data/processor/cleaner_test.py index d21a6960c5..ef2aa13d20 100644 --- a/tests/data/processor/cleaner_test.py +++ b/tests/data/processor/cleaner_test.py @@ -15,7 +15,7 @@ def setUp(self) -> None: print("setup", flush=True) self.rft_config = load_config("./tests/test_configs/cleaner_test_rft_cfg.yaml") - print(self.rft_config) + # print(self.rft_config) self.ds_list = [ {"text": "Today is"}, {"text": "Today is Sund Sund Sund Sund Sund Sunda and it's a happy day!"}, @@ -25,95 +25,67 @@ def setUp(self) -> None: ] def _run_test(self, tgt_list, weight=1, data_dist="gaussian"): - task_parser = DataTaskParser(self.rft_config) + task_parser = DataTaskParser(self.rft_config.data_processor.task_pipeline) dj_config, _, _, _ = task_parser.parse_to_dj_config() + op_weights = {} for op_config in dj_config.process: - _, op_args = list(op_config.items())[0] - op_args["op_weight"] = weight + op_name, _ = list(op_config.items())[0] + op_weights[op_name] = weight cleaner = DataCleaner( dj_config, clean_strategy="iterative", - min_size_ratio=self.rft_config.data.min_size_ratio, + min_size_ratio=self.rft_config.data_processor.task_pipeline.min_size_ratio, data_dist=data_dist, + op_weights=op_weights, ) - dataset = RftDataset(self.rft_config.data) + 
dataset = RftDataset(self.rft_config.data_processor.task_pipeline) + dataset.read_from_buffer() dataset = cleaner.process([dataset]) - res_list = dataset.to_list() + res_list = dataset.data.select_columns("text").to_list() + print(res_list) self.assertEqual(res_list, tgt_list) self.assertNotIn("clean_email_mapper", cleaner.dj_cfg.process) def test_dj_executor(self): tgt_list = [ - { - "text": "a v s e c s f e f g a a a ", - "__dj__stats__": {"text_len": 27}, - }, - { - "text": ",。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►", - "__dj__stats__": {"text_len": 34}, - }, - { - "text": "中文也是一个字算一个长度", - "__dj__stats__": {"text_len": 12}, - }, + {"text": "Today is"}, + {"text": "Today is Sund Sund Sund Sund Sund Sunda and it's a happy day!"}, + {"text": "a v s e c s f e f g a a a "}, + {"text": "中文也是一个字算一个长度"}, ] - self.rft_config.data.min_size_ratio = None + self.rft_config.data_processor.task_pipeline.min_size_ratio = None self._run_test(tgt_list) def test_iterative_clean(self): tgt_list = [ - { - "text": "a v s e c s f e f g a a a ", - "__dj__stats__": {"text_len": 27}, - }, - { - "text": ",。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►", - "__dj__stats__": {"text_len": 34}, - }, + {"text": "Today is Sund Sund Sund Sund Sund Sunda and it's a happy day!"}, + {"text": "a v s e c s f e f g a a a "}, ] - self.rft_config.data.min_size_ratio = 0.5 + self.rft_config.data_processor.task_pipeline.min_size_ratio = 0.5 self._run_test(tgt_list) def test_weight(self): tgt_list = [ - { - "text": "a v s e c s f e f g a a a ", - "__dj__stats__": {"text_len": 27}, - }, - { - "text": ",。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►", - "__dj__stats__": {"text_len": 34}, - }, - { - "text": "中文也是一个字算一个长度", - "__dj__stats__": {"text_len": 12}, - }, + {"text": "Today is"}, + {"text": "Today is Sund Sund Sund Sund Sund Sunda and it's a happy day!"}, + {"text": "a v s e c s f e f g a a a "}, ] - self.rft_config.data.min_size_ratio = 0.5 + self.rft_config.data_processor.task_pipeline.min_size_ratio = 0.5 self._run_test(tgt_list, weight=0.5) def test_uniform_dist(self): - tgt_list = [ - { - "text": "a v s e c s f e f g a a a ", - "__dj__stats__": {"text_len": 27}, - }, - { - "text": ",。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►", - "__dj__stats__": {"text_len": 34}, - }, - ] + tgt_list = [] - self.rft_config.data.min_size_ratio = 0.5 + self.rft_config.data_processor.task_pipeline.min_size_ratio = 0.5 self._run_test(tgt_list, data_dist="uniform") diff --git a/tests/test_configs/active_iterator_test_cfg.yaml b/tests/test_configs/active_iterator_test_cfg.yaml index 3b105e1f66..3e6008b7cf 100644 --- a/tests/test_configs/active_iterator_test_cfg.yaml +++ b/tests/test_configs/active_iterator_test_cfg.yaml @@ -1,13 +1,18 @@ data_processor: # basic info - source_data_path: 'tests/test_data/test_10/' - load_kwargs: - split: 'train' - format: - prompt_key: 'problem' - response_key: 'solution' - # cleaner related - dj_config_path: 'tests/test_configs/active_iterator_test_dj_cfg.yaml' - clean_strategy: 'iterative' - # db related - db_url: 'postgresql://{username}@localhost:5432/{db_name}' + task_pipeline: + input_buffers: + - name: 'raw_input' + path: 'tests/test_data/test_10/' + storage_type: 'file' + raw: true + output_buffer: + name: 'raw_output' + path: './outputs/task_pipeline_output/processed.jsonl' + storage_type: 'file' + format: + prompt_key: 'problem' + response_key: 'solution' + # cleaner related + dj_config_path: 'tests/test_configs/active_iterator_test_dj_cfg.yaml' + clean_strategy: 'iterative' diff --git a/tests/test_configs/active_iterator_test_dj_cfg.yaml 
b/tests/test_configs/active_iterator_test_dj_cfg.yaml index f7f848e338..367709968f 100644 --- a/tests/test_configs/active_iterator_test_dj_cfg.yaml +++ b/tests/test_configs/active_iterator_test_dj_cfg.yaml @@ -1,7 +1,5 @@ project_name: 'demo-process' -export_path: './outputs/demo-process/demo-processed.jsonl' - text_keys: 'solution' process: diff --git a/tests/test_configs/cleaner_test_dj_cfg.yaml b/tests/test_configs/cleaner_test_dj_cfg.yaml index 9e2da88d64..cf11488963 100644 --- a/tests/test_configs/cleaner_test_dj_cfg.yaml +++ b/tests/test_configs/cleaner_test_dj_cfg.yaml @@ -3,7 +3,5 @@ project_name: 'demo-process' export_path: './outputs/demo-process/demo-processed.jsonl' process: - - text_length_filter: - min_len: 10 - max_len: 50 + - alphanumeric_filter: - clean_email_mapper: diff --git a/tests/test_configs/cleaner_test_rft_cfg.yaml b/tests/test_configs/cleaner_test_rft_cfg.yaml index 7f8581c0ef..c78e3a1ac8 100644 --- a/tests/test_configs/cleaner_test_rft_cfg.yaml +++ b/tests/test_configs/cleaner_test_rft_cfg.yaml @@ -1,5 +1,7 @@ data_processor: - source_data_path: './tests/test_data/test_cleaner' - load_kwargs: {"split": "train"} - dj_config_path: './tests/test_configs/cleaner_test_dj_cfg.yaml' - clean_strategy: 'iterative' + task_pipeline: + input_buffers: + - path: './tests/test_data/test_cleaner' + raw: true + dj_config_path: './tests/test_configs/cleaner_test_dj_cfg.yaml' + clean_strategy: 'iterative' diff --git a/tests/test_configs/human_annotator_test_rft_cfg.yaml b/tests/test_configs/human_annotator_test_rft_cfg.yaml index 79d8b8108b..b20f015182 100644 --- a/tests/test_configs/human_annotator_test_rft_cfg.yaml +++ b/tests/test_configs/human_annotator_test_rft_cfg.yaml @@ -1,10 +1,10 @@ data_processor: - source_data_path: './tests/test_data/test_human_annotator' - load_kwargs: {"split": "train"} - dj_config_path: './tests/test_configs/human_annotator_test_dj_cfg.yaml' - format: - prompt_key: 'prompt' - chosen_key: 'chosen' - rejected_key: 'rejected' - # db related - db_url: 'postgresql://{user_name}@localhost:5432/{db_name}' + task_pipeline: + input_buffers: + - path: './tests/test_data/test_human_annotator' + raw: true + dj_config_path: './tests/test_configs/human_annotator_test_dj_cfg.yaml' + format: + prompt_key: 'prompt' + chosen_key: 'chosen' + rejected_key: 'rejected' diff --git a/trinity/buffer/buffer.py b/trinity/buffer/buffer.py index 73bda50d5b..060ed05b9e 100644 --- a/trinity/buffer/buffer.py +++ b/trinity/buffer/buffer.py @@ -42,7 +42,9 @@ def get_buffer_reader(storage_config: StorageConfig, buffer_config: BufferConfig from trinity.buffer.reader.file_reader import FILE_READERS algorithm_type = storage_config.algorithm_type - if algorithm_type is not None: + if storage_config.raw: + file_read_type = "raw" + elif algorithm_type is not None: file_read_type = algorithm_type else: file_read_type = "rollout" diff --git a/trinity/buffer/ray_wrapper.py b/trinity/buffer/ray_wrapper.py index 63de366db6..b7cf06b2b5 100644 --- a/trinity/buffer/ray_wrapper.py +++ b/trinity/buffer/ray_wrapper.py @@ -142,6 +142,8 @@ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: raise ValueError( f"File path must end with '.json' or '.jsonl', got {storage_config.path}" ) + path_dir = os.path.dirname(storage_config.path) + os.makedirs(path_dir, exist_ok=True) self.file = open(storage_config.path, "a", encoding="utf-8") self.encoder = _Encoder(ensure_ascii=False) diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py index 
0dd9aef75e..507bfb7c82 100644 --- a/trinity/buffer/reader/file_reader.py +++ b/trinity/buffer/reader/file_reader.py @@ -67,7 +67,7 @@ def __init__(self, meta: StorageConfig, config: BufferConfig): self.response_key = meta.format.response_key self.read_batch_size = config.read_batch_size self.dataset = _HFBatchReader( - load_dataset(meta.path, name=subset_name, split=self.split), + load_dataset(meta.path, name=subset_name, split=self.split, trust_remote_code=True), max_epoch=meta.total_epochs, ) # TODO: support resume self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path) @@ -144,7 +144,7 @@ def __init__(self, meta: StorageConfig, config: BufferConfig): self.rejected_key = meta.format.rejected_key self.read_batch_size = config.read_batch_size self.dataset = _HFBatchReader( - load_dataset(meta.path, name=subset_name, split=self.split), + load_dataset(meta.path, name=subset_name, split=self.split, trust_remote_code=True), max_epoch=meta.total_epochs, ) # TODO: support resume self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path) @@ -216,7 +216,7 @@ def __init__(self, meta: StorageConfig, config: BufferConfig): self.epoch = 0 datasets.disable_caching() self.dataset = _HFBatchReader( - load_dataset(meta.path, name=subset_name, split=self.split), + load_dataset(meta.path, name=subset_name, split=self.split, trust_remote_code=True), max_epoch=self.meta.total_epochs if meta.task_type == TaskType.EXPLORE else 1, offset=self.meta.index, ) @@ -261,3 +261,23 @@ def read( ) tasks.append(task) return tasks + + +@FILE_READERS.register_module("raw") +class RawDataReader(BufferReader): + def __init__(self, meta: StorageConfig, config: Optional[BufferConfig]): + self.returned = False + self.dataset = load_dataset( + meta.path, name=meta.subset_name, split=meta.split, trust_remote_code=True + ) + + def __len__(self): + return len(self.dataset) + + def read( + self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None + ) -> List: + if self.returned: + raise StopIteration + self.returned = True + return self.dataset.to_list() diff --git a/trinity/cli/client.py b/trinity/cli/client.py index 311de1b9d8..cc3318b570 100644 --- a/trinity/cli/client.py +++ b/trinity/cli/client.py @@ -31,12 +31,12 @@ def request(url, **kwargs): if __name__ == "__main__": # --- only for local testing - LOCAL_DATA_WORKFLOW_SERVER_URL = "http://127.0.0.1:5005/data_workflow" + LOCAL_DATA_PROCESSOR_SERVER_URL = "http://127.0.0.1:5005/data_processor" LOCAL_TRINITY_TRAINING_SERVER_URL = "http://127.0.0.1:5006/trinity_rft" # --- only for local testing res = request( - url=LOCAL_DATA_WORKFLOW_SERVER_URL, + url=LOCAL_DATA_PROCESSOR_SERVER_URL, configPath="examples/grpo_gsm8k/gsm8k.yaml", ) if res: diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index a63b06a36d..3ea4f0486f 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -8,7 +8,7 @@ import ray -from trinity.common.config import Config, load_config +from trinity.common.config import Config, DataPipelineConfig, load_config from trinity.common.constants import EXPLORER_NAME, TRAINER_NAME from trinity.explorer.explorer import Explorer from trinity.trainer.trainer import Trainer @@ -112,13 +112,13 @@ def both(config: Config) -> None: trainer.shutdown.remote() -def activate_data_module(data_workflow_url: str, config_path: str): +def activate_data_module(data_processor_url: str, config_path: str): """Check whether to activate data module and preprocess datasets.""" from 
trinity.cli.client import request - logger.info("Activating data module...") + logger.info(f"Activating data module of {data_processor_url}...") res = request( - url=data_workflow_url, + url=data_processor_url, configPath=config_path, ) if res["return_code"] != 0: @@ -126,17 +126,71 @@ def activate_data_module(data_workflow_url: str, config_path: str): return +def validate_data_pipeline(data_pipeline_config: DataPipelineConfig, pipeline_type: str): + """ + Check if the data pipeline is valid. The config should: + 1. Non-empty input buffer + 2. Different input/output buffers + + :param data_pipeline_config: the input data pipeline to be validated. + :param pipeline_type: the type of pipeline, should be one of ["task", "experience"] + """ + input_buffers = data_pipeline_config.input_buffers + output_buffer = data_pipeline_config.output_buffer + # common checks + # check if the input buffer list is empty + if len(input_buffers) == 0: + logger.warning("Empty input buffers in the data pipeline. Won't activate it.") + return False + # check if the input and output buffers are different + input_buffer_names = [buffer.name for buffer in input_buffers] + if output_buffer.name in input_buffer_names: + logger.warning("Output buffer exists in input buffers. Won't activate it.") + return False + if pipeline_type == "task": + # task pipeline specific + # "raw" field should be True for task pipeline because the data source must be raw data files + for buffer in input_buffers: + if not buffer.raw: + logger.warning( + 'Input buffers should be raw data files for task pipeline ("raw" field should be True). Won\'t activate it.' + ) + return False + elif pipeline_type == "experience": + # experience pipeline specific + raise NotImplementedError("experience_pipeline is not implemented yet.") + else: + logger.warning( + f'Invalid pipeline type: {pipeline_type}. Should be one of ["task", "experience"].' + ) + return False + return True + + def run(config_path: str, dlc: bool = False, plugin_dir: str = None): load_plugins(plugin_dir) config = load_config(config_path) config.check_and_update() pprint(config) - # try to activate data module + # try to activate task pipeline for raw data data_processor_config = config.data_processor - if data_processor_config.data_workflow_url and ( - data_processor_config.dj_config_path or data_processor_config.dj_process_desc + if ( + data_processor_config.data_processor_url + and data_processor_config.task_pipeline + and validate_data_pipeline(data_processor_config.task_pipeline, "task") ): - activate_data_module(data_processor_config.data_workflow_url, config_path) + activate_data_module( + f"{data_processor_config.data_processor_url}/task_pipeline", config_path + ) + # try to activate experience pipeline for experiences + if ( + data_processor_config.data_processor_url + and data_processor_config.experience_pipeline + and validate_data_pipeline(data_processor_config.experience_pipeline, "experience") + ): + activate_data_module( + f"{data_processor_config.data_processor_url}/experience_pipeline", config_path + ) ray_namespace = f"{config.project}-{config.name}" if dlc: from trinity.utils.dlc_utils import setup_ray_cluster diff --git a/trinity/common/config.py b/trinity/common/config.py index 1409fa33f3..f4480da311 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -70,6 +70,9 @@ class StorageConfig: storage_type: StorageType = StorageType.FILE path: Optional[str] = None + # only available for StorageType.FILE. 
When requiring data processing on raw data, set the raw to True. + raw: bool = False + # used for StorageType.FILE split: str = "train" subset_name: Optional[str] = None @@ -99,16 +102,17 @@ class StorageConfig: @dataclass -class DataProcessorConfig: - """Data-Juicer config""" +class DataPipelineConfig: + """Config for data pipeline.""" - data_workflow_url: Optional[str] = None + # I/O buffer + input_buffers: List[StorageConfig] = field(default_factory=list) + output_buffer: StorageConfig = field(default_factory=StorageConfig) - source_data_path: str = "" + # data format format: FormatConfig = field(default_factory=FormatConfig) # data active iterator related - load_kwargs: Dict[str, Any] = field(default_factory=dict) dj_config_path: Optional[str] = None # The path to Data-Juicer config file. dj_process_desc: Optional[ str @@ -121,10 +125,18 @@ class DataProcessorConfig: priority_weights: Optional[Dict[str, float]] = None data_dist: Optional[str] = "gaussian" # one of ["gaussian", "uniform"] - # dataset database related - db_url: str = "" - max_retry_times: int = 3 - max_retry_interval: int = 1 + +@dataclass +class DataProcessorConfig: + """Data-Juicer config""" + + data_processor_url: Optional[str] = None + + # support two types of data pipelines for now + # 1. For task. Data preprocessing from raw dataset to the task set + task_pipeline: Optional[DataPipelineConfig] = None + # 2. For experience. Data processing for rollouts + experience_pipeline: Optional[DataPipelineConfig] = None @dataclass diff --git a/trinity/data/controllers/active_iterator.py b/trinity/data/controllers/active_iterator.py index 40da73384b..963a1015a9 100644 --- a/trinity/data/controllers/active_iterator.py +++ b/trinity/data/controllers/active_iterator.py @@ -1,14 +1,14 @@ import os import traceback +from numbers import Number from typing import Any, Dict, List import ray -from trinity.common.config import Config +from trinity.common.config import BufferConfig, DataPipelineConfig from trinity.data.controllers.default_ops import DIMENSION_STATS_KEYS from trinity.data.controllers.task_parser import DataTaskParser from trinity.data.core.dataset import RftDataset -from trinity.data.core.dataset_db import RftDatasetDB from trinity.data.processors.cleaner import DataCleaner from trinity.data.processors.human_annotator import DataHumanAnnotator from trinity.data.processors.synthesizer import DataSynthesizer @@ -21,42 +21,39 @@ class DataActiveIterator: def __init__( self, - config: Config, + config: DataPipelineConfig, + buffer_config: BufferConfig, ): self.config = config - self.data_config = config.data - if ( - self.data_config.agent_model_name is not None - and self.data_config.agent_model_config is not None - ): + self.buffer_config = buffer_config + if self.config.agent_model_name is not None and self.config.agent_model_config is not None: # get the api key api_key = os.environ.get("OPENAI_API_KEY") # initialize the agent import agentscope from agentscope.models import DashScopeChatWrapper - agentscope.init(model_configs=[self.data_config.agent_model_config]) + agentscope.init(model_configs=[self.config.agent_model_config]) self.llm_agent = DashScopeChatWrapper( config_name="_", - model_name=self.data_config.agent_model_name, + model_name=self.config.agent_model_name, api_key=api_key, stream=False, ) else: self.llm_agent = None self.task_parser = DataTaskParser(config, self.llm_agent) - self.dsdb = RftDatasetDB(self.data_config) # Priority weights # larger positive values means larger scores --> higher priority # 
smaller negative values means lower scores --> higher priority - self.priority_weights = self.data_config.priority_weights or { + self.priority_weights = self.config.priority_weights or { "difficulty": -0.7, "diversity": 0.8, "usage_frequency": -0.5, "quality": 1.0, } - self.min_priority_score = self.data_config.min_priority_score + self.min_priority_score = self.config.min_priority_score # Statistics tracking self.state = {"iterations": 0, "samples_selected": 0, "avg_priority_score": 0.0} @@ -67,17 +64,17 @@ def __init__( # 2. input_keys: [prompt_key, response_key] if they are available # 3. field_names: [prompt_key, response_key] if they are available self.updated_op_args = { - "text_key": self.data_config.format.prompt_key, + "text_key": self.config.format.prompt_key, "input_keys": [ - self.data_config.format.prompt_key, + self.config.format.prompt_key, ], "field_names": [ - self.data_config.format.prompt_key, + self.config.format.prompt_key, ], } - if self.data_config.format.response_key != "": - self.updated_op_args["input_keys"].append(self.data_config.format.response_key) - self.updated_op_args["field_names"].append(self.data_config.format.response_key) + if self.config.format.response_key != "": + self.updated_op_args["input_keys"].append(self.config.format.response_key) + self.updated_op_args["field_names"].append(self.config.format.response_key) # flake8: noqa: C901 def run(self): @@ -94,9 +91,9 @@ def run(self): traceback.print_exc() return 1, "config parsing failed." - # step 2. load dataset + # step 2. load data from the input buffers try: - dataset = RftDataset(self.data_config) + dataset = RftDataset(self.config, self.buffer_config) except Exception: traceback.print_exc() return 2, "RftDataset loading failed." @@ -106,9 +103,9 @@ def run(self): if hit_cleaner: cleaner = DataCleaner( dj_config, - clean_strategy=self.data_config.clean_strategy, - min_size_ratio=self.data_config.min_size_ratio, - data_dist=self.data_config.data_dist, + clean_strategy=self.config.clean_strategy, + min_size_ratio=self.config.min_size_ratio, + data_dist=self.config.data_dist, ) if hit_synthesizer: synthesizer = DataSynthesizer( @@ -122,43 +119,61 @@ def run(self): traceback.print_exc() return 3, "DataCleaner loading failed." - # step 4. apply processors to calculate scores of different dimensions - try: - res_dataset = dataset - if hit_cleaner: - res_dataset = cleaner.process([res_dataset]) - if hit_synthesizer: - res_dataset = synthesizer.process([res_dataset]) - if hit_human_annotator: - res_dataset = human_annotator.process([res_dataset]) - except Exception: - traceback.print_exc() - return 4, "DataProcessors processing failed." - - # step 5. calculate the average and final scores, including priority - try: - if hit_cleaner: - scored_dataset = self._group_scores(res_dataset) - scored_dataset = self._compute_priority_scores(scored_dataset) - else: - scored_dataset = res_dataset - except Exception: - traceback.print_exc() - return 5, "Grouping and computing priority score failed." - - # step 6. track lineage if they are changed - try: - res_dataset = scored_dataset - except Exception: - traceback.print_exc() - return 6, "Tracking lineage failed." - - # step 7. export the result to the database - try: - self.dsdb.add_entries(res_dataset) - except Exception: - traceback.print_exc() - return 7, "Exporting result to database failed." + while True: + # step 4. 
+            # step 4. load data from the input buffers for the next batch
+            try:
+                dataset.read_from_buffer()
+            except StopIteration:
+                break
+            except Exception:
+                traceback.print_exc()
+                return 4, "RftDataset loading from buffers failed."
+
+            # step 5. apply processors to calculate scores of different dimensions
+            try:
+                res_dataset = dataset
+                if hit_cleaner:
+                    res_dataset = cleaner.process([res_dataset])
+                if hit_synthesizer:
+                    res_dataset = synthesizer.process([res_dataset])
+                if hit_human_annotator:
+                    res_dataset = human_annotator.process([res_dataset])
+            except Exception:
+                traceback.print_exc()
+                return 5, "DataProcessors processing failed."
+
+            # step 6. calculate the average and final scores, including priority
+            try:
+                if hit_cleaner:
+                    scored_dataset = self._group_scores(res_dataset)
+                    scored_dataset = self._compute_priority_scores(scored_dataset)
+                else:
+                    scored_dataset = res_dataset
+            except Exception:
+                traceback.print_exc()
+                return 6, "Grouping and computing priority score failed."
+
+            # step 7. track lineage if the dataset changed
+            try:
+                res_dataset = scored_dataset
+            except Exception:
+                traceback.print_exc()
+                return 7, "Tracking lineage failed."
+
+            # step 8. sort the results by priority
+            try:
+                if "priority" in res_dataset.data.features:
+                    res_dataset.sort_by("priority", reverse=True)
+            except Exception:
+                traceback.print_exc()
+                return 8, "Sorting results by priority failed."
+
+            # step 9. export the result to the output buffer
+            try:
+                res_dataset.write_to_buffer()
+            except Exception:
+                traceback.print_exc()
+                return 9, "Exporting result to output buffer failed."

         return 0, "success"
@@ -171,7 +186,8 @@ def _group_scores(self, dataset: RftDataset) -> RftDataset:
         all_stats = [
             sample[Fields.stats][stats] for sample in dataset.data if Fields.stats in sample
         ]
-        stats_min_max[stats] = [min(all_stats), max(all_stats)]
+        if len(all_stats) > 0 and isinstance(all_stats[0], Number):
+            stats_min_max[stats] = [min(all_stats), max(all_stats)]

     def _group_single(sample):
         stats = sample[Fields.stats]
@@ -240,7 +256,7 @@ def _compute_combined_score(
         difficulty = stats.get("difficulty_score", 0.5)
         score += self.priority_weights["difficulty"] * difficulty

-        sample["priority"] = [score]
+        sample["priority"] = [score] if isinstance(sample[Fields.stats], list) else score
         return sample

     def _compute_diversity_score(self) -> float:
@@ -252,10 +268,6 @@ def _compute_priority_scores(self, dataset: RftDataset) -> RftDataset:
         dataset.data = dataset.data.map(self._compute_combined_score)
         return dataset

-    def _select_top_k(self, dataset: RftDataset, k: int) -> List:
-        """Select top-k samples based on utility scores"""
-        return dataset.data.sort("priority", reverse=True).take(k).to_list()
-
     @ray.method(num_returns=1)
     def select_batch(self, dataset: RftDataset, batch_size: int) -> List[Dict[str, Any]]:
         """Select a batch of samples for training"""
@@ -267,7 +279,8 @@ def select_batch(self, dataset: RftDataset, batch_size: int) -> List[Dict[str, A
         dataset.data = dataset.data.filter(lambda s: s["priority"] >= self.min_priority_score)

         # Select top-k samples
-        selected_samples = self._select_top_k(dataset, batch_size)
+        dataset.sort_by("priority", reverse=True, top_k=batch_size)
+        selected_samples = dataset.data.to_list()

         # Update state
         self._update_state(selected_samples, dataset.data["priority"])
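For intuition, the priority used for sorting above is just a weighted sum over per-dimension stats. Here is a minimal sketch, assuming the default weights from `DataActiveIterator.__init__` and hypothetical stat values that are already min-max normalized to [0, 1]; the real code reads its stats from `Fields.stats` and uses keys like `difficulty_score`, so the names below are illustrative only:

```python
# Sketch of the combined priority score; stat names and values are illustrative.
priority_weights = {
    "difficulty": -0.7,       # negative weight: lower difficulty -> higher priority
    "diversity": 0.8,
    "usage_frequency": -0.5,  # negative weight: less-used samples -> higher priority
    "quality": 1.0,
}
stats = {"difficulty": 0.9, "diversity": 0.6, "usage_frequency": 0.1, "quality": 0.8}

# Missing dimensions fall back to 0.5, mirroring stats.get("difficulty_score", 0.5) above.
priority = sum(w * stats.get(dim, 0.5) for dim, w in priority_weights.items())
print(f"{priority:.2f}")  # -0.63 + 0.48 - 0.05 + 0.80 = 0.60
```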
diff --git a/trinity/data/controllers/task_parser.py b/trinity/data/controllers/task_parser.py
index 23b169ab2d..2e30dace63 100644
--- a/trinity/data/controllers/task_parser.py
+++ b/trinity/data/controllers/task_parser.py
@@ -7,7 +7,7 @@
 from jsonargparse import Namespace
 from loguru import logger

-from trinity.common.config import Config
+from trinity.common.config import DataPipelineConfig
 from trinity.data.core.dataset import RftDataset

 from .default_ops import (
@@ -128,7 +128,7 @@ class DataTaskParser:
     def __init__(
         self,
-        rft_config: Config,
+        data_pipeline_config: DataPipelineConfig,
         llm_agent: DashScopeChatWrapper = None,
         dataset: RftDataset = None,
         validate_config: bool = True,
@@ -136,12 +136,12 @@ def __init__(
         """
         Initialization method.

-        :param rft_config: All configs.
+        :param data_pipeline_config: The config of the specified data pipeline.
         :param llm_agent: The LLM agent for natural language parsing.
         :param dataset: The dataset to be processed.
         :param validate_config: If execute the config validation check.
         """
-        self.config = rft_config.data
+        self.config = data_pipeline_config
         self.llm_agent = llm_agent
         self.validate_config = validate_config
         # TODO: refer dataset to support natural language parsing.
@@ -164,15 +164,21 @@ def parse_to_dj_config(self, extra_op_args=None):
         return dj_config, hit_cleaner, hit_synthesizer, hit_human_annotator

     def _check_types_of_processors(self, dj_config):
+        if dj_config is None:
+            return False, False, False
         hit_cleaner, hit_synthesizer, hit_human_annotator = False, False, False
-        for op in dj_config.process:
+        process_list = dj_config.get("process", [])
+        for op in process_list:
             op_name = list(op.keys())[0]
-            if op_name in DEFAULT_CLEANER:
-                hit_cleaner = True
-            elif op_name in DEFAULT_SYNTHESIZER:
+            if op_name in DEFAULT_SYNTHESIZER:
                 hit_synthesizer = True
             elif op_name in DEFAULT_HUMAN_ANNOTATOR:
                 hit_human_annotator = True
+            else:
+                for dimension in DEFAULT_CLEANER:
+                    if op_name in DEFAULT_CLEANER[dimension]:
+                        hit_cleaner = True
+                        break
         return hit_cleaner, hit_synthesizer, hit_human_annotator

     def _update_common_op_args(self, dj_config: Namespace, extra_op_args: Dict) -> Namespace:
@@ -185,20 +191,10 @@ def _update_common_op_args(self, dj_config: Namespace, extra_op_args: Dict) -> N
             print(op)
         return dj_config

-    def _add_extra_args(self, dj_config: Namespace, op_weights: Dict = {}) -> Namespace:
-        """Add extra argument for RFT project"""
-        for op in dj_config.process:
-            op_name = list(op.keys())[0]
-            if "op_weight" not in op[op_name]:
-                op[op_name]["op_weight"] = op_weights[op_name] if op_name in op_weights else 1
-            op[op_name]["op_weight"] = max(0, op[op_name]["op_weight"])
-        return dj_config
-
     def _direct_mapping(self) -> Namespace:
         """Direct mapping from RFT config to DJ config"""
         dj_config = prepare_side_configs(self.config.dj_config_path)
         dj_config = get_init_configs(dj_config)
-        dj_config = self._add_extra_args(dj_config)
         return dj_config

     def _agent_based_parsing(self, extra_op_args=None, try_num=3) -> Namespace:
@@ -251,13 +247,11 @@ def _parse_llm_response(self, response: ModelResponse, extra_op_args=None):
             other_op_args = DEFAULT_OP_ARGS

         dj_process = []
-        op_weights = {}

         def json_to_dj_config(parsed_json):
             for dim in set(parsed_json.keys()) & set(cleaners.keys()):
                 for op_name in set(parsed_json[dim].keys()) & set(cleaners[dim].keys()):
                     dj_process.append({op_name: {}})
-                    op_weights[op_name] = float(parsed_json[dim][op_name])

         json_match = re.search(r"```json\n(.*?)\n```", response.text, re.DOTALL)
         if json_match:
@@ -284,20 +278,5 @@ def json_to_dj_config(parsed_json):
                     op[op_name][key] = val
             dj_config = Namespace(process=dj_process)
             dj_config = get_init_configs(dj_config)
-            dj_config = self._add_extra_args(dj_config, op_weights)
-
-            if self.validate_config and not self._validate_config(dj_config):
-                return None

             return dj_config
-
-    def _validate_config(self, config: Namespace) -> bool:
-        """Validate generated DJ config"""
-        try:
-            for op in config.process:
-                op_name = list(op.keys())[0]
-                weight = float(op[op_name]["op_weight"])
-                assert 0 <= weight and weight <= 1
-        except Exception:
-            return False
-        return True
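Apart from the narrower config type, the parser's contract is unchanged: it turns a `DataPipelineConfig` into a Data-Juicer config plus flags for which processor families the recipe touches. A rough usage sketch, assuming the test recipe referenced in `trinity/data/readme.md` and no LLM agent, so the `dj_config_path` direct-mapping path is taken:

```python
from trinity.common.config import DataPipelineConfig
from trinity.data.controllers.task_parser import DataTaskParser

# Assumed: a hand-written Data-Juicer recipe shipped with the tests.
pipeline_config = DataPipelineConfig(
    dj_config_path="tests/test_configs/active_iterator_test_dj_cfg.yaml",
)
parser = DataTaskParser(pipeline_config, llm_agent=None)

# The parser reports which processor types the recipe involves.
dj_config, hit_cleaner, hit_synthesizer, hit_human_annotator = parser.parse_to_dj_config()
```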
diff --git a/trinity/data/core/dataset.py b/trinity/data/core/dataset.py
index 3e4af0fe12..93be832cc7 100644
--- a/trinity/data/core/dataset.py
+++ b/trinity/data/core/dataset.py
@@ -3,13 +3,10 @@
 from typing import Any, Dict, List, Optional, Union

 import networkx as nx
-from data_juicer.core.data.dj_dataset import Dataset
-from datasets import load_dataset
+from datasets import Dataset, concatenate_datasets

-from trinity.common.config import DataProcessorConfig
-from trinity.common.rewards import REWARD_FUNCTIONS
-from trinity.common.task import TaskSet
-from trinity.common.workflows import WORKFLOWS
+from trinity.buffer import get_buffer_reader, get_buffer_writer
+from trinity.common.config import BufferConfig, DataPipelineConfig, StorageConfig
 from trinity.data.core.formatter import BaseDataFormatter
@@ -31,25 +28,27 @@ class RftDataset:
     4. Basic statistics and metrics computation

     Args:
-        config (Dict): Configuration dict including DJ config
+        data_pipeline_config (DataPipelineConfig): Configuration including DJ config
+        buffer_config (BufferConfig): Config of the buffer module, used to build buffer readers and writers
         reward_schema (Union[str, Dict]): Schema definition for reward fields
         track_lineage (bool): Whether to track data lineage
     """

     def __init__(
         self,
-        data_config: DataProcessorConfig,
+        data_pipeline_config: DataPipelineConfig,
+        buffer_config: Optional[BufferConfig] = None,
         reward_schema: Union[str, Dict] = "default",
         track_lineage: bool = True,
     ):
-        self.config = data_config
-        source_data_path = data_config.source_data_path
-        if not source_data_path:
-            raise ValueError("source_data_path is not specified in DJ config")
-        load_kwargs = data_config.load_kwargs
-        self.data = load_dataset(source_data_path, trust_remote_code=True, **load_kwargs)
-
-        self.format = data_config.format
+        self.config = data_pipeline_config
+        self.buffer_config = buffer_config
+        input_buffer_configs = self.config.input_buffers
+        if len(input_buffer_configs) == 0:
+            raise ValueError("input_buffers is empty in data pipeline config")
+        self.buffers = []
+        for input_buffer_config in input_buffer_configs:
+            self.buffers.append(get_buffer_reader(input_buffer_config, self.buffer_config))
+        self.data = Dataset.from_list([])

         self.reward_schema = self._init_reward_schema(reward_schema)
         self.stats: Dict[str, Any] = {}
@@ -65,15 +64,28 @@ def format(
         for formatter in formatters:
             self.data = formatter(self.data, num_proc)

-    def to_taskset(self, **kwargs) -> TaskSet:
-        default_workflow_cls = WORKFLOWS.get(self.config.default_workflow_type)
-        default_reward_fn_cls = REWARD_FUNCTIONS.get(self.config.default_reward_fn_type)
-        return TaskSet(
-            dataset=self.data,
-            config=self.config,
-            workflow=default_workflow_cls,
-            reward_fn=default_reward_fn_cls,
-        )
+    def sort_by(self, key: str, reverse: bool = False, top_k: int = -1):
+        if top_k == -1:
+            top_k = len(self.data)
+        self.data = self.data.sort(key, reverse=reverse).take(top_k)
+
+    def read_from_buffer(self):
+        datasets = []
+        for buffer in self.buffers:
+            datasets.append(Dataset.from_list(buffer.read()))
+        self.data = concatenate_datasets(datasets)
+
+    def write_to_buffer(
+        self, output_storage_config: Optional[StorageConfig] = None, buffer_config: Optional[BufferConfig] = None
+    ):
+        if output_storage_config is None:
+            output_storage_config = self.config.output_buffer
+        if buffer_config is None:
+            buffer_config = self.buffer_config
+        output_buffer = get_buffer_writer(output_storage_config, buffer_config)
+        output_buffer.write(self.data.to_list())
+        output_buffer.finish()
+        self.data = Dataset.from_list([])

     def to_parquet(self, path: str):
         self.data.to_parquet(path)
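With `to_taskset` gone, `RftDataset` is now a thin shuttle between buffers: it drains its input buffers batch by batch, and flushing to the output buffer resets the in-memory table. A minimal sketch of that life cycle, assuming `pipeline_config` and `buffer_config` objects are already loaded elsewhere and, for `sort_by`, that a `priority` column exists:

```python
from trinity.data.core.dataset import RftDataset

# pipeline_config: DataPipelineConfig with input_buffers/output_buffer set
# buffer_config: the global BufferConfig
dataset = RftDataset(pipeline_config, buffer_config)

dataset.read_from_buffer()                 # concatenate the next batch from all input buffers
dataset.sort_by("priority", reverse=True)  # highest-priority samples first
dataset.write_to_buffer()                  # flush to config.output_buffer and reset to empty
```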
diff --git a/trinity/data/core/dataset_db.py b/trinity/data/core/dataset_db.py
deleted file mode 100644
index f47b138995..0000000000
--- a/trinity/data/core/dataset_db.py
+++ /dev/null
@@ -1,84 +0,0 @@
-from typing import List
-
-from sqlalchemy import asc, create_engine, desc
-from sqlalchemy.exc import OperationalError
-from sqlalchemy.orm import sessionmaker
-from sqlalchemy.pool import NullPool
-
-from trinity.buffer.utils import retry_session
-from trinity.common.config import DataProcessorConfig
-from trinity.common.schema import Base, RftDatasetModel
-from trinity.data.core.dataset import RftDataset
-from trinity.utils.log import get_logger
-
-logger = get_logger(__name__)
-
-
-def rft_dataset_to_model(dataset: RftDataset) -> List[RftDatasetModel]:
-    # hit keys of schema
-    hit_schema_keys = []
-    hit_dataset_keys = []
-    # get hit keys & vals
-    # - for content keys, we need to map it with content_key_mapping and try to
-    #   find them in the dataset
-    # - for other keys, we just need to check if they are in the dataset
-    data = dataset.data
-    features = data.features
-    content_key_mapping = dataset.format.__dict__
-    schema_keys = {key for key in RftDatasetModel.__dict__.keys() if not key.startswith("_")}
-    for schema_key in schema_keys:
-        key = schema_key
-        if f"{schema_key}_key" in content_key_mapping:
-            key = content_key_mapping[f"{schema_key}_key"]
-        if key in features:
-            hit_schema_keys.append(schema_key)
-            hit_dataset_keys.append(key)
-    # construct entries
-    entries = []
-    for sample in data:
-        valid_data = {
-            schema_key: sample[key] for schema_key, key in zip(hit_schema_keys, hit_dataset_keys)
-        }
-        entries.append(RftDatasetModel(**valid_data))
-    return entries
-
-
-class RftDatasetDB:
-    def __init__(self, config: DataProcessorConfig) -> None:
-        self.db_url = config.db_url
-        self.engine = create_engine(self.db_url, poolclass=NullPool)
-        self.config = config
-        try:
-            Base.metadata.create_all(self.engine, checkfirst=True)
-        except OperationalError:
-            logger.warning("Failed to create database, assuming it already exists.")
-        self.session = sessionmaker(bind=self.engine)
-
-    def add_entries(self, dataset: RftDataset):
-        with retry_session(
-            self, self.config.max_retry_times, self.config.max_retry_interval
-        ) as session:
-            session.add_all(rft_dataset_to_model(dataset))
-
-    def get_entries(self, num_entries: int, order_by: str = None, ascending: bool = False):
-        # get num_entries entries from the database
-        if order_by is not None and hasattr(RftDatasetModel, order_by):
-            order_by_key = getattr(RftDatasetModel, order_by)
-            order_by_key = asc(order_by_key) if ascending else desc(order_by_key)
-        else:
-            order_by_key = None
-        with retry_session(
-            self, self.config.max_retry_times, self.config.max_retry_interval
-        ) as session:
-            entries = (
-                session.query(RftDatasetModel)
-                .order_by(order_by_key)
-                .limit(num_entries)
-                .with_for_update()
-                .all()
-            )
-
-            for entry in entries:
-                entry.consumed_cnt += 1
-            samples = [entry.to_dict() for entry in entries]
-        return samples
diff --git a/trinity/data/processors/cleaner.py b/trinity/data/processors/cleaner.py
index b031e528e1..10979990b1 100644
--- a/trinity/data/processors/cleaner.py
+++ b/trinity/data/processors/cleaner.py
@@ -36,6 +36,7 @@ def __init__(
         clean_strategy: str = "iterative",
         min_size_ratio: PositiveFloat = None,
         data_dist: str = "gaussian",
+        op_weights: dict = None,
         **kwargs,
     ):
         """
@@ -54,6 +55,7 @@
         self.min_size_ratio = min_size_ratio
         self.data_dist = data_dist
         self.op_name_to_stats_key = {}
+        self.op_weights = op_weights or {}

     def keep_cleaner_op_cfg(self, dj_cfg):
         """Only consider cleaner op in data-juicer configs."""
@@ -112,7 +114,7 @@ def update_op_threshold(
         update_record = {}
         for process in exe_cfg.process:
             op_name, args = list(process.items())[0]
-            op_weight = args["op_weight"]
+            op_weight = self.op_weights.get(op_name, 1)
             update_record[op_name] = {}

             temp_args = copy.deepcopy(args)
@@ -164,7 +166,7 @@ def process(
         else:
             logger.info("Executing Data-Juicer analyzer...")
             analyzer = Analyzer(self.dj_cfg)
-            analyzer.run(dataset)
+            analyzer.run(dataset, skip_export=True)
             df = analyzer.overall_result
             mean_series = df[df.index == "mean"]
             stats_key_to_mean = mean_series.iloc[0, :].to_dict()
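Note that `op_weight` no longer piggybacks on each op's arguments inside the DJ config; callers that want non-uniform weights now pass an `op_weights` dict to `DataCleaner`, and any op left out defaults to weight 1. A hypothetical construction (the op name and weight are made up for illustration; `dj_config` would come from `DataTaskParser.parse_to_dj_config()`):

```python
from trinity.data.processors.cleaner import DataCleaner

cleaner = DataCleaner(
    dj_config,                  # a parsed Data-Juicer config
    clean_strategy="iterative",
    op_weights={"text_length_filter": 0.5},  # hypothetical weighting; unlisted ops default to 1
)
```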
diff --git a/trinity/data/readme.md b/trinity/data/readme.md
index 3294819f43..4b5c828ee6 100644
--- a/trinity/data/readme.md
+++ b/trinity/data/readme.md
@@ -88,14 +88,14 @@ synth_data = synthesizer.process(clean_data)
 - Then you need to prepare the `data_processor` section in the config file (e.g. [test_cfg.yaml](tests/test_configs/active_iterator_test_cfg.yaml))
   - For the `dj_config_path` argument in it, you can either specify a data-juicer config file path (e.g. [test_dj_cfg.yaml](tests/test_configs/active_iterator_test_dj_cfg.yaml)), or write the demand in `dj_process_desc` argument in natural language and our agent will help you to organize the data-juicer config.
 - Finally you can send requests to the data server to start an active iterator to process datasets in many ways:
-  - Request with `curl`: `curl "http://127.0.0.1:5000/data_workflow?configPath=tests%2Ftest_configs%2Factive_iterator_test_cfg.yaml"`
+  - Request with `curl`: `curl "http://127.0.0.1:5005/data_processor/task_pipeline?configPath=tests%2Ftest_configs%2Factive_iterator_test_cfg.yaml"`
   - Request using our simple client:

   ```python
   from trinity.cli.client import request

   res = request(
-      url="http://127.0.0.1:5005/data_workflow",
+      url="http://127.0.0.1:5005/data_processor/task_pipeline",
       configPath="tests/test_configs/active_iterator_test_cfg.yaml"
   )
diff --git a/trinity/data/server.py b/trinity/data/server.py
index 08ca5ebfea..e1f57ba81b 100644
--- a/trinity/data/server.py
+++ b/trinity/data/server.py
@@ -1,20 +1,39 @@
 import fire
 from flask import Flask, jsonify, request
+from markupsafe import escape

 app = Flask(__name__)

-APP_NAME = "data_workflow"
+APP_NAME = "data_processor"


-@app.route(f"/{APP_NAME}", methods=["GET"])
-def data_workflow():
+@app.route(f"/{APP_NAME}/<pipeline_type>", methods=["GET"])
+def data_processor(pipeline_type):
     from trinity.common.config import load_config
     from trinity.data.controllers.active_iterator import DataActiveIterator

     config_path = request.args.get("configPath")
+    pipeline_type = escape(pipeline_type)
     config = load_config(config_path)
-    iterator = DataActiveIterator(config)
+    pipeline_config = getattr(config.data_processor, pipeline_type)
+    if pipeline_config is None:
+        return jsonify(
+            {
+                "return_code": -1,
+                "message": f"Error: {pipeline_type} is not supported or the corresponding config is empty",
+            }
+        )
+
+    if pipeline_config.dj_config_path is None and pipeline_config.dj_process_desc is None:
+        return jsonify(
+            {
+                "return_code": -1,
+                "message": "Error: Both dj_config_path and dj_process_desc in the pipeline config are None.",
+            }
+        )
+
+    iterator = DataActiveIterator(pipeline_config, config.buffer)
     ret, msg = iterator.run()
     return jsonify({"return_code": ret, "message": msg})
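Because the route is now parameterized by pipeline type, the same endpoint should also drive the new experience pipeline once `experience_pipeline` is filled in under `data_processor`. A sketch mirroring the readme's client snippet (the config path here is hypothetical):

```python
from trinity.cli.client import request

res = request(
    url="http://127.0.0.1:5005/data_processor/experience_pipeline",
    configPath="path/to/experience_pipeline_cfg.yaml",  # hypothetical config file
)
```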