@@ -1,3 +1,4 @@
+import concurrent
 import json
 import logging
 import os
@@ -27,6 +28,7 @@
     _LIGHTNING_CLOUD_LATEST,
     _TORCH_GREATER_EQUAL_2_1_0,
 )
+from lightning.data.streaming.resolver import _resolve_dir
 from lightning.data.utilities.broadcast import broadcast_object
 from lightning.data.utilities.packing import _pack_greedily

@@ -35,7 +37,6 @@

 if _LIGHTNING_CLOUD_LATEST:
     from lightning_cloud.openapi import V1DatasetType
-    from lightning_cloud.resolver import _resolve_dir
     from lightning_cloud.utils.dataset import _create_dataset

@@ -120,7 +121,9 @@ def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue
         index, paths = r

         # 5. Check whether all the files are already downloaded
-        if all(os.path.exists(p.replace(input_dir.path, cache_dir) if input_dir else p) for p in paths):
+        if input_dir.path and all(
+            os.path.exists(p.replace(input_dir.path, cache_dir) if input_dir else p) for p in paths
+        ):
             queue_out.put(index)
             continue

@@ -131,9 +134,10 @@ def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue

         # 7. Download all the required paths to unblock the current index
         for path in paths:
-            local_path = path.replace(input_dir.path, cache_dir)
+            if input_dir.path:
+                local_path = path.replace(input_dir.path, cache_dir)

-            if input_dir.url:
+            if input_dir.url and input_dir.path:
                 path = path.replace(input_dir.path, input_dir.url)

             obj = parse.urlparse(path)
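Note on the two hunks above: they hinge on input_dir.path possibly being None, in which case the old path.replace(input_dir.path, cache_dir) call would raise a TypeError. A minimal standalone sketch of that remapping guard follows; the Dir class here is a simplified stand-in for illustration, not the real lightning class.

# Standalone sketch (not part of the diff) of the remapping that the new
# input_dir.path guards protect: str.replace fails when path is None.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Dir:
    path: Optional[str] = None
    url: Optional[str] = None


def to_cache_path(remote_path: str, input_dir: Dir, cache_dir: str) -> str:
    # Only remap when a local input path exists; otherwise keep the path as-is.
    if input_dir.path:
        return remote_path.replace(input_dir.path, cache_dir)
    return remote_path


print(to_cache_path("/data/train/0.jpg", Dir(path="/data"), "/cache"))  # /cache/train/0.jpg
print(to_cache_path("/data/train/0.jpg", Dir(path=None), "/cache"))     # unchanged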
@@ -168,7 +172,7 @@ def _remove_target(input_dir: Dir, cache_dir: str, queue_in: Queue) -> None:
         # 3. Iterate through the paths and delete them sequentially.
         for path in paths:
             if input_dir:
-                if not path.startswith(cache_dir):
+                if not path.startswith(cache_dir) and input_dir.path is not None:
                     path = path.replace(input_dir.path, cache_dir)

             if os.path.exists(path):
@@ -199,11 +203,13 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_
         if obj.scheme == "s3":
             try:
                 s3.client.upload_file(
-                    local_filepath, obj.netloc, os.path.join(obj.path.lstrip("/"), os.path.basename(local_filepath))
+                    local_filepath,
+                    obj.netloc,
+                    os.path.join(str(obj.path).lstrip("/"), os.path.basename(local_filepath)),
                 )
             except Exception as e:
                 print(e)
-        elif os.path.isdir(output_dir.path):
+        elif output_dir.path and os.path.isdir(output_dir.path):
             shutil.copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
         else:
             raise ValueError(f"The provided {output_dir.path} isn't supported.")
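For context on the upload branch touched above, here is a small illustration, using an assumed s3:// output URL, of how urlparse splits the destination into bucket and key prefix, and why wrapping obj.path in str(...) keeps the lstrip call safe when the path is not already a plain string.

# Illustration only: the bucket URL and chunk filename below are assumptions.
import os
from urllib import parse

obj = parse.urlparse("s3://my-bucket/optimized/dataset")
# netloc is the bucket, path is the key prefix (with a leading slash).
key = os.path.join(str(obj.path).lstrip("/"), os.path.basename("/cache/chunk-0-0.bin"))
print(obj.netloc)  # my-bucket
print(key)         # optimized/dataset/chunk-0-0.bin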
@@ -254,20 +260,30 @@ def _map_items_to_workers_weighted(
     return [worker_items[worker_id] for worker_id in worker_ids_this_node]


+def _get_num_bytes(item: Any, base_path: str) -> int:
+    flattened_item, _ = tree_flatten(item)
+
+    num_bytes = 0
+    for element in flattened_item:
+        if isinstance(element, str) and element.startswith(base_path) and os.path.exists(element):
+            file_bytes = os.path.getsize(element)
+            if file_bytes == 0:
+                raise RuntimeError(f"The file {element} has 0 bytes!")
+            num_bytes += file_bytes
+    return num_bytes
+
+
 def _get_item_filesizes(items: List[Any], base_path: str = "") -> List[int]:
     """Computes the total size in bytes of all file paths for every datastructure in the given list."""
     item_sizes = []
-    for item in items:
-        flattened_item, _ = tree_flatten(item)
-
-        num_bytes = 0
-        for element in flattened_item:
-            if isinstance(element, str) and element.startswith(base_path) and os.path.exists(element):
-                file_bytes = os.path.getsize(element)
-                if file_bytes == 0:
-                    raise RuntimeError(f"The file {element} has 0 bytes!")
-                num_bytes += file_bytes
-        item_sizes.append(num_bytes)
+
+    cpu_count = os.cpu_count() or 1
+
+    # Parallelize to accelerate retrieving the number of file bytes to read for each item
+    with concurrent.futures.ThreadPoolExecutor(max_workers=cpu_count * 2 if cpu_count > 4 else cpu_count) as executor:
+        futures = [executor.submit(_get_num_bytes, item, base_path) for item in items]
+        for future in futures:
+            item_sizes.append(future.result())
     return item_sizes

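The hunk above swaps the sequential size computation for a thread pool. A rough standalone sketch of the same fan-out pattern, assuming items that are plain lists of file paths rather than arbitrary nested structures:

# Rough sketch of the thread-pool fan-out introduced above; items and helper
# names here are hypothetical, only the pattern mirrors the diff.
import concurrent.futures
import os
from typing import List


def _size_of(paths: List[str]) -> int:
    # Sum the sizes of the existing files referenced by a single item.
    return sum(os.path.getsize(p) for p in paths if os.path.exists(p))


def get_item_filesizes(items: List[List[str]]) -> List[int]:
    cpu_count = os.cpu_count() or 1
    # The stat calls are I/O bound, so larger machines get twice as many threads as cores.
    max_workers = cpu_count * 2 if cpu_count > 4 else cpu_count
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(_size_of, item) for item in items]
        # Results are collected in submission order, so sizes line up with items.
        return [future.result() for future in futures]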
@@ -358,7 +374,7 @@ def _loop(self) -> None:
                         for uploader in self.uploaders:
                             uploader.join()

-                    if self.remove and self.input_dir.path is not None:
+                    if self.remove:
                         assert self.remover
                         self.remove_queue.put(None)
                         self.remover.join()
@@ -487,7 +503,7 @@ def _start_downloaders(self) -> None:
             self.to_download_queues[downloader_index].put(None)

     def _start_remover(self) -> None:
-        if not self.remove or self.input_dir.path is None:
+        if not self.remove:
             return

         self.remover = Process(
@@ -696,9 +712,9 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra
         if obj.scheme == "s3":
             s3 = S3Client()
             s3.client.upload_file(
-                local_filepath, obj.netloc, os.path.join(obj.path.lstrip("/"), os.path.basename(local_filepath))
+                local_filepath, obj.netloc, os.path.join(str(obj.path).lstrip("/"), os.path.basename(local_filepath))
             )
-        elif os.path.isdir(output_dir.path):
+        elif output_dir.path and os.path.isdir(output_dir.path):
             shutil.copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))

         if num_nodes == 1 or node_rank is None:
@@ -710,16 +726,16 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra
         if num_nodes == node_rank + 1:
             # Get the index file locally
             for node_rank in range(num_nodes - 1):
-                remote_filepath = os.path.join(
-                    output_dir.url if output_dir.url else output_dir.path, f"{node_rank}-{_INDEX_FILENAME}"
-                )
+                output_dir_path = output_dir.url if output_dir.url else output_dir.path
+                assert output_dir_path
+                remote_filepath = os.path.join(output_dir_path, f"{node_rank}-{_INDEX_FILENAME}")
                 node_index_filepath = os.path.join(cache_dir, os.path.basename(remote_filepath))
                 if obj.scheme == "s3":
                     obj = parse.urlparse(remote_filepath)
                     _wait_for_file_to_exist(s3, obj)
                     with open(node_index_filepath, "wb") as f:
                         s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f)
-                elif os.path.isdir(output_dir.path):
+                elif output_dir.path and os.path.isdir(output_dir.path):
                     shutil.copyfile(remote_filepath, node_index_filepath)

             merge_cache = Cache(cache_dir, chunk_bytes=1)