examples/hello-world/hello-numpy-cross-val/client.py (89 changes: 60 additions & 29 deletions)
@@ -53,41 +53,72 @@ def main():

print(f"Client {client_name} initialized")

# Track last trained params so we can submit them when CSE asks for our local model.
last_params = None

while flare.is_running():
# Receive model from server
input_model = flare.receive()
print(f"Client {client_name}, current_round={input_model.current_round}")

# Get model parameters
input_np_arr = input_model.params[NPConstants.NUMPY_KEY]
print(f"Received weights: {input_np_arr}")

# Train the model
new_params = train(input_np_arr)

# Evaluate the model
metrics = evaluate(new_params)
print(f"Client {client_name} evaluation metrics: {metrics}")
print(f"Client {client_name} finished training for round {input_model.current_round}")
if flare.is_train():
# Training task: receive global model, train, send update.
if input_model.params is None or NPConstants.NUMPY_KEY not in input_model.params:
raise RuntimeError(
"Train task received no model params (params is None or missing numpy_key). "
"Server requires a valid initial model; empty response would break aggregation."
)
input_np_arr = input_model.params[NPConstants.NUMPY_KEY]
print(f"Received weights: {input_np_arr}")
new_params = train(input_np_arr)
last_params = new_params
metrics = evaluate(new_params)
print(f"Client {client_name} evaluation metrics: {metrics}")
print(f"Client {client_name} finished training for round {input_model.current_round}")
if args.update_type == "diff":
params_to_send = new_params - input_np_arr
params_type = flare.ParamsType.DIFF
else:
params_to_send = new_params
params_type = flare.ParamsType.FULL
print(f"Sending weights: {params_to_send}")
flare.send(
flare.FLModel(
params={NPConstants.NUMPY_KEY: params_to_send},
params_type=params_type,
metrics=metrics,
current_round=input_model.current_round,
)
)

elif flare.is_evaluate():
# Validate task: evaluate the received model and send metrics only (no params).
if input_model.params is None or NPConstants.NUMPY_KEY not in input_model.params:
flare.send(flare.FLModel(metrics={}))
continue
input_np_arr = input_model.params[NPConstants.NUMPY_KEY]
metrics = evaluate(input_np_arr)
print(f"Client {client_name} validation metrics: {metrics}")
flare.send(flare.FLModel(metrics=metrics))

elif flare.is_submit_model():
# Submit local model for cross-site evaluation (must be WEIGHTS DXO).
if last_params is None:
raise RuntimeError(
"submit_model called but no local model (last_params) available. "
"CSE expects client weights; run training first or fix job order."
)
print(f"Client {client_name} submitting local model")
flare.send(
flare.FLModel(
params={NPConstants.NUMPY_KEY: last_params},
params_type=flare.ParamsType.FULL,
)
)

# Prepare parameters to send
if args.update_type == "diff":
params_to_send = new_params - input_np_arr
params_type = flare.ParamsType.DIFF
else:
params_to_send = new_params
params_type = flare.ParamsType.FULL

# Send updated model back to server
print(f"Sending weights: {params_to_send}")
output_model = flare.FLModel(
params={NPConstants.NUMPY_KEY: params_to_send},
params_type=params_type,
metrics=metrics,
current_round=input_model.current_round,
)

flare.send(output_model)
# Task is not train, evaluate, or submit_model (e.g. future task type or different job config).
# Send empty metrics so we always reply and the protocol does not hang.
flare.send(flare.FLModel(metrics={}))


if __name__ == "__main__":
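For readability, here is a condensed sketch of the task-dispatch loop this diff adds to client.py. It is not verbatim from the PR: train() and evaluate() are stand-ins modeled on the unit tests further below, the --update_type argument is replaced by a plain variable, and the closing else branch is inferred from the "Task is not train, evaluate, or submit_model" comment in the added lines.

# Condensed sketch (not verbatim from the PR) of the new task-dispatch loop in client.py.
import numpy as np

import nvflare.client as flare
from nvflare.app_common.np.constants import NPConstants


def train(arr):
    # Stand-in: the unit tests show the example's train() adds 1.0 to every weight.
    return arr + 1.0


def evaluate(arr):
    # Stand-in: the unit tests show the example's evaluate() reports the mean weight.
    return {"weight_mean": float(np.mean(arr))}


def main():
    flare.init()
    update_type = "full"  # the example script reads this from argparse (--update_type)
    last_params = None  # kept so submit_model can return the local model for cross-site eval

    while flare.is_running():
        input_model = flare.receive()

        if flare.is_train():
            # Fail fast: an empty response would break server-side aggregation.
            if input_model.params is None or NPConstants.NUMPY_KEY not in input_model.params:
                raise RuntimeError("Train task received no model params")
            global_arr = input_model.params[NPConstants.NUMPY_KEY]
            new_params = train(global_arr)
            last_params = new_params
            if update_type == "diff":
                payload, ptype = new_params - global_arr, flare.ParamsType.DIFF
            else:
                payload, ptype = new_params, flare.ParamsType.FULL
            flare.send(
                flare.FLModel(
                    params={NPConstants.NUMPY_KEY: payload},
                    params_type=ptype,
                    metrics=evaluate(new_params),
                    current_round=input_model.current_round,
                )
            )
        elif flare.is_evaluate():
            # Validate task: reply with metrics only, never params.
            if input_model.params is None or NPConstants.NUMPY_KEY not in input_model.params:
                flare.send(flare.FLModel(metrics={}))
                continue
            flare.send(flare.FLModel(metrics=evaluate(input_model.params[NPConstants.NUMPY_KEY])))
        elif flare.is_submit_model():
            # Cross-site evaluation asks for the local model; fail fast if training never ran.
            if last_params is None:
                raise RuntimeError("submit_model called but no local model")
            flare.send(
                flare.FLModel(params={NPConstants.NUMPY_KEY: last_params}, params_type=flare.ParamsType.FULL)
            )
        else:
            # Inferred branch: always reply so the protocol does not hang on unknown task types.
            flare.send(flare.FLModel(metrics={}))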
tests/integration_test/recipe_system_test.py (2 changes: 1 addition & 1 deletion)
@@ -117,7 +117,7 @@ def test_dict_model_config_simulation(self):

# Use dict config instead of class instance
model_config = {
"class_path": "model.SimpleNetwork",
"path": "model.SimpleNetwork",
"args": {},
}

tests/unit_test/recipe/hello_numpy_cross_val_client_test.py (145 changes: 145 additions & 0 deletions)
@@ -0,0 +1,145 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for hello-numpy-cross-val client: train/evaluate helpers and fail-fast behavior.

The client must:
- Raise RuntimeError when a train task has no params (no NUMPY_KEY), so the server
does not receive an empty response and break aggregation.
- Raise RuntimeError when submit_model is requested but last_params is missing,
so CSE does not silently submit wrong weights.

Note: Other tests in this folder (fedavg_recipe_test, fedopt_recipe_test, eval_recipe_test)
test recipe classes (job construction, config). This file tests an *example client script*
that uses the public nvflare.client API (imported as ``flare``). To run the client's main()
without a real FL runtime we replace ``flare`` with a mock object (mock_flare) that
provides receive(), is_train(), is_evaluate(), is_submit_model(), send(), etc., so we can
drive one loop iteration and assert fail-fast behavior. The client/in_process/api_test.py
tests the real InProcessClientAPI implementation; here we mock the API to test the
example's response to bad inputs.
"""

import os
import sys
from unittest.mock import MagicMock, patch

import numpy as np
import pytest

from nvflare.app_common.abstract.fl_model import FLModel


def _client_dir():
return os.path.join(
os.path.dirname(__file__),
"..",
"..",
"..",
"examples",
"hello-world",
"hello-numpy-cross-val",
)


def _import_client():
client_dir = os.path.abspath(_client_dir())
if client_dir not in sys.path:
sys.path.insert(0, client_dir)
import client as client_mod # noqa: E402

return client_mod


class TestHelloNumpyCrossValClientHelpers:
"""Test pure helper functions used by the client."""

def test_train_adds_one(self):
client_mod = _import_client()
x = np.array([0.0, 1.0, 2.0], dtype=np.float32)
out = client_mod.train(x)
np.testing.assert_array_almost_equal(out, np.array([1.0, 2.0, 3.0], dtype=np.float32))

def test_evaluate_returns_weight_mean(self):
client_mod = _import_client()
x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
metrics = client_mod.evaluate(x)
assert "weight_mean" in metrics
assert metrics["weight_mean"] == 2.0


def _make_mock_flare(*, receive_return, is_train, is_evaluate, is_submit_model):
"""Build a mock for nvflare.client (flare) so the example client's main() can run one loop.

Used only in TestHelloNumpyCrossValClientFailFast. Other recipe tests in this folder
test recipe classes and use patch("os.path.*"); they do not run or mock client scripts.

``is_running.side_effect = [True, False]`` so the while loop runs at most one iteration and then exits; this avoids hanging if the client logic under test stops raising.
"""
mock = MagicMock()
mock.init.return_value = None
mock.system_info.return_value = {"site_name": "site-1"}
mock.is_running.side_effect = [True, False]
mock.receive.return_value = receive_return
mock.is_train.return_value = is_train
mock.is_evaluate.return_value = is_evaluate
mock.is_submit_model.return_value = is_submit_model
mock.send.return_value = None
mock.ParamsType = MagicMock()
mock.ParamsType.FULL = "FULL"
mock.ParamsType.DIFF = "DIFF"
return mock


class TestHelloNumpyCrossValClientFailFast:
"""Test that the client raises RuntimeError instead of sending invalid responses."""

def test_train_task_with_no_params_raises(self):
"""Train task with params None or missing NUMPY_KEY must raise; empty response breaks aggregation."""
client_mod = _import_client()
mock_flare = _make_mock_flare(
receive_return=FLModel(params=None, current_round=0),
is_train=True,
is_evaluate=False,
is_submit_model=False,
)
with patch("sys.argv", ["client.py"]), patch.object(client_mod, "flare", mock_flare):
with pytest.raises(RuntimeError, match="Train task received no model params"):
client_mod.main()

def test_train_task_with_params_missing_numpy_key_raises(self):
"""Train task with params dict but no NUMPY_KEY must raise."""
client_mod = _import_client()
mock_flare = _make_mock_flare(
receive_return=FLModel(params={}, current_round=0),
is_train=True,
is_evaluate=False,
is_submit_model=False,
)
with patch("sys.argv", ["client.py"]), patch.object(client_mod, "flare", mock_flare):
with pytest.raises(RuntimeError, match="Train task received no model params"):
client_mod.main()

def test_submit_model_with_no_last_params_raises(self):
"""submit_model when last_params was never set must raise; otherwise wrong weights are submitted."""
client_mod = _import_client()
mock_flare = _make_mock_flare(
receive_return=FLModel(params=None, current_round=None),
is_train=False,
is_evaluate=False,
is_submit_model=True,
)
with patch("sys.argv", ["client.py"]), patch.object(client_mod, "flare", mock_flare):
with pytest.raises(RuntimeError, match="submit_model called but no local model"):
client_mod.main()
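The same mock pattern extends naturally to the happy paths. A possible follow-on test (not part of this PR; a sketch only) for the evaluate branch, assuming client.py imports NPConstants at module level and that a validate reply carries metrics but no params:

def test_evaluate_task_sends_metrics_only():
    """Sketch: a validate task with valid params should reply with metrics and no params."""
    client_mod = _import_client()
    numpy_key = client_mod.NPConstants.NUMPY_KEY  # assumes client.py imports NPConstants
    mock_flare = _make_mock_flare(
        receive_return=FLModel(
            params={numpy_key: np.array([1.0, 2.0, 3.0], dtype=np.float32)}, current_round=0
        ),
        is_train=False,
        is_evaluate=True,
        is_submit_model=False,
    )
    with patch("sys.argv", ["client.py"]), patch.object(client_mod, "flare", mock_flare):
        client_mod.main()
    sent = mock_flare.send.call_args[0][0]  # the FLModel passed to flare.send()
    assert sent.metrics["weight_mean"] == 2.0
    assert sent.params is None  # validate replies carry metrics only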