Commit a5bf301

Handle too-high batch steps more gracefully
Instead of erroring when so many batchsteps are set that the final batch size would exceed the dataset length, simply warn and truncate the batch steps. This change enables experimenting with more aggressive batch steps, and it also comes in handy when working with long-read data.
1 parent 744ebda commit a5bf301
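
The arithmetic behind the truncation: each batch step doubles the batch size, so k steps turn a starting batch size b into b * 2**k, and the number of steps a dataset of n contigs can absorb is floor(log2(n / b)). A minimal sketch of that cap (illustrative only; max_batch_steps is a hypothetical standalone helper, not part of vamb's API):

def max_batch_steps(n_contigs: int, starting_batch_size: int) -> int:
    # After k batch steps the batch size is starting_batch_size * 2**k.
    # The largest k that keeps this at or below n_contigs is
    # floor(log2(n_contigs / starting_batch_size)), which for positive
    # integers equals (n_contigs // starting_batch_size).bit_length() - 1.
    return (n_contigs // starting_batch_size).bit_length() - 1

# 16 * 2**5 = 512 still fits in 1000 contigs, but 16 * 2**6 = 1024 does not:
assert max_batch_steps(1000, 16) == 5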

2 files changed: +40 -16 lines

test/test_encode.py

Lines changed: 12 additions & 2 deletions

@@ -174,7 +174,7 @@ def test_loss_falls(self):
         vae = vamb.encode.VAE(self.rpkm.shape[1])
         rpkm_copy = self.rpkm.copy()
         tnfs_copy = self.tnfs.copy()
-        dl, mask = vamb.encode.make_dataloader(
+        dl, _ = vamb.encode.make_dataloader(
             rpkm_copy, tnfs_copy, self.lens, batchsize=16, destroy=True
         )
         di = torch.Tensor(rpkm_copy)
@@ -202,10 +202,20 @@ def test_loss_falls(self):
         after_encoding = vae_2.encode(dl)
         self.assertTrue(np.all(np.abs(before_encoding - after_encoding) < 1e-6))
 
+    def test_warn_too_many_batch_steps(self):
+        vae = vamb.encode.VAE(self.rpkm.shape[1])
+        rpkm_copy = self.rpkm.copy()
+        tnfs_copy = self.tnfs.copy()
+        dl, _ = vamb.encode.make_dataloader(
+            rpkm_copy, tnfs_copy, self.lens, batchsize=16, destroy=True
+        )
+        with self.assertWarns(Warning):
+            vae.trainmodel(dl, nepochs=4, batchsteps=[1, 2, 3])
+
     def test_encoding(self):
         nlatent = 15
         vae = vamb.encode.VAE(self.rpkm.shape[1], nlatent=nlatent)
-        dl, mask = vamb.encode.make_dataloader(
+        dl, _ = vamb.encode.make_dataloader(
             self.rpkm, self.tnfs, self.lens, batchsize=32
         )
         encoding = vae.encode(dl)
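
The new test leans on unittest's assertWarns context manager. For readers unfamiliar with the pattern, here is a self-contained sketch of the same check (hypothetical names, no vamb dependency); it passes because warnings.warn defaults to UserWarning, a subclass of Warning:

import unittest
import warnings

def truncating_call():
    # Stand-in for trainmodel warning about truncated batch steps.
    warnings.warn("only the first 2 batch steps can be used")

class TestWarnPattern(unittest.TestCase):
    def test_warns(self):
        with self.assertWarns(Warning):
            truncating_call()

if __name__ == "__main__":
    unittest.main()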

vamb/encode.py

Lines changed: 28 additions & 14 deletions

@@ -9,6 +9,7 @@
 from torch import nn as _nn
 from math import log as _log
 from time import time
+import warnings
 
 __doc__ = """Encode a depths matrix and a tnf matrix to latent representation.
 
@@ -379,7 +380,7 @@ def trainepoch(
         epoch_celoss = 0.0
 
         if epoch in batchsteps:
-            data_loader = set_batchsize(data_loader, data_loader.batch_size * 2)
+            data_loader = set_batchsize(data_loader, data_loader.batch_size * 2)  # type: ignore
 
         for depths_in, tnf_in, weights in data_loader:
             depths_in.requires_grad = True
@@ -450,7 +451,7 @@ def encode(self, data_loader) -> _np.ndarray:
 
         row = 0
         with _torch.no_grad():
-            for depths, tnf, weights in new_data_loader:
+            for depths, tnf, _ in new_data_loader:
                 # Move input to GPU if requested
                 if self.usecuda:
                     depths = depths.cuda()
@@ -551,28 +552,41 @@ def trainmodel(
         if nepochs < 1:
             raise ValueError("Minimum 1 epoch, not {nepochs}")
 
-        if batchsteps is None:
-            batchsteps_set: set[int] = set()
+        if batchsteps is None or len(batchsteps) == 0:
+            sorted_batch_steps: list[int] = []
         else:
             # First collect to list in order to allow all element types, then check that
             # they are integers
-            batchsteps = list(batchsteps)
             if not all(isinstance(i, int) for i in batchsteps):
                 raise ValueError("All elements of batchsteps must be integers")
-            if max(batchsteps, default=0) >= nepochs:
+            sorted_batch_steps = sorted(set(batchsteps))
+            if sorted_batch_steps[0] < 1:
+                raise ValueError(
+                    f"Minimum of batchsteps must be 1, not {sorted_batch_steps[0]}"
+                )
+            if sorted_batch_steps[-1] >= nepochs:
                 raise ValueError("Max batchsteps must not equal or exceed nepochs")
-            last_batchsize = dataloader.batch_size * 2 ** len(batchsteps)
-            if len(dataloader.dataset) < last_batchsize:  # type: ignore
+
+            n_contigs = len(dataloader.dataset)  # type: ignore
+            starting_batch_size: int = dataloader.batch_size  # type: ignore
+            if n_contigs < starting_batch_size:
                 raise ValueError(
-                    f"Last batch size of {last_batchsize} exceeds dataset length "
-                    f"of {len(dataloader.dataset)}. "  # type: ignore
+                    f"Starting batch size of {starting_batch_size} exceeds dataset length "
+                    f"of {n_contigs}. "
                     "This means you have too few contigs left after filtering to train. "
                     "It is not adviced to run Vamb with fewer than 10,000 sequences "
                     "after filtering. "
                     "Please check the Vamb log file to see where the sequences were "
                     "filtered away, and verify BAM files has sensible content."
                 )
-            batchsteps_set = set(batchsteps)
+            maximum_batch_steps = (n_contigs // starting_batch_size).bit_length() - 1
+            if maximum_batch_steps < len(sorted_batch_steps):
+                warnings.warn(
+                    f"Requested {len(sorted_batch_steps)} batch steps, but with a starting "
+                    f"batch size of {starting_batch_size} and {n_contigs} contigs, "
+                    f"only the first {maximum_batch_steps} batch steps can be used."
+                )
+                sorted_batch_steps = sorted_batch_steps[:maximum_batch_steps]
 
         # Get number of features
         # Following line is un-inferrable due to typing problems with DataLoader
@@ -591,8 +605,8 @@ def trainmodel(
         print("\tN epochs:", nepochs, file=logfile)
         print("\tStarting batch size:", dataloader.batch_size, file=logfile)
         batchsteps_string = (
-            ", ".join(map(str, sorted(batchsteps_set)))
-            if batchsteps_set
+            ", ".join(map(str, sorted_batch_steps))
+            if len(sorted_batch_steps) > 0
             else "None"
         )
         print("\tBatchsteps:", batchsteps_string, file=logfile)
@@ -603,7 +617,7 @@ def trainmodel(
         # Train
         for epoch in range(nepochs):
             dataloader = self.trainepoch(
-                dataloader, epoch, optimizer, sorted(batchsteps_set), time(), logfile
+                dataloader, epoch, optimizer, sorted_batch_steps, time(), logfile
             )
 
         # Save weights - Lord forgive me, for I have sinned when catching all exceptions
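
Read end to end, the new validation in trainmodel amounts to a small pure function over the batch steps. A sketch distilling it (plan_batch_steps is a hypothetical helper for illustration; in the commit the logic lives inline in trainmodel):

import warnings

def plan_batch_steps(batchsteps, nepochs, n_contigs, starting_batch_size):
    # Dedupe and sort the steps, reject steps < 1 or >= nepochs, then
    # truncate (with a warning) any steps whose doubled batch size would
    # exceed the dataset length.
    if batchsteps is None or len(batchsteps) == 0:
        return []
    steps = sorted(set(batchsteps))
    if steps[0] < 1 or steps[-1] >= nepochs:
        raise ValueError("batchsteps must lie in the range [1, nepochs)")
    maximum = (n_contigs // starting_batch_size).bit_length() - 1
    if maximum < len(steps):
        warnings.warn(
            f"Requested {len(steps)} batch steps, but only the first "
            f"{maximum} can be used."
        )
        steps = steps[:maximum]
    return steps

# Mirrors the new test: more steps than 100 contigs at batchsize 16 can
# absorb (16 -> 32 -> 64 is the last doubling that fits), so the third
# step is dropped with a UserWarning.
print(plan_batch_steps([1, 2, 3], nepochs=4, n_contigs=100, starting_batch_size=16))
# -> UserWarning, then [1, 2]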

0 commit comments