Code and report for my semester project on using rotationally and translationally equivariant graph neural networks to predict cardiac arrest from 3D reconstructed arteries. For more information about the experiments, please check out my write-up!
I recommend navigating this repo as follows:
- Read the write-up.
- Look at the code structure below.
- Look at the experiment results on the wandb platform. For more information on how to navigate them, see the appendix of the write-up.
- To run your own experiments, you will need an MI-proj/data folder containing a patient_dict.pickle file (described under datasets.py below) and a data folder with your mesh data, for example CoordToCnc.
- Generate your experiments by running python main_cross_val.py from inside the MI-proj/experiments directory, after setting the appropriate hyperparameters in hyper_params.yaml.
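main_cross_val.py runs one experiment per combination of the hyperparameter values in hyper_params.yaml. The sketch below illustrates that expansion with itertools.product. It assumes (hypothetically) that the YAML file maps each hyperparameter name to a list of candidate values; the keys lr, hidden_dim and n_layers, and the helper expand_grid, are made up for illustration and are not the repo's actual code.

```python
from itertools import product


def expand_grid(hyper_params):
    """Return one dict per combination of hyperparameter values.

    Hypothetical helper: assumes each key maps to a list of candidate
    values; a scalar value is treated as a single-element list.
    """
    keys = list(hyper_params)
    value_lists = [v if isinstance(v, list) else [v] for v in hyper_params.values()]
    return [dict(zip(keys, combo)) for combo in product(*value_lists)]


# Hypothetical contents of hyper_params.yaml once loaded into a dict
# (for example with yaml.safe_load from PyYAML).
grid = {"lr": [1e-3, 1e-4], "hidden_dim": [32, 64], "n_layers": 4}

configs = expand_grid(grid)
print(len(configs))  # 4 combinations: 2 learning rates x 2 hidden sizes
for config in configs:
    print(config)

# evaluate.py expects one value per hyperparameter, so its grid
# should expand to exactly one configuration.
final = expand_grid({"lr": 1e-3, "hidden_dim": 64, "n_layers": 4})
assert len(final) == 1
```

Because every listed value multiplies the number of runs, it is worth keeping the grid small: each combination is also cross-validated.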
Name | Description
--- | ---
create_data.py | Data fetching and preprocessing. This should be run from the MI-proj directory. The executed function is at the bottom of the file. Note that our dataset is not public, so you won't have access to the path and label_path directories.
data_augmentation.py | Contains all the data augmentation schemes we attempted. Used in create_data.py.
datasets.py | Contains our custom DataSet object, which is how we store the meshes. Also contains a custom split_data function that performs the train, validation, and test splits at the patient level. Note that you will need a file "MI-proj/data/patient_dict.pickle" containing a dictionary with patients as keys and lists of artery names as values.
hyper_params.yaml | File containing all hyperparameters of a given model. Used in evaluate.py and main_cross_val.py. If you use it for evaluate.py, there should be only one value per hyperparameter.
main_cross_val.py | Runs a grid search with cross validation over all combinations of hyperparameters in hyper_params.yaml. All experiments are recorded on the wandb platform. Make sure to change and remember MODEL_TYPE so that you can retrieve the experiment on the wandb platform! This does not use the test set. This should be called from inside the MI-proj/experiments directory.
evaluate.py | Same as main_cross_val.py, but evaluates the model on the test set once it has finished training. This should be run with only one value per hyperparameter in hyper_params.yaml. It is crucial to use the same seed here as in the grid search. Also records all results on the wandb platform. This should be called from inside the MI-proj/experiments directory.
gnnexplainer.ipynb | Coming soon! Jupyter notebook for the GNNExplainer experiment and visualization.
GNNExplainer.py | Slightly modified code from the paper [1], obtained from the repo of [1].
egnn.py | Slightly modified code from the paper [2], obtained from the repo of [2].
models.py | Contains all the models used in our experiments.
train.py | Contains a custom GNN object definition. Main script used for training and evaluating our models.
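A patient-level split keeps all arteries of a given patient in the same subset, so no patient contributes to both training and testing. Below is a minimal sketch of the idea, assuming patient_dict.pickle maps each patient ID to a list of artery names as described for datasets.py. The helper split_by_patient, the split fractions, and the patient and artery names are hypothetical; this is not the repo's split_data. It also shows why reusing the seed matters: the same seed reproduces the same split, so the test patients stay unseen during the grid search.

```python
import pickle
import random
import tempfile
from pathlib import Path


def split_by_patient(patient_dict, val_frac=0.15, test_frac=0.15, seed=0):
    """Split artery names into train/val/test at the patient level.

    Hypothetical helper: every artery of a patient lands in the same
    split, and the same seed always yields the same split.
    """
    patients = sorted(patient_dict)        # fixed order before shuffling
    random.Random(seed).shuffle(patients)  # seeded, reproducible shuffle
    n_test = round(test_frac * len(patients))
    n_val = round(val_frac * len(patients))
    groups = {
        "test": patients[:n_test],
        "val": patients[n_test:n_test + n_val],
        "train": patients[n_test + n_val:],
    }
    # Expand each group of patients into the list of their arteries.
    return {name: [a for p in ps for a in patient_dict[p]] for name, ps in groups.items()}


# Hypothetical data in the described format: patient -> list of artery names.
patient_dict = {f"patient_{i}": [f"patient_{i}_LAD", f"patient_{i}_RCA"] for i in range(10)}

# Write and reload the pickle (a temporary folder stands in for MI-proj/data).
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "patient_dict.pickle"
    with path.open("wb") as f:
        pickle.dump(patient_dict, f)
    with path.open("rb") as f:
        patient_dict = pickle.load(f)

splits = split_by_patient(patient_dict, seed=42)
print({name: len(arteries) for name, arteries in splits.items()})
```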
[1] paper: GNNExplainer: Generating Explanations for Graph Neural Networks, repo: GNNExplainer.
[2] paper: E(n) Equivariant Graph Neural Networks, repo: egnn.