@@ -82,12 +82,12 @@ def test(self):
8282 There are no tests by default.
8383 """
8484
85- class DaskParallelProcessor (BaseProcessor ):
85+ class BaseParallelProcessor (BaseProcessor ):
8686 """
8787 Processor class which allows operations on each entry to be parallelized using Dask.
8888
8989 Parallelization is done by distributing the workload using Dask bags inside
90- the :meth:`process` method.
90+ the :meth:`process` method.
9191
9292 Actual processing should be defined on a per-example basis inside the
9393 :meth:`process_dataset_entry` method.
@@ -127,16 +127,40 @@ def process(self):
127127 import multiprocessing
128128 import psutil
129129
130+ os .environ .setdefault ("PATH" , os .defpath )
131+ default_path = os .environ ["PATH" ]
130132
131133 self .prepare ()
132134 os .makedirs (os .path .dirname (self .output_manifest_file ), exist_ok = True )
133135 metrics = []
134136
135- # total line count.
136- with open (self .input_manifest_file , "rb" ) as f :
137- total_entries = sum (1 for _ in f )
138-
139- # check that we did not set different amount of cpu.
137+ # If no input manifest file is provided, use the read_manifest() method to get entries.
138+ if self .input_manifest_file is None :
139+ manifest_entries = list (self .read_manifest ())
140+ total_entries = len (manifest_entries )
141+ if total_entries == 0 :
142+ logger .info ("No input manifest entries found; using empty bag." )
143+ bag = db .from_sequence ([])
144+ else :
145+ bag = db .from_sequence (manifest_entries )
146+ else :
147+ # Check if input manifest file exists and is non-empty.
148+ if (not os .path .exists (self .input_manifest_file ) or os .path .getsize (self .input_manifest_file ) == 0 ):
149+ logger .info ("Input manifest file not found or empty; using empty bag." )
150+ total_entries = 0
151+ bag = db .from_sequence ([])
152+ else :
153+ # Compute total line count.
154+ with open (self .input_manifest_file , "rb" ) as f :
155+ total_entries = sum (1 for _ in f )
156+ if total_entries == 0 :
157+ logger .info ("Input manifest file is empty; using empty bag." )
158+ bag = db .from_sequence ([])
159+ else :
160+ bag = db .read_text (self .input_manifest_file , encoding = "utf-8" , blocksize = "8MB" )
161+ bag = bag .map (json .loads )
162+
163+ # Set up worker resources.
140164 num_cpus = multiprocessing .cpu_count () if self .max_workers == - 1 else self .max_workers
141165 total_memory = psutil .virtual_memory ().total
142166 mem_per_worker = total_memory // num_cpus
@@ -154,11 +178,14 @@ def process(self):
154178 "distributed.comm.timeouts.connect" : "30s" ,
155179 "distributed.comm.timeouts.tcp" : "30s"
156180 }):
157- client = Client (n_workers = num_cpus , processes = True , threads_per_worker = 2 , memory_limit = memory_limit )
181+ client = Client (
182+ n_workers = num_cpus ,
183+ processes = True ,
184+ threads_per_worker = 2 ,
185+ memory_limit = memory_limit ,
186+ env = {"PATH" : default_path }
187+ )
158188 try :
159- bag = db .read_text (self .input_manifest_file , encoding = "utf-8" , blocksize = "8MB" )
160- bag = bag .map (json .loads )
161-
162189 def process_partition (partition ):
163190 results = []
164191 for data_entry in partition :
@@ -175,13 +202,11 @@ def process_partition(partition):
175202 delayed_results = bag .to_delayed ()
176203 futures = client .compute (delayed_results )
177204
178- # Open output file before starting to process futures.
179205 with open (self .output_manifest_file , "wt" , encoding = "utf8" ) as fout , \
180206 tqdm (total = total_entries , desc = "Processing entries" , unit = "entry" ,
181207 mininterval = 0.5 , miniters = 10 ) as pbar :
182208 for future in as_completed (futures ):
183209 partition_result = future .result ()
184- # Write results immediately (as soon as available)
185210 for data_entry in partition_result :
186211 metrics .append (data_entry .metrics )
187212 if data_entry .data is None :
@@ -218,10 +243,14 @@ def _chunk_manifest(self):
218243 yield manifest_chunk
219244
220245 def read_manifest (self ):
221- """Read entries from the input manifest file."""
222- if not self .input_manifest_file :
223- raise ValueError ("Input manifest file is not specified." )
246+ """Read entries from the input manifest file.
224247
248+ Returns an iterator over JSON objects. If no input manifest file is set,
249+ returns an empty iterator.
250+ """
251+ if not self .input_manifest_file :
252+ logger .info ("No input manifest file specified; returning empty manifest iterator." )
253+ return iter ([])
225254 with open (self .input_manifest_file , "r" , encoding = "utf-8" ) as fin :
226255 for line in fin :
227256 yield json .loads (line )
@@ -249,9 +278,28 @@ def finalize(self, metrics: List[Any]):
249278 )
250279 logger .info ("Processor completed in (seconds): %.2f" , time .time () - self .start_time )
251280
281+ def test (self ):
282+ """Applies processing to each test case and raises an error if the output does not match expected output."""
283+ for test_case in self .test_cases :
284+ input_data = test_case ["input" ].copy () if isinstance (test_case ["input" ], dict ) else test_case ["input" ]
285+ generated_outputs = self .process_dataset_entry (input_data )
286+ expected_outputs = (
287+ [test_case ["output" ]] if not isinstance (test_case ["output" ], list )
288+ else test_case ["output" ]
289+ )
290+ for generated_output , expected_output in zip (generated_outputs , expected_outputs ):
291+ generated_data = generated_output .data if hasattr (generated_output , "data" ) else generated_output
292+ if generated_data != expected_output :
293+ raise RuntimeError (
294+ "Runtime test failed.\n "
295+ f"Test input: { test_case ['input' ]} \n "
296+ f"Generated output: { generated_data } \n "
297+ f"Expected output: { expected_output } "
298+ )
252299
253300
254- class BaseParallelProcessor (BaseProcessor ):
301+
302+ class LegacyParallelProcessor (BaseProcessor ):
255303 """Processor class which allows operations on each utterance to be parallelized.
256304
257305 Parallelization is done using ``tqdm.contrib.concurrent.process_map`` inside
0 commit comments