Skip to content

Commit 1507bda

Browse files
[2.7] Add Brats to research (#4026)
Fixes # . ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Quick tests passed locally by running `./runtest.sh`. - [ ] In-line docstrings updated. - [ ] Documentation updated. --------- Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
1 parent 3b3dd81 commit 1507bda

File tree

18 files changed

+7426
-10
lines changed

18 files changed

+7426
-10
lines changed

research/README.md

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,17 @@ NVIDIA FLARE has been used in several research studies. In this directory, you c
66

77
## Research Implementations
88

9-
10. [FedNCA - Equitable Federated Learning with NCA](./FedNCA/README.md) ([MICCAI 2025](https://arxiv.org/abs/2506.21735))
10-
09. [FedBPT: Efficient Federated Black-box Prompt Tuning for Large Language Models](./fed-bpt/README.md) ([ICML 2024](https://arxiv.org/abs/2310.01467))
11-
08. [ConDistFL: Conditional Distillation for Federated Learning from Partially Annotated Data](./condist-fl/README.md) ([DeCaF 2023](https://arxiv.org/abs/2308.04070))
12-
07. [FedOBD: Opportunistic Block Dropout for Efficiently Training Large-scale Neural Networks through Federated Learning](./fedobd/README.md) ([IJCAI 2023](https://arxiv.org/abs/2208.05174))
13-
06. [Fair Federated Medical Image Segmentation via Client Contribution Estimation](./fed-ce/README.md) ([CVPR 2023](https://arxiv.org/abs/2303.16520))
14-
05. [Communication-Efficient Vertical Federated Learning with Limited Overlapping Samples](./one-shot-vfl/README.md) ([ICCV 2023](https://arxiv.org/abs/2303.16270))
15-
04. [Closing the Generalization Gap of Cross-silo Federated Medical Image Segmentation](./fed-sm/README.md) ([CVPR 2022](https://arxiv.org/abs/2203.10144))
16-
03. [Do Gradient Inversion Attacks Make Federated Learning Unsafe?](./quantifying-data-leakage/README.md) ([IEEE Transactions on Medical Imaging 2022](https://arxiv.org/abs/2202.06924))
17-
02. [Auto-FedRL: Federated Hyperparameter Optimization for Multi-institutional Medical Image Segmentation](./auto-fed-rl/README.md) ([ECCV 2022](https://arxiv.org/abs/2203.06338))
18-
01. [FedBN: Federated Learning on Non-IID Features via Local Batch Normalization](./fed-bn/README.md) ([ICLR 2021](https://arxiv.org/abs/2102.07623))
9+
1. [FedNCA - Equitable Federated Learning with NCA](./FedNCA/README.md) ([MICCAI 2025](https://arxiv.org/abs/2506.21735))
10+
2. [FedBPT: Efficient Federated Black-box Prompt Tuning for Large Language Models](./fed-bpt/README.md) ([ICML 2024](https://arxiv.org/abs/2310.01467))
11+
3. [ConDistFL: Conditional Distillation for Federated Learning from Partially Annotated Data](./condist-fl/README.md) ([DeCaF 2023](https://arxiv.org/abs/2308.04070))
12+
4. [FedOBD: Opportunistic Block Dropout for Efficiently Training Large-scale Neural Networks through Federated Learning](./fedobd/README.md) ([IJCAI 2023](https://arxiv.org/abs/2208.05174))
13+
5. [Fair Federated Medical Image Segmentation via Client Contribution Estimation](./fed-ce/README.md) ([CVPR 2023](https://arxiv.org/abs/2303.16520))
14+
6. [Communication-Efficient Vertical Federated Learning with Limited Overlapping Samples](./one-shot-vfl/README.md) ([ICCV 2023](https://arxiv.org/abs/2303.16270))
15+
7. [Closing the Generalization Gap of Cross-silo Federated Medical Image Segmentation](./fed-sm/README.md) ([CVPR 2022](https://arxiv.org/abs/2203.10144))
16+
8. [Do Gradient Inversion Attacks Make Federated Learning Unsafe?](./quantifying-data-leakage/README.md) ([IEEE Transactions on Medical Imaging 2022](https://arxiv.org/abs/2202.06924))
17+
9. [Auto-FedRL: Federated Hyperparameter Optimization for Multi-institutional Medical Image Segmentation](./auto-fed-rl/README.md) ([ECCV 2022](https://arxiv.org/abs/2203.06338))
18+
10. [FedBN: Federated Learning on Non-IID Features via Local Batch Normalization](./fed-bn/README.md) ([ICLR 2021](https://arxiv.org/abs/2102.07623))
19+
11. [Privacy-preserving Federated Brain Tumour Segmentation](./brats18/README.md) ([MLMI 2019](https://arxiv.org/abs/1910.00962))
1920

2021
## Contributing
2122

research/brats18/.gitignore

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# BraTS18 dataset training data (keep datalist JSON files, but ignore the actual data)
2+
dataset_brats18/dataset/training
3+
4+
# Compressed medical image files
5+
*.nii.gz
6+
7+
# PyTorch model checkpoints
8+
*.pt
9+
*.pth
10+
11+
# Python cache
12+
__pycache__/
13+
*.pyc

research/brats18/README.md

Lines changed: 333 additions & 0 deletions
Large diffs are not rendered by default.

research/brats18/client.py

Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
Client-side training script for BraTS18 using NVFlare Client API.
16+
"""
17+
import argparse
18+
import copy
19+
import os
20+
from typing import Sequence, Tuple
21+
22+
import numpy as np
23+
import torch
24+
import torch.optim as optim
25+
from model import BratsSegResNet
26+
from monai.data import CacheDataset, DataLoader, Dataset, load_decathlon_datalist
27+
from monai.inferers import SlidingWindowInferer
28+
from monai.losses import DiceLoss
29+
from monai.metrics import DiceMetric
30+
from monai.transforms import (
31+
Activations,
32+
AsDiscrete,
33+
Compose,
34+
ConvertToMultiChannelBasedOnBratsClassesd,
35+
DivisiblePadd,
36+
EnsureChannelFirstd,
37+
LoadImaged,
38+
NormalizeIntensityd,
39+
Orientationd,
40+
RandFlipd,
41+
RandScaleIntensityd,
42+
RandShiftIntensityd,
43+
RandSpatialCropd,
44+
Spacingd,
45+
)
46+
47+
import nvflare.client as flare
48+
from nvflare.app_opt.pt.fedproxloss import PTFedProxLoss
49+
from nvflare.client.tracking import SummaryWriter
50+
51+
52+
def parse_args():
53+
parser = argparse.ArgumentParser(description="BraTS18 client training with NVFlare Client API.")
54+
parser.add_argument("--aggregation_epochs", type=int, default=1, help="Local epochs per round.")
55+
parser.add_argument("--learning_rate", type=float, default=1e-4)
56+
parser.add_argument("--fedproxloss_mu", type=float, default=0.0)
57+
parser.add_argument("--cache_dataset", type=float, default=0.0)
58+
parser.add_argument("--dataset_base_dir", type=str, required=True)
59+
parser.add_argument("--datalist_json_path", type=str, required=True)
60+
parser.add_argument(
61+
"--roi_size",
62+
type=int,
63+
nargs=3,
64+
default=(224, 224, 144),
65+
metavar=("X", "Y", "Z"),
66+
)
67+
parser.add_argument(
68+
"--infer_roi_size",
69+
type=int,
70+
nargs=3,
71+
default=(240, 240, 160),
72+
metavar=("X", "Y", "Z"),
73+
)
74+
parser.add_argument("--centralized", action="store_true", help="Use all data for centralized training")
75+
return parser.parse_args()
76+
77+
78+
def custom_client_datalist_json_path(datalist_json_path: str, client_id: str, centralized: bool = False) -> str:
79+
"""Customize datalist_json_path for each client.
80+
81+
Args:
82+
datalist_json_path: Root path containing all json files
83+
client_id: Client identifier (e.g., site-1, site-2, etc.)
84+
centralized: If True, use site-All.json for centralized training with all data
85+
86+
Returns:
87+
Path to the appropriate datalist json file
88+
"""
89+
if centralized:
90+
# Use site-All.json for centralized training with all data
91+
all_data_path = os.path.join(datalist_json_path, "site-All.json")
92+
if os.path.exists(all_data_path):
93+
return all_data_path
94+
return os.path.join(datalist_json_path, client_id + ".json")
95+
96+
97+
def build_dataloaders(
98+
*,
99+
client_id: str,
100+
cache_rate: float,
101+
dataset_base_dir: str,
102+
datalist_json_path: str,
103+
roi_size: Sequence[int],
104+
infer_roi_size: Sequence[int],
105+
centralized: bool = False,
106+
) -> Tuple[DataLoader, DataLoader, SlidingWindowInferer, Compose, DiceMetric]:
107+
datalist_json_path = custom_client_datalist_json_path(datalist_json_path, client_id, centralized)
108+
109+
print(f"[{client_id}] Loading datalist from: {datalist_json_path}")
110+
111+
train_list = load_decathlon_datalist(
112+
data_list_file_path=datalist_json_path,
113+
is_segmentation=True,
114+
data_list_key="training",
115+
base_dir=dataset_base_dir,
116+
)
117+
valid_list = load_decathlon_datalist(
118+
data_list_file_path=datalist_json_path,
119+
is_segmentation=True,
120+
data_list_key="validation",
121+
base_dir=dataset_base_dir,
122+
)
123+
124+
print(f"[{client_id}] Training samples: {len(train_list)}, Validation samples: {len(valid_list)}")
125+
126+
transform_train = Compose(
127+
[
128+
LoadImaged(keys=["image", "label"]),
129+
EnsureChannelFirstd(keys="image"),
130+
ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
131+
Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
132+
Orientationd(keys=["image", "label"], axcodes="RAS"),
133+
RandSpatialCropd(keys=["image", "label"], roi_size=roi_size, random_size=False),
134+
RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
135+
RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
136+
RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
137+
NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
138+
RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
139+
RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
140+
]
141+
)
142+
transform_valid = Compose(
143+
[
144+
LoadImaged(keys=["image", "label"]),
145+
EnsureChannelFirstd(keys="image"),
146+
ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
147+
Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
148+
DivisiblePadd(keys=["image", "label"], k=32),
149+
Orientationd(keys=["image", "label"], axcodes="RAS"),
150+
NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
151+
]
152+
)
153+
154+
if cache_rate > 0.0:
155+
train_dataset = CacheDataset(data=train_list, transform=transform_train, cache_rate=cache_rate, num_workers=1)
156+
valid_dataset = CacheDataset(data=valid_list, transform=transform_valid, cache_rate=cache_rate, num_workers=1)
157+
else:
158+
train_dataset = Dataset(data=train_list, transform=transform_train)
159+
valid_dataset = Dataset(data=valid_list, transform=transform_valid)
160+
161+
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=1)
162+
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=1)
163+
164+
inferer = SlidingWindowInferer(roi_size=infer_roi_size, sw_batch_size=1, overlap=0.5)
165+
transform_post = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
166+
valid_metric = DiceMetric(include_background=True, reduction="mean")
167+
168+
return train_loader, valid_loader, inferer, transform_post, valid_metric
169+
170+
171+
def validate(model, valid_loader, inferer, transform_post, valid_metric, device):
172+
model.eval()
173+
with torch.no_grad():
174+
metric = 0.0
175+
ct = 0
176+
for batch_data in valid_loader:
177+
val_images = batch_data["image"].to(device)
178+
val_labels = batch_data["label"].to(device)
179+
val_outputs = inferer(val_images, model)
180+
val_outputs = transform_post(val_outputs)
181+
metric_score = valid_metric(y_pred=val_outputs, y=val_labels)
182+
for sub_region in range(3):
183+
metric_score_single = metric_score[0][sub_region].item()
184+
if not np.isnan(metric_score_single):
185+
metric += metric_score_single
186+
ct += 1
187+
if ct == 0:
188+
raise ValueError("No valid validation metrics computed. Check validation dataset and data preprocessing.")
189+
return metric / ct
190+
191+
192+
def main():
193+
args = parse_args()
194+
195+
flare.init()
196+
sys_info = flare.system_info()
197+
client_name = sys_info["site_name"]
198+
summary_writer = SummaryWriter()
199+
200+
train_loader, valid_loader, inferer, transform_post, valid_metric = build_dataloaders(
201+
client_id=client_name,
202+
cache_rate=args.cache_dataset,
203+
dataset_base_dir=args.dataset_base_dir,
204+
datalist_json_path=args.datalist_json_path,
205+
roi_size=args.roi_size,
206+
infer_roi_size=args.infer_roi_size,
207+
centralized=args.centralized,
208+
)
209+
210+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
211+
model = BratsSegResNet().to(device)
212+
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=1e-5)
213+
criterion = DiceLoss(smooth_nr=0, smooth_dr=1e-5, squared_pred=True, to_onehot_y=False, sigmoid=True)
214+
criterion_prox = PTFedProxLoss(mu=args.fedproxloss_mu) if args.fedproxloss_mu > 0 else None
215+
216+
while flare.is_running():
217+
input_model = flare.receive()
218+
model.load_state_dict(input_model.params, strict=True)
219+
model.to(device)
220+
221+
global_metric = validate(model, valid_loader, inferer, transform_post, valid_metric, device)
222+
summary_writer.add_scalar("val_metric_global_model", global_metric, input_model.current_round)
223+
224+
model_global = None
225+
if args.fedproxloss_mu > 0:
226+
model_global = copy.deepcopy(model)
227+
for param in model_global.parameters():
228+
param.requires_grad = False
229+
230+
steps_per_epoch = len(train_loader)
231+
total_steps = steps_per_epoch * args.aggregation_epochs
232+
233+
for epoch in range(args.aggregation_epochs):
234+
model.train()
235+
running_loss = 0.0
236+
for batch_data in train_loader:
237+
inputs = batch_data["image"].to(device)
238+
labels = batch_data["label"].to(device)
239+
outputs = model(inputs)
240+
loss = criterion(outputs, labels)
241+
if args.fedproxloss_mu > 0:
242+
loss += criterion_prox(model, model_global)
243+
optimizer.zero_grad()
244+
loss.backward()
245+
optimizer.step()
246+
running_loss += loss.item()
247+
248+
if len(train_loader) == 0:
249+
raise ValueError("Training data loader is empty. Check dataset preparation and datalist configuration.")
250+
avg_loss = running_loss / len(train_loader)
251+
global_step = input_model.current_round * total_steps + epoch
252+
summary_writer.add_scalar("train_loss", avg_loss, global_step)
253+
254+
# Send trained model weights (API will compute diff automatically with TransferType.DIFF)
255+
output_model = flare.FLModel(
256+
params=model.cpu().state_dict(),
257+
metrics={"val_dice": global_metric},
258+
meta={"NUM_STEPS_CURRENT_ROUND": total_steps},
259+
)
260+
flare.send(output_model)
261+
262+
263+
if __name__ == "__main__":
264+
main()

0 commit comments

Comments
 (0)