@@ -82,7 +82,14 @@ def test_basic_tokenization_with_bos_eos(
8282
8383 # Mock single document
8484 mock_ds .__iter__ = lambda self : iter ([{"documents" : "Test document" }])
85- mock_ds .map = MagicMock ()
85+
86+ # Create filtered dataset mock
87+ filtered_ds = MagicMock ()
88+ filtered_ds .num_rows = 1
89+ filtered_ds .column_names = ["documents" ]
90+
91+ # Mock filter to return the filtered dataset
92+ mock_ds .filter = MagicMock (return_value = filtered_ds )
8693
8794 # Make map return a dataset with tokenized data
8895 def map_side_effect (func , ** kwargs ):
@@ -93,7 +100,7 @@ def map_side_effect(func, **kwargs):
93100 mapped_ds .to_json = MagicMock ()
94101 return mapped_ds
95102
96- mock_ds .map . side_effect = map_side_effect
103+ filtered_ds .map = MagicMock ( side_effect = map_side_effect )
97104 mock_load_dataset .return_value = mock_ds
98105
99106 # Run function
@@ -108,8 +115,8 @@ def map_side_effect(func, **kwargs):
108115 # Verify tokenizer was loaded
109116 mock_from_pretrained .assert_called_once_with ("test-model" )
110117
111- # Verify dataset map was called
112- assert mock_ds .map .called
118+ # Verify dataset filter was called
119+ assert mock_ds .filter .called
113120
114121 @patch ("instructlab.training.data_process.AutoTokenizer.from_pretrained" )
115122 @patch ("instructlab.training.data_process.load_dataset" )
@@ -127,6 +134,14 @@ def test_multiple_documents_separate_records(
127134
128135 docs = [{"documents" : "Doc 1" }, {"documents" : "Doc 2" }, {"documents" : "Doc 3" }]
129136
137+ # Create filtered dataset mock
138+ filtered_ds = MagicMock ()
139+ filtered_ds .num_rows = 3
140+ filtered_ds .column_names = ["documents" ]
141+
142+ # Mock filter to return the filtered dataset
143+ mock_ds .filter = MagicMock (return_value = filtered_ds )
144+
130145 # Mock map to process all documents
131146 def map_side_effect (func , ** kwargs ):
132147 results = [func (doc ) for doc in docs ]
@@ -136,7 +151,7 @@ def map_side_effect(func, **kwargs):
136151 mapped_ds .to_json = MagicMock ()
137152 return mapped_ds
138153
139- mock_ds .map . side_effect = map_side_effect
154+ filtered_ds .map = MagicMock ( side_effect = map_side_effect )
140155 mock_load_dataset .return_value = mock_ds
141156
142157 # Run
@@ -148,8 +163,8 @@ def map_side_effect(func, **kwargs):
148163 document_column_name = "documents" ,
149164 )
150165
151- # Verify map was called (which processes each document)
152- assert mock_ds .map .called
166+ # Verify filter was called; map (which processes each document) is mocked on the filtered dataset
167+ assert mock_ds .filter .called
153168
154169 @patch ("instructlab.training.data_process.load_dataset" )
155170 def test_empty_dataset_raises_error (self , mock_load_dataset , temp_output_dir ):
@@ -236,6 +251,14 @@ def test_statistics_logging(
236251 mock_ds .num_rows = 2
237252 mock_ds .column_names = ["documents" ]
238253
254+ # Create filtered dataset mock
255+ filtered_ds = MagicMock ()
256+ filtered_ds .num_rows = 2
257+ filtered_ds .column_names = ["documents" ]
258+
259+ # Mock filter to return the filtered dataset
260+ mock_ds .filter = MagicMock (return_value = filtered_ds )
261+
239262 # Mock map to return known lengths
240263 def map_side_effect (func , ** kwargs ):
241264 # Simulate 2 documents with 5 and 10 tokens each
@@ -245,7 +268,7 @@ def map_side_effect(func, **kwargs):
245268 mapped_ds .to_json = MagicMock ()
246269 return mapped_ds
247270
248- mock_ds .map . side_effect = map_side_effect
271+ filtered_ds .map = MagicMock ( side_effect = map_side_effect )
249272 mock_load_dataset .return_value = mock_ds
250273
251274 # Run
@@ -272,7 +295,14 @@ def test_parallel_processing(
272295 mock_ds = MagicMock ()
273296 mock_ds .num_rows = 1
274297 mock_ds .column_names = ["documents" ]
275- mock_ds .map = MagicMock ()
298+
299+ # Create filtered dataset mock
300+ filtered_ds = MagicMock ()
301+ filtered_ds .num_rows = 1
302+ filtered_ds .column_names = ["documents" ]
303+
304+ # Mock filter to return the filtered dataset
305+ mock_ds .filter = MagicMock (return_value = filtered_ds )
276306
277307 def map_side_effect (func , ** kwargs ):
278308 mapped_ds = MagicMock ()
@@ -281,7 +311,7 @@ def map_side_effect(func, **kwargs):
281311 mapped_ds .to_json = MagicMock ()
282312 return mapped_ds
283313
284- mock_ds .map . side_effect = map_side_effect
314+ filtered_ds .map = MagicMock ( side_effect = map_side_effect )
285315 mock_load_dataset .return_value = mock_ds
286316
287317 # Run with specific num_cpu_procs
@@ -293,9 +323,13 @@ def map_side_effect(func, **kwargs):
293323 document_column_name = "documents" ,
294324 )
295325
296- # Verify map was called with num_proc=4
297- call_args = mock_ds .map .call_args
298- assert call_args [1 ]["num_proc" ] == 4
326+ # Verify filter was called with num_proc=4
327+ filter_call_args = mock_ds .filter .call_args
328+ assert filter_call_args [1 ]["num_proc" ] == 4
329+
330+ # Verify map was also called with num_proc=4
331+ map_call_args = filtered_ds .map .call_args
332+ assert map_call_args [1 ]["num_proc" ] == 4
299333
300334 def test_output_directory_creation (self , tmp_path , mock_tokenizer ):
301335 """Verify directory is created if it doesn't exist."""
@@ -314,6 +348,14 @@ def test_output_directory_creation(self, tmp_path, mock_tokenizer):
314348 mock_ds .num_rows = 1
315349 mock_ds .column_names = ["documents" ]
316350
351+ # Create filtered dataset mock
352+ filtered_ds = MagicMock ()
353+ filtered_ds .num_rows = 1
354+ filtered_ds .column_names = ["documents" ]
355+
356+ # Mock filter to return the filtered dataset
357+ mock_ds .filter = MagicMock (return_value = filtered_ds )
358+
317359 def map_side_effect (func , ** kwargs ):
318360 mapped_ds = MagicMock ()
319361 mapped_ds .__len__ = lambda self : 1
@@ -323,7 +365,7 @@ def map_side_effect(func, **kwargs):
323365 mapped_ds .to_json = MagicMock ()
324366 return mapped_ds
325367
326- mock_ds .map . side_effect = map_side_effect
368+ filtered_ds .map = MagicMock ( side_effect = map_side_effect )
327369 mock_load_dataset .return_value = mock_ds
328370
329371 # Run
@@ -351,6 +393,14 @@ def test_output_jsonl_format(
351393 mock_ds .num_rows = 1
352394 mock_ds .column_names = ["documents" ]
353395
396+ # Create filtered dataset mock
397+ filtered_ds = MagicMock ()
398+ filtered_ds .num_rows = 1
399+ filtered_ds .column_names = ["documents" ]
400+
401+ # Mock filter to return the filtered dataset
402+ mock_ds .filter = MagicMock (return_value = filtered_ds )
403+
354404 # Track what gets written
355405 output_file_path = None
356406
@@ -371,7 +421,7 @@ def to_json_side_effect(path, **kw):
371421 mapped_ds .to_json = to_json_side_effect
372422 return mapped_ds
373423
374- mock_ds .map . side_effect = map_side_effect
424+ filtered_ds .map = MagicMock ( side_effect = map_side_effect )
375425 mock_load_dataset .return_value = mock_ds
376426
377427 # Run
0 commit comments