77import types
88from abc import abstractmethod
99from dataclasses import dataclass
10+ from datetime import datetime
1011from multiprocessing import Process , Queue
1112from queue import Empty
1213from shutil import copyfile , rmtree
1516from urllib import parse
1617
1718import torch
18- from tqdm .auto import tqdm
19+ from tqdm .auto import tqdm as _tqdm
1920
2021from lightning import seed_everything
2122from lightning .data .streaming import Cache
@@ -278,6 +279,7 @@ def __init__(
278279 error_queue : Queue ,
279280 stop_queue : Queue ,
280281 num_downloaders : int ,
282+ num_uploaders : int ,
281283 remove : bool ,
282284 ) -> None :
283285 """The BaseWorker is responsible to process the user data."""
@@ -290,18 +292,19 @@ def __init__(
290292 self .items = items
291293 self .num_items = len (self .items )
292294 self .num_downloaders = num_downloaders
295+ self .num_uploaders = num_uploaders
293296 self .remove = remove
294297 self .paths : List [List [str ]] = []
295298 self .remover : Optional [Process ] = None
296299 self .downloaders : List [Process ] = []
300+ self .uploaders : List [Process ] = []
297301 self .to_download_queues : List [Queue ] = []
302+ self .to_upload_queues : List [Queue ] = []
298303 self .stop_queue = stop_queue
299304 self .ready_to_process_queue : Queue = Queue ()
300305 self .remove_queue : Queue = Queue ()
301- self .upload_queue : Queue = Queue ()
302306 self .progress_queue : Queue = progress_queue
303307 self .error_queue : Queue = error_queue
304- self .uploader : Optional [Process ] = None
305308 self ._collected_items = 0
306309 self ._counter = 0
307310 self ._last_time = time ()
@@ -316,14 +319,14 @@ def run(self) -> None:
316319 traceback_format = traceback .format_exc ()
317320 print (traceback_format )
318321 self .error_queue .put (traceback_format )
319- print (f"Worker { self .worker_index } is done." )
322+ print (f"Worker { str ( _get_node_rank () * self .num_workers + self . worker_index ) } is done." )
320323
321324 def _setup (self ) -> None :
322325 self ._set_environ_variables ()
323326 self ._create_cache ()
324327 self ._collect_paths ()
325328 self ._start_downloaders ()
326- self ._start_uploader ()
329+ self ._start_uploaders ()
327330 self ._start_remover ()
328331
329332 def _loop (self ) -> None :
@@ -335,13 +338,19 @@ def _loop(self) -> None:
335338 if index is None :
336339 num_downloader_finished += 1
337340 if num_downloader_finished == self .num_downloaders :
341+ print (f"Worker { str (_get_node_rank () * self .num_workers + self .worker_index )} is terminating." )
342+
338343 if isinstance (self .data_recipe , DataChunkRecipe ):
339344 self ._handle_data_chunk_recipe_end ()
340345
341346 if self .output_dir .url if self .output_dir .url else self .output_dir .path :
342- assert self .uploader
343- self .upload_queue .put (None )
344- self .uploader .join ()
347+ # Inform the uploaders they are done working
348+ for i in range (self .num_uploaders ):
349+ self .to_upload_queues [i ].put (None )
350+
351+ # Wait for them all to be finished
352+ for uploader in self .uploaders :
353+ uploader .join ()
345354
346355 if self .remove :
347356 assert self .remover
@@ -402,7 +411,7 @@ def _try_upload(self, filepath: Optional[str]) -> None:
402411 return
403412
404413 assert os .path .exists (filepath ), filepath
405- self .upload_queue .put (filepath )
414+ self .to_upload_queues [ self . _counter % self . num_uploaders ] .put (filepath )
406415
407416 def _collect_paths (self ) -> None :
408417 items = []
@@ -475,19 +484,24 @@ def _start_remover(self) -> None:
475484 )
476485 self .remover .start ()
477486
478- def _start_uploader (self ) -> None :
487+ def _start_uploaders (self ) -> None :
479488 if self .output_dir .path is None and self .output_dir .url is None :
480489 return
481- self .uploader = Process (
482- target = _upload_fn ,
483- args = (
484- self .upload_queue ,
485- self .remove_queue ,
486- self .cache_chunks_dir ,
487- self .output_dir ,
488- ),
489- )
490- self .uploader .start ()
490+
491+ for _ in range (self .num_uploaders ):
492+ to_upload_queue : Queue = Queue ()
493+ p = Process (
494+ target = _upload_fn ,
495+ args = (
496+ to_upload_queue ,
497+ self .remove_queue ,
498+ self .cache_chunks_dir ,
499+ self .output_dir ,
500+ ),
501+ )
502+ p .start ()
503+ self .uploaders .append (p )
504+ self .to_upload_queues .append (to_upload_queue )
491505
492506 def _handle_data_chunk_recipe (self , index : int ) -> None :
493507 try :
@@ -509,10 +523,10 @@ def _handle_data_chunk_recipe(self, index: int) -> None:
509523 def _handle_data_chunk_recipe_end (self ) -> None :
510524 chunks_filepaths = self .cache .done ()
511525
512- if chunks_filepaths :
513- for chunk_filepath in chunks_filepaths :
526+ if chunks_filepaths and len ( self . to_upload_queues ) :
527+ for i , chunk_filepath in enumerate ( chunks_filepaths ) :
514528 if isinstance (chunk_filepath , str ) and os .path .exists (chunk_filepath ):
515- self .upload_queue .put (chunk_filepath )
529+ self .to_upload_queues [ i % self . num_uploaders ] .put (chunk_filepath )
516530
517531 def _handle_data_transform_recipe (self , index : int ) -> None :
518532 # Don't use a context manager to avoid deleting files that are being uploaded.
@@ -721,6 +735,7 @@ def __init__(
721735 output_dir : Optional [Union [str , Dir ]] = None ,
722736 num_workers : Optional [int ] = None ,
723737 num_downloaders : Optional [int ] = None ,
738+ num_uploaders : Optional [int ] = None ,
724739 delete_cached_files : bool = True ,
725740 fast_dev_run : Optional [Union [bool , int ]] = None ,
726741 random_seed : Optional [int ] = 42 ,
@@ -734,6 +749,7 @@ def __init__(
734749 output_dir: The path to where the output data are stored.
735750 num_workers: The number of worker threads to use.
736751 num_downloaders: The number of file downloaders to use.
752+ num_uploaders: The number of file uploaders to use.
737753 delete_cached_files: Whether to delete the cached files.
738754 fast_dev_run: Whether to run a quick dev run.
739755 random_seed: The random seed to be set before shuffling the data.
@@ -744,7 +760,8 @@ def __init__(
744760 self .input_dir = _resolve_dir (input_dir )
745761 self .output_dir = _resolve_dir (output_dir )
746762 self .num_workers = num_workers or (1 if fast_dev_run else (os .cpu_count () or 1 ) * 4 )
747- self .num_downloaders = num_downloaders or 1
763+ self .num_downloaders = num_downloaders or 2
764+ self .num_uploaders = num_uploaders or 5
748765 self .delete_cached_files = delete_cached_files
749766 self .fast_dev_run = _get_fast_dev_run () if fast_dev_run is None else fast_dev_run
750767 self .workers : Any = []
@@ -816,30 +833,43 @@ def run(self, data_recipe: DataRecipe) -> None:
816833
817834 current_total = 0
818835 has_failed = False
819- with tqdm (total = num_items , smoothing = 0 , position = - 1 , mininterval = 1 ) as pbar :
820- while True :
836+ pbar = _tqdm (
837+ desc = "Progress" ,
838+ total = num_items ,
839+ smoothing = 0 ,
840+ position = - 1 ,
841+ mininterval = 1 ,
842+ leave = True ,
843+ dynamic_ncols = True ,
844+ )
845+
846+ while True :
847+ try :
848+ error = self .error_queue .get (timeout = 0.001 )
849+ self ._exit_on_error (error )
850+ except Empty :
851+ assert self .progress_queue
821852 try :
822- error = self .error_queue .get (timeout = 0.001 )
823- self ._exit_on_error (error )
853+ index , counter = self .progress_queue .get (timeout = 0.001 )
824854 except Empty :
825- assert self . progress_queue
826- try :
827- index , counter = self .progress_queue . get ( timeout = 0.001 )
828- except Empty :
829- continue
830- self . workers_tracker [ index ] = counter
831- new_total = sum ( self . workers_tracker . values ())
832-
833- pbar . update ( new_total - current_total )
834- current_total = new_total
835- if current_total == num_items :
836- break
837-
838- # Exit early if all the workers are done.
839- # This means there were some kinda of errors.
840- if all ( not w . is_alive () for w in self . workers ):
841- has_failed = True
842- break
855+ continue
856+ self . workers_tracker [ index ] = counter
857+ new_total = sum ( self .workers_tracker . values () )
858+
859+ pbar . set_postfix ({ "time" : datetime . now (). strftime ( "%H:%M:%S.%f" )})
860+ pbar . update ( new_total - current_total )
861+
862+ current_total = new_total
863+ if current_total == num_items :
864+ break
865+
866+ # Exit early if all the workers are done.
867+ # This means there were some kind of errors.
868+ if all ( not w . is_alive () for w in self . workers ):
869+ has_failed = True
870+ break
871+
872+ pbar . close ()
843873
844874 num_nodes = _get_num_nodes ()
845875 node_rank = _get_node_rank ()
@@ -896,6 +926,7 @@ def _create_process_workers(self, data_recipe: DataRecipe, workers_user_items: L
896926 self .error_queue ,
897927 stop_queues [- 1 ],
898928 self .num_downloaders ,
929+ self .num_uploaders ,
899930 self .delete_cached_files ,
900931 )
901932 worker .start ()
0 commit comments