384 changes: 53 additions & 331 deletions research/fedumm/README.md

Large diffs are not rendered by default.

@@ -1,4 +1,4 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,29 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Centralized (non-FL) training baseline for any registered VLM."""
"""Centralized (non-FL) training baseline for BLIP-VQA."""

import argparse
import os

import src.blip_backend # noqa: F401 — registers the blip_vqa backend
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader

import sys, os as _os # noqa: F401 - triggers backend registration
sys.path.insert(0, _os.path.dirname(_os.path.abspath(__file__)))
import src
from src.common import (
count_trainable_params, maybe_subsample, set_seed, train_one_epoch,
)
from src.common import count_trainable_params, maybe_subsample, set_seed, train_one_epoch
from src.model_registry import get_backend
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter


def main() -> None:
p = argparse.ArgumentParser()
p.add_argument("--model_backend", type=str, required=True)
p.add_argument("--model_name_or_path", type=str, default="")
p.add_argument("--output_dir", type=str, default="./workspace_centralized")
p.add_argument("--output_dir", type=str, default="/tmp/nvflare/workspaces/fedumm/centralized")
p.add_argument("--seed", type=int, default=42)
p.add_argument("--max_train_samples", type=int, default=-1)
p.add_argument("--max_eval_samples", type=int, default=-1)
@@ -53,17 +48,16 @@ def main() -> None:

set_seed(args.seed)
os.makedirs(args.output_dir, exist_ok=True)
cache_dir = args.data_path or os.environ.get("HF_HOME", "/tmp/hf_cache")
backend = get_backend(args.model_backend)
writer = SummaryWriter(log_dir=os.path.join(args.output_dir, "tb_logs"))
cache_dir = args.data_path or None
backend = get_backend("blip_vqa")

import aiohttp
_timeout = {'client_kwargs': {'timeout': aiohttp.ClientTimeout(total=3600)}}
train_hf = load_dataset(backend.hf_dataset_name(),
split=backend.hf_train_split(), cache_dir=cache_dir,
storage_options=_timeout)
eval_hf = load_dataset(backend.hf_dataset_name(),
split=backend.hf_eval_split(), cache_dir=cache_dir,
storage_options=_timeout)
train_hf = load_dataset(
backend.hf_dataset_name(), split=backend.hf_train_split(), cache_dir=cache_dir, trust_remote_code=True
)
eval_hf = load_dataset(
backend.hf_dataset_name(), split=backend.hf_eval_split(), cache_dir=cache_dir, trust_remote_code=True
)
keep = set(backend.keep_columns())
train_hf = train_hf.remove_columns([c for c in train_hf.column_names if c not in keep])
eval_hf = eval_hf.remove_columns([c for c in eval_hf.column_names if c not in keep])
@@ -73,28 +67,52 @@ def main() -> None:

device = "cuda" if torch.cuda.is_available() else "cpu"
model, processor = backend.build_model_and_processor(
args.model_name_or_path, args.lora_r, args.lora_alpha,
args.lora_dropout, device)
args.model_name_or_path, args.lora_r, args.lora_alpha, args.lora_dropout, device
)
print(count_trainable_params(model))

train_ds = backend.build_dataset(train_hf, processor, args.max_q_len, args.max_a_len)
eval_ds = backend.build_dataset(eval_hf, processor, args.max_q_len, args.max_a_len)
train_ld = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
num_workers=args.num_workers,
collate_fn=backend.collate_fn, pin_memory=True)
eval_ld = DataLoader(eval_ds, batch_size=args.batch_size, shuffle=False,
num_workers=args.num_workers,
collate_fn=backend.collate_fn, pin_memory=True)
train_ld = DataLoader(
train_ds,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.num_workers,
collate_fn=backend.collate_fn,
pin_memory=True,
)
eval_ld = DataLoader(
eval_ds,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
collate_fn=backend.collate_fn,
pin_memory=True,
)

optimizer = torch.optim.AdamW(
(p for p in model.parameters() if p.requires_grad), lr=args.lr)
optimizer = torch.optim.AdamW((p for p in model.parameters() if p.requires_grad), lr=args.lr)

steps_per_epoch = len(train_ld)
for epoch in range(args.num_epochs):
loss = train_one_epoch(model, train_ld, optimizer, device, args.grad_accum, backend)
print(f"Epoch {epoch + 1}/{args.num_epochs}", flush=True)
loss = train_one_epoch(
model,
train_ld,
optimizer,
device,
args.grad_accum,
backend,
prefix=f"epoch={epoch + 1}",
writer=writer,
global_step_offset=epoch * steps_per_epoch,
)
acc = backend.evaluate(model, eval_ld, processor, device)
print(f"Epoch {epoch+1}/{args.num_epochs} loss={loss:.4f} acc={acc:.4f}")
writer.add_scalar("val/acc", acc, epoch + 1)
print(f"Epoch {epoch + 1}/{args.num_epochs} loss={loss:.4f} acc={acc:.4f}")

model.save_pretrained(args.output_dir)
model.text_encoder.save_pretrained(os.path.join(args.output_dir, "text_encoder"))
model.text_decoder.save_pretrained(os.path.join(args.output_dir, "text_decoder"))
writer.close()
print(f"Saved to {args.output_dir}")


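Aside: both scripts resolve the model through `src.model_registry.get_backend("blip_vqa")`, and the otherwise-unused imports (`import src.blip_backend` here, `import src` in the client) exist only for their registration side effect. The registry module itself is not part of this diff; below is a minimal sketch of the pattern, with every name except `get_backend` assumed for illustration:

```python
# Sketch of the registry pattern behind get_backend(); the actual
# src/model_registry.py is not shown in this diff, so treat all names
# other than get_backend() as illustrative.
_BACKENDS = {}


def register_backend(name):
    """Class decorator that records a backend class under a string key."""

    def _wrap(cls):
        _BACKENDS[name] = cls
        return cls

    return _wrap


def get_backend(name):
    """Instantiate the backend registered under `name`."""
    if name not in _BACKENDS:
        raise KeyError(f"unknown backend {name!r}; known: {sorted(_BACKENDS)}")
    return _BACKENDS[name]()
```

This is why the import lines carry `# noqa: F401`: the imported name is never used, but the import must run so the `blip_vqa` key exists before `get_backend` is called.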
164 changes: 91 additions & 73 deletions research/fedumm/src/fl_client.py → research/fedumm/client.py
@@ -1,4 +1,4 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,29 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unified NVFlare FL client for any registered VLM backend.

Select the model at launch time::

python fl_client.py --model_backend blip_vqa ...
python fl_client.py --model_backend januspro ...

The SubprocessLauncher in the NVFlare job config passes ``--model_backend``
via ``script_args``, so different sites can even run *different* models
(though typically all sites use the same one for FedAvg to make sense).
"""
"""NVFlare FL client for federated BLIP-VQA fine-tuning."""

import argparse
import os

import src.blip_backend # noqa: F401 — triggers backend registration
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader

import nvflare.client as flare

# This import triggers backend registration
import src # noqa: F401
from src.common import (
count_trainable_params,
get_trainable_params,
@@ -45,6 +29,10 @@
train_one_epoch,
)
from src.model_registry import get_backend
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import nvflare.client as flare


def parse_site_id(site_name: str) -> int:
@@ -56,10 +44,7 @@ def parse_site_id(site_name: str) -> int:

def _parse_args():
p = argparse.ArgumentParser()
p.add_argument("--model_backend", type=str, required=True,
help="Registry key: blip_vqa | januspro")
p.add_argument("--model_name_or_path", type=str, default="",
help="HF model id (uses backend default if empty).")
p.add_argument("--model_name_or_path", type=str, default="", help="HF model id (uses backend default if empty).")
p.add_argument("--num_clients", type=int, default=2)
p.add_argument("--local_epochs", type=int, default=1)
p.add_argument("--batch_size", type=int, default=8)
@@ -74,66 +59,78 @@ def _parse_args():
p.add_argument("--lora_alpha", type=int, default=32)
p.add_argument("--lora_dropout", type=float, default=0.1)
p.add_argument("--data_path", type=str, default="")
p.add_argument("--dirichlet_alpha", type=float, default=0.0,
help="Dirichlet concentration for non-IID partition. "
"0 = IID round-robin, 0.1 = extreme non-IID, "
"0.5 = moderate, 1.0 = mild. (Paper: 0.1/0.5/1.0)")
p.add_argument(
"--dirichlet_alpha",
type=float,
default=0.0,
help="Dirichlet concentration for non-IID partition. "
"0 = IID round-robin, 0.1/0.5/1.0 = non-IID levels from paper.",
)
p.add_argument("--seed", type=int, default=42)
return p.parse_args()
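Aside: `shard_dataset` is imported from `src.common` and its implementation is outside this diff. The sketch below is one way to realize the `--dirichlet_alpha` semantics documented above (0 = IID round-robin, smaller positive values = heavier skew); the pseudo-class key and the zero-based `site_id` handling are assumptions, since a VQA example carries no single class label:

```python
import zlib

import numpy as np


def shard_dataset(ds, num_clients, site_id, alpha=0.0, seed=42):
    """Return this site's shard of `ds` (sketch, not the real src.common code)."""
    site = site_id % num_clients  # assumes zero-based site indexing
    if alpha <= 0.0:
        # IID: deterministic round-robin over example indices.
        return ds.select(range(site, len(ds), num_clients))
    rng = np.random.default_rng(seed)
    # Pseudo-classes via a *stable* hash of the question text (Python's
    # built-in hash() is salted per process); purely illustrative.
    groups = {}
    for idx, question in enumerate(ds["question"]):
        groups.setdefault(zlib.crc32(question.encode("utf-8")) % 50, []).append(idx)
    mine = []
    for idxs in groups.values():
        # Divide each pseudo-class among clients by Dirichlet weights.
        props = rng.dirichlet(alpha * np.ones(num_clients))
        cuts = (np.cumsum(props)[:-1] * len(idxs)).astype(int)
        mine.extend(np.split(np.array(idxs), cuts)[site].tolist())
    return ds.select(sorted(mine))
```

Because every site would call this with the same seed, the Dirichlet draws agree across clients and the shards partition the dataset without coordination; passing `alpha=0.0` for the eval split keeps validation IID, as the client does below.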


def main() -> None:
args = _parse_args()
set_seed(args.seed)
cache_dir = args.data_path or os.environ.get("HF_HOME", "/tmp/hf_cache")
cache_dir = args.data_path or None

backend = get_backend(args.model_backend)
backend = get_backend("blip_vqa")
print(f">>> Backend: {backend.name}", flush=True)

# ---- Data ----
print(">>> Loading dataset ...", flush=True)
train_hf = load_dataset(backend.hf_dataset_name(),
split=backend.hf_train_split(), cache_dir=cache_dir)
eval_hf = load_dataset(backend.hf_dataset_name(),
split=backend.hf_eval_split(), cache_dir=cache_dir)
train_hf = load_dataset(
backend.hf_dataset_name(), split=backend.hf_train_split(), cache_dir=cache_dir, trust_remote_code=True
)
eval_hf = load_dataset(
backend.hf_dataset_name(), split=backend.hf_eval_split(), cache_dir=cache_dir, trust_remote_code=True
)
keep = set(backend.keep_columns())
train_hf = train_hf.remove_columns([c for c in train_hf.column_names if c not in keep])
eval_hf = eval_hf.remove_columns([c for c in eval_hf.column_names if c not in keep])
train_hf = maybe_subsample(train_hf, args.max_train_samples, args.seed)
eval_hf = maybe_subsample(eval_hf, args.max_eval_samples, args.seed)

# ---- NVFlare init ----
flare.init()
site = flare.get_site_name()
writer = SummaryWriter()
site_id = parse_site_id(site)
train_hf = shard_dataset(train_hf, args.num_clients, site_id,
alpha=args.dirichlet_alpha, seed=args.seed)
eval_hf = shard_dataset(eval_hf, args.num_clients, site_id,
alpha=0.0, seed=args.seed) # eval always IID
train_hf = shard_dataset(train_hf, args.num_clients, site_id, alpha=args.dirichlet_alpha, seed=args.seed)
eval_hf = shard_dataset(eval_hf, args.num_clients, site_id, alpha=0.0, seed=args.seed) # eval always IID
print(f"[{site}] train={len(train_hf)}, eval={len(eval_hf)}", flush=True)

# ---- Model ----
device = "cuda" if torch.cuda.is_available() else "cpu"
model, processor = backend.build_model_and_processor(
args.model_name_or_path, args.lora_r, args.lora_alpha,
args.lora_dropout, device,
args.model_name_or_path,
args.lora_r,
args.lora_alpha,
args.lora_dropout,
device,
)
print(f"[{site}] {count_trainable_params(model)}", flush=True)

# ---- Dataloaders ----
train_ds = backend.build_dataset(train_hf, processor, args.max_q_len, args.max_a_len)
eval_ds = backend.build_dataset(eval_hf, processor, args.max_q_len, args.max_a_len)
train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
num_workers=args.num_workers,
collate_fn=backend.collate_fn, pin_memory=True)
eval_loader = DataLoader(eval_ds, batch_size=args.batch_size, shuffle=False,
num_workers=args.num_workers,
collate_fn=backend.collate_fn, pin_memory=True)

# ---- FL loop ----
train_loader = DataLoader(
train_ds,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.num_workers,
collate_fn=backend.collate_fn,
pin_memory=True,
)
eval_loader = DataLoader(
eval_ds,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
collate_fn=backend.collate_fn,
pin_memory=True,
)

while flare.is_running():
input_model = flare.receive()
cur_round = getattr(input_model, "current_round", None)
cur_round = getattr(input_model, "current_round", None) or 0

if input_model and getattr(input_model, "params", None):
load_trainable_params(model, input_model.params, device)
@@ -142,32 +139,53 @@ def main() -> None:
if flare.is_evaluate():
acc = backend.evaluate(model, eval_loader, processor, device)
print(f"[{site}] validate round={cur_round} acc={acc:.4f}", flush=True)
flare.send(flare.FLModel(
params=None,
metrics={"val_accuracy": float(acc), "n_eval": len(eval_ds)},
meta={"n_eval": len(eval_ds)},
))
writer.add_scalar("val/acc", acc, cur_round)
flare.send(
flare.FLModel(
params=None,
metrics={"val_accuracy": float(acc), "n_eval": len(eval_ds)},
meta={"n_eval": len(eval_ds)},
)
)
continue

# -- train --
optimizer = torch.optim.AdamW(
(p for p in model.parameters() if p.requires_grad), lr=args.lr)
optimizer = torch.optim.AdamW((p for p in model.parameters() if p.requires_grad), lr=args.lr)
steps_per_epoch = len(train_loader)
loss = 0.0
for _ in range(args.local_epochs):
loss = train_one_epoch(model, train_loader, optimizer, device,
args.grad_accum, backend)
for epoch in range(args.local_epochs):
print(f"[{site}] round={cur_round} epoch={epoch + 1}/{args.local_epochs}", flush=True)
global_step_offset = cur_round * args.local_epochs * steps_per_epoch + epoch * steps_per_epoch
loss = train_one_epoch(
model,
train_loader,
optimizer,
device,
args.grad_accum,
backend,
prefix=f"[{site}] round={cur_round}",
writer=writer,
global_step_offset=global_step_offset,
)
acc = backend.evaluate(model, eval_loader, processor, device)
steps = args.local_epochs * len(train_loader)
print(f"[{site}] train round={cur_round} loss={loss:.4f} acc={acc:.4f}",
flush=True)

flare.send(flare.FLModel(
params=get_trainable_params(model),
metrics={"train_loss": float(loss), "local_acc": float(acc),
"n_train": len(train_ds), "n_eval": len(eval_ds)},
meta={"NUM_STEPS_CURRENT_ROUND": steps,
"n_train": len(train_ds), "n_eval": len(eval_ds)},
))
steps = args.local_epochs * steps_per_epoch
print(f"[{site}] train round={cur_round} loss={loss:.4f} acc={acc:.4f}", flush=True)
writer.add_scalar("train/acc", acc, cur_round)

flare.send(
flare.FLModel(
params=get_trainable_params(model),
metrics={
"train_loss": float(loss),
"local_acc": float(acc),
"n_train": len(train_ds),
"n_eval": len(eval_ds),
},
meta={"NUM_STEPS_CURRENT_ROUND": steps, "n_train": len(train_ds), "n_eval": len(eval_ds)},
)
)

writer.close()


if __name__ == "__main__":
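Aside: the TensorBoard bookkeeping added in both training loops offsets each epoch's steps by `cur_round * local_epochs * steps_per_epoch + epoch * steps_per_epoch` (the centralized script is the `cur_round = 0` special case), so per-step scalars from successive rounds extend one x-axis instead of overwriting earlier points. A quick check of the arithmetic, with illustrative values:

```python
# Worked example of the global-step offsets used above.
local_epochs, steps_per_epoch = 2, 500

for cur_round in range(2):
    for epoch in range(local_epochs):
        offset = cur_round * local_epochs * steps_per_epoch + epoch * steps_per_epoch
        print(f"round={cur_round} epoch={epoch} first_step={offset}")

# round=0 epoch=0 first_step=0
# round=0 epoch=1 first_step=500
# round=1 epoch=0 first_step=1000
# round=1 epoch=1 first_step=1500
```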
19 changes: 0 additions & 19 deletions research/fedumm/envs/env_blip.yml

This file was deleted.
