Commit 4946f55

BUG: fix bugs in lazy loader.

Shard the rows streamed by the lazy-loading IterableDataset across distributed ranks and DataLoader workers so each sample is yielded exactly once per pass, add a __len__ so validation-throughput logging does not crash, and build the dataloaders without a DistributedSampler when lazy loading is enabled.
1 parent 3aedf16 commit 4946f55

File tree: 2 files changed, +75 -41 lines

  barcodebert/datasets.py
  barcodebert/pretraining.py

barcodebert/datasets.py (33 additions, 1 deletion)
@@ -11,6 +11,8 @@
 from torch.utils.data import Dataset, IterableDataset
 from torchtext.vocab import vocab as build_vocab_from_dict
 from transformers import AutoTokenizer
+import torch.distributed as dist
+from torch.utils.data import get_worker_info
 
 
 class KmerTokenizer(object):
@@ -151,14 +153,44 @@ def parse_row(self, row):
         tokens, att_mask = self.tokenizer(dna_seq, offset=offset)
         return tokens, torch.tensor(int(label), dtype=torch.int64), att_mask
 
+    def __len__(self):
+        # count lines once at startup (cheap) so val-throughput logging doesn't crash
+        if not hasattr(self, "_n_samples"):
+            with open(self.file_path, "r") as f:
+                # subtract header if CSV has one; adjust accordingly
+                self._n_samples = sum(1 for _ in f) - 1
+        return self._n_samples
+
+
     def __iter__(self):
+        # Determine global rank & world size
+        if dist.is_available() and dist.is_initialized():
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+        else:
+            rank, world_size = 0, 1
+
+        # If we're also using multiple DataLoader workers (num_workers > 0),
+        # further subdivide per-worker:
+        worker_info = get_worker_info()
+        if worker_info is not None:
+            # each worker in the same process gets a unique ID
+            worker_id = worker_info.id
+            total_workers = worker_info.num_workers
+            # flatten ranks + workers into a single shard index
+            rank = rank * total_workers + worker_id
+            world_size = world_size * total_workers
+
+        # Now stream the file, and only yield rows where idx % world_size == rank
        df_iter = pd.read_csv(
             self.file_path,
             sep="\t" if self.file_path.endswith(".tsv") else ",",
             chunksize=1,
             keep_default_na=False,
         )
-        for chunk in df_iter:
+        for idx, chunk in enumerate(df_iter):
+            if idx % world_size != rank:
+                continue
             yield self.parse_row(chunk.iloc[0])
 
 
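The sharding pattern added above is self-contained: each (rank, worker) pair maps to one shard index, and a row is yielded only when idx % world_size == rank. Below is a minimal, hypothetical sketch of the same pattern on a toy in-memory stream (not code from this repo), which can be run locally to confirm that every row is produced exactly once across DataLoader workers:

# Toy reproduction of the modulo sharding used in __iter__ above (illustrative only).
import torch.distributed as dist
from torch.utils.data import DataLoader, IterableDataset, get_worker_info


class ShardedRange(IterableDataset):
    """Streams 0..n-1, yielding only the rows owned by this rank/worker shard."""

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        # Global rank / world size (single process when torch.distributed is not initialized)
        if dist.is_available() and dist.is_initialized():
            rank, world_size = dist.get_rank(), dist.get_world_size()
        else:
            rank, world_size = 0, 1
        # Fold DataLoader workers into the shard index, as in the diff above
        info = get_worker_info()
        if info is not None:
            rank = rank * info.num_workers + info.id
            world_size *= info.num_workers
        for idx in range(self.n):
            if idx % world_size == rank:
                yield idx


if __name__ == "__main__":
    loader = DataLoader(ShardedRange(10), batch_size=None, num_workers=2)
    seen = sorted(int(x) for x in loader)
    assert seen == list(range(10))  # every row exactly once, no duplicates or gaps
    print(seen)

Without the idx % world_size == rank filter, each worker (and each rank) would replay the full stream, so every row would be yielded multiple times per pass.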
barcodebert/pretraining.py (42 additions, 40 deletions)
@@ -181,46 +181,48 @@ def print_pass(*args, **kwargs):
 
     eval_set = "Val"
 
-    # Dataloader --------------------------------------------------------------
-    dl_train_kwargs = {
-        "batch_size": config.batch_size_per_gpu,
-        "drop_last": True,
-        "sampler": None,
-        "shuffle": True,
-        "worker_init_fn": utils.worker_seed_fn,
-    }
-    dl_val_kwargs = {
-        "batch_size": config.batch_size_per_gpu,
-        "drop_last": False,
-        "sampler": None,
-        "shuffle": False,
-        "worker_init_fn": utils.worker_seed_fn,
-    }
-    if config.cpu_workers is None:
-        config.cpu_workers = utils.get_num_cpu_available()
-    if use_cuda:
-        cuda_kwargs = {"num_workers": config.cpu_workers, "pin_memory": True}
-        dl_train_kwargs.update(cuda_kwargs)
-        dl_val_kwargs.update(cuda_kwargs)
-
-    if config.distributed:
-        # The DistributedSampler breaks up the dataset across the GPUs
-        dl_train_kwargs["sampler"] = DistributedSampler(
-            dataset_train,
-            shuffle=True,
-            seed=config.seed if config.seed is not None else 0,
-            drop_last=False,
-        )
-        dl_train_kwargs["shuffle"] = None
-        dl_val_kwargs["sampler"] = DistributedSampler(
-            dataset_val,
-            shuffle=False,
-            drop_last=False,
-        )
-        dl_val_kwargs["shuffle"] = None
-
-    dataloader_train = torch.utils.data.DataLoader(dataset_train, **dl_train_kwargs)
-    dataloader_val = torch.utils.data.DataLoader(dataset_val, **dl_val_kwargs)
+    # Dataloaders -------------------------------------------------------------
+    if config.lazy_load:
+        # streaming IterableDataset -> no sampler, no shuffle
+        stream_kwargs = {
+            "batch_size": config.batch_size_per_gpu,
+            "drop_last": True,
+            "num_workers": config.cpu_workers,
+            "pin_memory": use_cuda,
+            "worker_init_fn": utils.worker_seed_fn,
+        }
+        dataloader_train = torch.utils.data.DataLoader(dataset_train, **stream_kwargs)
+        dataloader_val = torch.utils.data.DataLoader(dataset_val, **stream_kwargs)
+    else:
+        # map-style Dataset -> use DistributedSampler in dist. mode
+        map_train_kwargs = {
+            "batch_size": config.batch_size_per_gpu,
+            "drop_last": True,
+            "shuffle": True,
+            "num_workers": config.cpu_workers,
+            "pin_memory": use_cuda,
+            "worker_init_fn": utils.worker_seed_fn,
+        }
+        map_val_kwargs = {
+            "batch_size": config.batch_size_per_gpu,
+            "drop_last": False,
+            "shuffle": False,
+            "num_workers": config.cpu_workers,
+            "pin_memory": use_cuda,
+            "worker_init_fn": utils.worker_seed_fn,
+        }
+        if config.distributed:
+            map_train_kwargs["shuffle"] = False
+            map_train_kwargs["sampler"] = DistributedSampler(
+                dataset_train, shuffle=True,
+                seed=(config.seed or 0),
+                drop_last=False,
+            )
+            map_val_kwargs["sampler"] = DistributedSampler(
+                dataset_val, shuffle=False, drop_last=False
+            )
+        dataloader_train = torch.utils.data.DataLoader(dataset_train, **map_train_kwargs)
+        dataloader_val = torch.utils.data.DataLoader(dataset_val, **map_val_kwargs)
 
     # MODEL ===================================================================
     base_pairs = "ACGT"
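For context on why the streaming branch passes neither a sampler nor shuffle: a PyTorch DataLoader rejects those options for an IterableDataset, so the distributed split has to live inside the dataset's __iter__, as done in datasets.py above. A small illustrative snippet (toy dataset, not repo code) showing both the rejection and the plain streaming loader that stream_kwargs corresponds to:

# Illustration only: combining an IterableDataset with shuffle/sampler is refused by DataLoader.
from torch.utils.data import DataLoader, IterableDataset


class Stream(IterableDataset):
    def __iter__(self):
        yield from range(8)


try:
    DataLoader(Stream(), batch_size=4, shuffle=True)
except ValueError as err:
    print(f"rejected as expected: {err}")

# The streaming loader is built without sampler/shuffle, mirroring stream_kwargs above:
loader = DataLoader(Stream(), batch_size=4, drop_last=True, num_workers=0)
print([batch.tolist() for batch in loader])  # [[0, 1, 2, 3], [4, 5, 6, 7]]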
