A scalable distributed-memory Graph Neural Network (GNN) training framework using MPI for inter-process communication. Implements GraphSAGE-style message passing with vertex partitioning and halo exchange.
- Distributed Training: Scale across multiple MPI ranks (1, 2, 5+ nodes)
- GraphSAGE Implementation: Efficient graph convolution with mean aggregation
- Vertex Partitioning: Graph partitioned across workers with halo exchange
- MPI Communication: Distributed gradient synchronization using PyTorch distributed (gloo/nccl backends)
- Flexible Backend: Supports both single-node and multi-node execution
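The GraphSAGE-style mean aggregation mentioned above can be sketched as a minimal PyTorch layer. This is an illustrative sketch, not the framework's `graphsage.py` implementation; the class name `SAGEMeanConv` and the `edge_index` convention (`[2, E]` with `(src, dst)` rows) are assumptions:

```python
import torch
import torch.nn as nn

class SAGEMeanConv(nn.Module):
    """Minimal GraphSAGE layer with mean aggregation (illustrative sketch)."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.lin_self = nn.Linear(in_dim, out_dim)   # transform of the node's own features
        self.lin_neigh = nn.Linear(in_dim, out_dim)  # transform of aggregated neighbors

    def forward(self, x, edge_index):
        # edge_index: [2, E] tensor of (src, dst) pairs
        src, dst = edge_index
        # Sum neighbor features into each destination node
        agg = torch.zeros_like(x)
        agg.index_add_(0, dst, x[src])
        # Divide by in-degree to get the mean (clamp avoids division by zero)
        deg = torch.zeros(x.size(0), dtype=x.dtype, device=x.device)
        deg.index_add_(0, dst, torch.ones_like(dst, dtype=x.dtype))
        agg = agg / deg.clamp(min=1).unsqueeze(-1)
        return torch.relu(self.lin_self(x) + self.lin_neigh(agg))
```

In a real GNN stack, layers like this are chained (`NUM_LAYERS` deep) with the hidden dimension set by `HIDDEN_DIM`.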
- Python 3.8+
- PyTorch 2.0+
- NumPy
- Open Graph Benchmark (OGB)
```bash
pip install -r requirements.txt
```

Run training with full performance metrics and accuracy tracking:

```bash
# Single node with demo graph (for testing)
USE_DEMO=true NUM_EPOCHS=10 python -u run_benchmark.py

# Single node with real dataset
NUM_EPOCHS=20 python -u run_benchmark.py

# Distributed training (2 ranks)
torchrun --nproc_per_node=2 run_benchmark.py
```

Output includes:
- Training/validation/test accuracy
- Epoch timing (forward/backward pass breakdown)
- Memory usage (peak RSS)
- Throughput (samples/second)
- Scaling metrics for distributed runs
Results are saved to ./results/benchmark_TIMESTAMP.json and ./results/benchmark_summary.txt.
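Since results land in timestamped JSON files, a small helper can pick up the most recent run for post-processing. This is a sketch; the exact JSON schema inside `benchmark_*.json` is not specified here:

```python
import glob
import json
import os

def load_latest_benchmark(results_dir="./results"):
    """Return the most recent benchmark_*.json as a dict, or None if absent.

    Filenames embed a timestamp, so lexicographic sort picks the newest run.
    """
    paths = sorted(glob.glob(os.path.join(results_dir, "benchmark_*.json")))
    if not paths:
        return None
    with open(paths[-1]) as f:
        return json.load(f)
```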
For training with accuracy metrics only:
```bash
python -u run.py
```

For timing/memory metrics without accuracy:

```bash
python -u benchmark.py
```

Set the number of processes via environment variables:
```bash
export RANK=0
export WORLD_SIZE=2
export MASTER_ADDR=127.0.0.1
export MASTER_PORT=29500
torchrun --nproc_per_node=2 run_benchmark.py
```

| Variable | Description | Default |
|---|---|---|
| `NUM_EPOCHS` | Number of training epochs | `10` |
| `PARTITION_METHOD` | Graph partitioning method | `hash` |
| `BATCH_SIZE` | Training batch size | `1024` |
| `LEARNING_RATE` | Learning rate | `0.01` |
| `HIDDEN_DIM` | Hidden layer dimension | `256` |
| `NUM_LAYERS` | Number of GNN layers | `2` |
| `MPI_BACKEND` | PyTorch distributed backend | `gloo` |
| `USE_DEMO` | Use demo graph instead of Papers100M | `false` |
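Reading these variables with the documented defaults might look like the sketch below. This is illustrative only; the framework's actual `config.py` may structure things differently:

```python
import os

def load_config(env=os.environ):
    """Build a config dict from environment variables, using the table's defaults.

    Sketch only: key names and types here are assumptions, not config.py's API.
    """
    return {
        "num_epochs": int(env.get("NUM_EPOCHS", "10")),
        "partition_method": env.get("PARTITION_METHOD", "hash"),
        "batch_size": int(env.get("BATCH_SIZE", "1024")),
        "learning_rate": float(env.get("LEARNING_RATE", "0.01")),
        "hidden_dim": int(env.get("HIDDEN_DIM", "256")),
        "num_layers": int(env.get("NUM_LAYERS", "2")),
        "mpi_backend": env.get("MPI_BACKEND", "gloo"),
        # Any value other than "true" (case-insensitive) disables the demo graph
        "use_demo": env.get("USE_DEMO", "false").lower() == "true",
    }
```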
```
DGNN-MPI/
├── comm/                    # MPI communication utilities
│   ├── mpi_utils.py         # Distributed primitives
│   ├── gradient_sync.py
│   └── halo_exchange.py
├── data/                    # Data loading
│   ├── dataset_loader.py
│   └── batch_sampler.py
├── model/                   # GNN models
│   ├── gnn_model.py         # GraphSAGENet
│   ├── graphsage.py         # GraphSAGE convolution
│   └── distributed_gnn.py
├── partition/               # Graph partitioning
│   ├── graph_partitioner.py
│   ├── partition_loader.py
│   └── halo_builder.py
├── train/                   # Training utilities
│   ├── trainer.py
│   ├── evaluator.py
│   └── logger.py
├── results/                 # Output directory
│   ├── benchmark_*.json
│   └── benchmark_summary.txt
├── config.py                # Configuration
├── run_benchmark.py         # Unified benchmark (recommended)
├── run.py                   # Training entry point
└── benchmark.py             # Performance benchmark
```
The framework uses a vertex partitioning approach:
- Each MPI rank holds a graph partition (subset of vertices + edges)
- Local node features are stored per rank
- Forward pass is computed locally
- Boundary node embeddings are exchanged via halo exchange
- Backward pass computes local gradients
- Gradients are synchronized across ranks using MPI all-reduce
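The gradient-synchronization step at the end of that pipeline is commonly implemented with `torch.distributed.all_reduce`. The sketch below illustrates the idea under that assumption (it is not the framework's `gradient_sync.py`); it sums each parameter's gradient across ranks, then divides by the world size to average:

```python
import torch
import torch.distributed as dist

def sync_gradients(model):
    """Average gradients across ranks via all-reduce (sketch).

    No-op when not running under torch.distributed, so single-process
    runs behave identically.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return
    world_size = dist.get_world_size()
    for p in model.parameters():
        if p.grad is not None:
            # Sum this parameter's gradient across all ranks, then average
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
            p.grad /= world_size
```

Called between `loss.backward()` and `optimizer.step()`, this keeps model replicas consistent across ranks.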
Default dataset: [OGBN-Papers100M](https://ogb.stanford.edu/)
- ~111M nodes
- ~1.6B edges
- 172 classes
- 128-dimensional features
A smaller demo graph is available for testing.
MIT