@@ -87,11 +87,11 @@ def _get_cache_data_dir(name: Optional[str] = None) -> str:
     return os.path.join(cache_dir, name.lstrip("/"))


-def _wait_for_file_to_exist(s3: Any, obj: parse.ParseResult, sleep_time: int = 2) -> Any:
+def _wait_for_file_to_exist(s3: S3Client, obj: parse.ParseResult, sleep_time: int = 2) -> Any:
     """Wait until the given S3 object exists, polling HeadObject until it succeeds."""
     while True:
         try:
-            return s3.head_object(Bucket=obj.netloc, Key=obj.path.lstrip("/"))
+            return s3.client.head_object(Bucket=obj.netloc, Key=obj.path.lstrip("/"))
         except botocore.exceptions.ClientError as e:
             if "the HeadObject operation: Not Found" in str(e):
                 sleep(sleep_time)
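Note: the hunk above narrows the parameter type from Any to S3Client and routes the HeadObject call through its .client attribute, which implies s3 is now a small wrapper around a boto3 client rather than the client itself. A minimal sketch of such a wrapper is given below; the lazy construction and the class body are assumptions for illustration, not the repository's actual implementation.

# Sketch only: assumes S3Client is a thin wrapper exposing a boto3 client.
import boto3


class S3Client:
    """Illustrative wrapper; the real class may differ."""

    def __init__(self) -> None:
        self._client = None

    @property
    def client(self):
        # Build the boto3 client lazily so the wrapper can be created in
        # worker processes that never touch S3.
        if self._client is None:
            self._client = boto3.client("s3")
        return self._client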
@@ -659,7 +659,7 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra
                 obj = parse.urlparse(remote_filepath)
                 _wait_for_file_to_exist(s3, obj)
                 with open(node_index_filepath, "wb") as f:
-                    s3.download_fileobj(obj.netloc, obj.path.lstrip("/"), f)
+                    s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f)
             elif os.path.isdir(output_dir.path):
                 copyfile(remote_filepath, node_index_filepath)

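For context, the wait-then-download flow this hunk touches can be exercised on its own roughly as follows; the bucket, key, and local filename are placeholders, and S3Client refers to the illustrative wrapper sketched earlier.

# Sketch of the polling + download pattern (paths are placeholders).
from urllib import parse

s3 = S3Client()
obj = parse.urlparse("s3://some-bucket/0-index.json")
_wait_for_file_to_exist(s3, obj)  # blocks until HeadObject succeeds
with open("0-index.json", "wb") as f:
    # boto3's download_fileobj(Bucket, Key, Fileobj) streams the object into f.
    s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f)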
@@ -799,15 +799,16 @@ def run(self, data_recipe: DataRecipe) -> None:
                 break

         num_nodes = _get_num_nodes()
+        node_rank = _get_node_rank()
         # TODO: Understand why it hangs.
         if num_nodes == 1:
             for w in self.workers:
                 w.join(0)

         print("Workers are finished.")
-        result = data_recipe._done(num_items, self.delete_cached_files, self.output_dir)
+        result = data_recipe._done(len(user_items), self.delete_cached_files, self.output_dir)

-        if num_nodes == _get_node_rank() + 1:
+        if num_nodes == node_rank + 1:
             _create_dataset(
                 input_dir=self.input_dir.path,
                 storage_dir=self.output_dir.path,
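The last hunk caches the node rank once, reports len(user_items) to _done, and keeps gating dataset creation on the final node (rank == num_nodes - 1). A reduced sketch of that gating follows; reading the values from NUM_NODES/NODE_RANK environment variables is an assumption made for the sketch, not necessarily how the project's helpers work.

# Sketch: only the last node in a multi-node run finalizes the dataset.
import os


def _get_num_nodes() -> int:
    # Assumption: the real helper may read a different source.
    return int(os.getenv("NUM_NODES", "1"))


def _get_node_rank() -> int:
    # Assumption: the real helper may read a different source.
    return int(os.getenv("NODE_RANK", "0"))


num_nodes = _get_num_nodes()
node_rank = _get_node_rank()
if num_nodes == node_rank + 1:
    # Ranks are zero-based, so this branch runs only on the final node.
    print("last node: dataset finalization would happen here")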