I am playing around with the data and code provided by the paper "Utilizing Molecular Network Information via Graph Convolutional Neural Networks to Predict Metastatic Event in Breast Cancer" (not important, but the code is actually provided by a follow-up paper to the one mentioned here; just thought I'd mention it to avoid confusion). Basically, the authors have gene expression data for breast cancer patients who did or did not exhibit metastasis within a period of 5 years, and they trained a GNN model to classify them as such. The idea is to "project" the expression data onto a gene interaction graph (such as a PPI graph) and to associate a graph with each patient: each node of a given graph represents a gene and holds the expression value of that gene for the patient the graph represents.

The code the authors used is mostly adapted from the "famous" and "old" paper by Defferrard et al. 2016, "Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering" (Defferrard provides the code for that paper here). That code is "old" in the sense that it uses TensorFlow v1 and does not use any graph libraries, since I don't think there were any at the time. So I wanted to try to train a model on the dataset using newer code and libraries, namely PyTorch and PyTorch Geometric.

Now, the models I built / adapted from more recent papers all face the same problem: they only ever predict the majority class (note there is no real class imbalance, the split is about 59% / 41%), so I'm stuck at 59% accuracy. What baffles me is that Defferrard's model is able to learn and actually gets around 82% accuracy.
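For concreteness, here's a minimal sketch of how I represent one patient as a PyTorch Geometric graph; the sizes and tensors below are dummy placeholders, not the paper's actual data:

```python
# Minimal sketch (dummy data, not the paper's): one patient = one PyG graph.
# The PPI network gives a shared edge_index; the patient's expression vector
# becomes the node feature matrix x (one value per gene).
import torch
from torch_geometric.data import Data

num_genes = 2000                                      # nodes = genes (placeholder)
edge_index = torch.randint(0, num_genes, (2, 10000))  # shared PPI edges (dummy)
expression = torch.randn(num_genes, 1)                # expression value per gene
label = torch.tensor([1])                             # 1 = metastasis within 5 years

patient_graph = Data(x=expression, edge_index=edge_index, y=label)
```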
Result: it still only predicts the majority class. I also noticed that Defferrard's model is the only one that uses coarsening + pooling between layers, which it does via Graclus. So I tried adding that to DeeperGCN's Res+ block (I commented out anything that deals with edge features, since my data has none). A simplified sketch of the pooling step is below (it assumes `torch-cluster` is installed; `conv1` / `conv2` stand in for the actual Res+ blocks and `lin` for the classifier head):
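```python
# Simplified sketch of Graclus coarsening + pooling between GNN layers,
# analogous to the coarsening in Defferrard's code. conv1 / conv2 / lin are
# stand-ins for the real DeeperGCN Res+ blocks and classifier head.
import torch.nn.functional as F
from torch_geometric.nn import graclus, max_pool, global_mean_pool

def forward(self, data):
    data.x = F.relu(self.conv1(data.x, data.edge_index))
    cluster = graclus(data.edge_index, num_nodes=data.x.size(0))  # pair up nodes
    data = max_pool(cluster, data)   # merge each pair, roughly halving the graph
    data.x = F.relu(self.conv2(data.x, data.edge_index))
    x = global_mean_pool(data.x, data.batch)  # one vector per patient graph
    return self.lin(x)
```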
Again, no difference. I'm quite confident the problem isn't in how I feed in the data or in the data processing, because I checked that multiple times. Of course, I also tried playing around with hyperparameters, especially the learning rate, but nothing came of that. I'm kinda stumped at this point. Was wondering if anyone has any pointers.
-
This is hard to say, actually. I see that the TF code you linked also performs regularization and exponential LR decay, and it transforms the input node features before feeding them to the GNN. Furthermore, the difference might also be induced by different weight initializations. Sadly, I cannot give you clear guidance on what the cause might be :(
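For instance, the regularization + decay part could be mirrored in PyTorch roughly like this (the `weight_decay` / `gamma` values here are placeholders, not the TF code's actual settings):

```python
# Rough PyTorch equivalent of L2 regularization + exponential LR decay;
# all values are placeholders, and model / train_one_epoch / num_epochs
# are hypothetical names, not from the original code.
import torch

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

for epoch in range(num_epochs):
    train_one_epoch(model, optimizer)  # hypothetical training step
    scheduler.step()                   # decay the LR once per epoch
```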