-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathdata_loader.py
More file actions
67 lines (60 loc) · 2.84 KB
/
data_loader.py
File metadata and controls
67 lines (60 loc) · 2.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import config
import numpy as np
from torch.utils.data import Dataset, DataLoader
class LographDataset(Dataset):
    """Dataset of log samples; each item is (word, label, group_id, indice).

    Each entry of `samples` is a (template-index list, label, group key)
    triple. Template indices are resolved through `tmpl_list` and
    `template_map` into per-template word-id lists. Raw group keys are
    re-encoded as dense integer ids in first-seen order.
    """

    def __init__(self, samples, smpl_index, tmpl_list, template_map):
        self.dataset = []
        # raw group key -> dense integer id, assigned in first-seen order
        self.group_map = {}
        self.process(samples, smpl_index, tmpl_list, template_map)

    def process(self, samples, smpl_index, tmpl_list, template_map):
        """Materialize the samples selected by `smpl_index` into self.dataset."""
        for idx in smpl_index:
            indice, label, group = samples[idx][0], samples[idx][1], samples[idx][2]
            # Resolve each template index to its word-id list.
            word = [template_map[tmpl_list[j]] for j in indice]
            # setdefault assigns the next dense id only on first sight of the key.
            group_id = self.group_map.setdefault(group, len(self.group_map))
            self.dataset.append([word, label, group_id, indice])
        return self

    def __getitem__(self, idx):
        word, label, group, indice = self.dataset[idx]
        return word, label, group, indice

    def __len__(self):
        return len(self.dataset)
def lograph_collate_fn(batch):
    """Collate (word, label, group, indice) items into padded batch tensors.

    Returns:
        word_tensor:  LongTensor (B, L, W) of word ids, zero-padded.
        label_tensor: LongTensor (B,) of labels.
        group_tensor: LongTensor (B,) of dense group ids.
        mask_tensor:  ByteTensor (B, L, W); 1 marks real word positions.
        indice_array: np.ndarray (B, L) of template indices, padded with -1.
    """
    n = len(batch)
    seq_lens = [len(item[0]) for item in batch]
    # Per-sample list of word lengths, used both for sizing and for masking.
    per_word_lens = [[len(w) for w in item[0]] for item in batch]
    max_len = max(seq_lens)
    max_words = max(max(counts) for counts in per_word_lens)
    # Pad every indice list to max_len with -1 sentinels.
    padded_indices = [(item[3] + [-1] * max_len)[:max_len] for item in batch]

    word_tensor = torch.zeros(n, max_len, max_words, dtype=torch.long)
    label_tensor = torch.zeros(n, dtype=torch.long)
    group_tensor = torch.zeros(n, dtype=torch.long)
    mask_tensor = torch.zeros(n, max_len, max_words, dtype=torch.uint8)

    for i, (word, label, group, _indice) in enumerate(batch):
        for j, num in enumerate(per_word_lens[i]):
            word_tensor[i, j, :num] = torch.LongTensor(word[j])
            mask_tensor[i, j, :num] = 1
        label_tensor[i] = label
        group_tensor[i] = group

    return word_tensor, label_tensor, group_tensor, mask_tensor, np.array(padded_indices)
def convert_to_training_data_loader(samples, train_index, dev_index, tmpl_list, template_map, pseudo_label=None):
    """Build the training and dev DataLoaders over the given index splits.

    Args:
        samples: full sample list of (template-index list, label, group) triples.
        train_index: indices into `samples` for the training split.
        dev_index: indices into `samples` for the dev split.
        tmpl_list, template_map: lookup tables mapping template ids to word-id lists.
        pseudo_label: accepted for interface compatibility; currently unused here.

    Returns:
        (train_loader, dev_loader), both batching via `lograph_collate_fn`
        with `config.batch_size`.
    """
    train_dataset = LographDataset(samples, train_index, tmpl_list, template_map)
    dev_dataset = LographDataset(samples, dev_index, tmpl_list, template_map)
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size,
                              shuffle=True, collate_fn=lograph_collate_fn)
    # Fix: the dev split was previously shuffled; evaluation should iterate
    # in a deterministic order so dev metrics/diagnostics are reproducible.
    dev_loader = DataLoader(dev_dataset, batch_size=config.batch_size,
                            shuffle=False, collate_fn=lograph_collate_fn)
    return train_loader, dev_loader
def convert_to_testing_data_loader(samples, test_index, tmpl_list, template_map):
    """Wrap the test split in an unshuffled DataLoader using lograph_collate_fn.

    Args:
        samples: full sample list of (template-index list, label, group) triples.
        test_index: indices into `samples` selecting the test split.
        tmpl_list, template_map: lookup tables mapping template ids to word-id lists.

    Returns:
        A DataLoader over the test split with `config.batch_size`.
    """
    dataset = LographDataset(samples, test_index, tmpl_list, template_map)
    return DataLoader(dataset, batch_size=config.batch_size, collate_fn=lograph_collate_fn)