99from torch import nn as _nn
1010from math import log as _log
1111from time import time
12+ import warnings
1213
1314__doc__ = """Encode a depths matrix and a tnf matrix to latent representation.
1415
@@ -379,7 +380,7 @@ def trainepoch(
379380 epoch_celoss = 0.0
380381
381382 if epoch in batchsteps :
382- data_loader = set_batchsize (data_loader , data_loader .batch_size * 2 )
383+ data_loader = set_batchsize (data_loader , data_loader .batch_size * 2 ) # type: ignore
383384
384385 for depths_in , tnf_in , weights in data_loader :
385386 depths_in .requires_grad = True
@@ -450,7 +451,7 @@ def encode(self, data_loader) -> _np.ndarray:
450451
451452 row = 0
452453 with _torch .no_grad ():
453- for depths , tnf , weights in new_data_loader :
454+ for depths , tnf , _ in new_data_loader :
454455 # Move input to GPU if requested
455456 if self .usecuda :
456457 depths = depths .cuda ()
@@ -551,28 +552,41 @@ def trainmodel(
551552 if nepochs < 1 :
552553 raise ValueError ("Minimum 1 epoch, not {nepochs}" )
553554
554- if batchsteps is None :
555- batchsteps_set : set [int ] = set ()
555+ if batchsteps is None or len ( batchsteps ) == 0 :
556+ sorted_batch_steps : list [int ] = []
556557 else :
557558 # First collect to list in order to allow all element types, then check that
558559 # they are integers
559- batchsteps = list (batchsteps )
560560 if not all (isinstance (i , int ) for i in batchsteps ):
561561 raise ValueError ("All elements of batchsteps must be integers" )
562- if max (batchsteps , default = 0 ) >= nepochs :
562+ sorted_batch_steps = sorted (set (batchsteps ))
563+ if sorted_batch_steps [0 ] < 1 :
564+ raise ValueError (
565+ f"Minimum of batchsteps must be 1, not { sorted_batch_steps [0 ]} "
566+ )
567+ if sorted_batch_steps [- 1 ] >= nepochs :
563568 raise ValueError ("Max batchsteps must not equal or exceed nepochs" )
564- last_batchsize = dataloader .batch_size * 2 ** len (batchsteps )
565- if len (dataloader .dataset ) < last_batchsize : # type: ignore
569+
570+ n_contigs = len (dataloader .dataset ) # type: ignore
571+ starting_batch_size : int = dataloader .batch_size # type: ignore
572+ if n_contigs < starting_batch_size :
566573 raise ValueError (
567- f"Last batch size of { last_batchsize } exceeds dataset length "
568- f"of { len ( dataloader . dataset ) } . " # type: ignore
574+ f"Starting batch size of { starting_batch_size } exceeds dataset length "
575+ f"of { n_contigs } . "
569576 "This means you have too few contigs left after filtering to train. "
570577 "It is not adviced to run Vamb with fewer than 10,000 sequences "
571578 "after filtering. "
572579 "Please check the Vamb log file to see where the sequences were "
573580 "filtered away, and verify BAM files has sensible content."
574581 )
575- batchsteps_set = set (batchsteps )
582+ maximum_batch_steps = (n_contigs // starting_batch_size ).bit_length () - 1
583+ if maximum_batch_steps < len (sorted_batch_steps ):
584+ warnings .warn (
585+ f"Requested { len (sorted_batch_steps )} batch steps, but with a starting "
586+ f"batch size of { starting_batch_size } and { n_contigs } contigs, "
587+ f"only the first { maximum_batch_steps } batch steps can be used."
588+ )
589+ sorted_batch_steps = sorted_batch_steps [:maximum_batch_steps ]
576590
577591 # Get number of features
578592 # Following line is un-inferrable due to typing problems with DataLoader
@@ -591,8 +605,8 @@ def trainmodel(
591605 print ("\t N epochs:" , nepochs , file = logfile )
592606 print ("\t Starting batch size:" , dataloader .batch_size , file = logfile )
593607 batchsteps_string = (
594- ", " .join (map (str , sorted ( batchsteps_set ) ))
595- if batchsteps_set
608+ ", " .join (map (str , sorted_batch_steps ))
609+ if len ( sorted_batch_steps ) > 0
596610 else "None"
597611 )
598612 print ("\t Batchsteps:" , batchsteps_string , file = logfile )
@@ -603,7 +617,7 @@ def trainmodel(
603617 # Train
604618 for epoch in range (nepochs ):
605619 dataloader = self .trainepoch (
606- dataloader , epoch , optimizer , sorted ( batchsteps_set ) , time (), logfile
620+ dataloader , epoch , optimizer , sorted_batch_steps , time (), logfile
607621 )
608622
609623 # Save weights - Lord forgive me, for I have sinned when catching all exceptions
0 commit comments