
Commit 7511c96

* update .gitignore
1 parent 2bdcdea commit 7511c96

File tree

1 file changed: +56 -2 lines changed

trinity/cli/launcher.py

Lines changed: 56 additions & 2 deletions
@@ -7,7 +7,7 @@
 
 import ray
 
-from trinity.common.config import Config, load_config
+from trinity.common.config import Config, DataPipelineConfig, load_config
 from trinity.common.constants import AlgorithmType
 from trinity.explorer.explorer import Explorer
 from trinity.trainer.trainer import Trainer
@@ -158,17 +158,71 @@ def activate_data_module(data_workflow_url: str, config_path: str):
         return
 
 
+def validate_data_pipeline(data_pipeline_config: DataPipelineConfig, pipeline_type: str):
+    """
+    Check if the data pipeline is valid. The config should:
+    1. Non-empty input buffer
+    2. Different input/output buffers
+
+    :param data_pipeline_config: the input data pipeline to be validated.
+    :param pipeline_type: the type of pipeline, should be one of ["task", "experience"]
+    """
+    input_buffers = data_pipeline_config.input_buffers
+    output_buffer = data_pipeline_config.output_buffer
+    # common checks
+    # check if the input buffer list is empty
+    if len(input_buffers) == 0:
+        logger.warning("Empty input buffers in the data pipeline. Won't activate it.")
+        return False
+    # check if the input and output buffers are different
+    input_buffer_names = [buffer.name for buffer in input_buffers]
+    if output_buffer.name in input_buffer_names:
+        logger.warning("Output buffer exists in input buffers. Won't activate it.")
+        return False
+    if pipeline_type == "task":
+        # task pipeline specific
+        # "raw" field should be True for task pipeline because the data source must be raw data files
+        for buffer in input_buffers:
+            if not buffer.raw:
+                logger.warning(
+                    'Input buffers should be raw data files for task pipeline ("raw" field should be True). Won\'t activate it.'
+                )
+                return False
+    elif pipeline_type == "experience":
+        # experience pipeline specific
+        pass
+    else:
+        logger.warning(
+            f'Invalid pipeline type: {pipeline_type}. Should be one of ["task", "experience"].'
+        )
+        return False
+    return True
+
+
 def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
     load_plugins(plugin_dir)
     config = load_config(config_path)
     config.check_and_update()
     pprint(config)
     # try to activate task pipeline for raw data
     data_processor_config = config.data_processor
-    if data_processor_config.data_workflow_url and data_processor_config.task_pipeline:
+    if (
+        data_processor_config.data_workflow_url
+        and data_processor_config.task_pipeline
+        and validate_data_pipeline(data_processor_config.task_pipeline, "task")
+    ):
         activate_data_module(
             f"{data_processor_config.data_workflow_url}/task_pipeline", config_path
         )
+    # try to activate experience pipeline for experiences
+    if (
+        data_processor_config.data_workflow_url
+        and data_processor_config.experience_pipeline
+        and validate_data_pipeline(data_processor_config.experience_pipeline, "experience")
+    ):
+        activate_data_module(
+            f"{data_processor_config.data_workflow_url}/experience_pipeline", config_path
+        )
     ray_namespace = f"{config.project}-{config.name}"
     if dlc:
         from trinity.utils.dlc_utils import setup_ray_cluster
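
To make the new checks concrete, the sketch below re-implements the same three rules in a standalone form. BufferStub, PipelineStub, and the buffer names are hypothetical stand-ins used only for illustration; the real types are DataPipelineConfig and its buffer configs from trinity.common.config, and the real validator also logs a warning for each rejection.

# Minimal sketch of the validation rules added in this commit.
# Assumption: BufferStub/PipelineStub only mimic the fields the validator reads.
from dataclasses import dataclass
from typing import List


@dataclass
class BufferStub:
    name: str
    raw: bool = False  # True when the buffer points at raw data files


@dataclass
class PipelineStub:
    input_buffers: List[BufferStub]
    output_buffer: BufferStub


def validate(pipeline: PipelineStub, pipeline_type: str) -> bool:
    # 1. the pipeline must have at least one input buffer
    if not pipeline.input_buffers:
        return False
    # 2. the output buffer must not also appear among the input buffers
    if pipeline.output_buffer.name in [b.name for b in pipeline.input_buffers]:
        return False
    # 3. a task pipeline may only read raw data files
    if pipeline_type == "task":
        return all(b.raw for b in pipeline.input_buffers)
    return pipeline_type == "experience"


# rejected: the output buffer is also an input buffer
print(validate(PipelineStub([BufferStub("gsm8k", raw=True)], BufferStub("gsm8k")), "task"))    # False
# rejected: task pipelines require raw input buffers
print(validate(PipelineStub([BufferStub("replay")], BufferStub("taskset")), "task"))           # False
# accepted: raw input, distinct output buffer
print(validate(PipelineStub([BufferStub("gsm8k", raw=True)], BufferStub("taskset")), "task"))  # True

Note that in the launcher itself the validator returns False and logs a warning instead of raising, so a misconfigured pipeline is simply not activated and run() continues with the rest of startup.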
