How to estimate a node's feature vector based on its neighbours' data #2797
-
A novice in need of help: I have a very simple directed graph 1 -> 2 -> 3 -> 4 with one feature per node; let's say it is speed. A graph is generated by setting the speed of the first node (n = 1) to, for example, 100. The speed of node n + 1 is then calculated by transforming node n's speed with a randomly generated polynomial function plus a random bias. This way, there is a function between consecutive nodes that could be learned. I generate 1000 such graphs by drawing the starting speed randomly from [90, 130], keeping the functions between the nodes the same for all draws. The MSE loss is computed against the unbiased values.

Network definition with a GraphSAGE layer:

```python
import torch
from torch import nn
from torch_geometric.nn import SAGEConv


class Net(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, num_layers):
        super().__init__()
        self.num_layers = num_layers
        self.convs = nn.ModuleList()
        for i in range(num_layers):
            # The first layer reads the raw node features; later layers read hidden features.
            layer_in = in_channels if i == 0 else hidden_channels
            if i == num_layers - 1:
                # The last layer outputs one value per node (the estimated speed).
                self.convs.append(SAGEConv(layer_in, 1))
            else:
                self.convs.append(SAGEConv(layer_in, hidden_channels))

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            # No activation after the last layer, so the output stays a raw regression value.
            if i != self.num_layers - 1:
                x = x.relu()
        return x
```

With two layers and 16 hidden channels, the network converges. However, I am now a bit stuck on how to use what it has learned. In particular, I want to estimate a node's speed by transforming the previous node's speed with the learned parameters (chaining the message), i.e. carrying the message through the graph so that the previous node's signal is transformed into the desired one:

Node -> function -> Node

I think I need to call the propagate step and chain the calls (in a loop?), so that each node gets the information from its neighbours and the message is carried through the graph. To achieve this, I probably need to write a new method for the Net class. Could I just call message, message_and_aggregate or propagate from the SAGEConv class?
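To make the setup described above concrete, here is a minimal, dependency-free sketch of the data generation as I read it. Several details are assumptions, not taken from the post: the polynomials are quadratic with hypothetical coefficient ranges chosen to keep speeds bounded, the "random bias" is Gaussian noise drawn per graph and per edge, and "unbiased values" means the polynomial outputs before that noise is added. `make_edge_functions`, `poly` and `make_graph` are illustrative names only.

```python
import random

NUM_NODES = 4            # chain 0 -> 1 -> 2 -> 3 (nodes 1..4 in the question)
NUM_GRAPHS = 1000
SPEED_RANGE = (90.0, 130.0)

# Directed chain in COO form: first row holds sources, second row targets.
EDGE_INDEX = [list(range(NUM_NODES - 1)), list(range(1, NUM_NODES))]


def make_edge_functions(rng, degree=2):
    """Draw one hypothetical polynomial per edge; reused for every graph."""
    funcs = []
    for _ in range(NUM_NODES - 1):
        # Assumed coefficient ranges, picked so speeds stay in a sane range.
        coeffs = [rng.uniform(-5.0, 5.0), rng.uniform(0.8, 1.2)]
        coeffs += [rng.uniform(-1e-3, 1e-3) for _ in range(degree - 1)]
        funcs.append(coeffs)
    return funcs


def poly(coeffs, s):
    """Evaluate sum_k coeffs[k] * s**k."""
    return sum(c * s ** k for k, c in enumerate(coeffs))


def make_graph(rng, funcs, noise_std=1.0):
    """Return (x, y): noisy speeds used as features and clean targets."""
    x = [rng.uniform(*SPEED_RANGE)]
    y = [x[0]]  # node 0 has no predecessor; its target is its own speed
    for coeffs in funcs:
        clean = poly(coeffs, x[-1])                   # f_n(speed of node n)
        y.append(clean)                               # unbiased target for node n + 1
        x.append(clean + rng.gauss(0.0, noise_std))   # biased feature for node n + 1
    return x, y


rng = random.Random(0)
funcs = make_edge_functions(rng)  # same functions for all draws
dataset = [make_graph(rng, funcs) for _ in range(NUM_GRAPHS)]
```

Each `(x, y)` pair could then be wrapped for PyG as `Data(x=torch.tensor(x).view(-1, 1), edge_index=torch.tensor(EDGE_INDEX), y=torch.tensor(y).view(-1, 1))` from `torch_geometric.data`. Node indices are 0-based here, so the question's nodes 1 to 4 become 0 to 3.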
-
I think your GNN implementation looks fine. If you want to change specific input features (say `data.x[0] = 100`), you will need to re-evaluate the model on the complete graph (as there is an inter-dependency of outputs between nodes), e.g.:

```python
data.x[0] = 100
output = model(data)
output[1]  # New estimation for node 1 that takes the modification of node 0 into account
```
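To see why the whole graph has to be re-evaluated, and why a manual call to `propagate` should not be needed (as far as I know, `SAGEConv.forward` already calls `self.propagate` internally), here is a small dependency-free sketch that mimics a two-layer, mean-aggregation SAGE model on scalar features. The toy update `h_i' = w_self * h_i + w_neigh * mean_{j -> i} h_j + b` and the random positive weights are assumptions standing in for the trained `Net`; `sage_layer` and `model` are hypothetical helpers, not PyG functions.

```python
import random

# Directed chain 0 -> 1 -> 2 -> 3; edges are (source, target) pairs.
EDGES = [(0, 1), (1, 2), (2, 3)]


def sage_layer(h, edges, w_self, w_neigh, bias):
    """One SAGE-style layer on scalar features with mean aggregation.

    Node i averages the features of all nodes j with an edge j -> i
    (messages flow source -> target) and combines that mean with its own
    feature. Nodes without incoming edges get a zero aggregate.
    """
    n = len(h)
    total = [0.0] * n
    count = [0] * n
    for src, dst in edges:
        total[dst] += h[src]
        count[dst] += 1
    out = []
    for i in range(n):
        mean = total[i] / count[i] if count[i] else 0.0
        out.append(w_self * h[i] + w_neigh * mean + bias)
    return out


def model(x, params):
    """Two stacked layers with a ReLU in between, like Net(num_layers=2)."""
    h = sage_layer(x, EDGES, *params[0])
    h = [max(v, 0.0) for v in h]
    return sage_layer(h, EDGES, *params[1])


rng = random.Random(0)
# Hypothetical (w_self, w_neigh, bias) per layer; all positive for clarity.
params = [tuple(rng.uniform(0.1, 1.0) for _ in range(3)) for _ in range(2)]

x = [100.0, 110.0, 120.0, 130.0]
before = model(x, params)

x_mod = list(x)
x_mod[0] = 150.0                 # analogous to data.x[0] = ... in the reply
after = model(x_mod, params)     # re-evaluate the complete graph

changed = [abs(a - b) > 1e-9 for a, b in zip(before, after)]
print(changed)  # -> [True, True, True, False]
```

Modifying node 0 changes the estimates for nodes 0, 1 and 2 but not node 3: each layer moves information one hop along the edge direction, so with two layers a node only sees up to two predecessors. Carrying the first node's speed all the way to node 4 of the question's chain would therefore need at least three layers.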