How to implement graph coarsening in the PyG framework? #6536
Answered
by
SongZan222
SongZan222
asked this question in
Q&A
-
Graph coarsening is built into some of the pooling methods provided by PyG, but I want a function that only coarsens graphs. Does PyG provide such a function?
Beta Was this translation helpful? Give feedback.
Answered by
SongZan222
Jan 28, 2023
Replies: 1 comment
-
# The problem has been solved.
class SAG_h(torch.nn.Module):
    """One hierarchical-pooling stage: coarsen the graph, convolve, read out.

    The stage coarsens the input graph with ``SAGPooling`` (constructed with
    ``ratio``), applies a ``GCNConv`` to the coarsened graph, and reduces the
    convolved node features per graph with ``self.readout`` (``sum_aggr`` by
    default) to obtain a graph-level embedding.

    Args:
        c_in: Number of input node features (used by both the pooling layer
            and the convolution).
        c_out: Number of output features of the ``GCNConv``, i.e. the size of
            the readout embedding ``y``.
        ratio: Pooling ratio passed to ``SAGPooling``.
        readout: Optional callable used as the per-graph readout. It is called
            as ``readout(x=..., index=...)``. Defaults to ``sum_aggr``, the
            readout the original stage was hard-wired to.
    """

    def __init__(self, c_in, c_out, ratio, readout=None):
        super().__init__()
        self.pool = SAGPooling(c_in, ratio)
        self.conv = GCNConv(c_in, c_out)
        # Keep the original default so existing callers see identical behaviour.
        self.readout = sum_aggr if readout is None else readout

    def forward(self, X, Edge, Batch):
        """Coarsen the graph once and compute its readout embedding.

        Args:
            X: Node feature matrix of the (batched) input graph.
            Edge: ``edge_index`` connectivity of the input graph.
            Batch: Batch vector assigning each node to its graph.

        Returns:
            Tuple ``(x, edge_index, batch, y)``:
            ``x``, ``edge_index`` and ``batch`` describe the coarsened graph.
            Note that ``x`` is the *pooled input* features, not the GCN output.
            ``y`` is the per-graph readout of the GCN output on the coarsened
            graph.
        """
        # SAGPooling returns a 6-tuple; only x, edge_index and batch
        # (positions 0, 1 and 3) are used here.
        x, edge_index, _, batch, _, _ = self.pool(
            x=X,
            edge_index=Edge,
            batch=Batch,
        )
        y = self.conv(x, edge_index)
        # Use the configured readout (previously the global ``sum_aggr`` was
        # called directly, silently ignoring ``self.readout``).
        y = self.readout(x=y, index=batch)
        # Return the pooled input features (not ``y``) so the next stage keeps
        # receiving the original feature dimensionality.
        return x, edge_index, batch, y
class Net_h(torch.nn.Module):
    """Graph classifier built from three stacked ``SAG_h`` pooling stages.

    Each stage further coarsens the graph and contributes one graph-level
    embedding. The three embeddings are concatenated and passed through a
    two-layer MLP, and the output is log-probabilities over the classes.

    Args:
        num_features: Number of input node features. Defaults to
            ``dataset.num_node_features``.
        hidden_dim: Embedding size of each stage and of the MLP hidden layer.
            Defaults to the global ``HIDEN_DIMENSION``.
        num_classes: Number of output classes. Defaults to
            ``dataset.num_classes``.
        ratio: Pooling ratio for every stage. Defaults to the global ``RATIO``.
    """

    def __init__(self, num_features=None, hidden_dim=None, num_classes=None,
                 ratio=None):
        super().__init__()
        # Fall back to the script-level globals the original hard-coded, so
        # ``Net_h()`` keeps working exactly as before.
        if num_features is None:
            num_features = dataset.num_node_features
        if hidden_dim is None:
            hidden_dim = HIDEN_DIMENSION
        if ratio is None:
            ratio = RATIO
        if num_classes is None:
            num_classes = dataset.num_classes
        # Every stage takes ``num_features`` inputs because SAG_h forwards the
        # pooled *input* features (not its GCN output) to the next stage.
        # This presumes SAGPooling does not change the feature width.
        # Attribute names are kept so existing state_dicts still load.
        self.SAGh_1 = SAG_h(num_features, hidden_dim, ratio)
        self.SAGh_2 = SAG_h(num_features, hidden_dim, ratio)
        self.SAGh_3 = SAG_h(num_features, hidden_dim, ratio)
        self.fc1 = torch.nn.Linear(3 * hidden_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, num_classes)

    def forward(self, data):
        """Classify a (batched) graph.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch`` attributes
                (presumably a PyG ``Data``/``Batch``; verify against callers).

        Returns:
            Log-softmax scores over the classes, computed along the last
            dimension.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        readouts = []
        for stage in (self.SAGh_1, self.SAGh_2, self.SAGh_3):
            # Each stage coarsens the graph produced by the previous one.
            x, edge_index, batch, y = stage(x, edge_index, batch)
            readouts.append(y)
        y = torch.cat(readouts, dim=1)
        # Nonlinearity between the two linear layers: without it fc2(fc1(y))
        # collapses to a single linear map and the hidden layer is redundant.
        y = torch.nn.functional.relu(self.fc1(y))
        y = self.fc2(y)
        return torch.nn.functional.log_softmax(y, dim=-1)
Beta Was this translation helpful? Give feedback.
0 replies
Answer selected by
SongZan222
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The problem has been solved.
Please refer to the code below to implement hierarchical pooling.