Explain module with hetero graph and edge prediction: not working #8714
-
Hello, I am interested in a specific use case of the explain library that I cannot manage to solve. More precisely, I have a model on a `HeteroData` graph that performs edge classification. In particular, my model's output is not a single tensor but rather a dictionary of tensors (one tensor for each type of data). In this example, it is mentioned that 'It is assumed that model outputs a single tensor', which is not the case for my model. Indeed, if I try to call the explainer, I get the error `'dict' object has no attribute 'argmax'`. How can I (and how should I) use the explain library in this setting? I have already checked the examples page, but it does not help.
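To make the failure concrete, here is a minimal sketch in plain PyTorch, assuming a model whose `forward` returns a dict keyed by edge type. `ToyHeteroEdgeClassifier` and `SingleOutputWrapper` are hypothetical names used only for illustration, and this is one possible workaround, not necessarily the solution mentioned in the replies. Calling `.argmax` on the dict fails for the same reason as in the explainer; wrapping the model so that it returns only the logits of the edge type you want to explain yields a single `torch.Tensor`.

```python
import torch
from torch import nn


class ToyHeteroEdgeClassifier(nn.Module):
    """Hypothetical stand-in: returns one logit tensor per edge type."""

    def __init__(self, in_channels: int, num_classes: int):
        super().__init__()
        # One shared scorer for every edge type, to keep the sketch short.
        self.lin = nn.Linear(2 * in_channels, num_classes)

    def forward(self, x_dict, edge_index_dict):
        out = {}
        for (src, rel, dst), edge_index in edge_index_dict.items():
            # Score each edge from its concatenated source/target node features.
            h = torch.cat([x_dict[src][edge_index[0]], x_dict[dst][edge_index[1]]], dim=-1)
            out[(src, rel, dst)] = self.lin(h)
        return out


class SingleOutputWrapper(nn.Module):
    """Hypothetical wrapper exposing only one edge type's logits as a plain tensor."""

    def __init__(self, model: nn.Module, edge_type):
        super().__init__()
        self.model = model
        self.edge_type = edge_type

    def forward(self, *args, **kwargs) -> torch.Tensor:
        return self.model(*args, **kwargs)[self.edge_type]


torch.manual_seed(0)
x_dict = {"user": torch.randn(4, 8), "item": torch.randn(5, 8)}
edge_index_dict = {
    ("user", "rates", "item"): torch.tensor([[0, 1, 3], [2, 4, 0]]),
    ("item", "rev_rates", "user"): torch.tensor([[2, 4], [0, 1]]),
}
model = ToyHeteroEdgeClassifier(in_channels=8, num_classes=3)

raw = model(x_dict, edge_index_dict)
# raw is a dict, so raw.argmax(...) would raise
# AttributeError: 'dict' object has no attribute 'argmax'.

wrapped = SingleOutputWrapper(model, ("user", "rates", "item"))
logits = wrapped(x_dict, edge_index_dict)
pred = logits.argmax(dim=-1)  # works: one predicted class per "rates" edge
```

The trade-off of this design is that each explanation only covers one edge type; to explain another type, wrap the model again with a different `edge_type`.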
-
@arthurserres Could I see your script? What does your model look like? I feel like the example here will still be a reasonable reference for your use case.
Yes, the output needs to be a `torch.Tensor`. Your solution works. Alternatively, you can stack the outputs of the two node types together before returning them.
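The "stack the outputs together" option can be sketched as follows, assuming both types produce logits with the same number of classes (the type names `author`/`paper` and the helper `concat_outputs` are hypothetical). Note that `torch.stack` would require every type to have the same number of rows, so this sketch uses `torch.cat` along dimension 0 instead. Because the result is one tensor, `.argmax` works; the `offsets` dict records where each type starts, so a prediction you want to explain can be translated into its row in the combined tensor.

```python
import torch

torch.manual_seed(0)
num_classes = 3
# Hypothetical per-type logits, e.g. the dict a hetero model returns.
out_dict = {
    "author": torch.randn(4, num_classes),
    "paper": torch.randn(6, num_classes),
}


def concat_outputs(out_dict, order):
    """Concatenate per-type logits row-wise and record each type's start row."""
    offsets, start = {}, 0
    for key in order:
        offsets[key] = start
        start += out_dict[key].size(0)
    return torch.cat([out_dict[key] for key in order], dim=0), offsets


order = ["author", "paper"]  # fix an order so the offsets are reproducible
out, offsets = concat_outputs(out_dict, order)
pred = out.argmax(dim=-1)  # one prediction per row, across both types

# Row 2 of the "paper" output sits at row offsets["paper"] + 2 of `out`.
global_idx = offsets["paper"] + 2
```

In a model, the same concatenation would go at the end of `forward`, keeping the type order fixed so the offsets stay valid between calls.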