Explain graph level prediction on heterogeneous data. #9185
Answered · Asked by MarcPoulin1 in Q&A
I am trying to use CaptumExplainer('IntegratedGradients') to explain graph-level predictions on heterogeneous graphs:

```python
from torch_geometric.explain import Explainer, CaptumExplainer

explainer = Explainer(
    model,
    algorithm=CaptumExplainer('IntegratedGradients'),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config={'mode': 'regression', 'task_level': 'graph', 'return_type': 'raw'},
)
explanation = explainer(hetero_data.x_dict, hetero_data.edge_index_dict)
```

The issue is that my model's forward function expects a heterogeneous batch as input, because it needs the per-node-type batch vectors for pooling. Part of the model's forward:

```python
# batch[key].batch maps every node of type `key` to the graph it belongs to
x_dict_max = {key: self.global_max_pooling(x_dict[key], batch[key].batch) for key in x_dict}
x_dict_add = {key: self.global_add_pooling(x_dict[key], batch[key].batch) for key in x_dict}
x_dict_mean = {key: self.global_mean_pooling(x_dict[key], batch[key].batch) for key in x_dict}
```
Thank you for any help.
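The pooling comprehensions above need one batch vector per node type to know which graph each node belongs to; x_dict and edge_index_dict alone do not carry that information, so a call like explainer(hetero_data.x_dict, hetero_data.edge_index_dict) has nothing to hand the model for pooling. A minimal pure-PyTorch sketch of what add/mean/max pooling computes from such a vector follows. The helper pool_per_graph and the node types 'a' and 'b' are hypothetical; it only mirrors the semantics of PyG's global_add_pool, global_mean_pool and global_max_pool and is not their implementation.

```python
from typing import Dict, Tuple

import torch
from torch import Tensor


def pool_per_graph(x: Tensor, batch: Tensor, num_graphs: int) -> Tuple[Tensor, Tensor, Tensor]:
    """Hypothetical helper: (add, mean, max) pooling of x [N, F] into [num_graphs, F]."""
    index = batch.unsqueeze(-1).expand_as(x)  # graph id, repeated for every feature column
    add = torch.zeros(num_graphs, x.size(-1), dtype=x.dtype).scatter_add_(0, index, x)
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1).unsqueeze(-1)
    mean = add / counts
    # In this sketch, a graph with no nodes of this type keeps -inf in the max output.
    mx = torch.full((num_graphs, x.size(-1)), float('-inf'), dtype=x.dtype)
    mx = mx.scatter_reduce(0, index, x, reduce='amax', include_self=True)
    return add, mean, mx


# Two graphs in one batch, with hypothetical node types 'a' and 'b'.
x_dict: Dict[str, Tensor] = {
    'a': torch.tensor([[1.0], [2.0], [3.0]]),
    'b': torch.tensor([[10.0], [20.0]]),
}
# Per-type batch vectors: the role batch[key].batch plays in a HeteroData batch.
batch_dict: Dict[str, Tensor] = {
    'a': torch.tensor([0, 0, 1]),
    'b': torch.tensor([0, 1]),
}
pooled = {key: pool_per_graph(x_dict[key], batch_dict[key], num_graphs=2) for key in x_dict}
```

Without batch_dict (or the batch object it comes from), there is no way to tell that the first two 'a' nodes belong to graph 0 and the third to graph 1, so the model has to receive it as an extra forward argument.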
Answered by MarcPoulin1 on Apr 11, 2024
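I was able to change my forward signature to:

```python
from typing import Dict, Tuple
from torch import Tensor
from torch_geometric.data import Batch

def forward(self, x_dict: Dict[str, Tensor], edge_index_dict: Dict[Tuple[str, str, str], Tensor],
            edge_attr_dict: Dict[Tuple[str, str, str], Tensor], batch: Batch) -> Tensor:
    ...  # message passing, then pooling with batch[key].batch
```

and run:

```python
explanation = explainer(x=test_batch.x_dict, edge_index=test_batch.edge_index_dict,
                        edge_attr_dict=test_batch.edge_attr_dict, batch=test_batch)
```

Everything is working now.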