@@ -8,10 +8,11 @@
 
 import ray
 
-from trinity.common.config import Config, DataPipelineConfig, load_config
+from trinity.common.config import Config, load_config
 from trinity.common.constants import DataProcessorPipelineType
 from trinity.explorer.explorer import Explorer
 from trinity.trainer.trainer import Trainer
+from trinity.data.utils import activate_data_processor, stop_data_processor, validate_data_pipeline
 from trinity.utils.log import get_logger
 from trinity.utils.plugin_loader import load_plugins
 
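The three helpers deleted from this file below now come from `trinity.data.utils`. As a reading aid, here is a hypothetical skeleton of that module's public surface, inferred only from the signatures visible in this diff (the actual file may be organized differently):

```python
# trinity/data/utils.py -- hypothetical skeleton inferred from this diff;
# only the signatures are grounded in the code removed below.
from trinity.common.config import DataPipelineConfig
from trinity.common.constants import DataProcessorPipelineType


def activate_data_processor(data_processor_url: str, config_path: str):
    """Activate the data module and preprocess datasets."""


def stop_data_processor(base_data_processor_url: str):
    """Stop all pipelines in the data processor."""


def validate_data_pipeline(
    data_pipeline_config: DataPipelineConfig, pipeline_type: DataProcessorPipelineType
) -> bool:
    """Return True if the pipeline config is safe to activate."""
```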
@@ -147,74 +148,6 @@ def both(config: Config) -> None: |
     explorer.shutdown.remote()
     trainer.shutdown.remote()
 
-
-def activate_data_processor(data_processor_url: str, config_path: str):
-    """Check whether to activate data module and preprocess datasets."""
-    from trinity.cli.client import request
-
-    logger.info(f"Activating data module of {data_processor_url}...")
-    res = request(
-        url=data_processor_url,
-        configPath=config_path,
-    )
-    if res["return_code"] != 0:
-        logger.error(f"Failed to activate data module: {res['return_msg']}.")
-        return
-
-
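For illustration, a minimal sketch of calling the relocated helper above from its new module; the URL and config path are placeholder values, not taken from this repository:

```python
# Hypothetical call site; both argument values are placeholders.
from trinity.data.utils import activate_data_processor

activate_data_processor(
    data_processor_url="http://127.0.0.1:5005/data_processor",
    config_path="path/to/config.yaml",
)
```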
-def stop_data_processor(base_data_processor_url: str):
-    """Stop all pipelines in the data processor"""
-    from trinity.cli.client import request
-
-    logger.info(f"Stopping all pipelines in {base_data_processor_url}...")
-    res = request(url=f"{base_data_processor_url}/stop_all")
-    if res["return_code"] != 0:
-        logger.error(f"Failed to stop all data pipelines: {res['return_msg']}.")
-        return
-
-
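Likewise, a hedged sketch of stopping the processor. Note that the function itself appends `/stop_all` to the base URL, so callers pass the bare service address (the URL below is a placeholder):

```python
# Hypothetical call site; the base URL is a placeholder.
from trinity.data.utils import stop_data_processor

stop_data_processor(base_data_processor_url="http://127.0.0.1:5005/data_processor")
```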
-def validate_data_pipeline(
-    data_pipeline_config: DataPipelineConfig, pipeline_type: DataProcessorPipelineType
-):
-    """
-    Check if the data pipeline is valid. The config should have:
-    1. A non-empty input buffer list
-    2. Different input/output buffers
-
-    :param data_pipeline_config: the input data pipeline to be validated.
-    :param pipeline_type: the type of pipeline, should be one of DataProcessorPipelineType
-    """
-    input_buffers = data_pipeline_config.input_buffers
-    output_buffer = data_pipeline_config.output_buffer
-    # common checks
-    # check if the input buffer list is empty
-    if len(input_buffers) == 0:
-        logger.warning("Empty input buffers in the data pipeline. Won't activate it.")
-        return False
-    # check if the input and output buffers are different
-    input_buffer_names = [buffer.name for buffer in input_buffers]
-    if output_buffer.name in input_buffer_names:
-        logger.warning("Output buffer exists in input buffers. Won't activate it.")
-        return False
-    if pipeline_type == DataProcessorPipelineType.TASK:
-        # task pipeline specific:
-        # the "raw" field should be True because the data sources must be raw data files
-        for buffer in input_buffers:
-            if not buffer.raw:
-                logger.warning(
-                    'Input buffers should be raw data files for task pipeline ("raw" field should be True). Won\'t activate it.'
-                )
-                return False
-    elif pipeline_type == DataProcessorPipelineType.EXPERIENCE:
-        # experience pipeline specific:
-        # no special items need to be checked
-        pass
-    else:
-        logger.warning(f"Invalid pipeline type: {pipeline_type}.")
-        return False
-    return True
-
-
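To make the validation rules above concrete, here is a self-contained sketch that exercises them with stand-in dataclasses. The real `DataPipelineConfig` and buffer classes have more fields than the `name`/`raw` attributes modeled here, and the function accepts the stand-ins only because it is duck-typed; the buffer names are placeholders:

```python
from dataclasses import dataclass
from typing import List

from trinity.common.constants import DataProcessorPipelineType
from trinity.data.utils import validate_data_pipeline


# Stand-ins modeling only the attributes the validator reads.
@dataclass
class FakeBuffer:
    name: str
    raw: bool = False


@dataclass
class FakePipelineConfig:
    input_buffers: List[FakeBuffer]
    output_buffer: FakeBuffer


# Valid TASK pipeline: non-empty raw inputs, distinct output buffer.
ok = FakePipelineConfig(
    input_buffers=[FakeBuffer(name="gsm8k_raw", raw=True)],
    output_buffer=FakeBuffer(name="gsm8k_tasks"),
)
assert validate_data_pipeline(ok, DataProcessorPipelineType.TASK)

# Rejected: the output buffer reuses an input buffer's name.
bad = FakePipelineConfig(
    input_buffers=[FakeBuffer(name="tasks", raw=True)],
    output_buffer=FakeBuffer(name="tasks"),
)
assert not validate_data_pipeline(bad, DataProcessorPipelineType.TASK)
```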
 def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
     load_plugins(plugin_dir)
     config = load_config(config_path)
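The `run` entry point above is unchanged apart from now relying on the relocated helpers. A minimal invocation sketch; the importing module path and the config path are assumptions, not confirmed by this diff:

```python
# Module path assumed from this diff's context; the config path is a placeholder.
from trinity.cli.launcher import run

run(config_path="path/to/config.yaml", dlc=False, plugin_dir=None)
```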