How do node features get passed around in MessagePassing base class? #2120
HongtaoYang asked this question in Q&A. Answered by rusty1s.
I'm confused about how node features get passed around in the MessagePassing base class. See the toy example below.

```python
import numpy as np
import torch
from torch_geometric.nn import MessagePassing

# Construct a simple graph; node features are 2-dimensional vectors whose value is simply the node id.
# E.g. for node 0, the feature is [0, 0].
edge_index = np.array([[0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 4, 4, 5, 5],
                       [1, 2, 3, 0, 2, 0, 1, 4, 5, 0, 2, 5, 2, 4]])
x = np.array([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5]])
edge_index = torch.from_numpy(edge_index)
x = torch.from_numpy(x)


class MyMessagePassing(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(MyMessagePassing, self).__init__(aggr='add')

    def message(self, x_j: torch.Tensor) -> torch.Tensor:
        print(x)
        print(x_j)
        return x_j

    def forward(self, x, edge_index):
        x = x + 0.1  # Toggling this line does not change the print(x) output.
        return self.propagate(edge_index, x=x)


gcn = MyMessagePassing(2, 2)
output = gcn(x, edge_index)
```

In the example, I simply inherit `MessagePassing`. The problem is that no matter what operations I apply on …
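To make the relationship between `x` and `x_j` concrete, here is a minimal NumPy sketch of what `propagate` computes for this toy graph with `aggr='add'`. It assumes PyG's default `flow='source_to_target'` (messages travel from `edge_index[0]` to `edge_index[1]`) and does not use PyG at all; `propagate_add` is a hypothetical helper, not a library function.

```python
import numpy as np


def propagate_add(x, edge_index):
    """Hypothetical re-implementation of add-aggregation message passing
    (assumes flow='source_to_target'); not PyG's actual code."""
    src, dst = edge_index
    # "Lift" node features to edges: x_j is the feature of each edge's source node.
    x_j = x[src]
    # The question's message() just returns x_j unchanged.
    messages = x_j
    # Sum the incoming messages at each edge's target node (aggr='add').
    out = np.zeros((x.shape[0],) + messages.shape[1:], dtype=messages.dtype)
    np.add.at(out, dst, messages)
    return x_j, out


edge_index = np.array([[0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 4, 4, 5, 5],
                       [1, 2, 3, 0, 2, 0, 1, 4, 5, 0, 2, 5, 2, 4]])
x = np.array([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5]])

x_j, out = propagate_add(x, edge_index)
print(out.tolist())  # [[6, 6], [2, 2], [10, 10], [0, 0], [7, 7], [6, 6]]

# x_j is gathered from the argument, so shifting x shifts every message.
x_j_shifted, _ = propagate_add(x + 0.1, edge_index)
print(np.allclose(x_j_shifted - x_j, 0.1))  # True
```

Under these assumptions, `x_j` is always derived from whatever tensor is handed to `propagate` as `x=`, so it should reflect the `x + 0.1` edit made in `forward`.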
Answered by rusty1s on Feb 16, 2021
The `x` inside `message` comes from global scope, and is therefore unmodified: it is not a parameter of `message`, so Python resolves it to the module-level `x` defined at the top of the script, while `x = x + 0.1` in `forward` only rebinds a local name. This can be fixed by changing the `message` header to:

```python
def message(self, x: torch.Tensor, x_j: torch.Tensor) -> torch.Tensor:
    print(x)    # now the tensor passed to propagate(..., x=x), i.e. the modified one
    print(x_j)
    return x_j
```

Answer selected by HongtaoYang.
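The underlying issue is plain Python name resolution rather than anything PyG-specific. Here is a minimal, library-free sketch contrasting the two `message` signatures; the class `ScopeDemo` and its method names are hypothetical, and integers stand in for the tensors.

```python
x = 10  # module-level value, like the tensor `x` created at the top of the question's script


class ScopeDemo:
    def message_buggy(self, x_j):
        # `x` is neither a parameter nor assigned in this method, so the name
        # resolves to the module-level `x` (10), not the value used in forward().
        return x, x_j

    def message_fixed(self, x, x_j):
        # `x` is a parameter here, so it shadows the global and holds whatever
        # the caller passes in.
        return x, x_j

    def forward(self, x):
        x = x + 1    # rebinds the local name only; the global `x` is untouched
        x_j = x * 2  # stand-in for the per-edge features gathered from `x`
        return self.message_buggy(x_j), self.message_fixed(x, x_j)


buggy, fixed = ScopeDemo().forward(x)
print(buggy)  # (10, 22)
print(fixed)  # (11, 22)
```

As in the question, the value derived from the argument (`x_j`) picks up the change made in `forward` in both cases; only the bare `x` in `message_buggy` stays stuck at the global value.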