How to apply uniform length batching (smart batching)? #9740
-
Hi all! I have a question about applying a smart batching (uniform length batching) scheme like the one in the picture above.
```python
import numpy as np
import torch
import pytorch_lightning as pl
from itertools import chain
from torch.utils.data import DataLoader, Dataset, Sampler
from tqdm import tqdm


class ExampleDataset(Dataset):
    def __init__(self, datas, tokenizer):
        super(ExampleDataset, self).__init__()
        self.tokenizer = tokenizer
        tokenized = [self.tokenize(data) for data in tqdm(datas, desc='Tokenizing..')]
        self.input_ids, self.attention_masks, self.labels = list(zip(*tokenized))

    def tokenize(self, data):
        encodings_dict = self.tokenizer(data)
        # labels are simply a copy of input_ids (causal-LM style)
        return [
            encodings_dict['input_ids'],
            encodings_dict['attention_mask'],
            encodings_dict['input_ids'],
        ]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.attention_masks[idx], self.labels[idx]
```
```python
class SmartBatchingSampler(Sampler):
    def __init__(self, data_source: torch.utils.data.Dataset, batch_size=1, drop_last=False):
        super(SmartBatchingSampler, self).__init__(data_source)
        self.batch_size = batch_size
        self.data_source = data_source
        self.drop_last = drop_last
        # sort sample indices by tokenized length so neighbouring samples have similar lengths
        sentence_lengths = [len(sentence[0]) for sentence in data_source]
        sentence_indices = list(range(len(data_source)))
        pack_by_length = list(zip(sentence_lengths, sentence_indices))
        sort_by_length = sorted(pack_by_length)
        sentence_lengths, sentence_indices = zip(*sort_by_length)
        # chunk the sorted indices into batch-sized bins, then flatten back into one index list
        self.bins = [
            sentence_indices[i:i + batch_size]
            for i in range(0, len(sentence_indices), batch_size)
        ]
        self.bins = list(chain.from_iterable(self.bins))

    def __iter__(self):
        for ids in self.bins:
            yield ids

    def __len__(self):
        return len(self.bins)

    def shuffle(self, epoch):
        # note: this shuffles individual indices, not batch-sized bins
        np.random.shuffle(self.bins)
```
```python
def collate_fn(batch):
    def seq_length_(p):
        return len(p[0])

    # pad every sample to the longest sequence in this batch only
    max_seq_sample = max(batch, key=seq_length_)[0]
    max_seq_size = len(max_seq_sample)
    batch_size = len(batch)

    input_ids = torch.zeros(batch_size, max_seq_size, dtype=torch.long)
    attention_masks = torch.zeros(batch_size, max_seq_size, dtype=torch.long)
    labels = torch.zeros(batch_size, max_seq_size, dtype=torch.long)

    for idx in range(batch_size):
        sample_input_ids, sample_attention_masks, sample_labels = batch[idx]
        input_ids[idx].narrow(0, 0, len(sample_input_ids)).copy_(torch.LongTensor(sample_input_ids))
        attention_masks[idx].narrow(0, 0, len(sample_attention_masks)).copy_(torch.LongTensor(sample_attention_masks))
        labels[idx].narrow(0, 0, len(sample_labels)).copy_(torch.LongTensor(sample_labels))

    return input_ids, attention_masks, labels
```
```python
class ExampleDataModule(pl.LightningDataModule):
    ...

    def train_dataloader(self):
        sampler = SmartBatchingSampler(self.dataset['train'], batch_size=self.batch_size)
        return DataLoader(
            dataset=self.dataset['train'],  # the ExampleDataset instance
            sampler=sampler,
            collate_fn=collate_fn,
        )
```

I have three questions.
Thank you.
-
You can sort the data by `len` initially, while creating the dataset itself. Then just use a sequential sampler to avoid shuffling, by setting `shuffle=False` inside the dataloader. Your `collate_fn` looks good, although it can be optimized a little bit. Apart from that, even if you use `auto_scale_batch_size`, it will work just fine, since your dataset will already be sorted by length.
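As a rough illustration of that suggestion (not code from the thread), here is a minimal sketch: sort the examples by tokenized length once when the dataset is built, then use a plain `DataLoader` with `shuffle=False` and a padding collate built on `torch.nn.utils.rnn.pad_sequence`. The `SortedByLengthDataset`, `pad_collate`, and `make_dataloader` names are illustrative, as is the assumption that `tokenizer` is a Hugging Face-style tokenizer returning `input_ids` / `attention_mask`.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset


class SortedByLengthDataset(Dataset):
    """Illustrative dataset that sorts examples by tokenized length at construction time."""

    def __init__(self, texts, tokenizer):
        encoded = [tokenizer(text) for text in texts]
        # sort once by input length so consecutive samples have similar lengths
        encoded.sort(key=lambda enc: len(enc['input_ids']))
        self.examples = encoded

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        enc = self.examples[idx]
        return (
            torch.tensor(enc['input_ids'], dtype=torch.long),
            torch.tensor(enc['attention_mask'], dtype=torch.long),
            torch.tensor(enc['input_ids'], dtype=torch.long),  # labels = input_ids, as in the post
        )


def pad_collate(batch):
    # pad each field to the longest sequence in this batch only
    input_ids, attention_masks, labels = zip(*batch)
    return (
        pad_sequence(list(input_ids), batch_first=True, padding_value=0),
        pad_sequence(list(attention_masks), batch_first=True, padding_value=0),
        pad_sequence(list(labels), batch_first=True, padding_value=0),
    )


def make_dataloader(dataset, batch_size):
    # shuffle=False keeps the length-sorted order, so each batch needs little padding
    return DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=pad_collate)
```

With this setup the `DataLoader` does the batching itself, so no custom sampler is required; the trade-off is that, without shuffling the length-sorted batches between epochs, they arrive in the same order every epoch.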
-
How can I? Can you explain more?