1818from logging import Logger
1919from queue import Empty
2020from threading import Thread
21- from time import sleep
2221from typing import Any , Dict , List , Optional , Tuple , Union
2322
2423from lightning .data .streaming .config import ChunksConfig
3029
3130warnings .filterwarnings ("ignore" , message = ".*The given buffer is not writable.*" )
3231
33-
3432if _TORCH_GREATER_EQUAL_2_1_0 :
3533 pass
3634
3735
3836logger = Logger (__name__ )
3937
4038
39+ _END_TOKEN = "END"
40+
41+ # Note: The timeout here should not be too short. We need to prevent the caller from aggressively
42+ # querying the queue and consuming too many CPU cycles.
43+ _DEFAULT_TIMEOUT = 0.1
44+ _LONG_DEFAULT_TIMEOUT = 5
45+
46+
4147class PrepareChunksThread (Thread ):
4248 """This thread is responsible to download the chunks associated to a given worker."""
4349
@@ -59,22 +65,7 @@ def __init__(
5965 self ._parent_cache_dir = os .path .dirname (self ._config ._cache_dir )
6066 self ._to_download_queue : multiprocessing .Queue = multiprocessing .Queue ()
6167 self ._to_delete_queue : multiprocessing .Queue = multiprocessing .Queue ()
62- self ._to_stop_queue : multiprocessing .Queue = multiprocessing .Queue ()
63-
64- # populate back the queues with existing items. As they already exists, this is almost a no-op
65- for chunk_index in self ._collect_ordered_chunk_indexes_from_cache ():
66- self ._to_download_queue .put (chunk_index )
67- self ._to_delete_queue .put (chunk_index )
68-
69- def _collect_ordered_chunk_indexes_from_cache (self ) -> List [int ]:
70- """List the chunks available in the cache, order them based on their creation time and retrieves their
71- indexes."""
72- chunk_indexes = [
73- [self ._config ._get_chunk_index_from_filename (f ), os .path .getctime (os .path .join (self ._config ._cache_dir , f ))]
74- for f in os .listdir (self ._config ._cache_dir )
75- if f .endswith (".bin" )
76- ]
77- return [int (x [0 ]) for x in sorted (chunk_indexes , key = lambda x : x [1 ])]
68+ self ._delete_chunks_when_processed = self ._config .num_bytes > max_cache_size if max_cache_size else False
7869
7970 def download (self , chunk_indexes : List [int ]) -> None :
8071 """Receive the list of the chunk indices to download for the current epoch."""
@@ -93,10 +84,15 @@ def _delete(self, chunk_index: int) -> None:
9384
    def stop(self) -> None:
        """Ask the background download thread to terminate.

        Pushes the ``_END_TOKEN`` sentinel onto the download queue; the thread's
        ``run`` loop returns when it dequeues the sentinel.
        """
        self._to_download_queue.put(_END_TOKEN)
9788
9889 def _maybe_delete_chunks (self ) -> None :
99- chunk_index = _get_from_queue (self ._to_delete_queue )
90+ reached_pre_download = self ._pre_download_counter == self ._max_pre_download
91+
92+ # we have already pre-downloaded some chunks, we just need to wait for them to be processed.
93+ chunk_index = _get_from_queue (
94+ self ._to_delete_queue , timeout = _LONG_DEFAULT_TIMEOUT if reached_pre_download else _DEFAULT_TIMEOUT
95+ )
10096
10197 if chunk_index is not None :
10298 self ._pre_download_counter -= 1
@@ -105,14 +101,17 @@ def _maybe_delete_chunks(self) -> None:
105101 self ._chunks_index_to_be_deleted .append (chunk_index )
106102
107103 # Get the current cache size and decide whether we need to start cleanup. Otherwise, keep track of it
108- while (
109- self ._max_cache_size
110- and self ._chunks_index_to_be_deleted
111- and _get_folder_size (self ._parent_cache_dir ) >= self ._max_cache_size
112- ):
104+ while self ._max_cache_size and self ._chunks_index_to_be_deleted and self ._can_delete_chunk ():
113105 # Delete the oldest chunk
114106 self ._delete (self ._chunks_index_to_be_deleted .pop (0 ))
115107
108+ return
109+
110+ def _can_delete_chunk (self ) -> bool :
111+ if self ._delete_chunks_when_processed :
112+ return self ._pre_download_counter == self ._max_pre_download - 1
113+ return self ._max_cache_size is not None and _get_folder_size (self ._parent_cache_dir ) >= self ._max_cache_size
114+
116115 def _pre_load_chunk (self , chunk_index : int ) -> None :
117116 chunk_filepath , _ , _ = self ._config [ChunkedIndex (index = - 1 , chunk_index = chunk_index )]
118117 self ._item_loader .pre_load_chunk (chunk_index , chunk_filepath )
@@ -121,6 +120,9 @@ def run(self) -> None:
121120 while True :
122121 if self ._pre_download_counter <= self ._max_pre_download :
123122 chunk_index = _get_from_queue (self ._to_download_queue )
123+ if chunk_index == _END_TOKEN :
124+ return
125+
124126 if chunk_index is not None :
125127 self ._config .download_chunk_from_index (chunk_index )
126128
@@ -135,11 +137,6 @@ def run(self) -> None:
135137 if self ._max_cache_size :
136138 self ._maybe_delete_chunks ()
137139
138- if _get_from_queue (self ._to_stop_queue ):
139- return
140-
141- sleep (0.05 )
142-
143140
144141class BinaryReader :
145142 def __init__ (
@@ -238,6 +235,9 @@ def read(self, index: ChunkedIndex) -> Any:
238235 assert self ._prepare_thread
239236 self ._prepare_thread .download ([index .chunk_index ])
240237
238+ if self ._last_chunk_index is None :
239+ self ._last_chunk_index = index .chunk_index
240+
241241 # Fetch the element
242242 chunk_filepath , begin , _ = self .config [index ]
243243 item = self ._item_loader .load_item_from_chunk (index .index , index .chunk_index , chunk_filepath , begin )
@@ -246,9 +246,10 @@ def read(self, index: ChunkedIndex) -> Any:
246246 # Otherwise, this could trigger segmentation fault error depending on the item loader used.
247247 if self ._config and self ._config ._remote_dir and index .chunk_index != self ._last_chunk_index :
248248 assert self ._prepare_thread
249- if self ._last_chunk_index is not None :
250- # inform the chunk has been completely consumed
251- self ._prepare_thread .delete ([self ._last_chunk_index ])
249+ assert self ._last_chunk_index is not None
250+
251+ # inform the chunk has been completely consumed
252+ self ._prepare_thread .delete ([self ._last_chunk_index ])
252253
253254 # track the new chunk index as the latest one
254255 self ._last_chunk_index = index .chunk_index
@@ -294,11 +295,9 @@ def _get_folder_size(path: str) -> int:
294295 return size
295296
296297
297- def _get_from_queue (queue : multiprocessing .Queue ) -> Optional [Any ]:
298+ def _get_from_queue (queue : multiprocessing .Queue , timeout : float = _DEFAULT_TIMEOUT ) -> Optional [Any ]:
298299 try :
299- # Note: The timeout here should not be too short. We need to prevent the caller from aggressively
300- # querying the queue and consuming too many CPU cycles.
301- return queue .get (timeout = 0.1 )
300+ return queue .get (timeout = timeout )
302301 except Empty :
303302 pass
304303 except OSError as e :
0 commit comments