Skip to content

Commit 079201b

Browse files
shiweijiezeroweijie
andauthored
Progress Bar (#87)
Co-authored-by: weijie <[email protected]>
1 parent ad77ffe commit 079201b

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

trinity/buffer/ray_wrapper.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,6 @@ def read(self) -> List:
169169
raise NotImplementedError(
170170
"read() is not implemented for FileWrapper, please use QUEUE instead"
171171
)
172+
173+
def finish(self) -> None:
174+
self.file.close()

trinity/buffer/reader/file_reader.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import datasets
66
import transformers
77
from datasets import Dataset, load_dataset
8+
from tqdm import tqdm
89

910
from trinity.buffer.buffer_reader import BufferReader
1011
from trinity.common.config import BufferConfig, StorageConfig
@@ -34,11 +35,20 @@ def __init__(self, dataset: Dataset, max_epoch: int = 1, offset: int = 0):
3435
for _ in range(self.current_offset):
3536
next(self.iter)
3637

38+
# Initialize tqdm progress bar
39+
self.total_steps = self.dataset_size * self.max_epoch
40+
self.progress_bar = tqdm(
41+
total=self.total_steps,
42+
initial=self.current_epoch * self.dataset_size + self.current_offset,
43+
desc="Dataset Progressing",
44+
)
45+
3746
def read_batch(self, batch_size: int) -> List:
3847
batch = []
3948

4049
while len(batch) < batch_size:
4150
try:
51+
self.progress_bar.update(1)
4252
item = next(self.iter)
4353
batch.append(item)
4454
self.current_offset += 1
@@ -48,7 +58,9 @@ def read_batch(self, batch_size: int) -> List:
4858
self.current_offset = 0
4959

5060
if self.current_epoch >= self.max_epoch:
61+
self.progress_bar.close()
5162
raise StopIteration
63+
# Step to the next epoch
5264
self.iter = iter(self.dataset)
5365
return batch
5466

0 commit comments

Comments
 (0)