-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdataloading.py
More file actions
44 lines (37 loc) · 1.46 KB
/
dataloading.py
File metadata and controls
44 lines (37 loc) · 1.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import numpy as np
import torch
from torch.utils.data import Dataset
import dgl
import copy
class DSDataset(Dataset):
    '''
    Down-sampling dataset.

    Pairs every positive (label == 1) training node with one negative
    (label == 0) training node, so every item -- and therefore every batch
    built from items -- contains equal numbers of positive and negative
    samples. Call ``resample()`` to draw a fresh set of negatives.

    Args:
        graph: object exposing ``num_nodes()`` and an ``ndata`` mapping with
            the tensors ``'train_mask'``, ``'val_mask'``, ``'test_mask'`` and
            ``'label'`` (presumably a DGL graph, since this file imports dgl).
    '''

    def __init__(self, graph) -> None:
        super().__init__()
        self.num_nodes = graph.num_nodes()
        # .cpu() is a no-op for CPU tensors and lets .numpy() succeed when the
        # node data lives on a GPU.
        self.train_mask = graph.ndata['train_mask'].bool().cpu().numpy()
        self.val_mask = graph.ndata['val_mask'].bool().cpu().numpy()
        self.test_mask = graph.ndata['test_mask'].bool().cpu().numpy()
        self.labels = graph.ndata['label'].cpu().numpy()
        # Node ids of the positive / negative nodes inside the training split.
        self.pos_nids = np.flatnonzero((self.labels == 1) & self.train_mask)
        self.neg_nids = np.flatnonzero((self.labels == 0) & self.train_mask)
        self.pos_list = self.pos_nids
        self.resample()

    def resample(self):
        '''
        Draw ``len(self.pos_nids)`` negative node ids into ``self.neg_list``.

        Negatives are drawn without replacement (via the global
        ``np.random`` state) when there are enough of them. If the training
        split has fewer negatives than positives, every negative is used once
        and the remainder is topped up by sampling with replacement, so
        ``neg_list`` always has the same length as ``pos_list`` and
        ``__getitem__`` never indexes past its end.

        Raises:
            ValueError: if there are positive training nodes but no negative
                training nodes to pair them with.
        '''
        n_pos = len(self.pos_nids)
        if n_pos > 0 and len(self.neg_nids) == 0:
            raise ValueError(
                'DSDataset: no negative training nodes to pair with '
                f'{n_pos} positive ones')
        shuffled = np.random.permutation(self.neg_nids)
        if len(shuffled) >= n_pos:
            self.neg_list = shuffled[:n_pos]
        else:
            # Not enough negatives: keep all of them, then pad with
            # replacement so the positive/negative lists stay aligned.
            extra = np.random.choice(
                self.neg_nids, size=n_pos - len(shuffled), replace=True)
            self.neg_list = np.concatenate([shuffled, extra])

    def __getitem__(self, index):
        '''
        Return ``(neg_x, neg_y, pos_x, pos_y)`` for the ``index``-th pair.

        ``neg_x`` and ``pos_x`` are 1-element LongTensors holding node ids;
        ``neg_y`` is a matching LongTensor of zeros and ``pos_y`` one of ones.
        '''
        pos_x = torch.tensor([self.pos_list[index]], dtype=torch.long)
        neg_x = torch.tensor([self.neg_list[index]], dtype=torch.long)
        neg_y = torch.zeros(len(neg_x), dtype=torch.long)
        pos_y = torch.ones(len(pos_x), dtype=torch.long)
        return neg_x, neg_y, pos_x, pos_y

    def __len__(self):
        '''Number of (negative, positive) pairs: one per positive training node.'''
        return len(self.pos_nids)