How to apply custom mini-batching to HeteroData objects? #4772
DanielPerezJensen started this conversation in General
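This should work very similarly for `HeteroData`:
```python
import torch
from torch_geometric.data import HeteroData

class SeqData(HeteroData):
    def __inc__(self, key, value, store, *args, **kwargs):
        if key == "edge_indices":
            # For an edge store, `store._key` is the (src_type, rel_type, dst_type) tuple.
            src_type, _, dst_type = store._key
            return torch.tensor([[0], [self[src_type].xs[0].size(0)], [self[dst_type].xs[0].size(0)]])
        else:
            return super().__inc__(key, value, store, *args, **kwargs)
```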
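Here is a plain-Python sketch (no PyG or PyTorch needed) of the per-edge-type offsetting this `__inc__` override is meant to drive during batching. For an edge type `(src_type, rel, dst_type)`, the source indices are shifted by the number of `src_type` nodes in earlier samples. The destination indices are shifted by the number of `dst_type` nodes. This only illustrates the idea under that assumption and is not PyG's actual collate code. The helper `collate_hetero_edges` and the node type names `"context"` and `"target"` are hypothetical.

```python
from typing import Dict, List, Tuple

EdgeType = Tuple[str, str, str]  # (src_type, rel_type, dst_type), as in HeteroData
Edge = Tuple[int, int]           # (src_index, dst_index)


def collate_hetero_edges(
    num_nodes: List[Dict[str, int]],
    edges: List[Dict[EdgeType, List[Edge]]],
) -> Dict[EdgeType, List[Edge]]:
    """Concatenate the edges of several heterogeneous samples per edge type.

    Source indices are shifted by the number of src_type nodes in earlier
    samples, destination indices by the number of dst_type nodes.
    """
    seen: Dict[str, int] = {}  # running node count per node type
    out: Dict[EdgeType, List[Edge]] = {}
    for counts, sample_edges in zip(num_nodes, edges):
        for edge_type, pairs in sample_edges.items():
            src_type, _, dst_type = edge_type
            src_off = seen.get(src_type, 0)
            dst_off = seen.get(dst_type, 0)
            out.setdefault(edge_type, []).extend(
                (s + src_off, d + dst_off) for s, d in pairs
            )
        # Only after this sample's edges are shifted, add its node counts.
        for node_type, n in counts.items():
            seen[node_type] = seen.get(node_type, 0) + n
    return out


# Hypothetical node types: 3 "context" nodes feeding 2 "target" nodes per sample.
et = ("context", "to", "target")
merged = collate_hetero_edges(
    [{"context": 3, "target": 2}, {"context": 3, "target": 2}],
    [{et: [(0, 0), (2, 1)]}, {et: [(0, 0), (2, 1)]}],
)
print(merged[et])  # [(0, 0), (2, 1), (3, 2), (5, 3)]
```

The two offsets differ because the source and destination node types are counted separately. A single homogeneous offset would point edges at the wrong nodes.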
I am following the tutorial on advanced mini-batching.
I need this because I am using an LSTM within a graph structure. I want to use 6 graphs to predict 6 values for 4 of the nodes within a graph. To do this, I need a way to build one graph that contains 6 unconnected graphs. This seems simple enough with the plain `Data` class for homogeneous graphs, but I am working in a heterogeneous setting.
I have 4 nodes for which I want to predict the next 6 values, and 11 nodes carrying information that might be relevant to predicting those values. The two groups are different node types with different features, and the edges between them are also of different types.
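To combine the 6 snapshot graphs into one graph with 6 unconnected components, you shift each graph's node ids past the nodes of the graphs before it. Below is a minimal plain-Python sketch of that disjoint union. It assumes edges are lists of `(src, dst)` pairs rather than PyG tensors, and `disjoint_union` is a hypothetical helper.

```python
from typing import List, Tuple

Edge = Tuple[int, int]


def disjoint_union(
    num_nodes: List[int], edge_lists: List[List[Edge]]
) -> Tuple[int, List[Edge], List[int]]:
    """Merge several graphs into one graph whose components stay unconnected.

    Node ids of graph i are shifted by the total node count of graphs 0..i-1,
    so no edge can point into another graph's node range.
    Returns (total_nodes, merged_edges, graph_id_per_node).
    """
    offset = 0
    merged: List[Edge] = []
    graph_id: List[int] = []
    for i, (n, graph_edges) in enumerate(zip(num_nodes, edge_lists)):
        merged.extend((src + offset, dst + offset) for src, dst in graph_edges)
        graph_id.extend([i] * n)  # which snapshot each node came from
        offset += n
    return offset, merged, graph_id


# Two snapshots of the same 3-node path graph 0-1-2.
total, edges, batch = disjoint_union([3, 3], [[(0, 1), (1, 2)], [(0, 1), (1, 2)]])
print(total, edges, batch)  # 6 [(0, 1), (1, 2), (3, 4), (4, 5)] [0, 0, 0, 1, 1, 1]
```

The returned `graph_id` list plays the role of PyG's `batch` vector: it maps every node back to the snapshot it came from.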
There doesn't seem to be a guide on how to tackle this in the heterogeneous setting. In the homogeneous setting, I did it by holding a list of edge indices, a list of node attributes, and the values to be predicted within a custom `Data` class called `SeqData`.
Output:
Keep in mind that I am using random inputs for clarity's sake, so the number of graphs stored here is only 2 instead of the 6 I specified earlier. The number of nodes also differs.
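With the list-based homogeneous layout described above, batching several samples means merging them position-wise. Time step t of every sample goes into one disjoint graph, and each sample's edge indices at that step are offset by the nodes already placed there. Below is a rough plain-Python sketch. The dict keys `xs`, `edge_indices` and `ys` mirror the attribute names used in this thread. The `collate_sequences` helper and the list-of-rows representation are assumptions, not the actual `SeqData` code.

```python
from typing import Dict, List, Tuple

Edge = Tuple[int, int]
# Hypothetical stand-in for one SeqData sample: xs[t] holds the node feature
# rows at time step t, edge_indices[t] the (src, dst) pairs, ys the targets.
Sample = Dict[str, list]


def collate_sequences(samples: List[Sample]) -> Sample:
    """Merge samples position-wise: step t of all samples becomes one disjoint graph."""
    num_steps = len(samples[0]["xs"])
    out: Sample = {"xs": [], "edge_indices": [], "ys": []}
    for t in range(num_steps):
        xs_t: list = []
        edges_t: List[Edge] = []
        for sample in samples:
            offset = len(xs_t)  # nodes already placed at step t by earlier samples
            edges_t.extend((s + offset, d + offset) for s, d in sample["edge_indices"][t])
            xs_t.extend(sample["xs"][t])
        out["xs"].append(xs_t)
        out["edge_indices"].append(edges_t)
    for sample in samples:
        out["ys"].extend(sample["ys"])  # targets are simply concatenated
    return out


# Two samples, each with 2 time steps of a 2-node graph.
a = {"xs": [[[0.1], [0.2]], [[0.3], [0.4]]], "edge_indices": [[(0, 1)], [(1, 0)]], "ys": [1.0]}
b = {"xs": [[[0.5], [0.6]], [[0.7], [0.8]]], "edge_indices": [[(0, 1)], [(1, 0)]], "ys": [2.0]}
batched = collate_sequences([a, b])
print(batched["edge_indices"])  # [[(0, 1), (2, 3)], [(1, 0), (3, 2)]]
```

This sketch computes the offset separately for each time step, so it also works if the node count changes over time. An increment based only on `xs[0]` assumes the count is constant.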
But how would I do this with `HeteroData`? When I try something similar, as below:
Output:
SeqDataBatch()
I just get an empty batch, with no indication of what might be wrong. The `__inc__` function doesn't even seem to be called. What else do I need to change to make this work?