Binary graph classification explainer #7702
-
Hello, I'm trying to use the explainer module for binary graph classification problems, but I'm stuck with some errors. This is my code:
import os
import torch
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import TUDataset
from torch_geometric.nn import global_mean_pool
from torch_geometric.explain import Explainer, CaptumExplainer
class GCN(torch.nn.Module):
    """Three-layer GCN for binary graph classification.

    Produces ONE raw logit per graph (pair with ``BCEWithLogitsLoss``).

    Args:
        hidden_channels (int): Width of the hidden GCN layers.
        in_channels (int, optional): Number of input node features.
            Defaults to ``dataset.num_node_features`` for backward
            compatibility with the original script, which read the
            module-level ``dataset`` global from inside ``__init__``.
    """

    def __init__(self, hidden_channels, in_channels=None):
        super().__init__()
        if in_channels is None:
            # Backward-compatible fallback to the script-level global.
            in_channels = dataset.num_node_features
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, 1)

    def forward(self, x, edge_index, batch):
        """Return a ``(num_graphs, 1)`` tensor of raw logits."""
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        x = self.conv3(x, edge_index)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)                    # per-node logit
        return global_mean_pool(x, batch)  # pool to a per-graph logit
dataset = TUDataset(root="TUDataset", name="MUTAG")
loader = DataLoader(dataset, batch_size=32, shuffle=True)

model = GCN(hidden_channels=64)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# One raw logit per graph pairs with BCEWithLogitsLoss for binary tasks.
criterion = torch.nn.BCEWithLogitsLoss()

model.train()
for epoch in range(30):
    for data in loader:
        optimizer.zero_grad()  # clear stale gradients before backward
        out = model(data.x, data.edge_index, data.batch)
        # squeeze(-1) (not a full squeeze) so a final batch of size 1
        # keeps shape (1,) and still matches data.y.
        loss = criterion(out.squeeze(-1), data.y.float())
        loss.backward()
        optimizer.step()

# Switch off dropout so the attributions are deterministic.
model.eval()

explainer = Explainer(
    model=model,
    algorithm=CaptumExplainer("IntegratedGradients"),
    explanation_type="model",
    edge_mask_type="object",
    model_config=dict(
        mode="binary_classification",
        task_level="graph",
        return_type="raw",
    ),
)

# All nodes of the single explained graph belong to batch id 0.
dummy = torch.zeros(dataset[0].x.shape[0], dtype=torch.long)
explanation = explainer(dataset[0].x, dataset[0].edge_index, batch=dummy) This is the Traceback I received. Traceback (most recent call last):
File "/media/takaogahara/external/2_projects/gnn-toolkit/temp/explain_test/test.py", line 63, in <module>
explanation = explainer(dataset[0].x, dataset[0].edge_index, batch=dummy)
File "/home/takaogahara/virtualenvs/gnn-dev/lib/python3.10/site-packages/torch_geometric/explain/explainer.py", line 198, in __call__
explanation = self.algorithm(
File "/home/takaogahara/virtualenvs/gnn-dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/takaogahara/virtualenvs/gnn-dev/lib/python3.10/site-packages/torch_geometric/explain/algorithm/captum_explainer.py", line 162, in forward
attributions = self.attribution_method.attribute(
File "/home/takaogahara/virtualenvs/gnn-dev/lib/python3.10/site-packages/captum/log/__init__.py", line 42, in wrapper
return func(*args, **kwargs)
File "/home/takaogahara/virtualenvs/gnn-dev/lib/python3.10/site-packages/captum/attr/_core/integrated_gradients.py", line 274, in attribute
attributions = _batch_attribution(
File "/home/takaogahara/virtualenvs/gnn-dev/lib/python3.10/site-packages/captum/attr/_utils/batching.py", line 78, in _batch_attribution
current_attr = attr_method._attribute(
File "/home/takaogahara/virtualenvs/gnn-dev/lib/python3.10/site-packages/captum/attr/_core/integrated_gradients.py", line 351, in _attribute
grads = self.gradient_func(
File "/home/takaogahara/virtualenvs/gnn-dev/lib/python3.10/site-packages/captum/_utils/gradient.py", line 112, in compute_gradients
outputs = _run_forward(forward_fn, inputs, target_ind, additional_forward_args)
File "/home/takaogahara/virtualenvs/gnn-dev/lib/python3.10/site-packages/captum/_utils/common.py", line 487, in _run_forward
return _select_targets(output, target)
File "/home/takaogahara/virtualenvs/gnn-dev/lib/python3.10/site-packages/captum/_utils/common.py", line 501, in _select_targets
return _verify_select_column(output, cast(int, target.item()))
File "/home/takaogahara/virtualenvs/gnn-dev/lib/python3.10/site-packages/captum/_utils/common.py", line 548, in _verify_select_column
return output[(slice(None), *target)]
IndexError: index 1 is out of bounds for dimension 1 with size 1 Has anyone used the new Explainer module with graph binary classification and could help me? |
Beta Was this translation helpful? Give feedback.
Replies: 3 comments 3 replies
-
You are passing a batch of size 0 to the explainer function, which is why you are getting an IndexError. You need to pass a batch of data to the explainer function. You can do this by creating a batch of data with the same number of elements as the dataset you are using — something like this:
|
Beta Was this translation helpful? Give feedback.
-
Mh, I am afraid that does not work as written. Here is a corrected version:
import os
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.datasets import TUDataset
from torch_geometric.explain import CaptumExplainer, Explainer
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
class GCN(torch.nn.Module):
    """Three GCN layers, mean pooling, and a 2-way linear head.

    Emits two raw class logits per graph, suitable for
    ``CrossEntropyLoss`` and a multiclass explainer configuration.
    """

    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, 2)

    def forward(self, x, edge_index, batch):
        """Return a ``(num_graphs, 2)`` tensor of raw logits."""
        h = F.relu(self.conv1(x, edge_index))
        h = F.relu(self.conv2(h, edge_index))
        h = self.conv3(h, edge_index)
        h = F.dropout(h, p=0.5, training=self.training)
        h = self.lin(h)
        return global_mean_pool(h, batch)
dataset = TUDataset(root="TUDataset", name="MUTAG")
loader = DataLoader(dataset, batch_size=32, shuffle=True)

model = GCN(hidden_channels=64)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# Two logits per graph + CrossEntropyLoss: the binary task expressed as
# a 2-class problem, so Captum's per-class target indexing works.
criterion = torch.nn.CrossEntropyLoss()

model.train()
for epoch in range(30):
    for data in loader:
        optimizer.zero_grad()  # clear stale gradients before backward
        out = model(data.x, data.edge_index, data.batch)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()

# Switch off dropout so the attributions are deterministic.
model.eval()

explainer = Explainer(
    model=model,
    algorithm=CaptumExplainer("IntegratedGradients"),
    explanation_type="model",
    edge_mask_type="object",
    model_config=dict(
        mode="multiclass_classification",
        task_level="graph",
        return_type="raw",
    ),
)
explanation = explainer(data.x, data.edge_index, batch=data.batch, index=0) |
Beta Was this translation helpful? Give feedback.
-
Fixed this via #7787. |
Beta Was this translation helpful? Give feedback.
Fixed this via #7787.