This repository contains the code used for the experiments of:
"Understanding Pooling in Graph Neural Networks"
D. Grattarola, D. Zambon, F. M. Bianchi, C. Alippi
https://arxiv.org/abs/2110.05292
The dependencies of the project are listed in `requirements.txt`. You can install them with:

```bash
pip install -r requirements.txt
```

The code to run our experiments is in the following folders:

- `autoencoder/`
- `spectral_similarity/`
- `graph_classification/`
Each folder has a script called `run_all.sh` that will reproduce the results reported in the paper.
To generate the plots and tables from the paper, you can use the `plots.py`, `plots_datasets.py`, or `tables.py` scripts in each folder.
To run experiments for an individual pooling operator, you can use the `run_[OPERATOR NAME].py` scripts in each folder.
The pooling operators that we used for the experiments are in `layers/` (trainable) and `modules/` (non-trainable).
The GNN architectures used in the experiments are in `models/`.
The core of this repository is the `SRCPool` class, which implements a general
interface to create SRC pooling layers with the Keras API.
Our implementations of MinCutPool, DiffPool, LaPool, Top-K, and SAGPool using the
`SRCPool` class can be found in `src/layers`.
SRC layers have the following structure:

    S = Sel(X, A);    X' = Red(X, A, S);    A' = Con(X, A, S)

where `Sel` computes the supernode assignments, `Red` aggregates the nodes of each supernode into the pooled node features, and `Con` computes the pooled adjacency matrix.
By extending this class, it is possible to create any pooling layer in the SRC framework.
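For dense cluster-based operators like DiffPool and MinCutPool, `Red` and `Con` reduce to matrix products with the assignment matrix `S`. A minimal NumPy sketch of the three maps (using a hypothetical hard clustering as `Sel`; this is illustrative, not the repository's implementation):

```python
import numpy as np

np.random.seed(0)

N, F, K = 6, 4, 2            # input nodes, node features, output supernodes
X = np.random.rand(N, F)     # node features, shape (N, F)
A = np.random.rand(N, N)     # dense adjacency matrix, shape (N, N)
A = (A + A.T) / 2            # symmetrize

# Sel: assign each node to one of K supernodes (here: a fixed hard clustering)
clusters = np.array([0, 0, 0, 1, 1, 1])
S = np.eye(K)[clusters]      # assignment matrix, shape (N, K)

# Red: aggregate the features of the nodes in each supernode
X_pool = S.T @ X             # shape (K, F)

# Con: connect supernodes through the original edges
A_pool = S.T @ A @ S         # shape (K, K)

print(X_pool.shape, A_pool.shape)  # (2, 4) (2, 2)
```

Trainable operators replace the fixed `clusters` with a learned (possibly soft) assignment, but the Sel/Red/Con decomposition stays the same.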
Input
- `X`: Tensor of shape `([batch], N, F)` representing node features;
- `A`: Tensor or SparseTensor of shape `([batch], N, N)` representing the adjacency matrix;
- `I`: (optional) Tensor of integers with shape `(N, )` representing the batch index;
Output
- `X_pool`: Tensor of shape `([batch], K, F)` representing the node features of the output. `K` is the number of output nodes and depends on the specific pooling strategy;
- `A_pool`: Tensor or SparseTensor of shape `([batch], K, K)` representing the adjacency matrix of the output;
- `I_pool`: (only if `I` was given as input) Tensor of integers with shape `(K, )` representing the batch index of the output;
- `S_pool`: (if `return_sel=True`) Tensor or SparseTensor representing the supernode assignments;
API
- `pool(X, A, I, **kwargs)`: pools the graph and returns the reduced node features and adjacency matrix. If the batch index `I` is not `None`, a reduced version of `I` will be returned as well. Any given `kwargs` will be passed as keyword arguments to `select()`, `reduce()`, and `connect()` if any matching key is found. The mandatory arguments of `pool()` (`X`, `A`, and `I`) must be computed in `call()` by calling `self.get_inputs(inputs)`.
- `select(X, A, I, **kwargs)`: computes supernode assignments mapping the nodes of the input graph to the nodes of the output.
- `reduce(X, S, **kwargs)`: reduces the supernodes to form the nodes of the pooled graph.
- `connect(A, S, **kwargs)`: connects the reduced supernodes.
- `reduce_index(I, S, **kwargs)`: helper function to reduce the batch index (only called if `I` is given as input).
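The call order of these methods can be illustrated with a small framework-free mock (the class name, the pairwise `select`, and the `reduce_index` heuristic below are hypothetical stand-ins, not the actual Keras-based `SRCPool`):

```python
import numpy as np

class ToySRCPool:
    """Illustrative stand-in for the SRCPool call flow (not the real class)."""

    def pool(self, X, A, I=None, **kwargs):
        # select -> reduce -> connect, plus reduce_index only when I is given
        S = self.select(X, A, I, **kwargs)
        X_pool = self.reduce(X, S, **kwargs)
        A_pool = self.connect(A, S, **kwargs)
        if I is None:
            return X_pool, A_pool
        return X_pool, A_pool, self.reduce_index(I, S, **kwargs)

    def select(self, X, A, I, **kwargs):
        # Hypothetical Sel: merge consecutive node pairs into supernodes
        K = X.shape[0] // 2
        return np.repeat(np.eye(K), 2, axis=0)  # assignments, shape (N, K)

    def reduce(self, X, S, **kwargs):
        return S.T @ X                           # pooled features, shape (K, F)

    def connect(self, A, S, **kwargs):
        return S.T @ A @ S                       # pooled adjacency, shape (K, K)

    def reduce_index(self, I, S, **kwargs):
        # Each supernode inherits the batch index of its first member node
        first = np.argmax(S, axis=0)             # first node of each supernode
        return I[first]

X = np.arange(12.0).reshape(6, 2)                # 6 nodes, 2 features
A = np.ones((6, 6))
I = np.array([0, 0, 0, 0, 1, 1])                 # two graphs in disjoint mode
X_pool, A_pool, I_pool = ToySRCPool().pool(X, A, I)
print(X_pool.shape, A_pool.shape, I_pool)        # (3, 2) (3, 3) [0 0 1]
```

A concrete operator would override `select()` (and, where needed, `reduce()` and `connect()`) while inheriting the orchestration in `pool()`.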
When overriding any function of the API, it is possible to access the
true number of nodes of the input (`N`) as a Tensor in the instance variable
`self.N` (this is populated by `self.get_inputs()` at the beginning of
`call()`).
Arguments:
- `return_sel`: if `True`, the Tensor used to represent supernode assignments will be returned with `X_pool`, `A_pool`, and `I_pool`;
