FedMAP (Federated Maximum A Posteriori) is a novel federated learning (FL) algorithm that incorporates a global prior distribution over local model parameters using Input Convex Neural Networks (ICNNs), enabling personalised federated learning. This repository contains the complete implementation of the FedMAP algorithm with support for multiple healthcare datasets and tasks.
The FedMAP algorithm consists of three main steps:

1. Initialisation: A client is randomly selected, and its model parameters are used to initialise the global model and the local model parameters for all clients.
2. Local Optimisation: Each client optimises its model parameters by minimising the negative log-posterior, which combines the local data loss with an ICNN-based prior term that penalises deviations from the global model parameters.
3. Global Aggregation: The server aggregates the optimised local model parameters from all clients using a weighted average, where the weights are the contribution scores computed during local optimisation. The updated global model and ICNN modules are then broadcast to all clients for the next round.
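The aggregation step above can be sketched as a contribution-weighted average. The `aggregate` helper and the scalar contribution scores below are illustrative, not the repository's actual API:

```python
import numpy as np

def aggregate(client_updates):
    """Sketch of FedMAP's global aggregation step: a weighted average of
    local parameter vectors, weighted by each client's contribution score."""
    # client_updates: list of (params, contribution) pairs from local optimisation
    scores = np.array([c for _, c in client_updates], dtype=float)
    weights = scores / scores.sum()                 # normalise contribution scores
    stacked = np.stack([p for p, _ in client_updates])
    return np.tensordot(weights, stacked, axes=1)   # weighted average over clients

# Two clients with toy parameter vectors and contribution scores 1 and 3
new_global = aggregate([(np.array([0.0, 0.0]), 1.0),
                        (np.array([4.0, 8.0]), 3.0)])
print(new_global)  # weights 0.25/0.75 -> [3. 6.]
```

In the real strategy the parameters are full model state dicts rather than flat vectors, but the weighting logic is the same.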
- ICNN-based Prior: Adaptive prior distribution using Input Convex Neural Networks
- Multiple Task Support: Classification, binary classification, and survival analysis (Cox Proportional Hazards)
- Healthcare Applications: Pre-configured for INTERVAL, eICU, CPRD, and synthetic datasets
- Three-Tier Deployment: Support for training, fine-tuning, and inference workflows
- Flexible Architecture: Modular design supporting various neural network architectures
FedMAP/
├── 📂 src/
│ ├── 📂 client/ # Client-side federated learning logic
│ │ └── client_app.py # Client training and evaluation handlers
│ ├── 📂 server/ # Server-side aggregation and orchestration
│ │ └── server_app.py # Server initialisation and FedMAP strategy
│ ├── 📂 strategies/ # Federated learning aggregation strategies
│ │ └── fedmap.py # FedMAP strategy with ICNN training
│ ├── 📂 models/ # Neural network architectures
│ │ ├── iron_classifier.py # Dense classifier for INTERVAL
│ │ ├── multimodal_ffn.py # Multi-modal network for eICU
│ │ ├── example_classifier.py # Simple MLP for examples
│ │ └── __init__.py
│ ├── 📂 loss_modules/ # Loss functions and priors
│ │ └── map.py # FedMAP loss with ICNN prior
│ ├── 📂 tasks/ # Task-specific implementations
│ │ ├── interval.py # Iron deficiency classification (INTERVAL)
│ │ ├── eicu.py # ICU mortality prediction (eICU)
│ │ ├── cprd.py # CVD risk prediction (CPRD)
│ │ └── example.py # Synthetic data example
│ ├── 📂 tiers/ # Multi-tier deployment scripts
│ │ ├── tier2_finetune.py # Fine-tuning on new clients
│ │ └── tier3_infer.py # Inference on unseen clients
│ ├── 📂 utils/ # Utility functions
│ │ └── train_helper.py # ICNN implementation and training utilities
│ └── 📂 checkpoints/ # Model checkpoints (created during training)
├── 📂 config/ # Configuration files
│ └── 📂 task/
│ ├── interval.toml # INTERVAL task configuration
│ ├── eicu.toml # eICU task configuration
│ ├── cprd.toml # CPRD task configuration
│ └── example.toml # Example task configuration
├── 📂 datasets/ # Data directory (not included in repo)
│ ├── interval/
│ ├── eicu/
│ ├── cprd/
│ └── example/
├── 📂 scripts/ # Execution scripts
│ ├── run_t1.sh # Run Tier 1 (training)
│ ├── run_t2.sh # Run Tier 2 (fine-tuning)
│ └── run_t3.sh # Run Tier 3 (inference)
├── 📂 results/ # Output metrics (created during execution)
├── Dockerfile # Docker container configuration
├── docker-compose.yml # Docker Compose setup
├── pyproject.toml # Project metadata and dependencies
├── requirements.txt # Python dependencies
└── README.md
- Docker and Docker Compose
- NVIDIA GPU with CUDA support (optional but recommended)
- NVIDIA Container Toolkit (for GPU support)
1. Clone the repository

   ```bash
   git clone <repository-url>
   cd FedMAP
   ```

2. Build and start the Docker container

   ```bash
   docker compose up --build
   ```

   This will:
   - Build a Docker image based on PyTorch 2.3.0 with CUDA 12.1
   - Install Python 3.10 and all required dependencies
   - Mount the current directory to `/app` in the container
   - Enable GPU support (if available)

3. Access the container

   Open a new terminal and run:

   ```bash
   docker exec -it fedmap_container bash
   ```
1. Clone the repository

   ```bash
   git clone <repository-url>
   cd FedMAP
   ```

2. Create a virtual environment (Python 3.10)

   ```bash
   python3.10 -m venv venv
   source venv/bin/activate  # On Windows: venv\Scripts\activate
   ```

3. Install dependencies

   ```bash
   pip install --upgrade pip
   pip install -e .
   ```

   Or, if using requirements.txt:

   ```bash
   pip install -r requirements.txt
   ```
The project requires the following main packages:
- `flwr[simulation]>=1.8.0,<2.0` - Flower federated learning framework
- `torch>=2.3.0` - PyTorch deep learning framework
- `hydra-core==1.3.0` - Configuration management
- `numpy==2.2.6` - Numerical computing
- `pandas==2.3.0` - Data manipulation
- `scikit-learn==1.7.0` - Machine learning utilities
- `pycox==0.3.0` - Survival analysis (for CPRD task)
- `torchtuples==0.2.2` - PyTorch utilities
Before running experiments, you need to prepare your datasets. Place your data in the datasets/ directory with the following structure:
datasets/interval/
├── INTERVAL_irondef_site_1_train.csv
├── INTERVAL_irondef_site_1_val.csv
├── INTERVAL_irondef_site_2_train.csv
├── INTERVAL_irondef_site_2_val.csv
├── INTERVAL_irondef_site_3_train.csv
└── INTERVAL_irondef_site_3_val.csv
datasets/eicu/
├── {hospital_id}/
│ ├── mortality_train.csv
│ ├── mortality_val.csv
│ ├── medications_train.csv
│ ├── medications_val.csv
│ ├── diagnosis_train.csv
│ ├── diagnosis_val.csv
│ ├── physio_train.csv
│ └── physio_val.csv
datasets/cprd/
└── risk_factors_all.csv
datasets/example/
├── partition_0_train.csv
├── partition_0_test.csv
├── partition_1_train.csv
├── partition_1_test.csv
└── ...
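A quick way to confirm the example-task layout before launching a run is a small check script. This is a convenience sketch, not part of the repository; the file-name pattern matches the tree above:

```python
from pathlib import Path

def check_example_layout(root="datasets/example", n_partitions=2):
    """Check for the partition_{i}_{split}.csv files the example task expects.
    Returns a list of missing file paths (empty list means layout is complete)."""
    missing = []
    for i in range(n_partitions):
        for split in ("train", "test"):
            f = Path(root) / f"partition_{i}_{split}.csv"
            if not f.is_file():
                missing.append(str(f))
    return missing

if __name__ == "__main__":
    print(check_example_layout() or "example dataset layout OK")
```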
Train the global model with FedMAP across multiple clients:
Using Flower CLI:

```bash
flwr run . --run-config="./config/task/interval.toml"
```

Using shell script:

```bash
bash scripts/run_t1.sh
```

Configuration options (in `config/task/[task_name].toml`):

- `task-name`: Task identifier (interval, eicu, cprd, example)
- `model`: Model architecture to use
- `num-server-rounds`: Number of federated learning rounds
- `local-epochs`: Number of local training epochs per round
- `fraction-evaluate`: Fraction of clients to use for evaluation
- `learning-rate`: Learning rate for local optimisation
- `batch-size`: Batch size for training
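Putting the options together, a task config might look like the following (the values here are illustrative, not the shipped defaults):

```toml
task-name = "example"
model = "MLP"
num-server-rounds = 10
local-epochs = 5
fraction-evaluate = 1.0
learning-rate = 0.001
batch-size = 32
```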
Fine-tune the trained global model on new clients with ICNN prior:
```bash
bash scripts/run_t2.sh
```

This will:

- Load the trained global model from `src/checkpoints/global_model_{task_name}.pt`
- Load the trained ICNN modules from `src/checkpoints/icnn_modules.pt`
- Fine-tune on clients
- Save metrics to `results/{task_name}_metrics_test.csv`
Evaluate the global model on completely new clients without fine-tuning:
```bash
bash scripts/run_t3.sh
```

This will:

- Load the trained global model
- Evaluate on clients
- Save metrics to `results/{task_name}_metrics_test.csv`
By default, the repository is configured to run the example task with synthetic data. You can get started immediately without any additional setup:
Tier 1 - Federated Training (Clients 0-19):

```bash
bash scripts/run_t1.sh
```

This trains the global model and ICNN prior modules, then automatically generates performance plots.

Tier 2 - Fine-tuning (Clients 20-34):

```bash
bash scripts/run_t2.sh
```

This fine-tunes the global model on new clients using the trained ICNN prior, then generates comparison plots.

Tier 3 - Inference (Clients 35-44):

```bash
bash scripts/run_t3.sh
```

This evaluates the global model on completely unseen clients without any fine-tuning, then generates performance plots.
Performance metrics are automatically saved to CSV files in the results/ directory:
Tier 1 (Training):

- `results/{task_name}_metrics_test_tier1.csv` - Validation metrics per round for all Tier 1 clients

Tier 2 (Fine-tuning):

- `results/{task_name}_metrics_test_tier2.csv` - Performance metrics for Tier 2 clients

Tier 3 (Inference):

- `results/{task_name}_metrics_test_tier3.csv` - Performance metrics for Tier 3 clients
Performance plots are automatically generated and saved to results/performance_plots/:
Tier 1 plots (generated automatically):
- `Accuracy_over_rounds_tier1.png` - Accuracy progression across communication rounds
- `Balanced_Accuracy_over_rounds_tier1.png` - Balanced accuracy progression
- `ROC_AUC_over_rounds_tier1.png` - ROC AUC progression
- `Loss_over_rounds_tier1.png` - Loss progression
Each plot includes:
- Individual client performance curves
- Global average line
- Legend with all clients
Tier 2 plots & Tier 3 plots:
- `balanced_accuracy_tier2.png` - Bar chart of balanced accuracy
- `roc_auc_tier2.png` - Bar chart of ROC AUC
The FedMAP algorithm involves each client $k$ solving a maximum a posteriori (MAP) estimation problem over its local data:

$$\theta_{k} = \arg\min_{\theta} \; \mathcal{L}(\theta; Z_{k}) + \mathcal{R}(\theta; \mu, \psi)$$

Where:

- $\mathcal{L}(\theta; Z_{k})$: The local data loss (negative log-likelihood) on the local dataset $Z_k$.
- $\mathcal{R}(\theta; \mu, \psi)$: The ICNN-based prior term, which regularises the local parameters $\theta$.
- $\theta$: Local model parameters.
- $\mu$: Global model parameters.
- $\psi$: ICNN (prior) parameters.
- $N_k$: The number of data points at site $k$.
Each client's contribution weight is derived from the value of its MAP objective at the local optimum. This quantifies how well the local model fits its data (the likelihood $\mathbb{P}(Z_{k} \mid \theta_{k})$) while adhering to the global prior (the regularisation term $\mathcal{R}(\theta_{k}; \mu, \psi)$).
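As a sketch of how such weights could be normalised (assuming, hypothetically, that each client reports its negative log-posterior value $\ell_k = \mathcal{L} + \mathcal{R}$ at the optimum; the repository's exact scoring may differ), weights proportional to the posterior can be computed stably with a log-sum-exp shift:

```python
import math

def contribution_weights(neg_log_posteriors):
    """Turn per-client negative log-posterior values into normalised
    aggregation weights w_k proportional to exp(-l_k), shifting by the
    minimum so large values do not underflow. Illustrative only."""
    m = min(neg_log_posteriors)                       # shift for numerical stability
    exps = [math.exp(-(l - m)) for l in neg_log_posteriors]
    total = sum(exps)
    return [e / total for e in exps]

w = contribution_weights([1.0, 1.0, 2.0])
print(w)  # better-fitting clients (lower l_k) get larger weight; weights sum to 1
```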
The ICNN (Input Convex Neural Network) provides a learned, convex prior. As defined in Equation 6 of the paper, the regulariser evaluates the ICNN on the deviation of the local parameters from the global parameters:

$$\mathcal{R}(\theta; \mu, \psi) = g_{\psi}(\theta - \mu)$$

Where $g_{\psi}$ is an ICNN, which is convex in its input by construction (non-negative hidden-to-hidden weights and convex, non-decreasing activations).
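A minimal ICNN forward pass can be sketched in NumPy. This is a generic ICNN in the style of Amos et al., not the repository's `train_helper.py` implementation; layer sizes are arbitrary. Convexity in the input follows from the non-negative hidden-to-hidden weights and the convex, non-decreasing ReLU:

```python
import numpy as np

rng = np.random.default_rng(0)

def relu(x):
    return np.maximum(x, 0.0)

class TinyICNN:
    """One-hidden-layer Input Convex Neural Network f_psi(x).
    The hidden-to-output weights W_zz1 are kept non-negative,
    so f is convex as a function of x."""
    def __init__(self, dim, hidden=8):
        self.W_xz0 = rng.normal(size=(hidden, dim))    # input -> hidden (unconstrained)
        self.b0 = np.zeros(hidden)
        self.W_zz1 = np.abs(rng.normal(size=(1, hidden)))  # non-negative by construction
        self.W_xz1 = rng.normal(size=(1, dim))         # direct input passthrough
        self.b1 = np.zeros(1)

    def __call__(self, x):
        z = relu(self.W_xz0 @ x + self.b0)
        return (self.W_zz1 @ z + self.W_xz1 @ x + self.b1).item()

# Numerically check convexity along a segment:
# f(lam*a + (1-lam)*b) <= lam*f(a) + (1-lam)*f(b)
f = TinyICNN(dim=4)
a, b, lam = rng.normal(size=4), rng.normal(size=4), 0.3
assert f(lam * a + (1 - lam) * b) <= lam * f(a) + (1 - lam) * f(b) + 1e-9
```

In FedMAP the input `x` would play the role of the parameter deviation, so the learned prior is guaranteed convex in the local parameters.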
- Model: DenseClassifier (2-layer dense network)
- Task: Binary classification of iron deficiency
- Features: 18 features (16 haematology + age + sex)
- Model: MultimodalFFN (multi-modal feedforward network)
- Task: Binary classification of ICU mortality
- Modalities: Medications (1411), Diagnosis (686), Physiology (7)
- Model: CoxPH (Cox Proportional Hazards)
- Task: Survival analysis for cardiovascular disease risk
- Features: 8 risk factors (age, sex, SBP, cholesterol, etc.)
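For the survival task, the quantity a Cox model optimises is the negative partial log-likelihood. The sketch below (assuming no tied event times; not the pycox implementation used by the repository) shows the core computation:

```python
import math

def cox_neg_partial_log_likelihood(risk_scores, durations, events):
    """Negative Cox partial log-likelihood for untied event times.
    risk_scores: model outputs (linear predictors); durations: event or
    censoring times; events: 1 if the event occurred, 0 if censored."""
    loss, n_events = 0.0, 0
    for i, (t_i, e_i) in enumerate(zip(durations, events)):
        if not e_i:
            continue  # censored subjects contribute only through risk sets
        # risk set: everyone still under observation at time t_i
        log_denom = math.log(sum(math.exp(s)
                                 for s, t in zip(risk_scores, durations)
                                 if t >= t_i))
        loss += log_denom - risk_scores[i]
        n_events += 1
    return loss / n_events

# Assigning higher risk to the subject with the earlier event gives lower loss
good = cox_neg_partial_log_likelihood([2.0, 0.0], [1.0, 5.0], [1, 1])
bad = cox_neg_partial_log_likelihood([0.0, 2.0], [1.0, 5.0], [1, 1])
assert good < bad
```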
- Model: MLP (3-layer multilayer perceptron)
- Task: Binary classification
- Features: 31 synthetic features
Training saves the following checkpoints:
- `src/checkpoints/global_model_{task_name}.pt` - Final global model
- `src/checkpoints/icnn_modules.pt` - Trained ICNN prior modules
Performance metrics are saved to CSV files:
- `results/{task_name}_metrics_test.csv` - Test/validation metrics per round
- `results/{task_name}_metrics_train.csv` - Training metrics (if logged)
Metrics include:
- Loss, Accuracy, Balanced Accuracy
- ROC AUC, AUPRC (for classification tasks)
- Concordance Index, Integrated Brier Score (for survival tasks)
- Confusion matrix (TP, TN, FP, FN)
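Balanced accuracy, one of the metrics above, follows directly from the confusion-matrix counts. This is the standard formula, not repository code:

```python
def balanced_accuracy(tp, tn, fp, fn):
    """Balanced accuracy = mean of sensitivity (TPR) and specificity (TNR).
    More informative than plain accuracy on imbalanced data such as
    mortality or deficiency labels."""
    sensitivity = tp / (tp + fn)  # recall on the positive class
    specificity = tn / (tn + fp)  # recall on the negative class
    return 0.5 * (sensitivity + specificity)

print(balanced_accuracy(tp=90, tn=60, fp=20, fn=30))  # 0.75
```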
1. Create a task class in `src/tasks/your_task.py`:

   ```python
   class YourTask:
       def __init__(self, cid, config, device):
           # Initialise task
           pass

       def set_models(self, global_model, cnnet_modules):
           # Set up models
           pass

       def train(self, patience=3, batch_size=32):
           # Training logic
           return best_state_dict, contribution

       def validate(self, batch_size=32):
           # Validation logic
           return avg_loss, metrics
   ```
2. Add configuration in `config/task/your_task.toml`:

   ```toml
   task-name = "your_task"
   model = "YourModel"
   num-server-rounds = 10
   local-epochs = 5
   ```

3. Update imports in `src/client/client_app.py` and `src/server/server_app.py`
1. Create your model in `src/models/your_model.py`:

   ```python
   import torch.nn as nn

   class YourModel(nn.Module):
       def __init__(self, input_dim, output_dim):
           super().__init__()
           # Define layers

       def forward(self, x):
           # Forward pass
           return output
   ```

2. Import it in `src/models/__init__.py`

3. Add initialisation logic in `src/server/server_app.py`
The Docker setup automatically enables GPU support if available. To verify GPU access:
```bash
docker exec -it fedmap_container python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')"
```

To adjust GPU resources, modify docker-compose.yml:

```yaml
deploy:
  resources:
    reservations:
      devices:
        - driver: nvidia
          count: all # or specify number
          capabilities: [gpu]
```
1. Out of Memory (OOM) errors

   - Reduce batch size in configuration
   - Reduce model size or hidden dimensions
   - Increase shared memory: set `shm_size: 100gb` in docker-compose.yml

2. Dataset not found errors

   - Verify dataset paths in task files
   - Ensure data is properly mounted (check docker-compose.yml volumes)
   - Check file permissions

3. Client evaluation failures

   - Check for empty validation sets
   - Verify class balance in datasets
   - Ensure proper data preprocessing
If you use FedMAP in your research, please cite:
@misc{zhang2025fedmappersonalisedfederatedlearning,
title={FedMAP: Personalised Federated Learning for Real Large-Scale Healthcare Systems},
author={Fan Zhang and Daniel Kreuter and Carlos Esteve-Yagüe and Sören Dittmer and Javier Fernandez-Marques and Samantha Ip and BloodCounts! Consortium and Norbert C. J. de Wit and Angela Wood and James HF Rudd and Nicholas Lane and Nicholas S Gleadall and Carola-Bibiane Schönlieb and Michael Roberts},
year={2025},
eprint={2405.19000},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2405.19000}
}
This project is licensed under the Apache License 2.0 - see the LICENSE file for details.
Contributions are welcome! Please feel free to submit a Pull Request.
For questions or issues, please open an issue on GitHub or contact the maintainers.
