-
I'm trying to use PyG's `Explainer` module (with the `PGExplainer` algorithm) for a graph-level multiclass classification problem, but I'm running into errors.
class Cheb(nn.Module):
    """Three-layer ChebConv graph classifier with a max-pool readout and MLP head.

    ``forward`` accepts either a batch object exposing ``x``, ``edge_index``
    and ``batch`` (the original calling convention) or the raw tensors
    ``(x, edge_index, batch)``. The tensor form is required by
    ``torch_geometric.explain.Explainer`` / ``PGExplainer``, which call
    ``model(x, edge_index, **kwargs)`` rather than passing a ``Data`` object.

    Args:
        feature_dim: Number of input node features.
        hidden_dim: Width of the first ChebConv layer; the second and third
            layers use ``2 * hidden_dim`` and ``4 * hidden_dim``.
        output_dim: Currently unused; kept for backward compatibility.
        dropout: Probability for the ``nn.Dropout`` module stored in
            ``self.dropout``. NOTE(review): this module is never applied in
            ``forward`` -- confirm whether that is intended.
        use_residue: Currently unused; kept for backward compatibility.
        num_classes: Number of output logits (defaults to 4, the value that
            was previously hard-coded).
    """

    def __init__(self, feature_dim, hidden_dim=32, output_dim=256, dropout=0.4,
                 use_residue=False, num_classes=4):
        super().__init__()
        self.hidden = hidden_dim
        self.cconv1 = ChebConv(feature_dim, hidden_dim, aggr='sum', K=3)
        self.cconv2 = ChebConv(hidden_dim, hidden_dim * 2, aggr='sum', K=3)
        self.cconv3 = ChebConv(hidden_dim * 2, hidden_dim * 4, aggr='sum', K=3)
        # Attribute is named ``relu`` but holds a LeakyReLU.
        self.relu = nn.LeakyReLU()
        self.dropout = nn.Dropout(dropout)
        self.classification = nn.Sequential(
            nn.Linear(self.hidden * 4, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, num_classes),
        )

    def forward(self, data, edge_index=None, batch=None):
        """Return class logits, one row per graph after ``gmp`` pooling.

        Args:
            data: Either a batch object with ``x``, ``edge_index`` and
                ``batch`` attributes (used when ``edge_index`` is ``None``),
                or the node-feature tensor ``x`` itself.
            edge_index: Edge index tensor. When given, ``data`` is treated as
                the node-feature tensor -- this is how Explainer algorithms
                call the model.
            batch: Node-to-graph assignment vector passed to ``gmp``.
        """
        if edge_index is None:
            # Original convention: a single batch object.
            x, edge_index, batch = data.x, data.edge_index, data.batch
        else:
            # Explainer convention: model(x, edge_index, batch=...).
            x = data

        x = self.relu(self.cconv1(x, edge_index))
        x = self.relu(self.cconv2(x, edge_index))
        # No activation after the last conv layer, as in the original model.
        x = self.cconv3(x, edge_index)

        # Graph-level readout: max-pool node embeddings per graph.
        x = gmp(x, batch)
        return self.classification(x)
loader = DataLoader(test_data, batch_size=1, shuffle=False)
for batch in loader:
    print(batch.x)

# Use a single value for both PGExplainer's configured ``epochs`` and the
# training loop, so training runs for exactly the number of epochs the
# algorithm was configured with.
EXPLAINER_EPOCHS = 30

explainer = Explainer(
    model=model,
    algorithm=PGExplainer(epochs=EXPLAINER_EPOCHS, lr=0.003),
    explanation_type='phenomenon',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='graph',
        return_type='raw',
    ),
    threshold_config=dict(threshold_type='topk', value=10),
)

for epoch in range(EXPLAINER_EPOCHS):
    for batch in loader:
        # Extra keyword arguments (here ``batch=``) are forwarded to
        # ``model(x, edge_index, ...)`` so graph-level pooling can group
        # nodes by graph. This requires a model whose forward accepts raw
        # tensors rather than a single ``Data`` object.
        loss = explainer.algorithm.train(
            epoch,
            model,
            batch.x,
            batch.edge_index,
            target=batch.y,
            batch=batch.batch,
        )
|
Beta Was this translation helpful? Give feedback.
Answered by
rusty1s
Jul 24, 2023
Replies: 1 comment 1 reply
-
You should modify your model to take only the actual tensors as input, e.g.:

```python
def forward(self, x, edge_index, batch)
```

These parameters will then get passed to the model in:

```python
explainer.algorithm.train(epoch, model, batch.x, batch.edge_index, batch=batch.batch, target=batch.y)
```
Beta Was this translation helpful? Give feedback.
1 reply
Answer selected by
JFetish
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
You should modify your model to only take in the actual tensors as input, e.g.:
These parameters will get passed to the model in