Skip to content

Commit b6bc2ff

Browse files
committed
BUG: remove printing.
1 parent e6a2eb2 commit b6bc2ff

File tree

1 file changed

+26
-5
lines changed

1 file changed

+26
-5
lines changed

barcodebert/datasets.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from transformers import AutoTokenizer
1414
import torch.distributed as dist
1515
from torch.utils.data import get_worker_info
16+
import math
1617

1718

1819
class KmerTokenizer(object):
@@ -154,12 +155,32 @@ def parse_row(self, row):
154155
return tokens, torch.tensor(int(label), dtype=torch.int64), att_mask
155156

156157
def __len__(self):
    """Return the number of batches per epoch for *this* process/worker.

    The total sample count is read from the CSV file once and cached on
    the instance; it is then split evenly across distributed ranks and
    DataLoader workers, with the first ``rem`` ranks receiving one extra
    sample each.

    Returns:
        int: ``ceil(local_samples / self._batch_size_per_gpu)`` for the
        current (rank, worker) combination.
    """
    # 1) Count total lines in the file (once) and cache the result.
    if not hasattr(self, "_n_total"):
        with open(self.file_path, "r") as f:
            # Subtract 1 for the header row; clamp at 0 so an empty
            # file cannot produce a negative dataset length.
            # NOTE(review): assumes the CSV always has a header — confirm.
            self._n_total = max(sum(1 for _ in f) - 1, 0)

    # 2) Figure out how many samples this *process* (+ *worker*) sees.
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank, world_size = 0, 1

    worker_info = get_worker_info()
    if worker_info is not None:
        # Subdivide further across DataLoader workers: flatten the
        # (rank, worker_id) pair into one global index over all shards.
        total_workers = worker_info.num_workers
        rank = rank * total_workers + worker_info.id
        world_size = world_size * total_workers

    # 3) Even-split the total lines, giving the first `rem` ranks one extra.
    base, rem = divmod(self._n_total, world_size)
    local_n = base + (1 if rank < rem else 0)

    # 4) Batches per epoch for *this* process.
    return math.ceil(local_n / self._batch_size_per_gpu)
163184

164185

165186
def __iter__(self):

0 commit comments

Comments
 (0)