
Commit 048e23c

Use joblib for multiprocessing instead of itertools (#152)
* Use joblib for multiprocessing instead of itertools

Signed-off-by: Sushmitha Deva <[email protected]>

* Update base_processor.py

Signed-off-by: Sushmitha Deva <[email protected]>

---------

Signed-off-by: Sushmitha Deva <[email protected]>
1 parent 1731ef7 commit 048e23c
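
The change swaps process_map (whose per-entry results were flattened afterwards with itertools.chain) for joblib's Parallel/delayed, which returns results as an ordered list. For orientation, a minimal sketch of the joblib idiom the diff below adopts; the square function is a toy stand-in, not from the repo:

from joblib import Parallel, delayed

def square(x):
    return x * x

# backend="multiprocessing" runs the calls in worker processes;
# results come back as a list in input order.
results = Parallel(n_jobs=2, backend="multiprocessing")(
    delayed(square)(i) for i in range(8)
)
print(results)  # [0, 1, 4, 9, 16, 25, 36, 49]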

File tree

1 file changed (+29, −32)

sdp/processors/base_processor.py

Lines changed: 29 additions & 32 deletions
@@ -19,6 +19,7 @@
 import time
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from joblib import Parallel, delayed
 from typing import Any, Dict, List, Optional, Union
 
 from tqdm import tqdm
@@ -191,22 +192,20 @@ def _process_with_dask(self, metrics):
     def _process_with_multiprocessing(self, metrics):
         with open(self.output_manifest_file, "wt", encoding="utf8") as fout:
             for manifest_chunk in self._chunk_manifest():
-                data = itertools.chain(
-                    *process_map(
-                        self.process_dataset_entry,
-                        manifest_chunk,
-                        max_workers=self.max_workers,
-                        chunksize=self.chunksize,
-                    )
+                # Parallel processing using joblib
+                results = Parallel(n_jobs=self.max_workers, backend="multiprocessing")(
+                    delayed(self.process_dataset_entry)(entry) for entry in manifest_chunk
                 )
-                for data_entry in tqdm(data):
-                    metrics.append(data_entry.metrics)
-                    if data_entry.data is None:
-                        continue
-                    json.dump(data_entry.data, fout, ensure_ascii=False)
-                    fout.write("\n")
-                    self.number_of_entries += 1
-                    self.total_duration += data_entry.data.get("duration", 0)
+
+                for result_group in tqdm(results):
+                    for data_entry in result_group:
+                        metrics.append(data_entry.metrics)
+                        if data_entry.data is None:
+                            continue
+                        json.dump(data_entry.data, fout, ensure_ascii=False)
+                        fout.write("\n")
+                        self.number_of_entries += 1
+                        self.total_duration += data_entry.data.get("duration", 0)
 
     def _chunk_manifest(self):
         """Splits the input manifest into chunks of in_memory_chunksize size.
@@ -379,24 +378,22 @@ def process(self):
         metrics = []
         with open(self.output_manifest_file, "wt", encoding="utf8") as fout:
             for manifest_chunk in self._chunk_manifest():
-                # this will unroll all inner lists
-                data = itertools.chain(
-                    *process_map(
-                        self.process_dataset_entry,
-                        manifest_chunk,
-                        max_workers=self.max_workers,
-                        chunksize=self.chunksize,
-                    )
+
+                results = Parallel(n_jobs=self.max_workers, backend="multiprocessing")(
+                    delayed(self.process_dataset_entry)(entry) for entry in manifest_chunk
                 )
-                for data_entry in tqdm(data):
-                    if data_entry.metrics is not None:
-                        pass  # optionally accumulate metrics here
-                    if data_entry.data is None:
-                        continue
-                    json.dump(data_entry.data, fout, ensure_ascii=False)
-                    self.number_of_entries += 1
-                    self.total_duration += data_entry.data.get("duration", 0)
-                    fout.write("\n")
+
+                for result_group in tqdm(results):
+                    for data_entry in result_group:
+                        if data_entry.metrics is not None:
+                            pass  # optionally accumulate metrics here
+                        if data_entry.data is None:
+                            continue
+                        json.dump(data_entry.data, fout, ensure_ascii=False)
+                        self.number_of_entries += 1
+                        self.total_duration += data_entry.data.get("duration", 0)
+                        fout.write("\n")
+
         self.finalize(self.test_cases)
 
     def prepare(self):
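
One design note on the backend choice: joblib's default backend is "loky", which the joblib docs recommend as the more robust process-based option; this commit explicitly selects backend="multiprocessing", which uses Python's multiprocessing.Pool underneath. A brief sketch of the alternatives (toy function, not from the repo):

from joblib import Parallel, delayed

def work(x):
    return x + 1

# As in this commit: multiprocessing.Pool-based workers.
out = Parallel(n_jobs=4, backend="multiprocessing")(delayed(work)(i) for i in range(10))

# joblib default ("loky"); n_jobs=-1 would use all available cores.
out = Parallel(n_jobs=4)(delayed(work)(i) for i in range(10))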
