diff --git a/trinity/buffer/ray_wrapper.py b/trinity/buffer/ray_wrapper.py
index 9de80f32ab..63de366db6 100644
--- a/trinity/buffer/ray_wrapper.py
+++ b/trinity/buffer/ray_wrapper.py
@@ -169,3 +169,7 @@ def read(self) -> List:
         raise NotImplementedError(
             "read() is not implemented for FileWrapper, please use QUEUE instead"
         )
+
+    def finish(self) -> None:
+        """Close ``self.file``."""
+        self.file.close()
diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py
index 9d32d4ba04..5eec9c4464 100644
--- a/trinity/buffer/reader/file_reader.py
+++ b/trinity/buffer/reader/file_reader.py
@@ -5,6 +5,7 @@
 import datasets
 import transformers
 from datasets import Dataset, load_dataset
+from tqdm import tqdm
 
 from trinity.buffer.buffer_reader import BufferReader
 from trinity.common.config import BufferConfig, StorageConfig
@@ -34,11 +35,22 @@ def __init__(self, dataset: Dataset, max_epoch: int = 1, offset: int = 0):
         for _ in range(self.current_offset):
             next(self.iter)
 
+        # Initialize tqdm progress bar
+        self.total_steps = self.dataset_size * self.max_epoch
+        self.progress_bar = tqdm(
+            total=self.total_steps,
+            initial=self.current_epoch * self.dataset_size + self.current_offset,
+            desc="Dataset Progressing",
+        )
+
     def read_batch(self, batch_size: int) -> List:
         batch = []
 
         while len(batch) < batch_size:
             try:
                 item = next(self.iter)
+                # Advance the bar only after an item was actually fetched; the
+                # next() that raises StopIteration at an epoch end must not count.
+                self.progress_bar.update(1)
                 batch.append(item)
                 self.current_offset += 1
@@ -48,7 +60,9 @@ def read_batch(self, batch_size: int) -> List:
                 self.current_offset = 0
 
                 if self.current_epoch >= self.max_epoch:
+                    self.progress_bar.close()
                     raise StopIteration
+                # Step to the next epoch
                 self.iter = iter(self.dataset)
 
         return batch