Returning multiple subsets per batch when I set a small batch_size #9121
-
Dear PyG community, when I checked the status of the DataLoader with a small batch_size, it seemed to return multiple subsets per batch. In my case:

for fold, (train_idx, val_idx, test_idx) in enumerate(zip(*k_fold(DataList, folds))):
    print(f'FOLD {fold}')
    print('-------------------------------------------')
    kf_train_data = Data(edge_index=BaseKG.edge_index,
                         edge_label=DataList.edge_label[train_idx],
                         edge_label_index=DataList.edge_index[:, train_idx],
                         edge_class=DataList.edge_class[train_idx],
                         x=[opt_rotate.node_emb.weight, DataList.x],
                         num_nodes=opt_rotate.num_nodes)
    kf_val_data = Data(...)
    kf_test_data = Data(...)
    tr_sampler = ImbalancedSampler(dataset=kf_train_data.edge_label)
    val_sampler = ImbalancedSampler(dataset=kf_val_data.edge_label)
    ts_sampler = ImbalancedSampler(dataset=kf_test_data.edge_label)
    train_loader = LinkNeighborLoader(kf_train_data,
                                      edge_label_index=kf_train_data.edge_label_index,
                                      edge_label=kf_train_data.edge_label,
                                      batch_size=64, shuffle=False,
                                      neg_sampling_ratio=0.0, num_neighbors=[2, 2],
                                      disjoint=True, sampler=tr_sampler)
    val_loader = LinkNeighborLoader(kf_val_data, ...)
    test_loader = LinkNeighborLoader(kf_test_data, ...)
    for data in tqdm(train_loader):
        data = data.to(device)
        data.edge_class = data.edge_class[data.input_id]
        data.edge_index_class = torch.zeros(len(data.edge_index[0])).to(device)
        for i in range(len(data.edge_index_class)):
            if ((data.edge_index[1][i] in data.edge_label_index[0])
                    or (data.edge_index[1][i] in data.edge_label_index[1])):
                data.edge_index_class[i] = data.edge_class[(data.edge_index[1][i] == data.edge_label_index[0])
                                                           | (data.edge_index[1][i] == data.edge_label_index[1])]
            else:
                data.edge_index_class[i] = torch.max(data.edge_index_class[data.edge_index[1][i] == data.edge_index[0]])
        print(data)

When I set the batch_size to 128 or more, it returns one subset per batch. ...
train_loader = LinkNeighborLoader(kf_train_data,
                                  edge_label_index=kf_train_data.edge_label_index,
                                  edge_label=kf_train_data.edge_label,
                                  batch_size=128, shuffle=False,
                                  neg_sampling_ratio=0.0, num_neighbors=[2, 2],
                                  disjoint=True, sampler=tr_sampler)
val_loader = LinkNeighborLoader(...)
test_loader = LinkNeighborLoader(...)
... Does anyone know why this happens?
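(As context for readers unfamiliar with the ImbalancedSampler used above: it resamples elements with weights inversely proportional to class frequency. A minimal plain-Python sketch of that weighting idea; the labels and numbers here are made up for illustration:)

```python
import random
from collections import Counter

def balanced_sample_weights(labels):
    """Inverse class-frequency weights -- the idea behind samplers like
    PyG's ImbalancedSampler: an element's draw weight is 1 / (size of
    its class), so each class contributes equal total probability."""
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels]

# Made-up imbalanced binary edge labels (6 negatives, 2 positives).
labels = [0, 0, 0, 0, 0, 0, 1, 1]
weights = balanced_sample_weights(labels)

random.seed(0)
draws = random.choices(range(len(labels)), weights=weights, k=10_000)
frac_minority = sum(labels[i] == 1 for i in draws) / len(draws)
# The minority class is drawn ~50% of the time despite being 25% of the data.
```

In PyG itself these per-element weights feed a weighted random sampler (the library builds on torch.utils.data.WeightedRandomSampler), so the loader draws minority-class seed edges more often than their raw frequency would suggest.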
Replies: 1 comment 1 reply
-
I don't see the issue yet. It's just that tqdm may not update itself on every iteration (you can see that the iterations increase from 2 to 4 to 6).
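To see how a progress bar can "skip" numbers without skipping any work: tqdm throttles screen redraws (its mininterval parameter, 0.1 s by default), so several fast iterations can share a single refresh. A small plain-Python simulation of that throttling; this is not tqdm itself, and the timings are made up:

```python
def throttled_display(n_iters, step_time, mininterval):
    """Simulate a progress bar that only redraws when `mininterval`
    (simulated) seconds have passed since the last redraw, the way
    tqdm throttles its screen updates."""
    shown, last, t = [], None, 0.0
    for i in range(1, n_iters + 1):
        t += step_time  # simulated duration of one loop iteration
        if last is None or t - last >= mininterval:
            shown.append(i)  # the count the user actually sees
            last = t
    return shown

# Every one of the 10 iterations runs, but the on-screen counter jumps:
print(throttled_display(10, step_time=0.1, mininterval=0.25))  # [1, 4, 7, 10]
```

With slower iterations (step_time larger than mininterval), every number is displayed, which matches the observation that larger batch sizes make the counter advance one subset at a time.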