Skip to content

Commit d9501cf

Browse files
committed
+ init ray in the same namespace for data processor
+ release output buffer after the active iterator is finished
1 parent d9d4773 commit d9501cf

File tree

3 files changed

+22
-13
lines changed

3 files changed

+22
-13
lines changed

trinity/data/controllers/active_iterator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,12 @@ def run(self, thread_event: threading.Event = None):
221221
traceback.print_exc()
222222
return 10, "Exporting result to output buffer failed."
223223

224+
try:
225+
dataset.release_output_buffer()
226+
except Exception:
227+
traceback.print_exc()
228+
return -1, "Releasing output buffer failed."
229+
224230
return 0, "success"
225231

226232
def _group_scores(self, dataset: RftDataset) -> RftDataset:

trinity/data/core/dataset.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,16 @@ def __init__(
4949
):
5050
self.config = data_pipeline_config
5151
self.buffer_config = buffer_config
52+
# init input buffers
5253
input_buffer_configs = self.config.input_buffers
5354
if len(input_buffer_configs) == 0:
5455
raise ValueError("input_buffers is empty in data pipeline config")
55-
self.buffers = []
56+
self.input_buffers = []
5657
for input_buffer_config in input_buffer_configs:
57-
self.buffers.append(get_buffer_reader(input_buffer_config, self.buffer_config))
58+
self.input_buffers.append(get_buffer_reader(input_buffer_config, self.buffer_config))
59+
# init output buffer
60+
self.output_buffer = get_buffer_writer(self.config.output_buffer, self.buffer_config)
61+
5862
self.data = Dataset.from_list([])
5963
self.original_dataclass = None
6064

@@ -79,28 +83,23 @@ def sort_by(self, key: str, reverse: bool = False, top_k: int = -1):
7983

8084
def read_from_buffer(self):
8185
datasets = []
82-
for buffer in self.buffers:
86+
for buffer in self.input_buffers:
8387
exp_list = buffer.read()
8488
if self.original_dataclass is None:
8589
self.original_dataclass = exp_list[0].__class__
8690
datasets.append(Dataset.from_list([asdict(exp) for exp in exp_list]))
8791
self.data = concatenate_datasets(datasets)
8892
logger.info(f"Read {len(self.data)} samples from input buffers")
8993

90-
def write_to_buffer(
91-
self, output_storage_config: StorageConfig = None, buffer_config: BufferConfig = None
92-
):
93-
if output_storage_config is None:
94-
output_storage_config = self.config.output_buffer
95-
if buffer_config is None:
96-
buffer_config = self.buffer_config
97-
output_buffer = get_buffer_writer(output_storage_config, buffer_config)
94+
def write_to_buffer(self):
9895
exp_list = [dict_to_dataclass(self.original_dataclass, d) for d in self.data.to_list()]
99-
output_buffer.write(exp_list)
100-
output_buffer.release()
96+
self.output_buffer.write(exp_list)
10197
logger.info(f"Wrote {len(self.data)} samples to output buffer")
10298
self.data = Dataset.from_list([])
10399

100+
def release_output_buffer(self):
101+
self.output_buffer.release()
102+
104103
def to_parquet(self, path: str):
105104
self.data.to_parquet(path)
106105

trinity/data/server.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import fire
22
import threading
3+
import ray
34
from flask import Flask, jsonify, request
45
from markupsafe import escape
56
from typing import List
@@ -20,6 +21,9 @@ def data_processor(pipeline_type):
2021
config = load_config(config_path)
2122
config.check_and_update()
2223

24+
# init ray
25+
ray.init(namespace=config.ray_namespace)
26+
2327
pipeline_config = getattr(config.data_processor, pipeline_type)
2428
if pipeline_config is None:
2529
return jsonify(

0 commit comments

Comments (0)