Commit 062722f

* move data processor related funcs to data/utils.py
1 parent 17c91aa commit 062722f

2 files changed: +72 additions, -69 deletions

trinity/cli/launcher.py

Lines changed: 2 additions & 69 deletions

@@ -8,10 +8,11 @@
 
 import ray
 
-from trinity.common.config import Config, DataPipelineConfig, load_config
+from trinity.common.config import Config, load_config
 from trinity.common.constants import DataProcessorPipelineType
 from trinity.explorer.explorer import Explorer
 from trinity.trainer.trainer import Trainer
+from trinity.data.utils import activate_data_processor, stop_data_processor, validate_data_pipeline
 from trinity.utils.log import get_logger
 from trinity.utils.plugin_loader import load_plugins

@@ -147,74 +148,6 @@ def both(config: Config) -> None:
     explorer.shutdown.remote()
     trainer.shutdown.remote()
 
-
-def activate_data_processor(data_processor_url: str, config_path: str):
-    """Check whether to activate the data module and preprocess datasets."""
-    from trinity.cli.client import request
-
-    logger.info(f"Activating data module of {data_processor_url}...")
-    res = request(
-        url=data_processor_url,
-        configPath=config_path,
-    )
-    if res["return_code"] != 0:
-        logger.error(f"Failed to activate data module: {res['return_msg']}.")
-    return
-
-
-def stop_data_processor(base_data_processor_url: str):
-    """Stop all pipelines in the data processor."""
-    from trinity.cli.client import request
-
-    logger.info(f"Stopping all pipelines in {base_data_processor_url}...")
-    res = request(url=f"{base_data_processor_url}/stop_all")
-    if res["return_code"] != 0:
-        logger.error(f"Failed to stop all data pipelines: {res['return_msg']}.")
-    return
-
-
-def validate_data_pipeline(
-    data_pipeline_config: DataPipelineConfig, pipeline_type: DataProcessorPipelineType
-):
-    """
-    Check if the data pipeline is valid. The config should have:
-    1. A non-empty input buffer list.
-    2. Different input/output buffers.
-
-    :param data_pipeline_config: the data pipeline config to be validated.
-    :param pipeline_type: the type of pipeline; should be one of DataProcessorPipelineType.
-    """
-    input_buffers = data_pipeline_config.input_buffers
-    output_buffer = data_pipeline_config.output_buffer
-    # common checks
-    # check if the input buffer list is empty
-    if len(input_buffers) == 0:
-        logger.warning("Empty input buffers in the data pipeline. Won't activate it.")
-        return False
-    # check if the input and output buffers are different
-    input_buffer_names = [buffer.name for buffer in input_buffers]
-    if output_buffer.name in input_buffer_names:
-        logger.warning("Output buffer exists in input buffers. Won't activate it.")
-        return False
-    if pipeline_type == DataProcessorPipelineType.TASK:
-        # task pipeline specific
-        # the "raw" field must be True because a task pipeline reads from raw data files
-        for buffer in input_buffers:
-            if not buffer.raw:
-                logger.warning(
-                    'Input buffers should be raw data files for the task pipeline ("raw" should be True). Won\'t activate it.'
-                )
-                return False
-    elif pipeline_type == DataProcessorPipelineType.EXPERIENCE:
-        # experience pipeline specific
-        # no special items need to be checked
-        pass
-    else:
-        logger.warning(f"Invalid pipeline type: {pipeline_type}.")
-        return False
-    return True
-
-
 def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
     load_plugins(plugin_dir)
     config = load_config(config_path)
trinity/data/utils.py

Lines changed: 70 additions & 0 deletions

@@ -0,0 +1,70 @@
+from trinity.common.config import DataPipelineConfig
+from trinity.common.constants import DataProcessorPipelineType
+
+from trinity.utils.log import get_logger
+
+logger = get_logger(__name__)
+
+def activate_data_processor(data_processor_url: str, config_path: str):
+    """Check whether to activate the data module and preprocess datasets."""
+    from trinity.cli.client import request
+
+    logger.info(f"Activating data module of {data_processor_url}...")
+    res = request(
+        url=data_processor_url,
+        configPath=config_path,
+    )
+    if res["return_code"] != 0:
+        logger.error(f"Failed to activate data module: {res['return_msg']}.")
+    return
+
+def stop_data_processor(base_data_processor_url: str):
+    """Stop all pipelines in the data processor."""
+    from trinity.cli.client import request
+
+    logger.info(f"Stopping all pipelines in {base_data_processor_url}...")
+    res = request(url=f"{base_data_processor_url}/stop_all")
+    if res["return_code"] != 0:
+        logger.error(f"Failed to stop all data pipelines: {res['return_msg']}.")
+    return
+
+def validate_data_pipeline(
+    data_pipeline_config: DataPipelineConfig, pipeline_type: DataProcessorPipelineType
+):
+    """
+    Check if the data pipeline is valid. The config should have:
+    1. A non-empty input buffer list.
+    2. Different input/output buffers.
+
+    :param data_pipeline_config: the data pipeline config to be validated.
+    :param pipeline_type: the type of pipeline; should be one of DataProcessorPipelineType.
+    """
+    input_buffers = data_pipeline_config.input_buffers
+    output_buffer = data_pipeline_config.output_buffer
+    # common checks
+    # check if the input buffer list is empty
+    if len(input_buffers) == 0:
+        logger.warning("Empty input buffers in the data pipeline. Won't activate it.")
+        return False
+    # check if the input and output buffers are different
+    input_buffer_names = [buffer.name for buffer in input_buffers]
+    if output_buffer.name in input_buffer_names:
+        logger.warning("Output buffer exists in input buffers. Won't activate it.")
+        return False
+    if pipeline_type == DataProcessorPipelineType.TASK:
+        # task pipeline specific
+        # the "raw" field must be True because a task pipeline reads from raw data files
+        for buffer in input_buffers:
+            if not buffer.raw:
+                logger.warning(
+                    'Input buffers should be raw data files for the task pipeline ("raw" should be True). Won\'t activate it.'
+                )
+                return False
+    elif pipeline_type == DataProcessorPipelineType.EXPERIENCE:
+        # experience pipeline specific
+        # no special items need to be checked
+        pass
+    else:
+        logger.warning(f"Invalid pipeline type: {pipeline_type}.")
+        return False
+    return True
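
To make the validation rules concrete, here is a small behavioral sketch of validate_data_pipeline. The SimpleNamespace objects are stand-ins for the real DataPipelineConfig and buffer types; the function only reads input_buffers, output_buffer, and each buffer's name and raw attributes, so duck-typed stand-ins suffice here.

# Behavioral sketch, not part of this commit. SimpleNamespace stand-ins
# mimic the attributes validate_data_pipeline reads.
from types import SimpleNamespace

from trinity.common.constants import DataProcessorPipelineType
from trinity.data.utils import validate_data_pipeline

config = SimpleNamespace(
    input_buffers=[SimpleNamespace(name="raw_files", raw=True)],
    output_buffer=SimpleNamespace(name="taskset", raw=False),
)

# Valid task pipeline: non-empty raw inputs, distinct output buffer name.
assert validate_data_pipeline(config, DataProcessorPipelineType.TASK)

# Invalid: the output buffer name now collides with an input buffer name.
config.output_buffer.name = "raw_files"
assert not validate_data_pipeline(config, DataProcessorPipelineType.TASK)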
