Is there a working example of Heterogeneous graph classification? #3758
-
Hello all, we have been using PyG for a while, mostly on homogeneous graphs for graph classification. We are now expanding the richness of our graph representations to include multiple node types and edge types. We still want to do graph classification, but we are finding it hard to put a graph model together. Is there a working example of heterogeneous graph classification in PyG? We took inspiration from the heterogeneous model in the documentation and from the notebook for homogeneous graph classification. Below is the code we've tried; however, we get the following error:
Sometimes, we get the following error instead:
Any help would be really appreciated, even just a clean, simple working example of heterogeneous graph classification.
HGT model and training functions:
Code to load up a fake dataset and create the model training cycle:
Replies: 3 comments 2 replies
-
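I don't know of any heterogeneous graph classification benchmark dataset, which is why we don't provide an example yet (please ping me if you have any pointers). I haven't run the code since I'm currently on my phone (awesome to see you are already utilizing `FakeHeteroDataset`), but I think the issue is in the way you do "global pooling". IMO, there exist two options here: … `[num_graphs, hidden_channels]`.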
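To make the pooling question concrete, here is a minimal sketch in plain PyTorch (no PyG calls) of one way to get a `[num_graphs, hidden_channels]` graph embedding from several node types. It assumes every node type has already been encoded to the same `hidden_channels`, and that each node type comes with a `batch` vector mapping nodes to graphs, which is what PyG's `DataLoader` provides as `data[node_type].batch` for mini-batched `HeteroData`. The helper `mean_pool` is hypothetical (a stand-in for `global_mean_pool`), and the node type names and tensors are made up. Nodes of all types are stacked along the node dimension together with their batch vectors, so pooling still happens per graph.

```python
import torch


def mean_pool(x, batch, num_graphs):
    """Per-graph mean of node features (hypothetical stand-in for global_mean_pool)."""
    out = torch.zeros(num_graphs, x.size(1), dtype=x.dtype)
    out.index_add_(0, batch, x)  # sum each node's features into its graph's row
    count = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
    return out / count.unsqueeze(1).to(x.dtype)  # divide the sums by node counts


hidden_channels, num_graphs = 4, 2
x_dict = {  # toy node embeddings after the last heterogeneous layer
    "paper": torch.randn(5, hidden_channels),
    "author": torch.randn(3, hidden_channels),
}
batch_dict = {  # graph id of every node, per node type
    "paper": torch.tensor([0, 0, 0, 1, 1]),
    "author": torch.tensor([0, 1, 1]),
}

# Stack all node types along the node (batch) dimension and concatenate the
# batch vectors in the same order, so each node keeps its graph id.
node_types = list(x_dict)
x_all = torch.cat([x_dict[t] for t in node_types], dim=0)          # [8, hidden_channels]
batch_all = torch.cat([batch_dict[t] for t in node_types], dim=0)  # [8]

graph_emb = mean_pool(x_all, batch_all, num_graphs)
print(graph_emb.shape)  # torch.Size([2, 4]), i.e. [num_graphs, hidden_channels]
```

Because the batch vectors travel with the features, nodes from different graphs are never averaged together; the trade-off is that node types with many nodes dominate each graph's mean.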
-
Thanks @rusty1s for your quick response! I think I'm running in circles here. I don't think I want to concatenate in the batch dimension, since this would mean that different graphs in the same batch are averaged together or influence each other, right? Concatenating them in the feature dimension is what I think I want (averaging/summing the features from the different node types of the same graph). That brings me to a shape of
Progress! After that, I get the following error:
This seems to indicate that `data.y` is of type `Long`. Since the type is set by `FakeHeteroDataset()`, I could work around it by using `data.y.float()` instead. This seems to let me train! I'll test with real data and validate a bit further that I am aggregating as I expect. Thank you again!

PS: To make it work, I also had to include the number of node types when creating my linear layers.
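The changes described in this reply can be sketched in plain PyTorch (no PyG calls): pool each node type per graph, concatenate the results along the feature dimension, size the classifier's input accordingly, and cast the `Long` labels to float for a binary cross-entropy loss. Everything here is an assumption for illustration: the hypothetical helper `mean_pool` (a stand-in for PyG's `global_mean_pool`), the node type names, the toy tensors, and the choice of `binary_cross_entropy_with_logits` as the loss, since the thread does not show which loss was used.

```python
import torch
import torch.nn.functional as F
from torch import nn


def mean_pool(x, batch, num_graphs):
    """Per-graph mean of node features (hypothetical stand-in for global_mean_pool)."""
    out = torch.zeros(num_graphs, x.size(1), dtype=x.dtype)
    out.index_add_(0, batch, x)  # sum each node's features into its graph's row
    count = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
    return out / count.unsqueeze(1).to(x.dtype)  # a type absent from a graph gives zeros


hidden_channels, num_graphs = 4, 2
x_dict = {  # toy node embeddings, e.g. the output of the last HGT layer
    "paper": torch.randn(5, hidden_channels),
    "author": torch.randn(3, hidden_channels),
}
batch_dict = {
    "paper": torch.tensor([0, 0, 0, 1, 1]),
    "author": torch.tensor([0, 1, 1]),
}
y = torch.tensor([0, 1])  # graph labels stored as Long (int64)

node_types = sorted(x_dict)  # fixed order, so each feature slot means the same type
pooled = [mean_pool(x_dict[t], batch_dict[t], num_graphs) for t in node_types]

# Concatenate along the feature dimension: [num_graphs, hidden_channels * num_node_types].
graph_emb = torch.cat(pooled, dim=1)
# (Summing instead, torch.stack(pooled).sum(0), would keep [num_graphs, hidden_channels].)

# The classifier's input size therefore has to include the number of node types.
lin = nn.Linear(hidden_channels * len(node_types), 1)
out = lin(graph_emb).squeeze(-1)  # one logit per graph: [num_graphs]

# BCE expects float targets with the same shape as the logits, hence y.float().
loss = F.binary_cross_entropy_with_logits(out, y.float())
loss.backward()
print(graph_emb.shape, out.shape)  # torch.Size([2, 8]) torch.Size([2])
```

If the labels are class indices for more than two classes, `F.cross_entropy(logits, y)` with a `[num_graphs, num_classes]` output accepts the `Long` labels directly, so no cast is needed in that setup.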
-
Could you please provide code examples? I have been working on heterogeneous graph classification recently, but there is no relevant documentation.