Skip to content

Commit 2dc5edf

Browse files
Merge branch 'main' into cherry-pick-4150
2 parents cf0a5a0 + 3879ce6 commit 2dc5edf

File tree

14 files changed

+923
-52
lines changed

14 files changed

+923
-52
lines changed

AGENTS.md

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# NVFlare Agent Notes
2+
3+
- To trigger CI/CD from a PR review thread, post a single-line comment exactly: `/build`.
4+
- Prefer `rg` and `rg --files` for fast codebase search.
5+
- Keep edits scoped to the task; do not modify unrelated files in a dirty worktree.
6+
- Start with targeted tests for changed files, then run broader checks as needed.
7+
8+
## Fast Commands
9+
10+
- `./runtest.sh` runs license/style/tests with coverage.
11+
- `./runtest.sh -s` runs style checks (flake8, black, isort).
12+
- `./runtest.sh -f` auto-fixes style where possible.
13+
- `./runtest.sh -u` runs unit tests.
14+
- `python3 -m pytest tests/unit_test/path/to/test_file.py -v` runs one test file.
15+
- `python3 -m pytest --numprocesses=8 -v tests/unit_test` runs unit tests in parallel.
16+
- `./build_doc.sh --html` builds docs.
17+
- `./build_doc.sh --clean` cleans docs build artifacts.
18+
19+
## Style and Testing Conventions
20+
21+
- Format/lint stack: black (line length 120), flake8, isort (black profile).
22+
- Python support targets: 3.9, 3.10, 3.11, 3.12.
23+
- Add the standard NVIDIA Apache-2.0 license header to new Python source files.
24+
- Unit tests live in `tests/unit_test/`; integration tests live in `tests/integration_test/`.
25+
- Test file names follow `[module_name]_test.py`.
26+
27+
## Quick Package Map
28+
29+
- `nvflare/apis/`: core interfaces (Controller, Executor, Task, Shareable, FLContext).
30+
- `nvflare/app_common/`: common algorithms and utilities.
31+
- `nvflare/app_opt/`: optional integrations/dependencies.
32+
- `nvflare/client/`: client-side APIs.
33+
- `nvflare/job_config/`: FedJob/job configuration.
34+
- `nvflare/private/`: internal implementations.
35+
- `nvflare/fuel/`: shared infrastructure utilities.

examples/advanced/bionemo/downstream/client.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,23 @@
2222
from pathlib import Path
2323
from typing import Optional
2424

25+
from nvflare.fuel.utils.network_utils import get_open_ports
26+
2527
# Set NumExpr thread limits before importing numerical libraries to avoid thread conflicts
2628
os.environ.setdefault("NUMEXPR_MAX_THREADS", "64")
2729
os.environ.setdefault("NUMEXPR_NUM_THREADS", "8")
2830

31+
# Use an available port for PyTorch distributed to avoid EADDRINUSE and PID collisions.
32+
if "MASTER_PORT" not in os.environ:
33+
os.environ["MASTER_PORT"] = str(get_open_ports(1)[0])
34+
if "MASTER_ADDR" not in os.environ:
35+
os.environ.setdefault("MASTER_ADDR", "localhost")
36+
2937
from bionemo.core.utils.dtypes import PrecisionTypes, get_autocast_dtype
3038
from bionemo.esm2.data.tokenizer import get_tokenizer
3139
from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule
3240
from bionemo.esm2.model.finetune.dataset import InMemoryProteinDataset, InMemorySingleValueDataset
3341
from bionemo.esm2.model.finetune.sequence_model import ESM2FineTuneSeqConfig
34-
35-
# Reuse parser and config constants from bionemo
3642
from bionemo.esm2.scripts.finetune_esm2 import get_parser
3743
from bionemo.llm.model.biobert.lightning import biobert_lightning_module
3844
from bionemo.llm.model.biobert.model import BioBertConfig
@@ -413,6 +419,8 @@ def train_model(
413419
)
414420

415421
# perform local training starting with the received global model
422+
# Set MASTER_PORT so the training subprocess (spawned by Lightning) inherits an available port.
423+
os.environ["MASTER_PORT"] = str(get_open_ports(1)[0])
416424
llm.train(
417425
model=module,
418426
data=data_module,

examples/advanced/bionemo/downstream/downstream_nvflare.ipynb

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,22 @@
5858
"warnings.simplefilter(\"ignore\")"
5959
]
6060
},
61+
{
62+
"cell_type": "markdown",
63+
"metadata": {},
64+
"source": [
65+
"Copy the model.py file to each task's subfolder for execution."
66+
]
67+
},
68+
{
69+
"cell_type": "code",
70+
"execution_count": null,
71+
"metadata": {},
72+
"outputs": [],
73+
"source": [
74+
"! for d in tap sabdab scl; do cp model.py \"$d/\"; done"
75+
]
76+
},
6177
{
6278
"cell_type": "markdown",
6379
"metadata": {},
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""ESM2 server module: loads checkpoint state_dict for NVFlare FedAvg (no Megatron/Lightning init)."""
15+
16+
import os
17+
import warnings
18+
from collections import OrderedDict
19+
from typing import List, NamedTuple, Optional
20+
21+
22+
class _IncompatibleKeys(NamedTuple):
23+
"""Compatible with PyTorch's load_state_dict return type (missing_keys, unexpected_keys)."""
24+
25+
missing_keys: List[str]
26+
unexpected_keys: List[str]
27+
28+
29+
import torch
30+
31+
from nvflare.fuel.utils.network_utils import get_open_ports
32+
33+
34+
def _checkpoint_key_to_client(k: str) -> str:
35+
for old, new in (
36+
("encoder.layers.self_attention.", "encoder.layers.0.self_attention."),
37+
("encoder.layers.mlp.", "encoder.layers.0.mlp."),
38+
):
39+
if old in k:
40+
k = k.replace(old, new, 1)
41+
return k
42+
43+
44+
def _expand_checkpoint_state_dict(sd: OrderedDict) -> OrderedDict:
45+
"""Split layer-stacked tensors [n, ...] into per-layer keys (layers.0.*, layers.1.*, ...)."""
46+
out = OrderedDict()
47+
for k, v in sd.items():
48+
if not isinstance(v, torch.Tensor):
49+
out[k] = v
50+
continue
51+
# Keys that are layer-stacked: encoder.layers.self_attention.* or encoder.layers.mlp.*
52+
if "encoder.layers.self_attention." not in k and "encoder.layers.mlp." not in k:
53+
out[_checkpoint_key_to_client(k)] = v
54+
continue
55+
if v.ndim < 1:
56+
out[_checkpoint_key_to_client(k)] = v
57+
continue
58+
num_layers = v.shape[0]
59+
# Split into per-layer keys
60+
if "encoder.layers.self_attention." in k:
61+
base = k.replace("encoder.layers.self_attention.", "encoder.layers.{}.self_attention.", 1)
62+
else:
63+
base = k.replace("encoder.layers.mlp.", "encoder.layers.{}.mlp.", 1)
64+
for i in range(num_layers):
65+
out[base.format(i)] = v[i].clone()
66+
return out
67+
68+
69+
class ESM2ModuleForServer(torch.nn.Module):
70+
"""Holds state_dict loaded from checkpoint; BioNeMoParamsFilter adds prefix when sending to client."""
71+
72+
def __init__(self, checkpoint_path: str, **kwargs):
73+
super().__init__()
74+
path = os.path.abspath(checkpoint_path)
75+
if not os.path.isfile(path) and not os.path.isdir(path):
76+
raise FileNotFoundError(f"Checkpoint path does not exist or is not a file/directory: {checkpoint_path!r}")
77+
sd = load_state_dict_from_checkpoint_path(checkpoint_path)
78+
if sd is None:
79+
raise ValueError(
80+
f"Checkpoint is missing or invalid (could not load state dict from {checkpoint_path!r}). "
81+
"Ensure the path points to a valid NeMo or PyTorch checkpoint."
82+
)
83+
self._state_dict = _expand_checkpoint_state_dict(sd)
84+
85+
@staticmethod
86+
def _stored_key(k: str) -> str:
87+
if k.startswith("module.module."):
88+
return k[len("module.") :]
89+
return k
90+
91+
def state_dict(self, *args, **kwargs):
92+
return OrderedDict(self._state_dict)
93+
94+
def load_state_dict(self, state_dict, strict: bool = True):
95+
self._state_dict = OrderedDict((self._stored_key(k), v) for k, v in state_dict.items())
96+
return _IncompatibleKeys(missing_keys=[], unexpected_keys=[])
97+
98+
99+
def _flatten_state_dict(d: dict, prefix: str = "") -> OrderedDict:
100+
out = OrderedDict()
101+
for k, v in d.items():
102+
key = f"{prefix}.{k}" if prefix else k
103+
if isinstance(v, torch.Tensor):
104+
out[key] = v
105+
elif isinstance(v, (dict, OrderedDict)):
106+
out.update(_flatten_state_dict(v, key))
107+
return out
108+
109+
110+
def _extract_state_dict(loaded: dict) -> Optional[OrderedDict]:
111+
d = loaded
112+
for key in ("model", "state_dict", "weights", "checkpoint"):
113+
if key in d and isinstance(d[key], (dict, OrderedDict)):
114+
d = d[key]
115+
break
116+
if d is None or not d:
117+
return None
118+
if all(isinstance(v, torch.Tensor) for v in d.values()):
119+
return OrderedDict(d)
120+
flat = _flatten_state_dict(d)
121+
if flat is None or not flat:
122+
return None
123+
if all(isinstance(v, torch.Tensor) for v in flat.values()):
124+
return flat
125+
return None
126+
127+
128+
def _load_nemo_distributed_checkpoint(path: str) -> Optional[OrderedDict]:
129+
weights_dir = os.path.join(path, "weights")
130+
if not os.path.isdir(weights_dir):
131+
return None
132+
files = os.listdir(weights_dir)
133+
if "metadata.json" not in files or not any(f.endswith(".distcp") for f in files):
134+
return None
135+
try:
136+
from megatron.core.dist_checkpointing.serialization import load_plain_tensors
137+
except ImportError:
138+
try:
139+
from megatron.core import dist_checkpointing as dist_ckpt
140+
141+
load_plain_tensors = getattr(dist_ckpt, "load_plain_tensors", None)
142+
except ImportError:
143+
load_plain_tensors = None
144+
if load_plain_tensors is None:
145+
return None
146+
we_initialized = not torch.distributed.is_initialized()
147+
if we_initialized:
148+
os.environ.setdefault("MASTER_ADDR", "localhost")
149+
os.environ.setdefault("MASTER_PORT", str(get_open_ports(1)[0]))
150+
torch.distributed.init_process_group(backend="gloo", rank=0, world_size=1)
151+
try:
152+
ckpt_dir = os.path.abspath(weights_dir)
153+
loaded_sd = load_plain_tensors(ckpt_dir)
154+
if not isinstance(loaded_sd, dict):
155+
return None
156+
out = OrderedDict((k, v.cpu() if v.is_cuda else v) for k, v in loaded_sd.items() if isinstance(v, torch.Tensor))
157+
return out if out else None
158+
except Exception as e:
159+
warnings.warn(f"NeMo distributed checkpoint load failed: {e}", UserWarning, stacklevel=2)
160+
return None
161+
finally:
162+
if we_initialized and torch.distributed.is_initialized():
163+
torch.distributed.destroy_process_group()
164+
165+
166+
def load_state_dict_from_checkpoint_path(checkpoint_path: str) -> Optional[OrderedDict]:
167+
"""Load a state dict from a NeMo/Megatron checkpoint file or directory.
168+
169+
Supports single-file checkpoints and NeMo distributed checkpoint directories.
170+
Uses ``torch.load(..., weights_only=False)`` so that non-tensor objects in
171+
NeMo/Megatron checkpoints are restored correctly.
172+
173+
.. note::
174+
``weights_only=False`` uses Python's pickle module, which can execute
175+
arbitrary code during deserialization. Only load checkpoints from
176+
trusted sources.
177+
"""
178+
path = os.path.abspath(checkpoint_path)
179+
loaded = None
180+
if os.path.isfile(path):
181+
try:
182+
loaded = torch.load(path, map_location="cpu", weights_only=False)
183+
except Exception:
184+
return None
185+
elif os.path.isdir(path):
186+
result = _load_nemo_distributed_checkpoint(path)
187+
if result is not None:
188+
return result
189+
candidate = os.path.join(path, "weights", "common.pt")
190+
if os.path.isfile(candidate):
191+
try:
192+
loaded = torch.load(candidate, map_location="cpu", weights_only=False)
193+
except Exception:
194+
pass
195+
if loaded is None or not isinstance(loaded, dict):
196+
return None
197+
return _extract_state_dict(loaded)

examples/advanced/bionemo/downstream/sabdab/job.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,19 @@
1818

1919
from bionemo.core.data.load import load
2020

21-
from nvflare.app_common.widgets.decomposer_reg import DecomposerRegister
2221
from nvflare.app_opt.pt.recipes.fedavg import FedAvgRecipe
2322
from nvflare.recipe import SimEnv
2423

2524
# BioNeMo requires heavy imports (PyTorch, NeMo, Megatron) which can take longer than
2625
# the default 300s timeout on systems with slow I/O or resource contention
2726
BIONEMO_EXTERNAL_PRE_INIT_TIMEOUT = 900.0 # 15 minutes
2827

28+
# isort: off
2929
sys.path.append(os.path.join(os.getcwd(), "..")) # include parent folder in path
3030
from bionemo_filters import BioNeMoParamsFilter, BioNeMoStateDictFilter
3131

32+
# isort: on
33+
3234

3335
def main(args):
3436
checkpoint_path = load(f"esm2/{args.model}:2.0")
@@ -54,10 +56,17 @@ def main(args):
5456
script_args = f"--restore-from-checkpoint-path {checkpoint_path} --train-data-path /tmp/placeholder --valid-data-path /tmp/placeholder --config-class ESM2FineTuneSeqConfig --dataset-class InMemorySingleValueDataset --task-type classification --mlp-ft-dropout 0.1 --mlp-hidden-size 256 --mlp-target-size 2 --experiment-name sabdab_esm2_{args.model} --num-steps {args.local_steps} --num-gpus 1 --val-check-interval {val_check_interval} --log-every-n-steps 10 --lr 1e-4 --lr-multiplier 5 --scale-lr-layer classification_head --result-dir bionemo --micro-batch-size 64 --precision {precision} --save-top-k 1 --limit-val-batches 1.0 --classes {classes} --dataset-name sabdab --exp-name {args.exp_name}"
5557
print(f"Running {args.train_script} with base args (data paths will be resolved per-client)")
5658

59+
# Use dict config of the model so we only instantiate the model on the server.
60+
model = {
61+
"class_path": "model.ESM2ModuleForServer",
62+
"args": {"checkpoint_path": str(checkpoint_path)},
63+
}
64+
5765
# Create FedAvgRecipe
5866
job_name = f"{args.exp_name}_sabdab_esm2_{args.model}"
5967
recipe = FedAvgRecipe(
6068
name=job_name,
69+
model=model,
6170
min_clients=args.num_clients,
6271
num_rounds=args.num_rounds,
6372
train_script=f"../{args.train_script}",
@@ -73,10 +82,6 @@ def main(args):
7382
recipe.add_client_input_filter(BioNeMoParamsFilter(precision), tasks=["train", "validate"])
7483
recipe.add_client_output_filter(BioNeMoStateDictFilter(), tasks=["train", "validate"])
7584

76-
# Add decomposer register to server and clients
77-
recipe.job.to_server(DecomposerRegister(["nvflare.app_opt.pt.decomposers.TensorDecomposer"]))
78-
recipe.job.to_clients(DecomposerRegister(["nvflare.app_opt.pt.decomposers.TensorDecomposer"]))
79-
8085
# Add BioNeMo-specific timeout configuration to client config to override its default timeout
8186
recipe.add_client_config({"EXTERNAL_PRE_INIT_TIMEOUT": BIONEMO_EXTERNAL_PRE_INIT_TIMEOUT})
8287

examples/advanced/bionemo/downstream/scl/job.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,19 @@
1818

1919
from bionemo.core.data.load import load
2020

21-
from nvflare.app_common.widgets.decomposer_reg import DecomposerRegister
2221
from nvflare.app_opt.pt.recipes.fedavg import FedAvgRecipe
2322
from nvflare.recipe import SimEnv
2423

2524
# BioNeMo requires heavy imports (PyTorch, NeMo, Megatron) which can take longer than
2625
# the default 300s timeout on systems with slow I/O or resource contention
2726
BIONEMO_EXTERNAL_PRE_INIT_TIMEOUT = 900.0 # 15 minutes
2827

28+
# isort: off
2929
sys.path.append(os.path.join(os.getcwd(), "..")) # include parent folder in path
3030
from bionemo_filters import BioNeMoParamsFilter, BioNeMoStateDictFilter
3131

32+
# isort: on
33+
3234

3335
def main(args):
3436
checkpoint_path = load(f"esm2/{args.model}:2.0")
@@ -54,10 +56,16 @@ def main(args):
5456
script_args = f"--restore-from-checkpoint-path {checkpoint_path} --train-data-path /tmp/placeholder --valid-data-path /tmp/placeholder --config-class ESM2FineTuneSeqConfig --dataset-class InMemorySingleValueDataset --task-type classification --mlp-ft-dropout 0.1 --mlp-hidden-size 256 --mlp-target-size 10 --experiment-name scl_esm2_{args.model} --num-steps {args.local_steps} --num-gpus 1 --val-check-interval {val_check_interval} --log-every-n-steps 10 --lr 5e-4 --result-dir bionemo --micro-batch-size 64 --precision {precision} --save-top-k 1 --encoder-frozen --limit-val-batches 1.0 --classes {classes} --dataset-name scl --exp-name {args.exp_name}"
5557
print(f"Running {args.train_script} with base args (data paths will be resolved per-client)")
5658

57-
# Create FedAvgRecipe
59+
# Use dict config of the model so we only instantiate the model on the server.
60+
model = {
61+
"class_path": "model.ESM2ModuleForServer",
62+
"args": {"checkpoint_path": str(checkpoint_path)},
63+
}
64+
5865
job_name = f"{args.exp_name}_scl_esm2_{args.model}"
5966
recipe = FedAvgRecipe(
6067
name=job_name,
68+
model=model,
6169
min_clients=args.num_clients,
6270
num_rounds=args.num_rounds,
6371
train_script=f"../{args.train_script}",
@@ -73,10 +81,6 @@ def main(args):
7381
recipe.add_client_input_filter(BioNeMoParamsFilter(precision), tasks=["train", "validate"])
7482
recipe.add_client_output_filter(BioNeMoStateDictFilter(), tasks=["train", "validate"])
7583

76-
# Add decomposer register to server and clients
77-
recipe.job.to_server(DecomposerRegister(["nvflare.app_opt.pt.decomposers.TensorDecomposer"]))
78-
recipe.job.to_clients(DecomposerRegister(["nvflare.app_opt.pt.decomposers.TensorDecomposer"]))
79-
8084
# Add BioNeMo-specific timeout configuration to client config to override its default timeout
8185
recipe.add_client_config({"EXTERNAL_PRE_INIT_TIMEOUT": BIONEMO_EXTERNAL_PRE_INIT_TIMEOUT})
8286

0 commit comments

Comments
 (0)