Hello, I am relatively new to PyTorch and PyG and have some questions regarding the explainability functionality. I have the following heterogeneous dataset:
and I am building a GATConv model using the default GAT model in
to get the following model:
The model outputs a dictionary of tensors keyed by node type, which is the expected behaviour according to the docs. However, the problem arises when I try to run this model through the `Explainer`. For example, using
gives the following error:
Could someone please advise whether I am doing something wrong with the `Explainer`, or whether this is expected behaviour? If this is expected, is there any way I can use the
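I think the issue is that the `Explainer` still expects your heterogeneous model to return a single tensor, not a dictionary of tensors. As such, what you can do is wrap the model so that it returns the output for a single node type, something like:

```python
import torch
from torch_geometric.nn import GraphSAGE, to_hetero


class HeteroSAGE(torch.nn.Module):
    def __init__(self, metadata):
        super().__init__()
        # Convert a homogeneous GraphSAGE into a heterogeneous model
        self.graph_sage = to_hetero(GraphSAGE(-1, 32, num_layers=2), metadata)
        self.lin = torch.nn.Linear(32, 1)

    def forward(self, x_dict, edge_index_dict) -> torch.Tensor:
        # Select only the 'paper' embeddings so that a single tensor is returned
        return self.lin(self.graph_sage(x_dict, edge_index_dict)['paper'])
```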
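The same pattern can also be written as a small generic wrapper. This is a minimal sketch that assumes only that the wrapped model's `forward()` returns a dictionary keyed by node type. The names `SelectNodeType` and `ToyHeteroModel` are hypothetical and not part of the PyG API; the toy model just stands in for a `to_hetero`-converted network so that the example runs without a real dataset.

```python
import torch


class SelectNodeType(torch.nn.Module):
    """Hypothetical wrapper: returns only one node type's output tensor
    from a model whose forward() returns a dict keyed by node type."""

    def __init__(self, model: torch.nn.Module, node_type: str):
        super().__init__()
        self.model = model
        self.node_type = node_type

    def forward(self, *args, **kwargs) -> torch.Tensor:
        out = self.model(*args, **kwargs)
        # Pick a single tensor out of the dictionary of per-node-type outputs
        return out[self.node_type]


class ToyHeteroModel(torch.nn.Module):
    """Toy stand-in for a heterogeneous model (assumption, not a PyG model)."""

    def __init__(self):
        super().__init__()
        self.lins = torch.nn.ModuleDict({
            'paper': torch.nn.Linear(4, 1),
            'author': torch.nn.Linear(4, 1),
        })

    def forward(self, x_dict, edge_index_dict):
        # edge_index_dict is ignored in this toy example
        return {k: self.lins[k](x) for k, x in x_dict.items()}


x_dict = {'paper': torch.randn(5, 4), 'author': torch.randn(3, 4)}
model = SelectNodeType(ToyHeteroModel(), 'paper')
out = model(x_dict, {})
print(out.shape)  # torch.Size([5, 1])
```

In the question's setting, the wrapped model would be the output of `to_hetero` applied to the GAT model, and `'paper'` would be replaced by whichever node type you want explanations for.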