61 changes: 29 additions & 32 deletions sdp/processors/base_processor.py
@@ -19,6 +19,7 @@
 import time
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from joblib import Parallel, delayed
 from typing import Any, Dict, List, Optional, Union

 from tqdm import tqdm
@@ -191,22 +192,20 @@ def _process_with_dask(self, metrics):
     def _process_with_multiprocessing(self, metrics):
         with open(self.output_manifest_file, "wt", encoding="utf8") as fout:
             for manifest_chunk in self._chunk_manifest():
-                data = itertools.chain(
-                    *process_map(
-                        self.process_dataset_entry,
-                        manifest_chunk,
-                        max_workers=self.max_workers,
-                        chunksize=self.chunksize,
-                    )
+                # Parallel processing using joblib
+                results = Parallel(n_jobs=self.max_workers, backend="multiprocessing")(
+                    delayed(self.process_dataset_entry)(entry) for entry in manifest_chunk
                 )
-                for data_entry in tqdm(data):
-                    metrics.append(data_entry.metrics)
-                    if data_entry.data is None:
-                        continue
-                    json.dump(data_entry.data, fout, ensure_ascii=False)
-                    fout.write("\n")
-                    self.number_of_entries += 1
-                    self.total_duration += data_entry.data.get("duration", 0)
+
+                for result_group in tqdm(results):
+                    for data_entry in result_group:
+                        metrics.append(data_entry.metrics)
+                        if data_entry.data is None:
+                            continue
+                        json.dump(data_entry.data, fout, ensure_ascii=False)
+                        fout.write("\n")
+                        self.number_of_entries += 1
+                        self.total_duration += data_entry.data.get("duration", 0)

     def _chunk_manifest(self):
         """Splits the input manifest into chunks of in_memory_chunksize size.
@@ -379,24 +378,22 @@ def process(self):
         metrics = []
         with open(self.output_manifest_file, "wt", encoding="utf8") as fout:
             for manifest_chunk in self._chunk_manifest():
-                # this will unroll all inner lists
-                data = itertools.chain(
-                    *process_map(
-                        self.process_dataset_entry,
-                        manifest_chunk,
-                        max_workers=self.max_workers,
-                        chunksize=self.chunksize,
-                    )
+
+                results = Parallel(n_jobs=self.max_workers, backend="multiprocessing")(
+                    delayed(self.process_dataset_entry)(entry) for entry in manifest_chunk
                 )
-                for data_entry in tqdm(data):
-                    if data_entry.metrics is not None:
-                        pass # optionally accumulate metrics here
-                    if data_entry.data is None:
-                        continue
-                    json.dump(data_entry.data, fout, ensure_ascii=False)
-                    self.number_of_entries += 1
-                    self.total_duration += data_entry.data.get("duration", 0)
-                    fout.write("\n")
+
+                for result_group in tqdm(results):
+                    for data_entry in result_group:
+                        if data_entry.metrics is not None:
+                            pass # optionally accumulate metrics here
+                        if data_entry.data is None:
+                            continue
+                        json.dump(data_entry.data, fout, ensure_ascii=False)
+                        self.number_of_entries += 1
+                        self.total_duration += data_entry.data.get("duration", 0)
+                        fout.write("\n")
+
         self.finalize(self.test_cases)

     def prepare(self):
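For context on the change above: the PR replaces tqdm's process_map with joblib's Parallel/delayed. One practical difference is that Parallel returns one result object per input entry with no implicit flattening (the old code flattened with itertools.chain(*process_map(...))), which is why the new code iterates nested result_group lists. Below is a minimal, self-contained sketch of that pattern, not the SDP implementation itself: the DataEntry dataclass and process_dataset_entry worker are simplified stand-ins introduced only for illustration.

# Minimal sketch of the joblib Parallel/delayed pattern used in this PR.
# DataEntry and process_dataset_entry are hypothetical stand-ins, not SDP code.
from dataclasses import dataclass, field
from typing import Dict, List, Optional

from joblib import Parallel, delayed
from tqdm import tqdm


@dataclass
class DataEntry:
    data: Optional[Dict]
    metrics: Dict = field(default_factory=dict)


def process_dataset_entry(entry: Dict) -> List[DataEntry]:
    # A single manifest entry may expand into several output entries,
    # so each worker call returns a list.
    return [DataEntry(data=dict(entry))]


def main():
    manifest_chunk = [{"audio_filepath": f"audio_{i}.wav", "duration": float(i)} for i in range(4)]

    # Each element of `results` is the list returned by one process_dataset_entry
    # call; joblib does not flatten, unlike itertools.chain(*process_map(...)).
    results = Parallel(n_jobs=2, backend="multiprocessing")(
        delayed(process_dataset_entry)(entry) for entry in manifest_chunk
    )

    total_duration = 0.0
    for result_group in tqdm(results):
        for data_entry in result_group:
            if data_entry.data is None:
                continue
            total_duration += data_entry.data.get("duration", 0)
    print(total_duration)  # 0.0 + 1.0 + 2.0 + 3.0 = 6.0


if __name__ == "__main__":
    main()

The __main__ guard matters with the "multiprocessing" backend, since worker processes may re-import the module when started via spawn.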