How to train and test on only some of the nodes in each graph for node classification? Each graph uses different nodes. #6587
Hi, I am working on a node classification task. I have a dataset of 1000 graphs, each of which is an 11*7 grid graph. Each node has one of three labels (0, 1, 2). Only nodes with labels 0 or 1 are of interest; roughly 70 nodes in each graph have label 2, which we are not interested in. I want to train a node classification model on 800 graphs using only the nodes labeled 0 or 1, and test on the remaining 200 graphs. The nodes of interest are not consistent across graphs (e.g., in graph A they may be nodes 1 and 2, while in graph B they are nodes 2, 4, and 5). How should I build the dataset for this setting and train, for example, a GCN model? For now, I am adding a mask as in the following example:
And I test with a GCN model like this:
However, I am getting an index out-of-bounds error which can be traced to

Can anyone show me the proper way to build the dataset and train the model? Thank you!
Replies: 1 comment 2 replies
The problem is in the line `train_data = graph.__class__(x=graph.x[train_mask], edge_index=graph.edge_index)`, where you reduce the number of nodes but keep the original `edge_index`. If you want to remove nodes, it's better to use `data.subgraph(train_mask)` instead.