Skip to content

Commit 6df6162

Browse files
tchatonthomas
authored andcommitted
lightning.data: Fix some bugs with optimize (#18949)
Co-authored-by: thomas <[email protected]> (cherry picked from commit 3a86097)
1 parent ab76989 commit 6df6162

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

src/lightning/data/streaming/data_processor.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,11 @@ def _get_cache_data_dir(name: Optional[str] = None) -> str:
8787
return os.path.join(cache_dir, name.lstrip("/"))
8888

8989

90-
def _wait_for_file_to_exist(s3: Any, obj: parse.ParseResult, sleep_time: int = 2) -> Any:
90+
def _wait_for_file_to_exist(s3: S3Client, obj: parse.ParseResult, sleep_time: int = 2) -> Any:
9191
"""This function check."""
9292
while True:
9393
try:
94-
return s3.head_object(Bucket=obj.netloc, Key=obj.path.lstrip("/"))
94+
return s3.client.head_object(Bucket=obj.netloc, Key=obj.path.lstrip("/"))
9595
except botocore.exceptions.ClientError as e:
9696
if "the HeadObject operation: Not Found" in str(e):
9797
sleep(sleep_time)
@@ -659,7 +659,7 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra
659659
obj = parse.urlparse(remote_filepath)
660660
_wait_for_file_to_exist(s3, obj)
661661
with open(node_index_filepath, "wb") as f:
662-
s3.download_fileobj(obj.netloc, obj.path.lstrip("/"), f)
662+
s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f)
663663
elif os.path.isdir(output_dir.path):
664664
copyfile(remote_filepath, node_index_filepath)
665665

@@ -799,15 +799,16 @@ def run(self, data_recipe: DataRecipe) -> None:
799799
break
800800

801801
num_nodes = _get_num_nodes()
802+
node_rank = _get_node_rank()
802803
# TODO: Understand why it hangs.
803804
if num_nodes == 1:
804805
for w in self.workers:
805806
w.join(0)
806807

807808
print("Workers are finished.")
808-
result = data_recipe._done(num_items, self.delete_cached_files, self.output_dir)
809+
result = data_recipe._done(len(user_items), self.delete_cached_files, self.output_dir)
809810

810-
if num_nodes == _get_node_rank() + 1:
811+
if num_nodes == node_rank + 1:
811812
_create_dataset(
812813
input_dir=self.input_dir.path,
813814
storage_dir=self.output_dir.path,

tests/tests_data/streaming/test_data_processor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def fn(*_, **__):
204204
raise botocore.exceptions.ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject")
205205
return
206206

207-
s3.head_object = fn
207+
s3.client.head_object = fn
208208

209209
_wait_for_file_to_exist(s3, obj, sleep_time=0.01)
210210

@@ -213,7 +213,7 @@ def fn(*_, **__):
213213
def fn(*_, **__):
214214
raise ValueError("HERE")
215215

216-
s3.head_object = fn
216+
s3.client.head_object = fn
217217

218218
with pytest.raises(ValueError, match="HERE"):
219219
_wait_for_file_to_exist(s3, obj, sleep_time=0.01)

0 commit comments

Comments
 (0)