Skip to content

Commit ecf608c

Browse files
ZiyueXu77claude
andauthored
Cherry pick Recipe API fixes (#4192)
Fixes # . ### Description [pr 4183](#4183) to main ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Quick tests passed locally by running `./runtest.sh`. - [ ] In-line docstrings updated. - [ ] Documentation updated. --------- Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
1 parent e478063 commit ecf608c

File tree

4 files changed

+201
-35
lines changed

4 files changed

+201
-35
lines changed

examples/advanced/llm_hf/MULTINODE.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ SLURM Job (2 nodes allocated)
118118
- Uses `FedAvgRecipe` with `per_site_config` for multi-node setup
119119
- When `--multi_node` flag is set:
120120
- Sets command in per_site_config: `"command": "bash custom/client_wrapper.sh"`
121-
- Adds wrapper script to job: `recipe.job.to("client_wrapper.sh", site_name)`
121+
- Adds wrapper script to job: `recipe.add_client_file("client_wrapper.sh")`
122122
- Script arguments passed via `"train_args"` in per_site_config
123123
- No need to handle environment variables in Python
124124

@@ -131,7 +131,7 @@ SLURM Job (2 nodes allocated)
131131
- Uses `CUDA_VISIBLE_DEVICES` to set GPUs. Assumes that they are set as comma-separated list, e.g. "0,1,2,3,4,5,6,7".
132132

133133
**Why this works:**
134-
- The wrapper script is included in the FL job package via `recipe.job.to("client_wrapper.sh", site_name)`
134+
- The wrapper script is included in the FL job package via `recipe.add_client_file("client_wrapper.sh")`
135135
- It's placed in the `custom/` subdirectory of the job workspace
136136
- Command is set to `bash custom/client_wrapper.sh` in the per_site_config
137137
- It runs in the same environment as the SLURM job (has access to `srun` and SLURM variables)
@@ -164,7 +164,7 @@ SLURM Job (2 nodes allocated)
164164
- Uses `FedAvgRecipe` to configure the federated learning job
165165
- For multi-node mode (`--multi_node` flag):
166166
- Sets command via `per_site_config`: `"command": "bash custom/client_wrapper.sh"`
167-
- Adds wrapper script: `recipe.job.to("client_wrapper.sh", site_name)`
167+
- Adds wrapper script: `recipe.add_client_file("client_wrapper.sh")`
168168
- Includes `client.py` training script automatically via recipe
169169
- Exports job configuration to specified directory
170170

examples/advanced/llm_hf/job.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -181,30 +181,24 @@ def main():
181181
)
182182

183183
# Add client params to reduce timeout failures for longer LLM runs
184-
for site_name in client_names:
185-
client_params = {"get_task_timeout": 300, "submit_task_result_timeout": 300}
186-
recipe.job.to(client_params, site_name)
184+
recipe.add_client_config({"get_task_timeout": 300, "submit_task_result_timeout": 300})
187185

188186
# Add client_wrapper.sh for multi-node training
189187
if args.multi_node:
190-
for site_name in client_names:
191-
recipe.job.to("client_wrapper.sh", site_name)
188+
recipe.add_client_file("client_wrapper.sh")
192189

193190
# Add quantization filters if specified
194191
if args.quantize_mode:
195-
from nvflare import FilterType
196-
197192
quantizer = ModelQuantizer(quantization_type=args.quantize_mode.lower())
198193
dequantizer = ModelDequantizer()
199194

200-
# Add to server
201-
recipe.job.to(quantizer, "server", tasks=["train"], filter_type=FilterType.TASK_DATA)
202-
recipe.job.to(dequantizer, "server", tasks=["train"], filter_type=FilterType.TASK_RESULT)
195+
# Add to server: quantizer on output, dequantizer on input
196+
recipe.add_server_output_filter(quantizer, tasks=["train"])
197+
recipe.add_server_input_filter(dequantizer, tasks=["train"])
203198

204-
# Add to all clients
205-
for site_name in client_names:
206-
recipe.job.to(quantizer, site_name, tasks=["train"], filter_type=FilterType.TASK_RESULT)
207-
recipe.job.to(dequantizer, site_name, tasks=["train"], filter_type=FilterType.TASK_DATA)
199+
# Add to all clients: quantizer on output, dequantizer on input
200+
recipe.add_client_output_filter(quantizer, tasks=["train"])
201+
recipe.add_client_input_filter(dequantizer, tasks=["train"])
208202

209203
# Add experiment tracking if requested
210204
if args.use_tracking:

nvflare/recipe/spec.py

Lines changed: 83 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,36 @@ def process_env(self, env: ExecEnv):
114114
"""
115115
pass
116116

117+
def _add_to_client_apps(self, obj, clients: Optional[List[str]] = None, **kwargs):
118+
"""Add an object to client apps, preserving existing per-site structure.
119+
120+
Args:
121+
obj: Object to add to clients.
122+
clients: Optional list of specific client names. If None, applies to all clients.
123+
**kwargs: Extra options forwarded to `job.to()`/`job.to_clients()`.
124+
"""
125+
if clients is None:
126+
from nvflare.apis.job_def import ALL_SITES, SERVER_SITE_NAME
127+
from nvflare.job_config.defs import JobTargetType
128+
129+
# FedJob has no public API to list per-site deploy targets, so we inspect
130+
# private deploy map to preserve existing per-site client topology.
131+
deploy_map = getattr(self.job, "_deploy_map", {})
132+
existing_client_sites = [
133+
target
134+
for target in deploy_map.keys()
135+
if target not in [ALL_SITES, SERVER_SITE_NAME]
136+
and JobTargetType.get_target_type(target) == JobTargetType.CLIENT
137+
]
138+
if existing_client_sites:
139+
for site in existing_client_sites:
140+
self.job.to(obj, site, **kwargs)
141+
else:
142+
self.job.to_clients(obj, **kwargs)
143+
else:
144+
for client in clients:
145+
self.job.to(obj, client, **kwargs)
146+
117147
def add_client_input_filter(
118148
self, filter: Filter, tasks: Optional[List[str]] = None, clients: Optional[List[str]] = None
119149
):
@@ -127,11 +157,7 @@ def add_client_input_filter(
127157
Returns: None
128158
129159
"""
130-
if clients is None:
131-
self.job.to_clients(filter, filter_type=FilterType.TASK_DATA, tasks=tasks)
132-
else:
133-
for client in clients:
134-
self.job.to(filter, client, filter_type=FilterType.TASK_DATA, tasks=tasks)
160+
self._add_to_client_apps(filter, clients=clients, filter_type=FilterType.TASK_DATA, tasks=tasks)
135161

136162
def add_client_output_filter(
137163
self, filter: Filter, tasks: Optional[List[str]] = None, clients: Optional[List[str]] = None
@@ -146,11 +172,7 @@ def add_client_output_filter(
146172
Returns: None
147173
148174
"""
149-
if clients is None:
150-
self.job.to_clients(filter, filter_type=FilterType.TASK_RESULT, tasks=tasks)
151-
else:
152-
for client in clients:
153-
self.job.to(filter, client, filter_type=FilterType.TASK_RESULT, tasks=tasks)
175+
self._add_to_client_apps(filter, clients=clients, filter_type=FilterType.TASK_RESULT, tasks=tasks)
154176

155177
def add_client_config(self, config: Dict, clients: Optional[List[str]] = None):
156178
"""Add top-level configuration parameters to config_fed_client.json.
@@ -165,11 +187,32 @@ def add_client_config(self, config: Dict, clients: Optional[List[str]] = None):
165187
if not isinstance(config, dict):
166188
raise TypeError(f"config must be a dict, got {type(config).__name__}")
167189

168-
if clients is None:
169-
self.job.to_clients(config)
170-
else:
171-
for client in clients:
172-
self.job.to(obj=config, target=client)
190+
self._add_to_client_apps(config, clients=clients)
191+
192+
def add_client_file(self, file_path: str, clients: Optional[List[str]] = None):
193+
"""Add a file or directory to client apps.
194+
195+
The file will be added to the client's custom directory and bundled with the job.
196+
Can be a script, configuration file, or any resource needed by clients.
197+
198+
Args:
199+
file_path: Path to the file or directory to add to clients.
200+
clients: Optional list of specific client names. If None, applies to all clients.
201+
202+
Raises:
203+
TypeError: If file_path is not a string.
204+
205+
Example:
206+
# Add a wrapper script to all clients
207+
recipe.add_client_file("client_wrapper.sh")
208+
209+
# Add a script to specific clients
210+
recipe.add_client_file("custom_script.py", clients=["site1", "site2"])
211+
"""
212+
if not isinstance(file_path, str):
213+
raise TypeError(f"file_path must be a str, got {type(file_path).__name__}")
214+
215+
self._add_to_client_apps(file_path, clients=clients)
173216

174217
def add_server_output_filter(self, filter: Filter, tasks: Optional[List[str]] = None):
175218
"""Add a filter to the server for outgoing tasks to clients.
@@ -209,6 +252,27 @@ def add_server_config(self, config: Dict):
209252

210253
self.job.to_server(config)
211254

255+
def add_server_file(self, file_path: str):
256+
"""Add a file or directory to server app.
257+
258+
The file will be added to the server's custom directory and bundled with the job.
259+
Can be a script, configuration file, or any resource needed by the server.
260+
261+
Args:
262+
file_path: Path to the file or directory to add to server.
263+
264+
Raises:
265+
TypeError: If file_path is not a string.
266+
267+
Example:
268+
# Add a wrapper script to server
269+
recipe.add_server_file("server_wrapper.sh")
270+
"""
271+
if not isinstance(file_path, str):
272+
raise TypeError(f"file_path must be a str, got {type(file_path).__name__}")
273+
274+
self.job.to_server(file_path)
275+
212276
@staticmethod
213277
def _get_full_class_name(obj):
214278
"""
@@ -243,7 +307,8 @@ def add_decomposers(self, decomposers: List[Union[str, Decomposer]]):
243307

244308
reg = DecomposerRegister(class_names)
245309
self.job.to_server(reg, id="decomposer_reg")
246-
self.job.to_clients(reg, id="decomposer_reg")
310+
311+
self._add_to_client_apps(reg, id="decomposer_reg")
247312

248313
def export(
249314
self,
@@ -267,7 +332,7 @@ def export(
267332
self.job.to_server(server_exec_params)
268333

269334
if client_exec_params:
270-
self.job.to_clients(client_exec_params)
335+
self._add_to_client_apps(client_exec_params)
271336

272337
if env:
273338
self.process_env(env)
@@ -291,7 +356,7 @@ def execute(
291356
self.job.to_server(server_exec_params)
292357

293358
if client_exec_params:
294-
self.job.to_clients(client_exec_params)
359+
self._add_to_client_apps(client_exec_params)
295360

296361
self.process_env(env)
297362
job_id = env.deploy(self.job)

tests/unit_test/recipe/spec_test.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,113 @@ def test_add_client_config(self, temp_script):
212212
assert all_clients_app is not None
213213
assert all_clients_app.app_config.additional_params == config
214214

215+
def test_add_client_file_adds_to_ext_scripts_and_ext_dirs(self, temp_script):
216+
"""Test add_client_file stores file paths in ext_scripts and dirs in ext_dirs."""
217+
from nvflare.apis.job_def import ALL_SITES
218+
from nvflare.fuel.utils.constants import FrameworkType
219+
from nvflare.recipe.fedavg import FedAvgRecipe
220+
221+
recipe = FedAvgRecipe(
222+
name="test_job_files",
223+
num_rounds=2,
224+
min_clients=2,
225+
train_script=temp_script,
226+
initial_ckpt="/abs/path/to/model.npy",
227+
framework=FrameworkType.NUMPY,
228+
)
229+
230+
with tempfile.TemporaryDirectory() as temp_dir:
231+
recipe.add_client_file(temp_script)
232+
recipe.add_client_file(temp_dir)
233+
234+
all_clients_app = recipe.job._deploy_map.get(ALL_SITES)
235+
assert all_clients_app is not None
236+
assert temp_script in all_clients_app.app_config.ext_scripts
237+
assert temp_dir in all_clients_app.app_config.ext_dirs
238+
239+
def test_add_client_file_preserves_per_site_clients_without_all_sites(self, temp_script):
240+
"""Test add_client_file keeps per-site topology and does not create ALL_SITES app."""
241+
from nvflare.apis.job_def import ALL_SITES
242+
from nvflare.fuel.utils.constants import FrameworkType
243+
from nvflare.recipe.fedavg import FedAvgRecipe
244+
245+
recipe = FedAvgRecipe(
246+
name="test_job_per_site_files",
247+
num_rounds=2,
248+
min_clients=2,
249+
train_script=temp_script,
250+
initial_ckpt="/abs/path/to/model.npy",
251+
framework=FrameworkType.NUMPY,
252+
per_site_config={"site-1": {}, "site-2": {}},
253+
)
254+
255+
recipe.add_client_file(temp_script)
256+
257+
assert ALL_SITES not in recipe.job._deploy_map
258+
site_1_app = recipe.job._deploy_map.get("site-1")
259+
site_2_app = recipe.job._deploy_map.get("site-2")
260+
assert site_1_app is not None
261+
assert site_2_app is not None
262+
assert temp_script in site_1_app.app_config.ext_scripts
263+
assert temp_script in site_2_app.app_config.ext_scripts
264+
265+
def test_add_client_file_with_specific_clients_only_updates_selected_sites(self, temp_script):
266+
"""Test add_client_file(..., clients=[...]) only adds file to specified sites."""
267+
from nvflare.fuel.utils.constants import FrameworkType
268+
from nvflare.recipe.fedavg import FedAvgRecipe
269+
270+
recipe = FedAvgRecipe(
271+
name="test_job_targeted_files",
272+
num_rounds=2,
273+
min_clients=2,
274+
train_script=temp_script,
275+
initial_ckpt="/abs/path/to/model.npy",
276+
framework=FrameworkType.NUMPY,
277+
per_site_config={"site-1": {}, "site-2": {}, "site-3": {}},
278+
)
279+
280+
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
281+
f.write("targeted file for site-2 only")
282+
targeted_file = f.name
283+
284+
recipe.add_client_file(targeted_file, clients=["site-2"])
285+
286+
try:
287+
site_1_app = recipe.job._deploy_map.get("site-1")
288+
site_2_app = recipe.job._deploy_map.get("site-2")
289+
site_3_app = recipe.job._deploy_map.get("site-3")
290+
assert site_1_app is not None
291+
assert site_2_app is not None
292+
assert site_3_app is not None
293+
assert targeted_file not in site_1_app.app_config.ext_scripts
294+
assert targeted_file in site_2_app.app_config.ext_scripts
295+
assert targeted_file not in site_3_app.app_config.ext_scripts
296+
finally:
297+
os.unlink(targeted_file)
298+
299+
def test_add_server_file_adds_to_server_ext_scripts_and_ext_dirs(self, temp_script):
300+
"""Test add_server_file stores file paths in ext_scripts and dirs in ext_dirs."""
301+
from nvflare.fuel.utils.constants import FrameworkType
302+
from nvflare.recipe.fedavg import FedAvgRecipe
303+
304+
recipe = FedAvgRecipe(
305+
name="test_job_server_files",
306+
num_rounds=2,
307+
min_clients=2,
308+
train_script=temp_script,
309+
initial_ckpt="/abs/path/to/model.npy",
310+
framework=FrameworkType.NUMPY,
311+
)
312+
313+
with tempfile.TemporaryDirectory() as temp_dir:
314+
recipe.add_server_file(temp_script)
315+
recipe.add_server_file(temp_dir)
316+
317+
server_app = recipe.job._deploy_map.get("server")
318+
assert server_app is not None
319+
assert temp_script in server_app.app_config.ext_scripts
320+
assert temp_dir in server_app.app_config.ext_dirs
321+
215322
def test_config_in_generated_json(self, temp_script):
216323
"""Test that configs appear in generated JSON files."""
217324
from nvflare.fuel.utils.constants import FrameworkType

0 commit comments

Comments
 (0)