Skip to content

Commit 69fc0b4

Browse files
tchaton, awaelchli, and thomas
authored and committed
Add multiple uploaders to the map, optimize (#18989)
Co-authored-by: Adrian Wälchli <[email protected]> Co-authored-by: thomas <[email protected]> (cherry picked from commit 7288302)
1 parent 8d0830b commit 69fc0b4

File tree

2 files changed

+80
-47
lines changed

2 files changed

+80
-47
lines changed

src/lightning/data/streaming/data_processor.py

Lines changed: 77 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import types
88
from abc import abstractmethod
99
from dataclasses import dataclass
10+
from datetime import datetime
1011
from multiprocessing import Process, Queue
1112
from queue import Empty
1213
from shutil import copyfile, rmtree
@@ -15,7 +16,7 @@
1516
from urllib import parse
1617

1718
import torch
18-
from tqdm.auto import tqdm
19+
from tqdm.auto import tqdm as _tqdm
1920

2021
from lightning import seed_everything
2122
from lightning.data.streaming import Cache
@@ -278,6 +279,7 @@ def __init__(
278279
error_queue: Queue,
279280
stop_queue: Queue,
280281
num_downloaders: int,
282+
num_uploaders: int,
281283
remove: bool,
282284
) -> None:
283285
"""The BaseWorker is responsible to process the user data."""
@@ -290,18 +292,19 @@ def __init__(
290292
self.items = items
291293
self.num_items = len(self.items)
292294
self.num_downloaders = num_downloaders
295+
self.num_uploaders = num_uploaders
293296
self.remove = remove
294297
self.paths: List[List[str]] = []
295298
self.remover: Optional[Process] = None
296299
self.downloaders: List[Process] = []
300+
self.uploaders: List[Process] = []
297301
self.to_download_queues: List[Queue] = []
302+
self.to_upload_queues: List[Queue] = []
298303
self.stop_queue = stop_queue
299304
self.ready_to_process_queue: Queue = Queue()
300305
self.remove_queue: Queue = Queue()
301-
self.upload_queue: Queue = Queue()
302306
self.progress_queue: Queue = progress_queue
303307
self.error_queue: Queue = error_queue
304-
self.uploader: Optional[Process] = None
305308
self._collected_items = 0
306309
self._counter = 0
307310
self._last_time = time()
@@ -316,14 +319,14 @@ def run(self) -> None:
316319
traceback_format = traceback.format_exc()
317320
print(traceback_format)
318321
self.error_queue.put(traceback_format)
319-
print(f"Worker {self.worker_index} is done.")
322+
print(f"Worker {str(_get_node_rank() * self.num_workers + self.worker_index)} is done.")
320323

321324
def _setup(self) -> None:
322325
self._set_environ_variables()
323326
self._create_cache()
324327
self._collect_paths()
325328
self._start_downloaders()
326-
self._start_uploader()
329+
self._start_uploaders()
327330
self._start_remover()
328331

329332
def _loop(self) -> None:
@@ -335,13 +338,19 @@ def _loop(self) -> None:
335338
if index is None:
336339
num_downloader_finished += 1
337340
if num_downloader_finished == self.num_downloaders:
341+
print(f"Worker {str(_get_node_rank() * self.num_workers + self.worker_index)} is terminating.")
342+
338343
if isinstance(self.data_recipe, DataChunkRecipe):
339344
self._handle_data_chunk_recipe_end()
340345

341346
if self.output_dir.url if self.output_dir.url else self.output_dir.path:
342-
assert self.uploader
343-
self.upload_queue.put(None)
344-
self.uploader.join()
347+
# Inform the uploaders they are done working
348+
for i in range(self.num_uploaders):
349+
self.to_upload_queues[i].put(None)
350+
351+
# Wait for them all to be finished
352+
for uploader in self.uploaders:
353+
uploader.join()
345354

346355
if self.remove:
347356
assert self.remover
@@ -402,7 +411,7 @@ def _try_upload(self, filepath: Optional[str]) -> None:
402411
return
403412

404413
assert os.path.exists(filepath), filepath
405-
self.upload_queue.put(filepath)
414+
self.to_upload_queues[self._counter % self.num_uploaders].put(filepath)
406415

407416
def _collect_paths(self) -> None:
408417
items = []
@@ -475,19 +484,24 @@ def _start_remover(self) -> None:
475484
)
476485
self.remover.start()
477486

478-
def _start_uploader(self) -> None:
487+
def _start_uploaders(self) -> None:
479488
if self.output_dir.path is None and self.output_dir.url is None:
480489
return
481-
self.uploader = Process(
482-
target=_upload_fn,
483-
args=(
484-
self.upload_queue,
485-
self.remove_queue,
486-
self.cache_chunks_dir,
487-
self.output_dir,
488-
),
489-
)
490-
self.uploader.start()
490+
491+
for _ in range(self.num_uploaders):
492+
to_upload_queue: Queue = Queue()
493+
p = Process(
494+
target=_upload_fn,
495+
args=(
496+
to_upload_queue,
497+
self.remove_queue,
498+
self.cache_chunks_dir,
499+
self.output_dir,
500+
),
501+
)
502+
p.start()
503+
self.uploaders.append(p)
504+
self.to_upload_queues.append(to_upload_queue)
491505

492506
def _handle_data_chunk_recipe(self, index: int) -> None:
493507
try:
@@ -509,10 +523,10 @@ def _handle_data_chunk_recipe(self, index: int) -> None:
509523
def _handle_data_chunk_recipe_end(self) -> None:
510524
chunks_filepaths = self.cache.done()
511525

512-
if chunks_filepaths:
513-
for chunk_filepath in chunks_filepaths:
526+
if chunks_filepaths and len(self.to_upload_queues):
527+
for i, chunk_filepath in enumerate(chunks_filepaths):
514528
if isinstance(chunk_filepath, str) and os.path.exists(chunk_filepath):
515-
self.upload_queue.put(chunk_filepath)
529+
self.to_upload_queues[i % self.num_uploaders].put(chunk_filepath)
516530

517531
def _handle_data_transform_recipe(self, index: int) -> None:
518532
# Don't use a context manager to avoid deleting files that are being uploaded.
@@ -721,6 +735,7 @@ def __init__(
721735
output_dir: Optional[Union[str, Dir]] = None,
722736
num_workers: Optional[int] = None,
723737
num_downloaders: Optional[int] = None,
738+
num_uploaders: Optional[int] = None,
724739
delete_cached_files: bool = True,
725740
fast_dev_run: Optional[Union[bool, int]] = None,
726741
random_seed: Optional[int] = 42,
@@ -734,6 +749,7 @@ def __init__(
734749
output_dir: The path to where the output data are stored.
735750
num_workers: The number of worker threads to use.
736751
num_downloaders: The number of file downloaders to use.
752+
num_uploaders: The number of file uploaders to use.
737753
delete_cached_files: Whether to delete the cached files.
738754
fast_dev_run: Whether to run a quick dev run.
739755
random_seed: The random seed to be set before shuffling the data.
@@ -744,7 +760,8 @@ def __init__(
744760
self.input_dir = _resolve_dir(input_dir)
745761
self.output_dir = _resolve_dir(output_dir)
746762
self.num_workers = num_workers or (1 if fast_dev_run else (os.cpu_count() or 1) * 4)
747-
self.num_downloaders = num_downloaders or 1
763+
self.num_downloaders = num_downloaders or 2
764+
self.num_uploaders = num_uploaders or 5
748765
self.delete_cached_files = delete_cached_files
749766
self.fast_dev_run = _get_fast_dev_run() if fast_dev_run is None else fast_dev_run
750767
self.workers: Any = []
@@ -816,30 +833,43 @@ def run(self, data_recipe: DataRecipe) -> None:
816833

817834
current_total = 0
818835
has_failed = False
819-
with tqdm(total=num_items, smoothing=0, position=-1, mininterval=1) as pbar:
820-
while True:
836+
pbar = _tqdm(
837+
desc="Progress",
838+
total=num_items,
839+
smoothing=0,
840+
position=-1,
841+
mininterval=1,
842+
leave=True,
843+
dynamic_ncols=True,
844+
)
845+
846+
while True:
847+
try:
848+
error = self.error_queue.get(timeout=0.001)
849+
self._exit_on_error(error)
850+
except Empty:
851+
assert self.progress_queue
821852
try:
822-
error = self.error_queue.get(timeout=0.001)
823-
self._exit_on_error(error)
853+
index, counter = self.progress_queue.get(timeout=0.001)
824854
except Empty:
825-
assert self.progress_queue
826-
try:
827-
index, counter = self.progress_queue.get(timeout=0.001)
828-
except Empty:
829-
continue
830-
self.workers_tracker[index] = counter
831-
new_total = sum(self.workers_tracker.values())
832-
833-
pbar.update(new_total - current_total)
834-
current_total = new_total
835-
if current_total == num_items:
836-
break
837-
838-
# Exit early if all the workers are done.
839-
# This means there was some kind of error.
840-
if all(not w.is_alive() for w in self.workers):
841-
has_failed = True
842-
break
855+
continue
856+
self.workers_tracker[index] = counter
857+
new_total = sum(self.workers_tracker.values())
858+
859+
pbar.set_postfix({"time": datetime.now().strftime("%H:%M:%S.%f")})
860+
pbar.update(new_total - current_total)
861+
862+
current_total = new_total
863+
if current_total == num_items:
864+
break
865+
866+
# Exit early if all the workers are done.
867+
# This means there was some kind of error.
868+
if all(not w.is_alive() for w in self.workers):
869+
has_failed = True
870+
break
871+
872+
pbar.close()
843873

844874
num_nodes = _get_num_nodes()
845875
node_rank = _get_node_rank()
@@ -896,6 +926,7 @@ def _create_process_workers(self, data_recipe: DataRecipe, workers_user_items: L
896926
self.error_queue,
897927
stop_queues[-1],
898928
self.num_downloaders,
929+
self.num_uploaders,
899930
self.delete_cached_files,
900931
)
901932
worker.start()

tests/tests_data/streaming/test_data_processor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,8 @@ def test_data_processsor_distributed(fast_dev_run, delete_cached_files, tmpdir,
484484
delete_cached_files=delete_cached_files,
485485
fast_dev_run=fast_dev_run,
486486
output_dir=remote_output_dir,
487+
num_uploaders=1,
488+
num_downloaders=1,
487489
)
488490
data_processor.run(CustomDataChunkRecipe(chunk_size=2))
489491

@@ -508,6 +510,7 @@ def test_data_processsor_distributed(fast_dev_run, delete_cached_files, tmpdir,
508510
data_processor = TestDataProcessor(
509511
input_dir=input_dir,
510512
num_workers=2,
513+
num_uploaders=1,
511514
num_downloaders=1,
512515
delete_cached_files=delete_cached_files,
513516
fast_dev_run=fast_dev_run,
@@ -668,7 +671,6 @@ def test_data_processing_map(monkeypatch, tmpdir):
668671

669672

670673
def optimize_fn(filepath):
671-
print(filepath)
672674
from PIL import Image
673675

674676
return [Image.open(filepath), os.path.basename(filepath)]

0 commit comments

Comments
 (0)