GNNExplainer on AttentiveFP #5356
-
I'm trying to use GNNExplainer on AttentiveFP, but I'm running into difficulties and can't tell whether the problem is on my side or whether I should open an issue. I get an error, and I adapted the GNNExplainer example to reproduce it:

import os.path as osp
import torch
import torch.nn as nn
from torch_geometric.datasets import MoleculeNet
from torch_geometric.nn import AttentiveFP, GNNExplainer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = 'bbbp'
path = osp.join(osp.dirname(osp.realpath(__file__)), 'MoleculeNet')
dataset = MoleculeNet(path, dataset)
data = dataset[0]
net = AttentiveFP(in_channels=data.num_node_features,
                  hidden_channels=64,
                  out_channels=1,
                  edge_dim=data.num_edge_features,
                  num_layers=1,
                  num_timesteps=1,
                  dropout=0.2)
model = net.to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
# A single graph: every node belongs to graph 0.
data_batch = torch.zeros(data.x.shape[0], dtype=torch.long, device=device)
x = x.float()
loss_fn = nn.BCEWithLogitsLoss()
for epoch in range(1, 201):
    model.train()
    optimizer.zero_grad()
    logit = model(x=x, edge_index=edge_index,
                  edge_attr=edge_attr, batch=data_batch)
    loss = loss_fn(logit, data.y)
    loss.backward()
    optimizer.step()
explainer = GNNExplainer(model, epochs=200, return_type='log_prob')
_, edge_mask = explainer.explain_graph(x=x, edge_index=edge_index,
                                       edge_attr=edge_attr)
print(edge_mask)

This is the traceback I received:
Has anyone used GNNExplainer with AttentiveFP and could help me?
-
This looks like a bug in Explainer.get_prediction; I will try to send in a fix soon. The problem is that in attentive_fp the batch argument comes after edge_attr.
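A minimal, framework-free sketch of the argument-order clash, assuming the explainer calls the model as model(x, edge_index, batch, **kwargs), i.e. with batch as the third positional argument. All four functions below (attentive_fp_forward, positional_caller, keyword_caller, reorder_adapter) are hypothetical stand-ins, not PyG APIs; only the parameter order (x, edge_index, edge_attr, batch) mirrors AttentiveFP.forward.

```python
# Hypothetical stand-ins illustrating the parameter-order clash; not PyG code.


def attentive_fp_forward(x, edge_index, edge_attr, batch):
    """Stand-in with the same parameter order as AttentiveFP.forward."""
    return {"x": x, "edge_index": edge_index,
            "edge_attr": edge_attr, "batch": batch}


def positional_caller(model, x, edge_index, batch, **kwargs):
    """Assumed explainer behaviour: batch is passed as the 3rd positional arg.

    With AttentiveFP's order, batch lands in the edge_attr slot, and the
    edge_attr keyword then collides with it -> TypeError.
    """
    return model(x, edge_index, batch, **kwargs)


def keyword_caller(model, x, edge_index, batch, **kwargs):
    """Passing every tensor by name makes the parameter order irrelevant."""
    return model(x=x, edge_index=edge_index, batch=batch, **kwargs)


def reorder_adapter(model):
    """Hypothetical adapter exposing an (x, edge_index, batch, edge_attr) order."""
    def forward(x, edge_index, batch=None, edge_attr=None):
        # Forward the arguments in the order the wrapped model expects.
        return model(x, edge_index, edge_attr, batch)
    return forward


try:
    positional_caller(attentive_fp_forward, "X", "EI", "B", edge_attr="EA")
except TypeError as err:
    print("positional call fails:", err)

print(keyword_caller(attentive_fp_forward, "X", "EI", "B", edge_attr="EA"))
print(positional_caller(reorder_adapter(attentive_fp_forward),
                        "X", "EI", "B", edge_attr="EA"))
```

Calling by keyword (or through an adapter that reorders the arguments) maps batch and edge_attr to the right parameters. Applying the same reordering idea as a thin torch.nn.Module wrapper around the trained AttentiveFP might serve as a stopgap until a fix lands, but that is an untested assumption, not a confirmed workaround.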