CorpusReader expects number of training examples to be divisible by batch size #225

@oadams

Description

```python
if batch_size:
    self.batch_size = batch_size
    if num_train % batch_size != 0:
        logger.error("Number of training examples {} not divisible"
                     " by batch size {}.".format(num_train, batch_size))
        raise PersephoneException("Number of training examples {} not divisible"
                                  " by batch size {}.".format(num_train, batch_size))
else:
    # Dynamically change batch size based on number of training
    # examples.
    self.batch_size = int(num_train / 32.0)
    if self.batch_size > 64:
        # I was getting OOM errors when training with 4096 sents, as
        # the batch size jumped to 128
        self.batch_size = 64
    # For now we hope that training numbers are powers of two or
    # something... If not, crash before anything else happens.
    if num_train % self.batch_size != 0:
        logger.error("Number of training examples {} not divisible"
                     " by batch size {}.".format(num_train, self.batch_size))
        raise PersephoneException("Number of training examples {} not divisible"
                                  " by batch size {}.".format(num_train, self.batch_size))
```

This is an artificial limitation: the remainder could simply form a smaller final batch instead of triggering an exception.
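As a minimal sketch of the remainder-as-smaller-batch idea (the `make_batches` helper is hypothetical, not part of Persephone's API), the batching loop can simply let the last batch be short:

```python
# Hypothetical sketch: split training examples into batches where the
# final batch may be smaller than batch_size, instead of raising an
# exception when num_train is not divisible by batch_size.
def make_batches(examples, batch_size):
    """Yield successive batches; the last batch may be smaller."""
    for start in range(0, len(examples), batch_size):
        yield examples[start:start + batch_size]

# 10 examples with batch size 4 gives batches of sizes 4, 4 and 2;
# no divisibility requirement is needed.
batches = list(make_batches(list(range(10)), 4))
```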

It also causes a bug in the dynamic branch: if there are fewer than 32 training examples, `int(num_train / 32.0)` evaluates to zero, so the computed batch size is zero and the modulo check fails. The temporary workaround is to reduce the batch size manually, but it's messy.
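A possible fix for the dynamic branch is to clamp the computed batch size to a sensible range; `choose_batch_size` and its parameter names below are hypothetical, chosen to mirror the constants 32 and 64 in the quoted code:

```python
# Hypothetical fix sketch: derive a batch size that is always at least 1
# and never larger than either the OOM cap or the number of examples.
def choose_batch_size(num_train, target_divisor=32, max_batch=64):
    # Aim for roughly num_train / target_divisor batches, but never
    # return zero, exceed max_batch, or exceed num_train itself.
    batch_size = max(1, num_train // target_divisor)
    return min(batch_size, max_batch, num_train)
```

With this, 10 examples yield a batch size of 1 rather than 0, and 4096 examples still cap out at 64.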
