@@ -19,6 +19,7 @@
 import time
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from joblib import Parallel, delayed
 from typing import Any, Dict, List, Optional, Union
 
 from tqdm import tqdm
@@ -191,22 +192,20 @@ def _process_with_dask(self, metrics):
 def _process_with_multiprocessing(self, metrics):
     with open(self.output_manifest_file, "wt", encoding="utf8") as fout:
         for manifest_chunk in self._chunk_manifest():
-            data = itertools.chain(
-                *process_map(
-                    self.process_dataset_entry,
-                    manifest_chunk,
-                    max_workers=self.max_workers,
-                    chunksize=self.chunksize,
-                )
+            # Parallel processing using joblib
+            results = Parallel(n_jobs=self.max_workers, backend="multiprocessing")(
+                delayed(self.process_dataset_entry)(entry) for entry in manifest_chunk
             )
-            for data_entry in tqdm(data):
-                metrics.append(data_entry.metrics)
-                if data_entry.data is None:
-                    continue
-                json.dump(data_entry.data, fout, ensure_ascii=False)
-                fout.write("\n")
-                self.number_of_entries += 1
-                self.total_duration += data_entry.data.get("duration", 0)
+
+            for result_group in tqdm(results):
+                for data_entry in result_group:
+                    metrics.append(data_entry.metrics)
+                    if data_entry.data is None:
+                        continue
+                    json.dump(data_entry.data, fout, ensure_ascii=False)
+                    fout.write("\n")
+                    self.number_of_entries += 1
+                    self.total_duration += data_entry.data.get("duration", 0)
 
 def _chunk_manifest(self):
     """Splits the input manifest into chunks of in_memory_chunksize size.
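For context, a minimal standalone sketch (not part of the diff) of the joblib pattern the new code uses. Parallel returns results in the same order as the inputs, and because each process_dataset_entry call returns a list of entries, the result is a list of lists; the nested loop above flattens it, doing the job of the removed itertools.chain. The worker function and inputs below are hypothetical stand-ins.

from joblib import Parallel, delayed

def process_dataset_entry(entry):
    # Hypothetical stand-in worker: like the real method, each call
    # returns a list of results, so Parallel yields a list of lists.
    return [entry * 2]

if __name__ == "__main__":  # needed for process-based backends on some platforms
    manifest_chunk = [1, 2, 3, 4]
    # delayed() captures each call; Parallel dispatches them to worker
    # processes and collects results in input order.
    results = Parallel(n_jobs=2, backend="multiprocessing")(
        delayed(process_dataset_entry)(entry) for entry in manifest_chunk
    )
    # Equivalent to the removed itertools.chain(*results) unrolling.
    flat = [item for group in results for item in group]
    assert flat == [2, 4, 6, 8]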
@@ -379,24 +378,22 @@ def process(self):
     metrics = []
     with open(self.output_manifest_file, "wt", encoding="utf8") as fout:
         for manifest_chunk in self._chunk_manifest():
-            # this will unroll all inner lists
-            data = itertools.chain(
-                *process_map(
-                    self.process_dataset_entry,
-                    manifest_chunk,
-                    max_workers=self.max_workers,
-                    chunksize=self.chunksize,
-                )
+
+            results = Parallel(n_jobs=self.max_workers, backend="multiprocessing")(
+                delayed(self.process_dataset_entry)(entry) for entry in manifest_chunk
             )
-            for data_entry in tqdm(data):
-                if data_entry.metrics is not None:
-                    pass  # optionally accumulate metrics here
-                if data_entry.data is None:
-                    continue
-                json.dump(data_entry.data, fout, ensure_ascii=False)
-                self.number_of_entries += 1
-                self.total_duration += data_entry.data.get("duration", 0)
-                fout.write("\n")
+
+            for result_group in tqdm(results):
+                for data_entry in result_group:
+                    if data_entry.metrics is not None:
+                        pass  # optionally accumulate metrics here
+                    if data_entry.data is None:
+                        continue
+                    json.dump(data_entry.data, fout, ensure_ascii=False)
+                    self.number_of_entries += 1
+                    self.total_duration += data_entry.data.get("duration", 0)
+                    fout.write("\n")
+
     self.finalize(self.test_cases)
 
 def prepare(self):
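The _chunk_manifest helper referenced by both hunks is documented as splitting the input manifest into chunks of in_memory_chunksize entries. A minimal sketch of such a generator, assuming a JSON-lines manifest; the input_manifest_file attribute name is a hypothetical mirror of the output_manifest_file used above.

import json

def _chunk_manifest(self):
    """Splits the input manifest into chunks of in_memory_chunksize size."""
    chunk = []
    # input_manifest_file is an assumed attribute name (JSON-lines format).
    with open(self.input_manifest_file, "rt", encoding="utf8") as fin:
        for line in fin:
            chunk.append(json.loads(line))
            if len(chunk) == self.in_memory_chunksize:
                yield chunk
                chunk = []
    if chunk:
        # Flush the final, possibly smaller, chunk.
        yield chunk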