@@ -87,11 +87,11 @@ def _get_cache_data_dir(name: Optional[str] = None) -> str:
8787 return os .path .join (cache_dir , name .lstrip ("/" ))
8888
8989
90- def _wait_for_file_to_exist (s3 : Any , obj : parse .ParseResult , sleep_time : int = 2 ) -> Any :
90+ def _wait_for_file_to_exist (s3 : S3Client , obj : parse .ParseResult , sleep_time : int = 2 ) -> Any :
9191 """This function check."""
9292 while True :
9393 try :
94- return s3 .head_object (Bucket = obj .netloc , Key = obj .path .lstrip ("/" ))
94+ return s3 .client . head_object (Bucket = obj .netloc , Key = obj .path .lstrip ("/" ))
9595 except botocore .exceptions .ClientError as e :
9696 if "the HeadObject operation: Not Found" in str (e ):
9797 sleep (sleep_time )
@@ -659,7 +659,7 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra
659659 obj = parse .urlparse (remote_filepath )
660660 _wait_for_file_to_exist (s3 , obj )
661661 with open (node_index_filepath , "wb" ) as f :
662- s3 .download_fileobj (obj .netloc , obj .path .lstrip ("/" ), f )
662+ s3 .client . download_fileobj (obj .netloc , obj .path .lstrip ("/" ), f )
663663 elif os .path .isdir (output_dir .path ):
664664 copyfile (remote_filepath , node_index_filepath )
665665
@@ -799,15 +799,16 @@ def run(self, data_recipe: DataRecipe) -> None:
799799 break
800800
801801 num_nodes = _get_num_nodes ()
802+ node_rank = _get_node_rank ()
802803 # TODO: Understand why it hangs.
803804 if num_nodes == 1 :
804805 for w in self .workers :
805806 w .join (0 )
806807
807808 print ("Workers are finished." )
808- result = data_recipe ._done (num_items , self .delete_cached_files , self .output_dir )
809+ result = data_recipe ._done (len ( user_items ) , self .delete_cached_files , self .output_dir )
809810
810- if num_nodes == _get_node_rank () + 1 :
811+ if num_nodes == node_rank + 1 :
811812 _create_dataset (
812813 input_dir = self .input_dir .path ,
813814 storage_dir = self .output_dir .path ,
0 commit comments