
collaborativebioinformatics/MuFFLe


Multimodal Framework for Federated Learning (MuFFLe)


Contributors

  1. Tyler Yang
  2. Shreyan Nalwad
  3. Varalakshmi Perumal (https://orcid.org/0000-0002-8293-2308)
  4. Anna Boeva
  5. James Mu
  6. Amrit Gaire
  7. Zhenghao (Chester) Xiao
  8. Yiman Wu
  9. Jon Moller
  10. Mahtabin Rodela
  11. Yajushi Khurana

flowchart

Quick Start

The Memphis/San Diego example workflow is contained in src/. Below are instructions for setting up a .venv/ environment; you can use conda, PyEnv, or any other Python environment manager if you prefer.

python3 -m venv .venv       # create your environment
source .venv/bin/activate   # activate it
pip install "nvflare[PT]" torchinfo tensorboard matplotlib jupyter ipykernel # install necessary packages

Next, download the RNA-sequencing data and clinical data from CHIMERA Task 3. Make sure you have the AWS CLI installed (e.g., via Homebrew on macOS).

# List files
aws s3 ls --no-sign-request s3://chimera-challenge/v2/task3/data/3A_001/
# Copy all the Clinical Data and RNASeq Data
aws s3 cp --no-sign-request s3://chimera-challenge/v2/task3/data/ local/path/to/data/ --recursive --exclude "*" --include "*.json"

Now open src/multi-client-sim.py and set the DATASET_PATH variable to the directory where you downloaded the data.
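Before opening the notebook, it can help to confirm the download landed where DATASET_PATH points. The `aws s3 ls` path above suggests one folder per patient (e.g. 3A_001/) containing JSON files. The sketch below indexes that layout. It assumes that folder structure, and the `index_patient_files` helper is hypothetical, not part of the repository.

```python
from pathlib import Path
import json

# Hypothetical value: use the same directory you set as DATASET_PATH
# in src/multi-client-sim.py.
DATASET_PATH = Path("local/path/to/data")


def index_patient_files(dataset_path):
    """Map each patient folder name (e.g. "3A_001") to its parsed JSON files.

    Assumes the layout produced by the recursive `aws s3 cp` above:
    one sub-directory per patient holding one or more *.json files.
    """
    index = {}
    for json_file in sorted(Path(dataset_path).glob("*/*.json")):
        patient_id = json_file.parent.name
        with json_file.open() as f:
            index.setdefault(patient_id, {})[json_file.name] = json.load(f)
    return index


if __name__ == "__main__":
    patients = index_patient_files(DATASET_PATH)
    print(f"Found {len(patients)} patient folders under {DATASET_PATH}")
```

If the count is zero, DATASET_PATH most likely points one level too high or too low in the downloaded tree.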

You can now run the Jupyter notebook src/prf-of-concept.ipynb.

TensorBoard logs are stored in /tmp/nvflare/simulation/MEM_SAN_FedCollab/server/simulate_job/tb_events/; you can view them by running tensorboard --logdir with that directory. More instructions are in the Jupyter notebook src/prf-of-concept.ipynb.

Introduction

MuFFLe is a privacy-preserving framework for integrating multimodal biomedical data (RNA sequencing, clinical features) for cancer prognosis. Using NVIDIA's NVFlare, each hospital site trains on its local data and shares only model updates—not raw patient data—with a central server for aggregation.

Cancer prognosis models require multimodal data (imaging, RNA-seq, clinical variables) collected across institutions, but data sharing is restricted by privacy, regulatory, and institutional barriers. Integrating transcriptomics with clinical features improves prognostic performance, yet most hospitals cannot pool raw patient data across sites. Centralized training is often infeasible under HIPAA constraints, which motivates a federated learning approach in which data remains local.

With NVIDIA's NVFlare, each hospital trains locally on its multimodal data and shares only model updates (which can additionally be encrypted) with a central server. This lets the global model learn from all sites while preserving patient privacy.

Methods + Discussion

Training uses NVFlare's FedAvg algorithm across simulated sites, which can differ in the data modalities they hold (e.g., one site trains on clinical data only, while another also has RNA-seq data). Sites receive the global model, train locally, and send weight updates back for aggregation, enabling collaborative learning while preserving privacy.
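The aggregation step can be pictured with the textbook FedAvg rule: each parameter of the global model is the average of the client parameters, weighted by how many local examples each client trained on. The following is a conceptual sketch on plain Python lists. It is not NVFlare's implementation, which may weight clients differently (for example, by reported training steps), and the parameter name and sample counts in the usage example are hypothetical.

```python
def fedavg(client_weights, client_num_samples):
    """Conceptual FedAvg: sample-weighted average of client parameters.

    client_weights: list of dicts mapping parameter name -> list of floats
    client_num_samples: local training-set size for each client (same order)
    """
    total = sum(client_num_samples)
    if total <= 0:
        raise ValueError("clients must report a positive number of samples")
    averaged = {}
    for name, first_values in client_weights[0].items():
        averaged[name] = [
            # Weight each client's value by its share of the total samples.
            sum(w[name][k] * n for w, n in zip(client_weights, client_num_samples)) / total
            for k in range(len(first_values))
        ]
    return averaged


# Hypothetical round: San Diego has twice as many samples as Memphis.
san_diego = {"w": [1.0, 2.0]}
memphis = {"w": [4.0, 5.0]}
print(fedavg([san_diego, memphis], [200, 100]))  # {'w': [2.0, 3.0]}
```

Sample weighting lets the larger site pull the global model further toward its update. This is one simple way to give each client "appropriate importance".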

Example Scenario

At a high level, we wanted to simulate a real-world scenario where federated learning and multimodal frameworks can mitigate data gaps stemming from unequal access to healthcare.

We motivate this workflow with a hypothetical scenario: consider two hospitals, one in San Diego and one in Memphis. Assume that, because San Diego is a larger city, its hospital will (1) have more patients and (2) have more data modalities available per patient than the Memphis hospital.

We want to use federated learning to train a multimodal model that predicts tumor progression. We formalize this as a binary classification task: the label is 1 if the tumor progresses and 0 otherwise.

Initial Scenario

Federated Learning Setup for MEM/SAN simulation (before)

At first, both hospitals have clinical data, but the San Diego hospital has greater healthcare capabilities and can also provide RNA sequencing data. This additional modality would help predict tumor progression, but it creates a data gap between the two sites.

To bridge this gap, the multimodal fusion architecture we propose in MuFFLe can selectively turn off the RNA-seq encoder on the Memphis client and zero out its embedding, effectively telling the model to ignore RNA-seq there.
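The switch-off mechanism can be illustrated with a toy fusion step. When a site lacks RNA-seq, the RNA-seq encoder is skipped and a zero vector of the same size takes its place, so the fused embedding keeps the same shape at every site and one global model can be shared. This is a minimal sketch in plain Python. `toy_encoder`, `EMBED_DIM`, and the `use_rna` flag are hypothetical stand-ins, not the repository's actual encoder or API.

```python
from typing import Optional, Sequence

EMBED_DIM = 4  # hypothetical embedding size, kept tiny for illustration


def toy_encoder(features: Sequence[float], dim: int = EMBED_DIM) -> list:
    """Hypothetical stand-in for a learned encoder: folds the input into `dim` sums."""
    out = [0.0] * dim
    for i, x in enumerate(features):
        out[i % dim] += x
    return out


def fuse(clinical: Sequence[float], rna: Optional[Sequence[float]], use_rna: bool) -> list:
    """Concatenate per-modality embeddings, zeroing RNA-seq when it is switched off."""
    clinical_emb = toy_encoder(clinical)
    if use_rna and rna is not None:
        rna_emb = toy_encoder(rna)
    else:
        # RNA-seq encoder skipped: a zero block keeps the fused shape identical
        # across sites, so the same global model works for every client.
        rna_emb = [0.0] * EMBED_DIM
    return clinical_emb + rna_emb


print(fuse([1.0, 2.0], None, use_rna=False))  # Memphis: RNA-seq block is all zeros
print(fuse([1.0, 2.0], [5.0, 6.0], use_rna=True))  # San Diego: both modalities
```

In a real network, zero inputs also mean the weights reading the RNA-seq block receive no gradient from Memphis samples. The RNA-seq pathway is therefore learned only from sites that actually have that modality.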

Strengths of our Framework and Architecture

In short: extensibility and interpretability.

Extensibility: Hospitals can coordinate additional flexibility. For example, because San Diego has more data and more modalities, it may opt to train for more epochs and with a larger batch size than the Memphis client. Aggregation algorithms beyond FedAvg could also give each client's results the appropriate weight. That is beyond the scope of this hackathon, but the framework can be extended by writing a recipe analogous to NVFlare's PyTorch FedAvgRecipe.
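Per-site coordination of this kind could be captured as a small configuration object that each client reads before local training. The sketch below is hypothetical throughout: the `ClientTrainingConfig` class, its field names, and the epoch and batch-size values are illustrative and are not NVFlare's configuration schema or the project's settings.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ClientTrainingConfig:
    """Hypothetical per-site settings; field names are illustrative, not NVFlare's."""
    site_name: str
    local_epochs: int
    batch_size: int
    modalities: List[str] = field(default_factory=lambda: ["clinical"])


# Hypothetical values: the larger, RNA-seq-capable site trains longer with bigger batches.
CLIENT_CONFIGS = {
    "san_diego": ClientTrainingConfig(
        "san_diego", local_epochs=4, batch_size=32, modalities=["clinical", "rna_seq"]
    ),
    "memphis": ClientTrainingConfig("memphis", local_epochs=2, batch_size=16),
}

for cfg in CLIENT_CONFIGS.values():
    print(cfg.site_name, cfg.local_epochs, cfg.batch_size, cfg.modalities)
```

Keeping these settings per site, rather than global, means a site can change its own schedule without affecting what the others run.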

Interpretability: The learnable per-modality weights can be read directly as "importance scores" for each modality in predicting tumor progression. The attention layer can also show which parts of the concatenated embedding matter most, and which features within each sub-embedding are important. For example, if attention scores are high at specific positions in the RNA-seq sub-embedding, and those positions map to specific protein-coding genes, those genes are important for predicting tumor progression.

In the proof-of-concept, we simulate federated learning on a 42M-parameter model for a few rounds, simply to show that the MuFFLe framework works.
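A common way to make learnable per-modality weights interpretable is to normalize them with a softmax, so the scores are positive and sum to 1. The sketch below assumes that normalization, which the text does not specify. The modality names and logit values are hypothetical.

```python
import math


def modality_importance(logits: dict) -> dict:
    """Softmax over hypothetical learnable per-modality logits -> scores summing to 1."""
    max_logit = max(logits.values())  # subtract the max for numerical stability
    exps = {m: math.exp(v - max_logit) for m, v in logits.items()}
    total = sum(exps.values())
    return {m: e / total for m, e in exps.items()}


# Hypothetical learned logits: RNA-seq ends up three times as important as clinical.
print(modality_importance({"clinical": 0.0, "rna_seq": math.log(3)}))
```

With this normalization, a score of 0.75 versus 0.25 reads directly as "the model leans three times as heavily on RNA-seq as on clinical features".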

Later, Extending the Framework via Continued Learning

Federated Learning Setup for MEM/SAN simulation (after)

Suppose that the Memphis hospital later begins RNA sequencing its patients. The multimodal framework proposed in MuFFLe extends easily: the Memphis client can simply start passing its new RNA-seq data through the model, so that every site uses every data modality to train the global fusion model.

How we built this simulated scenario

We first chose an existing multimodal dataset so that we could perform supervised learning at the client level and evaluation at the server level. We chose data from the Combining HIstology, Medical imaging (radiology) and molEcular data for medical pRognosis and diAgnosis (CHIMERA) Challenge.

Our task and data are adapted from CHIMERA Task 2 (BCG Response Subtype Prediction in High-Risk NMIBC) and Task 3 (Bladder Cancer Recurrence Prediction). We take the clinical data and RNA-sequencing data from Task 3, and draw inspiration from the baseline model and evaluation objective of Task 2.

We split the data so that San Diego has twice as many data points as Memphis, and we held out a small subset for evaluation. More details are in the manuscript and src/README.md.
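The split described above (hold out an evaluation set, then divide the remainder 2:1 between San Diego and Memphis) can be sketched as follows. The holdout fraction, the seed, and the `split_sites` helper are hypothetical. The project's actual split is described in src/README.md.

```python
import random


def split_sites(patient_ids, holdout_fraction=0.1, seed=0):
    """Hypothetical split: shuffle, hold out an evaluation set, then divide
    the remaining patients 2:1 between San Diego and Memphis."""
    ids = list(patient_ids)
    random.Random(seed).shuffle(ids)  # seeded so the split is reproducible
    n_holdout = round(len(ids) * holdout_fraction)
    holdout, train = ids[:n_holdout], ids[n_holdout:]
    n_san_diego = (2 * len(train)) // 3
    return {
        "san_diego": train[:n_san_diego],
        "memphis": train[n_san_diego:],
        "holdout": holdout,
    }


splits = split_sites(range(100))
print({site: len(ids) for site, ids in splits.items()})  # {'san_diego': 60, 'memphis': 30, 'holdout': 10}
```

Holding out the evaluation set before the site split keeps server-side evaluation data disjoint from every client's training data.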

Future Directions

There is some low-hanging fruit this framework could be applied to. While searching for datasets for our proof-of-concept, we came across data from the Real-time Analysis and Discovery in Integrated And Networked Technologies (RADIANT) group, which "seeks to develop an extensible, federated framework for rapid exchange of multimodal clinical and research data on behalf of accelerated discovery and patient impact" (RADIANT Public Data, AWS).

We elected not to use this dataset because the S3 bucket had "controlled access," which required submitting a form for approval and did not fit the fast pace of the hackathon. However, our federated learning framework could easily be extended to RADIANT's data, which includes clinical, imaging, histology, genomic, and proteomics data, and more, from the Children's Brain Tumor Network (CBTN).

Vertical FL


CHIMERA Dataset Analysis

Clustering Results and Clinical Validation

The pipeline assigned all 176 patients to 3 clusters (Cluster 0: 53 patients, Cluster 1: 72 patients, Cluster 2: 51 patients). To test whether these clusters capture clinically meaningful patterns relevant to bladder cancer recurrence, we compared cluster assignments against established clinical risk factors and survival outcomes.

Clinical Relevance Summary

Statistical validation showing associations between clusters and clinical variables including progression rates, BRS categories, and demographic factors.

Scientific Question: Do the multimodal (WSI + RNA) clusters identify patient subgroups with distinct recurrence risk profiles?

Key Findings:

  • Differential Progression Rates: Clusters show varying progression rates, with Cluster 2 exhibiting the highest risk
  • BRS Association: Clusters align with BCG Response Subtype (BRS) categories, supporting their biological relevance
  • Clinical Variables: Statistical tests (chi-square, ANOVA) reveal associations with established prognostic factors
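For categorical variables such as progression status, the chi-square test of independence compares observed counts per cluster against the counts expected if cluster and outcome were unrelated. The sketch below computes Pearson's statistic in plain Python. The closed-form p-value exp(-x/2) is exact only for 2 degrees of freedom, such as 3 clusters by 2 outcomes. For other table shapes, `scipy.stats.chi2_contingency` is the general tool. The counts in the example are made-up toy numbers, not results from this study.

```python
import math


def chi_square_independence(table):
    """Pearson chi-square test for a (clusters x outcomes) table of counts.

    Returns (statistic, degrees_of_freedom, p_value). p_value uses the
    closed form exp(-x/2), valid only for 2 degrees of freedom; it is None otherwise.
    """
    row_totals = [sum(row) for row in table]
    col_totals = [sum(col) for col in zip(*table)]
    grand_total = sum(row_totals)
    statistic = 0.0
    for i, row in enumerate(table):
        for j, observed in enumerate(row):
            expected = row_totals[i] * col_totals[j] / grand_total
            statistic += (observed - expected) ** 2 / expected
    dof = (len(table) - 1) * (len(table[0]) - 1)
    p_value = math.exp(-statistic / 2) if dof == 2 else None
    return statistic, dof, p_value


# Toy counts (NOT study data): rows = clusters, columns = (progressed, not progressed).
stat, dof, p = chi_square_independence([[10, 40], [10, 40], [20, 30]])
print(f"chi2={stat:.3f}, dof={dof}, p={p:.4f}")
```

Continuous variables such as age would instead be compared across clusters with a one-way ANOVA (e.g. `scipy.stats.f_oneway`).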

Interpretability: Attention Heatmaps

A key advantage of our heuristic-based approach is interpretability—we can directly visualize which tissue regions drive clustering decisions. The gated attention mechanism weights patches based on morphological complexity (variance), allowing us to generate spatial heatmaps showing where the model focuses its attention.
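One reading of "weights patches based on morphological complexity (variance)" is sketched below. Each patch's feature variance serves as its attention score, a softmax turns the scores into weights, and the slide embedding is the weighted sum of patch features. Plotting the weights at each patch's slide coordinates yields a heatmap like the one shown. The softmax, the `temperature` parameter, and the function name are assumptions. The repository's gated mechanism may combine these scores differently.

```python
import math
from statistics import pvariance


def variance_attention_pool(patch_features, temperature=1.0):
    """Hypothetical heuristic attention: score each patch by the variance of its
    feature vector, softmax the scores, and pool patches into a slide embedding.

    Returns (attention_weights, slide_embedding).
    """
    scores = [pvariance(p) / temperature for p in patch_features]
    max_score = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - max_score) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(patch_features[0])
    slide_embedding = [
        sum(w * p[k] for w, p in zip(weights, patch_features)) for k in range(dim)
    ]
    return weights, slide_embedding


# A flat (low-variance) patch gets less attention than a heterogeneous one.
weights, embedding = variance_attention_pool([[0.0, 0.0, 0.0], [1.0, -1.0, 1.0]])
print(weights, embedding)
```

Because the weights come from a fixed formula rather than learned parameters, the same patches always receive the same attention. This is what makes the heatmaps directly checkable by a pathologist.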

Attention Heatmap

Spatial attention heatmap overlaid on a whole-slide image. Warm colors (red/yellow) indicate high-attention patches that contributed most to the patient's slide embedding. These regions typically correspond to morphologically complex areas such as tumor nests or regions with high cellular pleomorphism.

The attention mechanism enables biological validation: high-attention patches should localize to tumor regions with significant morphological features, not background stroma or artifacts. This interpretability is crucial for clinical adoption, as pathologists can verify that the model focuses on biologically relevant tissue patterns.

Survival Analysis: Kaplan-Meier Curves

Scientific Question: Do the clusters predict recurrence-free survival?

Kaplan-Meier survival analysis evaluates whether patients in different clusters experience different recurrence risk over time. The survival curves below show the probability of remaining recurrence-free for each cluster.
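The Kaplan-Meier estimator behind these curves multiplies, at each observed recurrence time t, the running survival probability by (1 - recurrences at t / patients still at risk at t). Patients censored at t still count as at risk at t. The plain-Python sketch below is fitted once per cluster. Libraries such as lifelines (`KaplanMeierFitter`) provide the same estimator with confidence intervals. This is an illustration, not the repository's code, and the toy follow-up data is made up.

```python
def kaplan_meier(times, events):
    """Kaplan-Meier estimate as [(event_time, survival_probability), ...].

    times: follow-up time per patient
    events: 1 if recurrence was observed, 0 if the patient was censored
    """
    if len(times) != len(events):
        raise ValueError("times and events must have the same length")
    survival = 1.0
    curve = []
    for t in sorted({ti for ti, ei in zip(times, events) if ei}):
        at_risk = sum(1 for ti in times if ti >= t)  # includes patients censored at t
        recurrences = sum(1 for ti, ei in zip(times, events) if ti == t and ei)
        survival *= 1.0 - recurrences / at_risk
        curve.append((t, survival))
    return curve


# Toy follow-up for one cluster (NOT study data).
print(kaplan_meier([1, 2, 2, 3, 4], [1, 1, 0, 1, 0]))  # steps to ~0.8, ~0.6, ~0.3
```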

Kaplan-Meier Survival Curves

Kaplan-Meier curves showing recurrence-free survival probability over time for each cluster. Separation between curves would indicate differential recurrence risk; here the log-rank test p-value (0.3069) did not reach statistical significance in this unsupervised setting. C-index: 0.5507.

Interpretation: The curves separate visually, with Cluster 2 declining fastest (highest observed risk) and Cluster 0 appearing more favorable. However, the modest C-index (0.5507, close to the 0.5 expected from random ordering) and the non-significant log-rank p-value mean this separation may be due to chance. Such results are not unexpected for unsupervised clustering without labeled training data. The pattern may still reflect prognostic information in the multimodal signatures that warrants further investigation with larger cohorts.
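Harrell's C-index measures how often a higher predicted risk goes with an earlier observed recurrence. Over all comparable pairs (patient i recurred before patient j's follow-up time), a pair counts 1 if i has the higher risk score and 0.5 for a tie. A value of 0.5 matches random ordering and 1.0 is a perfect ranking. The sketch below is a plain-Python illustration, not the repository's code. Using a per-cluster risk score as the prediction is an assumption, and pairs with tied times are skipped for simplicity.

```python
def harrell_c_index(times, events, risk_scores):
    """Harrell's concordance index; higher risk_scores should mean earlier recurrence."""
    concordant = 0.0
    comparable = 0
    for i in range(len(times)):
        if not events[i]:
            continue  # a censored patient cannot anchor a comparable pair
        for j in range(len(times)):
            if times[i] < times[j]:
                comparable += 1
                if risk_scores[i] > risk_scores[j]:
                    concordant += 1.0
                elif risk_scores[i] == risk_scores[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


# Toy data (NOT study data): risk ranks patients exactly by recurrence time.
print(harrell_c_index([1, 2, 3, 4], [1, 1, 1, 0], [4, 3, 2, 1]))  # 1.0
```

If lifelines is used instead, note that `lifelines.utils.concordance_index` treats higher scores as longer survival, so risk scores should be negated before passing them in.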

t-SNE Visualization

t-SNE projection of 1280-dimensional multimodal patient signatures colored by cluster assignment, illustrating how the clusters separate in the fused feature space.

Key Advantages

This heuristic-based approach offers several benefits:

  • High Interpretability: Attention heatmaps reveal which tissue regions drive clustering, enabling biological validation
  • No Training Required: Works immediately on new data using fixed mathematical operations
  • Immediate Deployment: No model training or fine-tuning needed
  • Robustness: Avoids the overfitting that deep learning models are prone to on small datasets

The full implementation, including attention heatmap visualization and survival analysis tools, is available in the Fusion_model_clustering/ directory of this repository.

About

Knowledge structures for Multimodal Models
