Is there a problem with my generated edge_index? #3826
Hello! Here is my code:

```python
import random

import torch
import torch.nn.functional as F
from torch.nn import Linear, ReLU

from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.graphgym.loader import get_loader
from torch_geometric.nn import GATConv, Sequential
from torch_geometric.transforms import NormalizeFeatures


def generate_DCG_edge_index(nodes_num):
    # Edge index of a directed complete graph: all ordered pairs (i, j)
    # with i != j, in COO format of shape [2, num_edges].
    pairs = [[i, j] for i in range(nodes_num) for j in range(nodes_num) if i != j]
    if len(pairs) == 0:  # single-node graph: return an empty [2, 0] edge_index
        return torch.empty((2, 0), dtype=torch.long)
    return torch.tensor(pairs, dtype=torch.long).t().contiguous()
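
# A vectorized equivalent of the list comprehension above (a sketch, assuming
# the goal is the same directed complete graph without self-loops):
#
#     row = torch.arange(nodes_num).repeat_interleave(nodes_num)
#     col = torch.arange(nodes_num).repeat(nodes_num)
#     mask = row != col
#     edge_index = torch.stack([row[mask], col[mask]], dim=0)
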
def feat_generator(gt=1):
    # Class-dependent features: positive labels get features in [0, 1),
    # negative labels get features in [-1, 0), so `y` is recoverable from
    # `x` alone (even without any edges).
    return torch.rand((1, 768)) if gt else torch.rand((1, 768)) - 1
    # return torch.ones((1, 768)) if gt is not None else torch.zeros((1, 768)) - 1
def get_data_list():
    data_list = []
    for graph_id in range(20):
        nodes_num = random.randint(1, 8)
        y = torch.randint(0, 2, (nodes_num,)).long()
        edge_index = generate_DCG_edge_index(nodes_num)
        edge_attr = None
        x = torch.cat([feat_generator(gt) for gt in y], 0)
        data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
        data_list.append(data)
    return data_list
class GNNNodeDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        return ['data.pt']

    def process(self):
        # Read data into a huge `Data` list, then collate and save it.
        data_list = get_data_list()
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
dataset = GNNNodeDataset(root='datasets/rand', transform=NormalizeFeatures())
print('dataset.num_features:', dataset.num_features)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_dataloader = get_loader(dataset, 'full_batch', 4, shuffle=False)
class GAT(torch.nn.Module):
    def __init__(self, num_features, hidden_channels_list, num_classes):
        super().__init__()
        torch.manual_seed(12345)
        hns = [num_features] + hidden_channels_list
        conv_list = []
        for idx in range(len(hidden_channels_list)):
            conv_list.append((GATConv(hns[idx], hns[idx + 1]), 'x, edge_index -> x'))
            conv_list.append(ReLU(inplace=True))
        self.convseq = Sequential('x, edge_index', conv_list)
        self.linear = Linear(hidden_channels_list[-1], num_classes)

    def forward(self, x, edge_index):
        # Uncommenting the next line removes all edges from the graph:
        # edge_index = torch.tensor([[], []], dtype=torch.long).to(x.device)
        x = self.convseq(x, edge_index)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.linear(x)
        return x
model = GAT(num_features=dataset.num_features, hidden_channels_list=[200, 100], num_classes=dataset.num_classes).to(device)
print(model)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()

def train():
    model.train()
    for batch in train_dataloader:
        optimizer.zero_grad()
        batch.to(device)
        out = model(batch.x, batch.edge_index)
        loss = criterion(out, batch.y)
        loss.backward()
        optimizer.step()
    return loss  # loss of the last batch in the epoch

for epoch in range(1, 201):
    loss = train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
print()
```

Executing the code above gives me:

```
Epoch: 001, Loss: 1.5121
Epoch: 002, Loss: 0.8916
Epoch: 003, Loss: 0.8128
Epoch: 004, Loss: 0.6889
Epoch: 005, Loss: 0.6904
Epoch: 006, Loss: 0.6570
Epoch: 007, Loss: 0.6481
Epoch: 008, Loss: 0.6550
Epoch: 009, Loss: 0.7098
Epoch: 010, Loss: 0.6991
Epoch: 011, Loss: 0.7111
Epoch: 012, Loss: 0.6546
Epoch: 013, Loss: 0.6791
Epoch: 014, Loss: 0.7347
Epoch: 015, Loss: 0.6288
Epoch: 016, Loss: 0.6552
Epoch: 017, Loss: 0.6142
Epoch: 018, Loss: 0.6456
Epoch: 019, Loss: 0.6616
Epoch: 020, Loss: 0.6083
Epoch: 021, Loss: 0.5717
Epoch: 022, Loss: 0.8385
Epoch: 023, Loss: 0.7059
Epoch: 024, Loss: 0.6492
Epoch: 025, Loss: 0.6575
Epoch: 026, Loss: 0.6623
Epoch: 027, Loss: 0.6333
Epoch: 028, Loss: 0.6741
Epoch: 029, Loss: 0.6538
Epoch: 030, Loss: 0.6098
Epoch: 031, Loss: 0.6129
Epoch: 032, Loss: 0.6211
Epoch: 033, Loss: 0.6055
Epoch: 034, Loss: 0.5394
Epoch: 035, Loss: 0.5429
Epoch: 036, Loss: 0.6780
Epoch: 037, Loss: 0.6013
Epoch: 038, Loss: 0.5889
Epoch: 039, Loss: 0.5840
Epoch: 040, Loss: 0.5948
Epoch: 041, Loss: 0.6023
Epoch: 042, Loss: 0.5885
Epoch: 043, Loss: 0.6713
Epoch: 044, Loss: 0.6410
Epoch: 045, Loss: 0.6219
Epoch: 046, Loss: 0.5462
Epoch: 047, Loss: 0.5900
Epoch: 048, Loss: 0.6372
Epoch: 049, Loss: 0.6223
Epoch: 050, Loss: 0.6211
Epoch: 051, Loss: 0.6066
Epoch: 052, Loss: 0.6613
Epoch: 053, Loss: 0.6040
Epoch: 054, Loss: 0.5901
Epoch: 055, Loss: 0.6139
Epoch: 056, Loss: 0.5747
Epoch: 057, Loss: 0.5887
Epoch: 058, Loss: 0.6412
Epoch: 059, Loss: 0.6455
Epoch: 060, Loss: 0.6587
Epoch: 061, Loss: 0.6193
Epoch: 062, Loss: 0.5903
Epoch: 063, Loss: 0.6303
Epoch: 064, Loss: 0.5928
Epoch: 065, Loss: 0.5696
Epoch: 066, Loss: 0.5917
Epoch: 067, Loss: 0.5966
Epoch: 068, Loss: 0.5640
Epoch: 069, Loss: 0.5661
Epoch: 070, Loss: 0.6423
Epoch: 071, Loss: 0.5574
Epoch: 072, Loss: 0.5851
Epoch: 073, Loss: 0.5589
Epoch: 074, Loss: 0.5795
Epoch: 075, Loss: 0.5988
Epoch: 076, Loss: 0.6208
Epoch: 077, Loss: 0.6584
Epoch: 078, Loss: 0.6360
Epoch: 079, Loss: 0.5667
Epoch: 080, Loss: 0.7061
Epoch: 081, Loss: 0.6214
Epoch: 082, Loss: 0.6855
Epoch: 083, Loss: 0.6081
Epoch: 084, Loss: 0.5112
Epoch: 085, Loss: 0.5991
Epoch: 086, Loss: 0.5900
Epoch: 087, Loss: 0.6510
Epoch: 088, Loss: 0.5862
Epoch: 089, Loss: 0.5917
Epoch: 090, Loss: 0.5801
Epoch: 091, Loss: 0.5779
Epoch: 092, Loss: 0.5850
Epoch: 093, Loss: 0.6232
Epoch: 094, Loss: 0.6390
Epoch: 095, Loss: 0.6712
Epoch: 096, Loss: 0.5919
Epoch: 097, Loss: 0.5939
Epoch: 098, Loss: 0.6714
Epoch: 099, Loss: 0.6498
Epoch: 100, Loss: 0.6493
Epoch: 101, Loss: 0.6734
Epoch: 102, Loss: 0.6623
Epoch: 103, Loss: 0.6686
Epoch: 104, Loss: 0.6213
Epoch: 105, Loss: 0.6352
Epoch: 106, Loss: 0.6083
Epoch: 107, Loss: 0.6454
Epoch: 108, Loss: 0.5559
Epoch: 109, Loss: 0.6640
Epoch: 110, Loss: 0.6626
Epoch: 111, Loss: 0.6582
Epoch: 112, Loss: 0.6994
Epoch: 113, Loss: 0.5918
Epoch: 114, Loss: 0.5922
Epoch: 115, Loss: 0.7155
Epoch: 116, Loss: 0.5989
Epoch: 117, Loss: 0.5973
Epoch: 118, Loss: 0.5978
Epoch: 119, Loss: 0.6621
Epoch: 120, Loss: 0.6062
Epoch: 121, Loss: 0.6245
```

The model doesn't seem to converge. But when I uncomment the line in `forward` that deletes all the edges:

```python
edge_index = torch.tensor([[], []], dtype=torch.long).to(x.device)
```

the model converges quickly:

```
Epoch: 001, Loss: 0.3095
Epoch: 002, Loss: 0.0964
Epoch: 003, Loss: 0.0031
Epoch: 004, Loss: 0.0004
Epoch: 005, Loss: 0.0001
Epoch: 006, Loss: 0.0000
Epoch: 007, Loss: 0.0000
Epoch: 008, Loss: 0.0000
Epoch: 009, Loss: 0.0001
Epoch: 010, Loss: 0.0001
Epoch: 011, Loss: 0.0000
Epoch: 012, Loss: 0.0000
Epoch: 013, Loss: 0.0001
Epoch: 014, Loss: 0.0000
```

Is there something wrong with the way I generate `edge_index`, or is this normal for complete graphs?
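
For what it's worth, the generated `edge_index` passes this quick structural check (a minimal sketch, only verifying shape, dtype, index range, and absence of self-loops):

```python
edge_index = generate_DCG_edge_index(5)
assert edge_index.dtype == torch.long                      # integer indices
assert edge_index.dim() == 2 and edge_index.size(0) == 2   # COO shape [2, num_edges]
assert edge_index.size(1) == 5 * 4                         # n * (n - 1) directed edges
assert edge_index.min() >= 0 and edge_index.max() < 5      # indices in [0, num_nodes)
assert bool((edge_index[0] != edge_index[1]).all())        # no self-loops
```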
I cannot really reproduce this, sorry. In both cases, the loss stays close to `~0.6`. Note that `GATConv` does not really have the best capacity to maintain central node information. Therefore, one thing that is often done and that may improve performance is to combine `GATConv` with learnable skip connections, see here.
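
A minimal sketch of such a learnable skip connection (the module name and layer sizes here are placeholders, not part of the code above):

```python
import torch
from torch.nn import Linear
from torch_geometric.nn import GATConv

class GATWithSkip(torch.nn.Module):
    """One GAT layer plus a learnable linear skip connection, so that
    central node information is preserved alongside aggregated messages."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = GATConv(in_channels, out_channels)
        self.skip = Linear(in_channels, out_channels)  # learnable skip path

    def forward(self, x, edge_index):
        return self.conv(x, edge_index) + self.skip(x)
```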