55import traceback
66import types
77from abc import abstractmethod
8+ from dataclasses import dataclass
89from multiprocessing import Process , Queue
910from queue import Empty
1011from shutil import copyfile , rmtree
1112from time import sleep , time
12- from typing import Any , Callable , Dict , List , Optional , Tuple , TypeVar
13+ from typing import Any , Callable , Dict , List , Optional , Tuple , TypeVar , Union
1314from urllib import parse
1415
1516import torch
2122 _BOTO3_AVAILABLE ,
2223 _DEFAULT_FAST_DEV_RUN_ITEMS ,
2324 _INDEX_FILENAME ,
24- _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_42 ,
25+ _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46 ,
2526 _TORCH_GREATER_EQUAL_2_1_0 ,
2627)
2728from lightning .fabric .accelerators .cuda import is_cuda_available
3536if _TORCH_GREATER_EQUAL_2_1_0 :
3637 from torch .utils ._pytree import tree_flatten , tree_unflatten
3738
38- if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_42 :
39+ if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46 :
3940 from lightning_cloud .resolver import _LightningSrcResolver , _LightningTargetResolver
4041
4142if _BOTO3_AVAILABLE :
@@ -160,10 +161,14 @@ def _remove_target(input_dir: str, cache_dir: str, queue_in: Queue) -> None:
160161 # 3. Iterate through the paths and delete them sequentially.
161162 for path in paths :
162163 if input_dir :
163- cached_filepath = path .replace (input_dir , cache_dir )
164+ if not path .startswith (cache_dir ):
165+ path = path .replace (input_dir , cache_dir )
164166
165- if os .path .exists (cached_filepath ):
166- os .remove (cached_filepath )
167+ if os .path .exists (path ):
168+ os .remove (path )
169+
170+ elif os .path .exists (path ) and "s3_connections" not in path :
171+ os .remove (path )
167172
168173
169174def _upload_fn (upload_queue : Queue , remove_queue : Queue , cache_dir : str , remote_output_dir : str ) -> None :
@@ -387,7 +392,9 @@ def _collect_paths(self) -> None:
387392 }
388393
389394 if len (indexed_paths ) == 0 :
390- raise ValueError (f"The provided item { item } didn't contain any filepaths. { flattened_item } " )
395+ raise ValueError (
396+ f"The provided item { item } didn't contain any filepaths. The input_dir is { self .input_dir } ."
397+ )
391398
392399 paths = []
393400 for index , path in indexed_paths .items ():
@@ -548,7 +555,7 @@ def __init__(self) -> None:
548555 def _setup (self , name : Optional [str ]) -> None :
549556 self ._name = name
550557
551- def _done (self , delete_cached_files : bool , remote_output_dir : str ) -> None :
558+ def _done (self , delete_cached_files : bool , remote_output_dir : Any ) -> None :
552559 pass
553560
554561
@@ -578,7 +585,6 @@ def prepare_item(self, item_metadata: T) -> Any: # type: ignore
578585
579586 def _done (self , delete_cached_files : bool , remote_output_dir : str ) -> None :
580587 num_nodes = _get_num_nodes ()
581- assert self ._name
582588 cache_dir = _get_cache_dir (self ._name )
583589
584590 chunks = [file for file in os .listdir (cache_dir ) if file .endswith (".bin" )]
@@ -647,6 +653,14 @@ def prepare_item(self, output_dir: str, item_metadata: T) -> None: # type: igno
647653 """Use your item metadata to process your files and save the file outputs into `output_dir`."""
648654
649655
@dataclass
class PrettyDirectory:
    """Pairs a display directory path with its resolved storage URL.

    Attributes:
        directory: Human-readable path shown to the user in log output.
        url: Resolved URL the processor actually writes files to.
    """

    directory: str
    url: str
663+
650664class DataProcessor :
651665 def __init__ (
652666 self ,
@@ -656,10 +670,11 @@ def __init__(
656670 num_downloaders : Optional [int ] = None ,
657671 delete_cached_files : bool = True ,
658672 src_resolver : Optional [Callable [[str ], Optional [str ]]] = None ,
659- fast_dev_run : Optional [bool ] = None ,
673+ fast_dev_run : Optional [Union [ bool , int ] ] = None ,
660674 remote_input_dir : Optional [str ] = None ,
661- remote_output_dir : Optional [str ] = None ,
675+ remote_output_dir : Optional [Union [ str , PrettyDirectory ] ] = None ,
662676 random_seed : Optional [int ] = 42 ,
677+ version : Optional [int ] = None ,
663678 ):
664679 """The `DatasetOptimiser` provides an efficient way to process data across multiple machine into chunks to make
665680 training faster.
@@ -692,18 +707,22 @@ def __init__(
692707 self .remote_input_dir = (
693708 str (remote_input_dir )
694709 if remote_input_dir is not None
695- else ((self .src_resolver (input_dir ) if input_dir else None ) if self .src_resolver else None )
710+ else ((self .src_resolver (str ( input_dir ) ) if input_dir else None ) if self .src_resolver else None )
696711 )
697712 self .remote_output_dir = (
698713 remote_output_dir
699714 if remote_output_dir is not None
700- else (self .dst_resolver (name ) if self .dst_resolver else None )
715+ else (self .dst_resolver (name , version = version ) if self .dst_resolver else None )
701716 )
702717 if self .remote_output_dir :
703718 self .name = self ._broadcast_object (self .name )
704719 # Ensure the remote src dir is the same across all ranks
705720 self .remote_output_dir = self ._broadcast_object (self .remote_output_dir )
706- print (f"Storing the files under { self .remote_output_dir } " )
721+ if isinstance (self .remote_output_dir , PrettyDirectory ):
722+ print (f"Storing the files under { self .remote_output_dir .directory } " )
723+ self .remote_output_dir = self .remote_output_dir .url
724+ else :
725+ print (f"Storing the files under { self .remote_output_dir } " )
707726
708727 self .random_seed = random_seed
709728
@@ -725,7 +744,7 @@ def run(self, data_recipe: DataRecipe) -> None:
725744 user_items : List [Any ] = data_recipe .prepare_structure (self .input_dir )
726745
727746 if not isinstance (user_items , list ):
728- raise ValueError ("The setup_fn should return a list of item metadata." )
747+ raise ValueError ("The `prepare_structure` should return a list of item metadata." )
729748
730749 # Associate the items to the workers based on num_nodes and node_rank
731750 begins , workers_user_items = _associated_items_to_workers (self .num_workers , user_items )
@@ -779,6 +798,8 @@ def run(self, data_recipe: DataRecipe) -> None:
779798 w .join (0 )
780799
781800 print ("Workers are finished." )
801+ if self .remote_output_dir :
802+ assert isinstance (self .remote_output_dir , str )
782803 data_recipe ._done (self .delete_cached_files , self .remote_output_dir )
783804 print ("Finished data processing!" )
784805
@@ -856,15 +877,15 @@ def _cleanup_cache(self) -> None:
856877
857878 # Cleanup the cache dir folder to avoid corrupted files from previous run to be there.
858879 if os .path .exists (cache_dir ):
859- rmtree (cache_dir )
880+ rmtree (cache_dir , ignore_errors = True )
860881
861882 os .makedirs (cache_dir , exist_ok = True )
862883
863884 cache_data_dir = _get_cache_data_dir (self .name )
864885
865886 # Cleanup the cache data folder to avoid corrupted files from previous run to be there.
866887 if os .path .exists (cache_data_dir ):
867- rmtree (cache_data_dir )
888+ rmtree (cache_data_dir , ignore_errors = True )
868889
869890 os .makedirs (cache_data_dir , exist_ok = True )
870891
0 commit comments