GCNConv: pass a Tensor of size [GraphNum, PointNum, FeatureDim] and [2, EdgeNum] in the forward method #8682
-
Hello everyone! Is it possible to pass a feature tensor of size [GraphNum, PointNum, FeatureDim] together with an edge_index of size [2, EdgeNum] to the forward method? Here is my model:

import torch
import torch.nn as nn
import torch_geometric.nn as gnn


class GraphNeuralNetwork(nn.Module):
    def __init__(
        self,
        num_layers: int,
        input_dim: int,
        hidden_dim: int,
    ):
        super(GraphNeuralNetwork, self).__init__()
        self.num_layers = num_layers
        self.gcn_layers = nn.ModuleList()
        self.gcn_layers.append(gnn.GraphConv(input_dim, hidden_dim))
        self.gcn_layers.append(nn.ReLU(inplace=True))
        self.gcn_layers.append(gnn.GraphNorm(hidden_dim))
        for _ in range(num_layers):
            self.gcn_layers.append(gnn.GraphConv(hidden_dim, hidden_dim))
            self.gcn_layers.append(nn.ReLU(inplace=True))
            self.gcn_layers.append(gnn.GraphNorm(hidden_dim))
        self.regression_head = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 2),
        )

    def forward(
        self,
        x: torch.Tensor,           # size [graph_num, point_num, dim]
        edge_index: torch.Tensor,  # size [2, edge_num]
    ) -> torch.Tensor:
        # Conv layers take (x, edge_index); ReLU and GraphNorm take x only.
        for layer in self.gcn_layers:
            if isinstance(layer, gnn.GraphConv):
                x = layer(x, edge_index)
            else:
                x = layer(x)
        return self.regression_head(x)


def test():
    num_layers = 1
    input_dim = 64
    hidden_dim = 128
    B, N, C = 8, 64, input_dim
    x = torch.randn(B, N, C)
    # The same chain graph is shared by all B graphs in the batch.
    edge_index = torch.tensor(
        [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9)],
        dtype=torch.long,
    ).permute(1, 0)
    gnn_model = GraphNeuralNetwork(num_layers, input_dim, hidden_dim)
    res = gnn_model(x, edge_index)


if __name__ == '__main__':
    test()
-
There exist a few layers in PyG that support this "static graph" computation, where the graph is static across a set of batches and which therefore accept feature tensors of shape [batch_size, num_nodes, num_features]. The supported layers are documented in the "static" column of the GNN cheatsheet.
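For instance, GCNConv is one of the layers marked as static-capable there. A minimal sketch of that mode (the sizes and the small chain graph below are made up for illustration):

import torch
import torch_geometric.nn as gnn

B, N, C = 8, 64, 16                          # 8 graphs, 64 nodes each, 16 features
x = torch.randn(B, N, C)                     # one feature tensor for the whole batch

# A single edge_index shared by every graph in the batch (the "static" graph).
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]], dtype=torch.long)

conv = gnn.GCNConv(C, 32)
out = conv(x, edge_index)                    # -> [8, 64, 32]

If a layer you want is not static-capable, the usual fallback is to merge the batch into one large disconnected graph, so that features become [B * N, C] and edge_index is offset per graph; torch_geometric.data.Batch does this bookkeeping for you:

from torch_geometric.data import Batch, Data

data_list = [Data(x=x[i], edge_index=edge_index) for i in range(B)]
batch = Batch.from_data_list(data_list)      # offsets edge_index per graph
out = conv(batch.x, batch.edge_index)        # -> [B * N, 32]
out = out.view(B, N, -1)                     # back to [B, N, 32]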