|
7 | 7 |
|
8 | 8 | import ray |
9 | 9 |
|
10 | | -from trinity.common.config import Config, load_config |
| 10 | +from trinity.common.config import Config, DataPipelineConfig, load_config |
11 | 11 | from trinity.common.constants import AlgorithmType |
12 | 12 | from trinity.explorer.explorer import Explorer |
13 | 13 | from trinity.trainer.trainer import Trainer |
@@ -158,17 +158,71 @@ def activate_data_module(data_workflow_url: str, config_path: str): |
158 | 158 | return |
159 | 159 |
|
160 | 160 |
|
def validate_data_pipeline(data_pipeline_config: DataPipelineConfig, pipeline_type: str) -> bool:
    """
    Check whether a data pipeline config is valid and may be activated.

    The pipeline is accepted only when all of the following hold:
        1. The input buffer list is non-empty.
        2. The output buffer's name differs from every input buffer's name.
        3. For the "task" pipeline only: every input buffer has a truthy "raw" field.

    The first violated rule is logged as a warning and the pipeline is rejected by
    returning False.

    :param data_pipeline_config: the input data pipeline to be validated.
    :param pipeline_type: the type of pipeline, should be one of ["task", "experience"]
    :return: True if every check passes, False otherwise.
    """
    input_buffers = data_pipeline_config.input_buffers
    output_buffer = data_pipeline_config.output_buffer

    # Common checks, applied to every pipeline type.
    # A truthiness test rejects both an empty list and an unset (None) value,
    # whereas len() would raise TypeError on None.
    if not input_buffers:
        logger.warning("Empty input buffers in the data pipeline. Won't activate it.")
        return False

    # The output buffer must not also be one of the inputs (compared by name).
    input_buffer_names = [buffer.name for buffer in input_buffers]
    if output_buffer.name in input_buffer_names:
        logger.warning("Output buffer exists in input buffers. Won't activate it.")
        return False

    if pipeline_type == "task":
        # Task pipeline specific: the data source must be raw data files, so the
        # "raw" field has to be True on every input buffer.
        if not all(buffer.raw for buffer in input_buffers):
            logger.warning(
                'Input buffers should be raw data files for task pipeline ("raw" field should be True). Won\'t activate it.'
            )
            return False
    elif pipeline_type == "experience":
        # Experience pipeline specific: no extra constraints.
        pass
    else:
        logger.warning(
            f'Invalid pipeline type: {pipeline_type}. Should be one of ["task", "experience"].'
        )
        return False
    return True
| 200 | + |
| 201 | + |
161 | 202 | def run(config_path: str, dlc: bool = False, plugin_dir: str = None): |
162 | 203 | load_plugins(plugin_dir) |
163 | 204 | config = load_config(config_path) |
164 | 205 | config.check_and_update() |
165 | 206 | pprint(config) |
166 | 207 | # try to activate task pipeline for raw data |
167 | 208 | data_processor_config = config.data_processor |
168 | | - if data_processor_config.data_workflow_url and data_processor_config.task_pipeline: |
| 209 | + if ( |
| 210 | + data_processor_config.data_workflow_url |
| 211 | + and data_processor_config.task_pipeline |
| 212 | + and validate_data_pipeline(data_processor_config.task_pipeline, "task") |
| 213 | + ): |
169 | 214 | activate_data_module( |
170 | 215 | f"{data_processor_config.data_workflow_url}/task_pipeline", config_path |
171 | 216 | ) |
| 217 | + # try to activate experience pipeline for experiences |
| 218 | + if ( |
| 219 | + data_processor_config.data_workflow_url |
| 220 | + and data_processor_config.experience_pipeline |
| 221 | + and validate_data_pipeline(data_processor_config.experience_pipeline, "experience") |
| 222 | + ): |
| 223 | + activate_data_module( |
| 224 | + f"{data_processor_config.data_workflow_url}/experience_pipeline", config_path |
| 225 | + ) |
172 | 226 | ray_namespace = f"{config.project}-{config.name}" |
173 | 227 | if dlc: |
174 | 228 | from trinity.utils.dlc_utils import setup_ray_cluster |
|
0 commit comments