Skip to content

Commit 77d39f6

Browse files
committed
fix tests
1 parent 0275455 commit 77d39f6

File tree

2 files changed

+82
-18
lines changed

src/instructlab/training/data_process.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1153,7 +1153,7 @@ def process_documents_for_pretraining(
11531153
Pattern: Each document → [BOS][tokens][EOS]
11541154
11551155
Args:
1156-
data_path: Path to input JSONL with {"documents": "text"} format
1156+
data_path: Path to input JSONL with {"document": "text"} format
11571157
data_output_path: Directory for processed data output
11581158
model_path: Path to model/tokenizer
11591159
num_cpu_procs: Number of parallel processes
@@ -1200,11 +1200,25 @@ def tokenize_document(sample):
12001200
"len": len(input_ids),
12011201
}
12021202

1203-
tokenized_data = data.map(
1203+
# Filter out empty documents before tokenization
1204+
def filter_empty_documents(batch):
1205+
return [bool(doc) for doc in batch[document_column_name]]
1206+
1207+
filtered_data = data.filter(
1208+
filter_empty_documents,
1209+
batched=True,
1210+
num_proc=num_cpu_procs,
1211+
desc="Filtering empty documents",
1212+
)
1213+
1214+
dropped_count = data.num_rows - filtered_data.num_rows
1215+
if dropped_count > 0:
1216+
logger.info(f"Dropped {dropped_count:,} empty documents")
1217+
tokenized_data = filtered_data.map(
12041218
tokenize_document,
12051219
num_proc=num_cpu_procs,
12061220
desc="Tokenizing documents",
1207-
remove_columns=data.column_names,
1221+
remove_columns=filtered_data.column_names,
12081222
)
12091223

12101224
# Calculate statistics

tests/unit/test_pretraining_data_process.py

Lines changed: 65 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,14 @@ def test_basic_tokenization_with_bos_eos(
8282

8383
# Mock single document
8484
mock_ds.__iter__ = lambda self: iter([{"documents": "Test document"}])
85-
mock_ds.map = MagicMock()
85+
86+
# Create filtered dataset mock
87+
filtered_ds = MagicMock()
88+
filtered_ds.num_rows = 1
89+
filtered_ds.column_names = ["documents"]
90+
91+
# Mock filter to return the filtered dataset
92+
mock_ds.filter = MagicMock(return_value=filtered_ds)
8693

8794
# Make map return a dataset with tokenized data
8895
def map_side_effect(func, **kwargs):
@@ -93,7 +100,7 @@ def map_side_effect(func, **kwargs):
93100
mapped_ds.to_json = MagicMock()
94101
return mapped_ds
95102

96-
mock_ds.map.side_effect = map_side_effect
103+
filtered_ds.map = MagicMock(side_effect=map_side_effect)
97104
mock_load_dataset.return_value = mock_ds
98105

99106
# Run function
@@ -108,8 +115,8 @@ def map_side_effect(func, **kwargs):
108115
# Verify tokenizer was loaded
109116
mock_from_pretrained.assert_called_once_with("test-model")
110117

111-
# Verify dataset map was called
112-
assert mock_ds.map.called
118+
# Verify dataset filter and map were called
119+
assert mock_ds.filter.called
113120

114121
@patch("instructlab.training.data_process.AutoTokenizer.from_pretrained")
115122
@patch("instructlab.training.data_process.load_dataset")
@@ -127,6 +134,14 @@ def test_multiple_documents_separate_records(
127134

128135
docs = [{"documents": "Doc 1"}, {"documents": "Doc 2"}, {"documents": "Doc 3"}]
129136

137+
# Create filtered dataset mock
138+
filtered_ds = MagicMock()
139+
filtered_ds.num_rows = 3
140+
filtered_ds.column_names = ["documents"]
141+
142+
# Mock filter to return the filtered dataset
143+
mock_ds.filter = MagicMock(return_value=filtered_ds)
144+
130145
# Mock map to process all documents
131146
def map_side_effect(func, **kwargs):
132147
results = [func(doc) for doc in docs]
@@ -136,7 +151,7 @@ def map_side_effect(func, **kwargs):
136151
mapped_ds.to_json = MagicMock()
137152
return mapped_ds
138153

139-
mock_ds.map.side_effect = map_side_effect
154+
filtered_ds.map = MagicMock(side_effect=map_side_effect)
140155
mock_load_dataset.return_value = mock_ds
141156

142157
# Run
@@ -148,8 +163,8 @@ def map_side_effect(func, **kwargs):
148163
document_column_name="documents",
149164
)
150165

151-
# Verify map was called (which processes each document)
152-
assert mock_ds.map.called
166+
# Verify filter and map were called (which processes each document)
167+
assert mock_ds.filter.called
153168

154169
@patch("instructlab.training.data_process.load_dataset")
155170
def test_empty_dataset_raises_error(self, mock_load_dataset, temp_output_dir):
@@ -236,6 +251,14 @@ def test_statistics_logging(
236251
mock_ds.num_rows = 2
237252
mock_ds.column_names = ["documents"]
238253

254+
# Create filtered dataset mock
255+
filtered_ds = MagicMock()
256+
filtered_ds.num_rows = 2
257+
filtered_ds.column_names = ["documents"]
258+
259+
# Mock filter to return the filtered dataset
260+
mock_ds.filter = MagicMock(return_value=filtered_ds)
261+
239262
# Mock map to return known lengths
240263
def map_side_effect(func, **kwargs):
241264
# Simulate 2 documents with 5 and 10 tokens each
@@ -245,7 +268,7 @@ def map_side_effect(func, **kwargs):
245268
mapped_ds.to_json = MagicMock()
246269
return mapped_ds
247270

248-
mock_ds.map.side_effect = map_side_effect
271+
filtered_ds.map = MagicMock(side_effect=map_side_effect)
249272
mock_load_dataset.return_value = mock_ds
250273

251274
# Run
@@ -272,7 +295,14 @@ def test_parallel_processing(
272295
mock_ds = MagicMock()
273296
mock_ds.num_rows = 1
274297
mock_ds.column_names = ["documents"]
275-
mock_ds.map = MagicMock()
298+
299+
# Create filtered dataset mock
300+
filtered_ds = MagicMock()
301+
filtered_ds.num_rows = 1
302+
filtered_ds.column_names = ["documents"]
303+
304+
# Mock filter to return the filtered dataset
305+
mock_ds.filter = MagicMock(return_value=filtered_ds)
276306

277307
def map_side_effect(func, **kwargs):
278308
mapped_ds = MagicMock()
@@ -281,7 +311,7 @@ def map_side_effect(func, **kwargs):
281311
mapped_ds.to_json = MagicMock()
282312
return mapped_ds
283313

284-
mock_ds.map.side_effect = map_side_effect
314+
filtered_ds.map = MagicMock(side_effect=map_side_effect)
285315
mock_load_dataset.return_value = mock_ds
286316

287317
# Run with specific num_cpu_procs
@@ -293,9 +323,13 @@ def map_side_effect(func, **kwargs):
293323
document_column_name="documents",
294324
)
295325

296-
# Verify map was called with num_proc=4
297-
call_args = mock_ds.map.call_args
298-
assert call_args[1]["num_proc"] == 4
326+
# Verify filter was called with num_proc=4
327+
filter_call_args = mock_ds.filter.call_args
328+
assert filter_call_args[1]["num_proc"] == 4
329+
330+
# Verify map was also called with num_proc=4
331+
map_call_args = filtered_ds.map.call_args
332+
assert map_call_args[1]["num_proc"] == 4
299333

300334
def test_output_directory_creation(self, tmp_path, mock_tokenizer):
301335
"""Verify directory is created if it doesn't exist."""
@@ -314,6 +348,14 @@ def test_output_directory_creation(self, tmp_path, mock_tokenizer):
314348
mock_ds.num_rows = 1
315349
mock_ds.column_names = ["documents"]
316350

351+
# Create filtered dataset mock
352+
filtered_ds = MagicMock()
353+
filtered_ds.num_rows = 1
354+
filtered_ds.column_names = ["documents"]
355+
356+
# Mock filter to return the filtered dataset
357+
mock_ds.filter = MagicMock(return_value=filtered_ds)
358+
317359
def map_side_effect(func, **kwargs):
318360
mapped_ds = MagicMock()
319361
mapped_ds.__len__ = lambda self: 1
@@ -323,7 +365,7 @@ def map_side_effect(func, **kwargs):
323365
mapped_ds.to_json = MagicMock()
324366
return mapped_ds
325367

326-
mock_ds.map.side_effect = map_side_effect
368+
filtered_ds.map = MagicMock(side_effect=map_side_effect)
327369
mock_load_dataset.return_value = mock_ds
328370

329371
# Run
@@ -351,6 +393,14 @@ def test_output_jsonl_format(
351393
mock_ds.num_rows = 1
352394
mock_ds.column_names = ["documents"]
353395

396+
# Create filtered dataset mock
397+
filtered_ds = MagicMock()
398+
filtered_ds.num_rows = 1
399+
filtered_ds.column_names = ["documents"]
400+
401+
# Mock filter to return the filtered dataset
402+
mock_ds.filter = MagicMock(return_value=filtered_ds)
403+
354404
# Track what gets written
355405
output_file_path = None
356406

@@ -371,7 +421,7 @@ def to_json_side_effect(path, **kw):
371421
mapped_ds.to_json = to_json_side_effect
372422
return mapped_ds
373423

374-
mock_ds.map.side_effect = map_side_effect
424+
filtered_ds.map = MagicMock(side_effect=map_side_effect)
375425
mock_load_dataset.return_value = mock_ds
376426

377427
# Run

0 commit comments

Comments (0)