Skip to content

Commit 989dc94

Browse files
committed
+ add model_params to text_tagging_by_prompt_mapper
+ flush the buffer when outputting the trace results and wait for 1 sec
1 parent 00f759e commit 989dc94

File tree

4 files changed

+23
-3
lines changed

4 files changed

+23
-3
lines changed

data_juicer/core/tracer/ray_tracer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,5 +167,7 @@ def finalize_traces(self):
167167
# We'll use a generic name for now, could be improved with operator type detection
168168
res_name = self.get_trace_file_path(op_name)
169169
dif_df = pd.DataFrame(traces)
170-
dif_df.to_json(res_name, orient="records", lines=True, force_ascii=False)
170+
with open(res_name, "w") as out_buf:
171+
dif_df.to_json(out_buf, orient="records", lines=True, force_ascii=False)
172+
out_buf.flush()
171173
print(f"Exported {len(traces)} traced samples for op [{op_name}] to {res_name}")

data_juicer/ops/mapper/text_tagging_by_prompt_mapper.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def __init__(
7474
tensor_parallel_size: int = None,
7575
max_model_len: int = None,
7676
max_num_seqs: int = 256,
77+
model_params: Dict = None,
7778
sampling_params: Dict = None,
7879
*args,
7980
**kwargs,
@@ -93,6 +94,7 @@ def __init__(
9394
derived from the model config.
9495
:param max_num_seqs: It is only valid when enable_vllm is True.
9596
Maximum number of sequences to be processed in a single iteration.
97+
:param model_params: Parameters for model initialization.
9698
:param sampling_params: Sampling parameters for text generation.
9799
e.g {'temperature': 0.9, 'top_p': 0.95}
98100
:param args: extra args
@@ -117,7 +119,9 @@ def __init__(
117119
self.prompt = prompt
118120
self.tag_list = tag_list
119121
self.enable_vllm = enable_vllm
120-
model_params = {"trust_remote_code": trust_remote_code, "max_num_seqs": max_num_seqs}
122+
if model_params is None:
123+
model_params = {}
124+
model_params.update({"trust_remote_code": trust_remote_code, "max_num_seqs": max_num_seqs})
121125
if tensor_parallel_size is not None:
122126
model_params["tensor_parallel_size"] = tensor_parallel_size
123127
if max_model_len is not None:

tests/core/tracer/test_ray_tracer.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import unittest
33
import tempfile
44
import shutil
5+
import time
56
import jsonlines as jl
67
from data_juicer.core.tracer.ray_tracer import RayTracer
78
from data_juicer.utils.unittest_utils import TEST_TAG
@@ -58,6 +59,7 @@ def test_collect_mapper_sample_basic(self):
5859

5960
# Finalize traces to write to file
6061
ray.get(tracer.finalize_traces.remote())
62+
time.sleep(1)
6163

6264
trace_file_path = os.path.join(self.work_dir, 'trace', 'sample_trace-test_mapper.jsonl')
6365
self.assertTrue(os.path.exists(trace_file_path))
@@ -87,6 +89,7 @@ def test_collect_mapper_sample_no_change(self):
8789

8890
# Finalize traces to write to file
8991
ray.get(tracer.finalize_traces.remote())
92+
time.sleep(1)
9093

9194
trace_file_path = os.path.join(self.work_dir, 'trace', 'sample_trace-test_mapper.jsonl')
9295
# File should not exist since no samples were collected
@@ -105,6 +108,7 @@ def test_collect_mapper_sample_with_trace_keys(self):
105108

106109
# Finalize traces to write to file
107110
ray.get(tracer.finalize_traces.remote())
111+
time.sleep(1)
108112

109113
trace_file_path = os.path.join(self.work_dir, 'trace', 'sample_trace-test_mapper.jsonl')
110114
self.assertTrue(os.path.exists(trace_file_path))
@@ -135,6 +139,7 @@ def test_collect_mapper_sample_with_missing_trace_keys(self):
135139

136140
# Finalize traces to write to file
137141
ray.get(tracer.finalize_traces.remote())
142+
time.sleep(1)
138143

139144
trace_file_path = os.path.join(self.work_dir, 'trace', 'sample_trace-test_mapper.jsonl')
140145
self.assertTrue(os.path.exists(trace_file_path))
@@ -166,6 +171,7 @@ def test_collect_mapper_sample_not_in_op_list(self):
166171

167172
# Finalize traces to write to file
168173
ray.get(tracer.finalize_traces.remote())
174+
time.sleep(1)
169175

170176
trace_file_path = os.path.join(self.work_dir, 'trace', 'sample_trace-test_mapper.jsonl')
171177
self.assertFalse(os.path.exists(trace_file_path))
@@ -183,6 +189,7 @@ def test_collect_filter_sample_basic(self):
183189

184190
# Finalize traces to write to file
185191
ray.get(tracer.finalize_traces.remote())
192+
time.sleep(1)
186193

187194
trace_file_path = os.path.join(self.work_dir, 'trace', 'sample_trace-test_filter.jsonl')
188195
self.assertTrue(os.path.exists(trace_file_path))
@@ -208,6 +215,7 @@ def test_collect_filter_sample_should_keep(self):
208215

209216
# Finalize traces to write to file
210217
ray.get(tracer.finalize_traces.remote())
218+
time.sleep(1)
211219

212220
trace_file_path = os.path.join(self.work_dir, 'trace', 'sample_trace-test_filter.jsonl')
213221
self.assertFalse(os.path.exists(trace_file_path))
@@ -225,6 +233,7 @@ def test_collect_filter_sample_not_in_op_list(self):
225233

226234
# Finalize traces to write to file
227235
ray.get(tracer.finalize_traces.remote())
236+
time.sleep(1)
228237

229238
trace_file_path = os.path.join(self.work_dir, 'trace', 'sample_trace-test_filter.jsonl')
230239
self.assertFalse(os.path.exists(trace_file_path))
@@ -254,6 +263,7 @@ def test_collect_mapper_sample_show_num_limit(self):
254263

255264
# Finalize traces to write to file
256265
ray.get(tracer.finalize_traces.remote())
266+
time.sleep(1)
257267

258268
trace_file_path = os.path.join(self.work_dir, 'trace', 'sample_trace-limited_mapper.jsonl')
259269
self.assertTrue(os.path.exists(trace_file_path))
@@ -288,6 +298,7 @@ def test_collect_filter_sample_show_num_limit(self):
288298

289299
# Finalize traces to write to file
290300
ray.get(tracer.finalize_traces.remote())
301+
time.sleep(1)
291302

292303
trace_file_path = os.path.join(self.work_dir, 'trace', 'sample_trace-limited_filter.jsonl')
293304
self.assertTrue(os.path.exists(trace_file_path))
@@ -327,6 +338,7 @@ def test_finalize_traces_empty(self):
327338

328339
# Don't collect anything, just finalize
329340
ray.get(tracer.finalize_traces.remote())
341+
time.sleep(1)
330342

331343
# No trace files should exist
332344
trace_dir = os.path.join(self.work_dir, 'trace')

tests/ops/mapper/test_text_tagging_by_prompt_mapper.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ def test_tagging_vllm(self):
5252
enable_vllm=True,
5353
max_model_len=1024,
5454
max_num_seqs=16,
55-
sampling_params={'temperature': 0.1, 'top_p': 0.95, 'max_tokens': 256})
55+
sampling_params={'temperature': 0.1, 'top_p': 0.95, 'max_tokens': 256},
56+
model_params={'gpu_memory_utilization': 0.8},
57+
)
5658

5759

5860
if __name__ == '__main__':

0 commit comments

Comments
 (0)