@@ -1,10 +1,12 @@
+import json
 import logging
 import os
 import signal
 import tempfile
 import traceback
 import types
 from abc import abstractmethod
+from dataclasses import dataclass
 from multiprocessing import Process, Queue
 from queue import Empty
 from shutil import copyfile, rmtree
@@ -23,7 +25,7 @@
     _BOTO3_AVAILABLE,
     _DEFAULT_FAST_DEV_RUN_ITEMS,
     _INDEX_FILENAME,
-    _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_48,
+    _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50,
     _TORCH_GREATER_EQUAL_2_1_0,
 )
 from lightning.fabric.accelerators.cuda import is_cuda_available
@@ -35,10 +37,12 @@
 from lightning.fabric.utilities.distributed import group as _group

 if _TORCH_GREATER_EQUAL_2_1_0:
-    from torch.utils._pytree import tree_flatten, tree_unflatten
+    from torch.utils._pytree import tree_flatten, tree_unflatten, treespec_loads

-if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_48:
+if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50:
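+    # Requires lightning-cloud >= 0.5.50 for the dataset registration helpers used below.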
+    from lightning_cloud.openapi import V1DatasetType
     from lightning_cloud.resolver import _resolve_dir
+    from lightning_cloud.utils.dataset import _create_dataset


 if _BOTO3_AVAILABLE:
@@ -191,7 +195,7 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_dir: Dir) -> None:
                 )
             except Exception as e:
                 print(e)
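+        # Fall back to a plain filesystem copy only when the upload target was not an S3 URL.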
-        if os.path.isdir(output_dir.path):
+        elif os.path.isdir(output_dir.path):
             copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
         else:
             raise ValueError(f"The provided {output_dir.path} isn't supported.")
@@ -506,6 +510,16 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         Process.__init__(self)


+@dataclass
+class _Result:
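+    # Summary statistics for a completed processing run; consumed by the dataset registration below.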
+    size: Optional[int] = None
+    num_bytes: Optional[str] = None
+    data_format: Optional[str] = None
+    compression: Optional[str] = None
+    num_chunks: Optional[int] = None
+    num_bytes_per_chunk: Optional[List[int]] = None
+
+
 T = TypeVar("T")


@@ -545,8 +559,8 @@ def listdir(self, path: str) -> List[str]:
     def __init__(self) -> None:
         self._name: Optional[str] = None

-    def _done(self, delete_cached_files: bool, output_dir: Dir) -> None:
-        pass
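+    # The base recipe reports only the item count; DataChunkRecipe overrides this with richer stats.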
+    def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Result:
+        return _Result(size=size)


 class DataChunkRecipe(DataRecipe):
@@ -576,7 +590,7 @@ def prepare_structure(self, input_dir: Optional[str]) -> List[T]:
     def prepare_item(self, item_metadata: T) -> Any:  # type: ignore
         """The return of this `prepare_item` method is persisted in chunked binary files."""

-    def _done(self, delete_cached_files: bool, output_dir: Dir) -> None:
+    def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Result:
         num_nodes = _get_num_nodes()
         cache_dir = _get_cache_dir()

@@ -589,6 +603,26 @@ def _done(self, delete_cached_files: bool, output_dir: Dir) -> None:
             merge_cache._merge_no_wait(node_rank if num_nodes > 1 else None)
         self._upload_index(output_dir, cache_dir, num_nodes, node_rank)

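+        # Only on the highest-ranked node (rank == num_nodes - 1): derive dataset stats from the merged index.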
+        if num_nodes == node_rank + 1:
+            with open(os.path.join(cache_dir, _INDEX_FILENAME)) as f:
+                config = json.load(f)
+
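+            # A chunk's item count comes from "dim" when set (variable-sized items, presumably), else "chunk_size".
+            # The per-item data format is rebuilt by unflattening the flat format list with its serialized treespec.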
+            size = sum([c["dim"] if c["dim"] is not None else c["chunk_size"] for c in config["chunks"]])
+            num_bytes = sum([c["chunk_bytes"] for c in config["chunks"]])
+            data_format = tree_unflatten(config["config"]["data_format"], treespec_loads(config["config"]["data_spec"]))
+
+            return _Result(
+                size=size,
+                num_bytes=num_bytes,
+                data_format=data_format,
+                compression=config["config"]["compression"],
+                num_chunks=len(config["chunks"]),
+                num_bytes_per_chunk=[c["chunk_size"] for c in config["chunks"]],
+            )
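+        # Other nodes report just the item count; the full stats come from the final node above.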
+        return _Result(
+            size=size,
+        )
+
     def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_rank: Optional[int]) -> None:
         """This method uploads the index file to the remote cloud directory."""
         if output_dir.path is None and output_dir.url is None:
@@ -764,13 +798,31 @@ def run(self, data_recipe: DataRecipe) -> None:
                 has_failed = True
                 break

+        num_nodes = _get_num_nodes()
         # TODO: Understand why it hangs.
-        if _get_num_nodes() == 1:
+        if num_nodes == 1:
             for w in self.workers:
                 w.join(0)

         print("Workers are finished.")
-        data_recipe._done(self.delete_cached_files, self.output_dir)
+        result = data_recipe._done(num_items, self.delete_cached_files, self.output_dir)
+
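+        # Only the highest-ranked node registers the resulting dataset via lightning_cloud's _create_dataset.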
+        if num_nodes == _get_node_rank() + 1:
+            _create_dataset(
+                input_dir=self.input_dir.path,
+                storage_dir=self.output_dir.path,
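+                # Chunk recipes register as CHUNKED datasets; any other recipe registers as TRANSFORMED.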
+                dataset_type=V1DatasetType.CHUNKED
+                if isinstance(data_recipe, DataChunkRecipe)
+                else V1DatasetType.TRANSFORMED,
+                empty=False,
+                size=result.size,
+                num_bytes=result.num_bytes,
+                data_format=result.data_format,
+                compression=result.compression,
+                num_chunks=result.num_chunks,
+                num_bytes_per_chunk=result.num_bytes_per_chunk,
+            )
+
         print("Finished data processing!")

         # TODO: Understand why it is required to avoid long shutdown.