As far as I understand, the problem is that you do not actually want to split links in an inductive fashion. Instead, you want to make use of all … I think in this case, you can do something like:

```python
import torch

# data[...]: replace ... with the labeled edge type of your HeteroData.
num_edges = data[...].num_edges
perm = torch.randperm(num_edges)
train_idx = perm[:int(0.8 * num_edges)]
val_idx = perm[int(0.8 * num_edges):int(0.9 * num_edges)]
test_idx = perm[int(0.9 * num_edges):]
# Convert to mask and add to the corresponding edge type.
...
```
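Here is a minimal sketch of the "convert to mask" step that uses only torch. The helper name `split_edge_masks`, the `seed` argument, and the default 80/10/10 fractions are assumptions for illustration. The three returned masks are disjoint and together cover every edge.

```python
import torch


def split_edge_masks(num_edges, train_frac=0.8, val_frac=0.1, seed=0):
    """Hypothetical helper: randomly split edge ids into disjoint boolean masks."""
    # Seeded generator so the split is reproducible across runs.
    gen = torch.Generator().manual_seed(seed)
    perm = torch.randperm(num_edges, generator=gen)

    train_end = int(train_frac * num_edges)
    val_end = train_end + int(val_frac * num_edges)

    def to_mask(idx):
        # Turn a set of edge ids into a boolean mask over all edges.
        mask = torch.zeros(num_edges, dtype=torch.bool)
        mask[idx] = True
        return mask

    return (
        to_mask(perm[:train_end]),
        to_mask(perm[train_end:val_end]),
        to_mask(perm[val_end:]),
    )


train_mask, val_mask, test_mask = split_edge_masks(100)
print(int(train_mask.sum()), int(val_mask.sum()), int(test_mask.sum()))  # 80 10 10
```

With a HeteroData object, you could then attach the masks to the labeled edge type, e.g. `data['user', 'rates', 'item'].train_mask = train_mask` (this edge type is hypothetical). You would keep the full `edge_index` for message passing and use the masks only to choose which predictions enter the loss.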
Hi PyG community,
I have a HeteroData object that has edge labels. I couldn't find any way to split it into train, val, and test sets other than RandomLinkSplit, since most DataLoaders are node-oriented. For example, …
and this is the example data that needs to be split
into train, val, and test sets (see the attached images).
In my network, to predict the value of edge_label, I want to use my edge_attr concatenated with the node features. However, there are no masks, so I'm unsure how to index the correct edge_index connections in val_data. I'm trying to hack my way through by adding conditionals under model.eval() that change which part of the predictions I consider, but there must be a better way to mask the validation set, as there is for nodes.
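The model described above could look roughly like the following sketch. It rests on a few assumptions: the node embeddings `x_src` and `x_dst` come from an encoder you already have (for example, a heterogeneous GNN), and the class name `EdgeRegressor` and the hidden size are hypothetical. The head scores every edge in `edge_index` by concatenating both endpoint embeddings with that edge's `edge_attr`, which means a boolean edge mask can later select which predictions to evaluate.

```python
import torch
from torch import nn


class EdgeRegressor(nn.Module):
    """Hypothetical head: predicts one scalar label per edge."""

    def __init__(self, src_dim, dst_dim, edge_dim, hidden=64):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(src_dim + dst_dim + edge_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, x_src, x_dst, edge_index, edge_attr):
        src, dst = edge_index  # row 0: source node ids, row 1: target node ids
        # Gather endpoint embeddings per edge and concatenate with the edge features.
        h = torch.cat([x_src[src], x_dst[dst], edge_attr], dim=-1)
        return self.mlp(h).squeeze(-1)  # shape: [num_edges]


torch.manual_seed(0)
x_src, x_dst = torch.randn(5, 4), torch.randn(6, 3)
edge_index = torch.stack([torch.randint(0, 5, (10,)), torch.randint(0, 6, (10,))])
edge_attr = torch.randn(10, 2)
model = EdgeRegressor(src_dim=4, dst_dim=3, edge_dim=2)
pred = model(x_src, x_dst, edge_index, edge_attr)
print(pred.shape)  # torch.Size([10])
```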
For train, I can see that the excess labels are just 0, so I can compare pred[:93325] against the target labels. For val and test, there seem to be more labels than edge connections, so for my RMSE loss I'd need to compare a very specific subset of those edge-label predictions. I cannot create a mask beforehand in the HeteroData itself, since I'll need to add edge connections for message passing.
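Suppose every edge stays in `edge_index` for message passing and the model predicts a value for every edge. In that setup, the RMSE for a split can be restricted with a boolean mask instead of slicing a range like `pred[:93325]`. Here is a minimal sketch (the helper name `masked_rmse` is hypothetical):

```python
import torch


def masked_rmse(pred, target, mask):
    """RMSE over only the edges selected by a boolean mask."""
    diff = pred[mask] - target[mask]
    return torch.sqrt(torch.mean(diff ** 2))


pred = torch.tensor([1.0, 2.0, 3.0, 4.0])
target = torch.tensor([1.0, 2.0, 0.0, 0.0])
train_mask = torch.tensor([True, True, False, False])
val_mask = torch.tensor([False, False, True, False])
print(masked_rmse(pred, target, train_mask).item())  # 0.0
print(masked_rmse(pred, target, val_mask).item())    # 3.0
```

The same helper works unchanged under `model.eval()` with `val_mask` or `test_mask`, so no split-specific conditionals are needed.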
Is there a way to create masks for these, or a better way to split the data into train/val/test sets based on edge_labels that I'm missing, perhaps even using some loader? HGT and neighbor sampling are more geared toward node tasks at the moment.
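PyG's `LinkNeighborLoader` (in `torch_geometric.loader`) takes `edge_label_index` and `edge_label` arguments for edge-level tasks and samples a subgraph around each batch of supervision edges. Its exact options depend on your PyG version. If the graph fits in memory, a simpler alternative is to run full-graph message passing and mini-batch only the supervision edge ids selected by a mask. Below is a torch-only sketch of that alternative (the helper name `edge_id_loader` is hypothetical):

```python
import torch
from torch.utils.data import DataLoader


def edge_id_loader(mask, batch_size, shuffle=True):
    """Hypothetical helper: yields mini-batches of edge ids where mask is True."""
    edge_ids = mask.nonzero(as_tuple=False).view(-1)
    # A 1-D tensor works as a map-style dataset; the default collate
    # stacks the scalar ids of each batch back into a 1-D tensor.
    return DataLoader(edge_ids, batch_size=batch_size, shuffle=shuffle)


mask = torch.tensor([True, False, True, True, False, False, True, False, False, False])
for batch in edge_id_loader(mask, batch_size=3, shuffle=False):
    print(batch.tolist())
# prints [0, 2, 3] then [6]
```

Inside a training loop, you would compute predictions on the full graph and then take the loss on `pred[batch]` against `edge_label[batch]`.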