Skip to content

Commit 053a578

Browse files
Merge branch 'main' into cherry-pick-4150
2 parents 2dc5edf + caa7254 commit 053a578

File tree

13 files changed

+256
-81
lines changed

13 files changed

+256
-81
lines changed

docs/programming_guide/component_configuration.rst

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -144,15 +144,15 @@ For example:
144144
"device": "cpu",
145145
"source_model": "model",
146146
"optimizer_args": {
147-
"path": "torch.optim.SGD",
147+
"class_path": "torch.optim.SGD",
148148
"args": {
149149
"lr": 1.0,
150150
"momentum": 0.6
151151
},
152152
"config_type": "dict"
153153
},
154154
"lr_scheduler_args": {
155-
"path": "torch.optim.lr_scheduler.CosineAnnealingLR",
155+
"class_path": "torch.optim.lr_scheduler.CosineAnnealingLR",
156156
"args": {
157157
"T_max": "{num_rounds}",
158158
"eta_min": 0.9
@@ -166,7 +166,7 @@ Notice the config:
166166
.. code-block:: json
167167
168168
"optimizer_args": {
169-
"path": "torch.optim.SGD",
169+
"class_path": "torch.optim.SGD",
170170
"args": {
171171
"lr": 1.0,
172172
"momentum": 0.6

docs/user_guide/data_scientist_guide/available_recipes.rst

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -258,7 +258,7 @@ PyTorch FedOpt
258258
num_rounds=5,
259259
model=MyModel(),
260260
train_script="client.py",
261-
optimizer_args={"path": "torch.optim.SGD", "args": {"lr": 1.0, "momentum": 0.6}},
261+
optimizer_args={"class_path": "torch.optim.SGD", "args": {"lr": 1.0, "momentum": 0.6}},
262262
)
263263
env = SimEnv(num_clients=2)
264264
run = recipe.execute(env)
@@ -281,7 +281,7 @@ TensorFlow FedOpt
281281
num_rounds=5,
282282
model=MyTFModel(),
283283
train_script="client.py",
284-
optimizer_args={"path": "tensorflow.keras.optimizers.SGD", "args": {"learning_rate": 1.0}},
284+
optimizer_args={"class_path": "tensorflow.keras.optimizers.SGD", "args": {"learning_rate": 1.0}},
285285
)
286286
env = SimEnv(num_clients=2)
287287
run = recipe.execute(env)

examples/advanced/cifar10/pt/cifar10-sim/cifar10_fedopt/README.md

Lines changed: 14 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -86,19 +86,19 @@ python job.py --n_clients 16 --num_rounds 100 --alpha 0.1 --aggregation_epochs 2
8686

8787
### Server-Side Optimization
8888

89-
FedOpt is configured in `job.py` by specifying the server optimizer:
89+
FedOpt is configured in `job.py` using `FedOptRecipe` with `optimizer_args` and optional `lr_scheduler_args`:
9090

9191
```python
92-
from nvflare.app_opt.pt.recipes.fedavg import FedAvgRecipe
92+
from nvflare.app_opt.pt.recipes.fedopt import FedOptRecipe
9393

94-
recipe = FedAvgRecipe(
94+
recipe = FedOptRecipe(
9595
name="cifar10_fedopt",
9696
# ... other parameters ...
97-
server_optimizer="sgd", # Optimizer type
98-
server_optimizer_args={
99-
"lr": 1.0, # Server learning rate
100-
"momentum": 0.9 # Momentum coefficient
101-
}
97+
optimizer_args={"class_path": "torch.optim.SGD", "args": {"lr": 1.0, "momentum": 0.6}},
98+
lr_scheduler_args={
99+
"class_path": "torch.optim.lr_scheduler.CosineAnnealingLR",
100+
"args": {"T_max": num_rounds, "eta_min": 0.9},
101+
},
102102
)
103103
```
104104

@@ -172,19 +172,18 @@ To try different server optimizers, modify `job.py`:
172172

173173
```python
174174
# Try Adam instead of SGD
175-
recipe = FedAvgRecipe(
175+
recipe = FedOptRecipe(
176176
# ... other parameters ...
177-
server_optimizer="adam",
178-
server_optimizer_args={
179-
"lr": 0.01,
180-
"betas": (0.9, 0.999)
181-
}
177+
optimizer_args={
178+
"class_path": "torch.optim.Adam",
179+
"args": {"lr": 0.01, "betas": (0.9, 0.999)},
180+
},
182181
)
183182
```
184183

185184
## References
186185

187186
- [FedOpt Paper](https://arxiv.org/abs/2003.00295) - Reddi et al., 2020
188187
- [NVFlare Documentation](https://nvflare.readthedocs.io/)
189-
- [NVFlare FedAvgRecipe](https://nvflare.readthedocs.io/en/main/apidocs/nvflare.app_opt.pt.recipes.fedavg.html)
188+
- [NVFlare FedOptRecipe](https://nvflare.readthedocs.io/en/main/apidocs/nvflare.app_opt.pt.recipes.fedopt.html)
190189

examples/advanced/cifar10/pt/cifar10-sim/cifar10_fedopt/job.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -80,11 +80,10 @@ def main():
8080
model=ModerateCNN(),
8181
train_script=os.path.join(os.path.dirname(__file__), "client.py"),
8282
train_args=f"--train_idx_root {train_idx_root} --num_workers {num_workers} --lr {lr} --batch_size {batch_size} --aggregation_epochs {aggregation_epochs}",
83-
optimizer_args={"path": "torch.optim.SGD", "args": {"lr": 1.0, "momentum": 0.6}, "config_type": "dict"},
83+
optimizer_args={"class_path": "torch.optim.SGD", "args": {"lr": 1.0, "momentum": 0.6}},
8484
lr_scheduler_args={
85-
"path": "torch.optim.lr_scheduler.CosineAnnealingLR",
85+
"class_path": "torch.optim.lr_scheduler.CosineAnnealingLR",
8686
"args": {"T_max": num_rounds, "eta_min": 0.9},
87-
"config_type": "dict",
8887
},
8988
)
9089
add_experiment_tracking(recipe, tracking_type="tensorboard")

examples/advanced/cifar10/pt/src/data/cifar10_data_utils.py

Lines changed: 24 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,7 @@
3838
# SOFTWARE.
3939

4040
import os
41+
import warnings
4142

4243
import numpy as np
4344
import torch
@@ -46,12 +47,17 @@
4647
from data.cifar10_dataset import CIFAR10_Idx
4748
from torchvision import transforms
4849

50+
# NumPy 2.4 deprecation when torchvision unpickles CIFAR batch files (align= in dtype)
51+
warnings.filterwarnings("ignore", message=".*align.*")
52+
4953
CIFAR10_ROOT = "/tmp/cifar10" # will be used for all CIFAR-10 experiments
5054

5155

5256
def load_cifar10_data():
53-
# load data
54-
train_dataset = datasets.CIFAR10(root=CIFAR10_ROOT, train=True, download=True)
57+
# load data (suppress NumPy 2.4 dtype align deprecation from CIFAR pickle files)
58+
with warnings.catch_warnings():
59+
warnings.filterwarnings("ignore", message=".*align.*")
60+
train_dataset = datasets.CIFAR10(root=CIFAR10_ROOT, train=True, download=True)
5561

5662
# only training label is needed for doing split
5763
train_label = np.array(train_dataset.targets)
@@ -105,20 +111,22 @@ def create_datasets(site_name, train_idx_root, central=False):
105111
else:
106112
site_idx = None # use whole training dataset if central=True
107113

108-
train_dataset = CIFAR10_Idx(
109-
root=CIFAR10_ROOT,
110-
data_idx=site_idx,
111-
train=True,
112-
download=False,
113-
transform=transform_train,
114-
)
115-
116-
valid_dataset = torchvision.datasets.CIFAR10(
117-
root=CIFAR10_ROOT,
118-
train=False,
119-
download=False,
120-
transform=transform_valid,
121-
)
114+
with warnings.catch_warnings():
115+
warnings.filterwarnings("ignore", message=".*align.*")
116+
train_dataset = CIFAR10_Idx(
117+
root=CIFAR10_ROOT,
118+
data_idx=site_idx,
119+
train=True,
120+
download=False,
121+
transform=transform_train,
122+
)
123+
124+
valid_dataset = torchvision.datasets.CIFAR10(
125+
root=CIFAR10_ROOT,
126+
train=False,
127+
download=False,
128+
transform=transform_valid,
129+
)
122130

123131
return train_dataset, valid_dataset
124132

examples/advanced/cifar10/tf/cifar10_fedopt/job.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -58,23 +58,21 @@ def main():
5858

5959
# Configure FedOpt optimizer arguments
6060
optimizer_args = {
61-
"path": "tensorflow.keras.optimizers.SGD",
61+
"class_path": "tensorflow.keras.optimizers.SGD",
6262
"args": {
6363
"learning_rate": args.server_lr,
6464
"momentum": args.server_momentum,
6565
},
66-
"config_type": "dict",
6766
}
6867

6968
# Configure FedOpt learning rate scheduler arguments
7069
lr_scheduler_args = {
71-
"path": "tensorflow.keras.optimizers.schedules.CosineDecay",
70+
"class_path": "tensorflow.keras.optimizers.schedules.CosineDecay",
7271
"args": {
7372
"initial_learning_rate": args.server_lr,
7473
"decay_steps": args.num_rounds,
7574
"alpha": args.server_lr_decay_alpha,
7675
},
77-
"config_type": "dict",
7876
}
7977

8078
# Create FedOpt recipe
Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,3 @@
1-
nvflare~=2.7.2
1+
nvflare~=2.7.2rc
22
tensorflow[and-cuda]
33
filelock>=3.12.0

nvflare/app_opt/pt/fedopt.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -45,9 +45,9 @@ def __init__(
4545
4646
Args:
4747
optimizer_args: dictionary of optimizer arguments, e.g.
48-
{'path': 'torch.optim.SGD', 'args': {'lr': 1.0}} (default).
48+
{'class_path': 'torch.optim.SGD', 'args': {'lr': 1.0}} (default). 'path' is also accepted.
4949
lr_scheduler_args: dictionary of server-side learning rate scheduler arguments, e.g.
50-
{'path': 'torch.optim.lr_scheduler.CosineAnnealingLR', 'args': {'T_max': 100}} (default: None).
50+
{'class_path': 'torch.optim.lr_scheduler.CosineAnnealingLR', 'args': {'T_max': 100}} (default: None). 'path' is also accepted.
5151
source_model: either a valid torch model object or a component ID of a torch model object
5252
device: specify the device to run server-side optimization, e.g. "cpu" or "cuda:0"
5353
(will default to cuda if available and no device is specified).
@@ -62,13 +62,13 @@ def __init__(
6262

6363
if not isinstance(optimizer_args, dict):
6464
raise TypeError(
65-
"optimizer_args must be a dict of format, e.g. {'path': 'torch.optim.SGD', 'args': {'lr': 1.0}}."
65+
"optimizer_args must be a dict of format, e.g. {'class_path': 'torch.optim.SGD', 'args': {'lr': 1.0}}."
6666
)
6767
if lr_scheduler_args is not None:
6868
if not isinstance(lr_scheduler_args, dict):
6969
raise TypeError(
70-
"optimizer_args must be a dict of format, e.g. "
71-
"{'path': 'torch.optim.lr_scheduler.CosineAnnealingLR', 'args': {'T_max': 100}}."
70+
"lr_scheduler_args must be a dict of format, e.g. "
71+
"{'class_path': 'torch.optim.lr_scheduler.CosineAnnealingLR', 'args': {'T_max': 100}}."
7272
)
7373
self.source_model = source_model
7474
self.optimizer_args = optimizer_args
@@ -82,7 +82,7 @@ def __init__(
8282

8383
def _get_component_name(self, component_args):
8484
if component_args is not None:
85-
name = component_args.get("path", None)
85+
name = component_args.get("path") or component_args.get("class_path")
8686
if name is None:
8787
name = component_args.get("name", None)
8888
return name

nvflare/app_opt/pt/fedopt_ctl.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -28,11 +28,11 @@ def __init__(
2828
*args,
2929
source_model: Union[str, torch.nn.Module],
3030
optimizer_args: dict = {
31-
"path": "torch.optim.SGD",
31+
"class_path": "torch.optim.SGD",
3232
"args": {"lr": 1.0, "momentum": 0.6},
3333
},
3434
lr_scheduler_args: dict = {
35-
"path": "torch.optim.lr_scheduler.CosineAnnealingLR",
35+
"class_path": "torch.optim.lr_scheduler.CosineAnnealingLR",
3636
"args": {"T_max": 3, "eta_min": 0.9},
3737
},
3838
device=None,

nvflare/app_opt/pt/recipes/fedopt.py

Lines changed: 11 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -78,13 +78,13 @@ class FedOptRecipe(Recipe):
7878
server_expected_format (str): What format to exchange the parameters between server and client.
7979
source_model (str): ID of the source model component. Defaults to "model".
8080
optimizer_args (dict): Configuration for server-side optimizer with keys:
81-
- path: Path to optimizer class (e.g., "torch.optim.SGD")
81+
- class_path: Fully qualified optimizer class (e.g., "torch.optim.SGD"). "path" is also accepted.
8282
- args: Dictionary of optimizer arguments (e.g., {"lr": 1.0, "momentum": 0.6})
83-
- config_type: Type of configuration, typically "dict"
83+
- config_type: Optional; if omitted, set to "dict" so the config is not instantiated at load time.
8484
lr_scheduler_args (dict): Optional configuration for learning rate scheduler with keys:
85-
- path: Path to scheduler class (e.g., "torch.optim.lr_scheduler.CosineAnnealingLR")
85+
- class_path: Fully qualified scheduler class (e.g., "torch.optim.lr_scheduler.CosineAnnealingLR"). "path" is also accepted.
8686
- args: Dictionary of scheduler arguments (e.g., {"T_max": 100, "eta_min": 0.9})
87-
- config_type: Type of configuration, typically "dict"
87+
- config_type: Optional; if omitted, set to "dict" so the config is not instantiated at load time.
8888
device (str): Device to use for server-side optimization, e.g. "cpu" or "cuda:0".
8989
Defaults to None; will default to cuda if available and no device is specified.
9090
server_memory_gc_rounds: Run memory cleanup (gc.collect + malloc_trim) every N rounds on server.
@@ -102,12 +102,12 @@ class FedOptRecipe(Recipe):
102102
device="cpu",
103103
source_model="model",
104104
optimizer_args={
105-
"path": "torch.optim.SGD",
105+
"class_path": "torch.optim.SGD",
106106
"args": {"lr": 1.0, "momentum": 0.6},
107107
"config_type": "dict"
108108
},
109109
lr_scheduler_args={
110-
"path": "torch.optim.lr_scheduler.CosineAnnealingLR",
110+
"class_path": "torch.optim.lr_scheduler.CosineAnnealingLR",
111111
"args": {"T_max": "{num_rounds}", "eta_min": 0.9},
112112
"config_type": "dict"
113113
}
@@ -158,7 +158,7 @@ def __init__(
158158
self.initial_ckpt = v.initial_ckpt
159159

160160
# Validate inputs using shared utilities
161-
from nvflare.recipe.utils import recipe_model_to_job_model, validate_ckpt
161+
from nvflare.recipe.utils import ensure_config_type_dict, recipe_model_to_job_model, validate_ckpt
162162

163163
validate_ckpt(self.initial_ckpt)
164164
if isinstance(self.model, dict):
@@ -174,8 +174,10 @@ def __init__(
174174
self.server_expected_format: ExchangeFormat = v.server_expected_format
175175
self.device = device
176176
self.source_model = source_model
177-
self.optimizer_args = optimizer_args
178-
self.lr_scheduler_args = lr_scheduler_args
177+
# Ensure config_type "dict" so the component builder does not try to instantiate
178+
# optimizer/scheduler at config load time (params/optimizer are set at runtime).
179+
self.optimizer_args = ensure_config_type_dict(optimizer_args)
180+
self.lr_scheduler_args = ensure_config_type_dict(lr_scheduler_args)
179181
self.server_memory_gc_rounds = v.server_memory_gc_rounds
180182

181183
# Replace {num_rounds} placeholder if present in lr_scheduler_args

0 commit comments

Comments (0)