|
13 | 13 | from transformers import AutoTokenizer |
14 | 14 | import torch.distributed as dist |
15 | 15 | from torch.utils.data import get_worker_info |
| 16 | +import math |
16 | 17 |
|
17 | 18 |
|
18 | 19 | class KmerTokenizer(object): |
@@ -154,12 +155,32 @@ def parse_row(self, row): |
154 | 155 | return tokens, torch.tensor(int(label), dtype=torch.int64), att_mask |
155 | 156 |
|
156 | 157 | def __len__(self): |
157 | | - # count lines once at startup (cheap) so val‑throughput logging doesn't crash |
158 | | - if not hasattr(self, "_n_samples"): |
| 158 | + # 1) count total lines in file (once) |
| 159 | + if not hasattr(self, "_n_total"): |
159 | 160 | with open(self.file_path, "r") as f: |
160 | | - # subtract header if CSV has one; adjust accordingly |
161 | | - self._n_samples = sum(1 for _ in f) - 1 |
162 | | - return self._n_samples |
| 161 | + # subtract 1 if there’s a header row |
| 162 | + self._n_total = sum(1 for _ in f) - 1 |
| 163 | + |
| 164 | + # 2) figure out how many each *process* + *worker* should see |
| 165 | + if dist.is_available() and dist.is_initialized(): |
| 166 | + rank = dist.get_rank() |
| 167 | + world_size = dist.get_world_size() |
| 168 | + else: |
| 169 | + rank, world_size = 0, 1 |
| 170 | + |
| 171 | + worker_info = get_worker_info() |
| 172 | + if worker_info is not None: |
| 173 | + # subdivide further across DataLoader workers |
| 174 | + total_workers = worker_info.num_workers |
| 175 | + rank = rank * total_workers + worker_info.id |
| 176 | + world_size = world_size * total_workers |
| 177 | + |
| 178 | + # 3) even‐split the total lines, giving the first (rem) ranks one extra |
| 179 | + base, rem = divmod(self._n_total, world_size) |
| 180 | + local_n = base + (1 if rank < rem else 0) |
| 181 | + |
| 182 | + # 4) batches per epoch for *this* process |
| 183 | + return math.ceil(local_n / self._batch_size_per_gpu) |
163 | 184 |
|
164 | 185 |
|
165 | 186 | def __iter__(self): |
|
0 commit comments