How to separate message passing by edge class? #9134
-
Hi, thank you as always for your discussions. When I have a dataset containing the components below, I'd like to update the node features by running message passing per `edge_class`.
print(GraphList[0]) # There are 4 graph objects
# Data(edge_index=[2, 9859], edge_label=[9859], x=[19392, 4], edge_class=[9859]) # edge_class follows the index of graph in GraphList
print(torch.bincount(GraphList[0].edge_label))
# tensor([8246, 1613])
I just created a simple GAT model for link prediction.
class MyModel(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels):
    """Create the layers of the two-stage GAT link-prediction model.

    Args:
        in_channels: Input width of the first GAT layer and of ``lin1``.
        hidden_channels: Per-head output width of the first GAT layer.
        out_channels: Output width of the second GAT layer and of ``lin2``.
    """
    super(MyModel, self).__init__()
    ...  # additional setup was elided in the original post

    # The first GAT layer uses 12 heads, so every layer that consumes its
    # output is 12 * hidden_channels wide.
    num_heads = 12
    wide_channels = hidden_channels * num_heads

    # First stage: multi-head attention, a linear layer of the same output
    # width, and a normalisation layer.
    self.conv1 = GATConv(in_channels, hidden_channels, heads=num_heads)
    self.lin1 = torch.nn.Linear(in_channels, wide_channels)
    self.norm1 = LayerNorm(wide_channels)

    # Second stage: single-head attention plus a linear skip projection.
    self.conv2 = GATConv(wide_channels, out_channels, heads=1, concat=True)
    self.lin2 = torch.nn.Linear(wide_channels, out_channels)
def forward(self, l_data, s_data, edge_index, node_ids, neighbor_cl_ids):
    """Run both GAT layers on node features augmented with per-class features.

    Args:
        l_data: Node feature tensor; the original post annotates it as
            (107940, 256).
        s_data: Per-edge-class feature source, indexed as
            ``s_data[edge_class][rows]``.
        edge_index: Edge index of the sampled subgraph. It is used both to
            look up ``node_ids`` and for message passing.
        node_ids: Lookup from positions in ``edge_index`` to the ids that
            are matched against the module-level ``cID_List``.
        neighbor_cl_ids: Edge class of every column of ``edge_index``.

    Returns:
        A tuple ``(x, a1, a2)``: the second GAT layer's output plus the
        ``lin2`` skip projection, and the attention outputs of the first
        and second GAT layer.
    """
    x1 = l_data  # (107940, 256) according to the original post

    # Map node features from s_data for each edge class present in the batch,
    # so tmp_filled only contains id-matched features (other rows stay zero).
    # Allocate directly on l_data's device instead of creating the tensor on
    # the CPU and moving it via a module-level `device` global.
    tmp_filled = torch.zeros(
        (len(l_data), 256), dtype=torch.float, device=l_data.device
    )  # (107940, 256)
    for e_i_class in torch.unique(neighbor_cl_ids):
        indices = (neighbor_cl_ids == e_i_class).nonzero(as_tuple=True)[0]
        key_n_id = torch.unique(node_ids[edge_index[:, indices]])
        # np.isin replaces np.in1d, which is deprecated since NumPy 2.0.
        # NOTE(review): assumes cID_List is a 1-D integer array - confirm.
        s_n_id = np.isin(cID_List, key_n_id.cpu().numpy()).nonzero()[0]
        tmp_filled[cID_List[s_n_id]] = s_data[int(e_i_class.item())][s_n_id].float()

    # trial2: Concat
    x = torch.cat((x1, tmp_filled), dim=1)  # (107940, 512)

    x_1, a1 = self.conv1(x, edge_index, return_attention_weights=True)
    x = F.leaky_relu(self.norm1(x_1))
    x = F.dropout(x, p=0.2, training=self.training)

    x_2, a2 = self.conv2(x, edge_index, return_attention_weights=True)
    x = x_2 + self.lin2(x)  # skip connection around the second GAT layer
    return x, a1, a2
# For the training and validation,
...
# One training epoch: label every sampled message-passing edge with an edge
# class, run the model, and optimise the dot-product link-prediction loss.
model.train()
for data in tqdm(train_loader):
    data = data.to(device)
    data.edge_class = data.edge_class[data.input_id]
    data.edge_index_class = torch.zeros(len(data.edge_index[0]), device=device)

    # This loop fills edge_index_class for the edges in edge_index that were
    # sampled for message passing from the seed edges. It must stay
    # sequential: the else-branch reads classes written for earlier edges.
    for i in range(len(data.edge_index_class)):
        target = data.edge_index[1][i]
        if (target in data.edge_label_index[0]) or (target in data.edge_label_index[1]):
            # The edge ends in a seed node: take the class of that seed edge.
            # NOTE(review): if the node belongs to several seed edges the mask
            # selects several values and this assignment fails - confirm.
            data.edge_index_class[i] = data.edge_class[
                (target == data.edge_label_index[0])
                | (target == data.edge_label_index[1])
            ]
        else:
            # Otherwise take the largest class among the edges that leave
            # this edge's target node.
            data.edge_index_class[i] = torch.max(
                data.edge_index_class[(target == data.edge_index[0])]
            )

    optimizer.zero_grad()
    # Bind by keyword: the original positional call passed
    # (n_id, edge_index_class, edge_index), which does not match
    # forward(l_data, s_data, edge_index, node_ids, neighbor_cl_ids).
    z, a1, a2 = model(
        data.x[0],
        data.x[1],
        edge_index=data.edge_index,
        node_ids=data.n_id,
        neighbor_cl_ids=data.edge_index_class,
    )
    out = ((z[data.edge_label_index[0]] * z[data.edge_label_index[1]]).sum(dim=-1)).view(-1)
    tr_loss = criterion(out, data.edge_label.float())
    tr_losses += tr_loss.item()
    tr_loss.backward()
    optimizer.step()
# There are two questions about the code implementation
# Thank you for reading this question.
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 1 reply
-
# I haven't read your code in detail, but IMO you should utilize different
# message passing layers/parameters in order to perform edge class dependent
# propagation. This should be part of the model, e.g.:
def forward(self, x, edge_index, edge_class):
    """Sum the outputs of one message-passing layer per edge class.

    Args:
        x: Node feature tensor passed unchanged to every layer.
        edge_index: Edge index with one column per edge.
        edge_class: Integer class id of every column of ``edge_index``.

    Returns:
        The sum over all layers ``i`` of
        ``self.convs[i](x, edge_index[:, edge_class == i])``.
    """
    # Assumes self.convs holds one layer per edge class (for example a
    # torch.nn.ModuleList); the original snippet called it `conv`.
    # Iterating over the layers rather than range(edge_class.max()) covers
    # the highest class id and also works when a batch has no edges.
    out = 0
    for i, conv in enumerate(self.convs):
        class_mask = edge_class == i  # comparison, not assignment
        out = out + conv(x, edge_index[:, class_mask])
    return out
|
Beta Was this translation helpful? Give feedback.
I haven't read your code in detail, but IMO you should utilize different message passing layers/parameters in order to perform edge class dependent propagation. This should be part of the model, e.g.: