|
34 | 34 | import threading |
35 | 35 | import time |
36 | 36 | import unittest |
37 | | - |
| 37 | +from types import SimpleNamespace |
| 38 | +from unittest import mock |
| 39 | + |
| 40 | +from nvflare.apis.fl_constant import ReturnCode as FLReturnCode |
| 41 | +from nvflare.apis.fl_context import FLContextManager |
| 42 | +from nvflare.apis.shareable import Shareable, make_reply |
| 43 | +from nvflare.apis.signal import Signal |
| 44 | +from nvflare.app_common.abstract.learnable import Learnable |
| 45 | +from nvflare.app_common.app_constant import AppConstants |
| 46 | +from nvflare.app_common.ccwf.common import Constant |
| 47 | +from nvflare.app_common.ccwf.swarm_client_ctl import SwarmClientController |
38 | 48 | from nvflare.fuel.f3.cellnet.core_cell import CoreCell, Message, MessageHeaderKey, TargetMessage |
39 | 49 | from nvflare.fuel.f3.cellnet.defs import ReturnCode |
40 | 50 | from nvflare.fuel.utils.network_utils import get_open_ports |
@@ -415,5 +425,79 @@ def blocking_handler(message: Message): |
415 | 425 | self.assertTrue(deadlock_detected.is_set(), "Deadlock should be detected - tensor wait timed out") |
416 | 426 |
|
417 | 427 |
|
class TestSwarmResultSubmissionFix(unittest.TestCase):
    """Exercises ``SwarmClientController.do_learn_task`` when the aggregator is the local site."""

    def test_local_submit_when_aggregator_is_self(self):
        """Check the result is processed locally when the task names this site as aggregator.

        The task header names "site-1" as aggregator and the controller's own
        identity is also "site-1". After ``do_learn_task`` runs, the test
        expects:

        * ``broadcast_and_wait`` was never invoked,
        * ``_process_learn_result`` was invoked exactly once with the learn
          result, a context other than the caller's, and the caller's abort
          signal,
        * the engine received exactly one ``send_aux_request`` call.
        """
        site = "site-1"

        class _StubGatherer:
            """Replacement patched over ``Gatherer``; only records the round it was built for."""

            def __init__(self, **kwargs):
                self.for_round = kwargs.get("for_round", 0)

        class _CountingEngine:
            """Minimal engine that counts aux requests and always answers OK for the site."""

            def __init__(self):
                self.submit_req_calls = 0

            def send_aux_request(self, **kwargs):
                self.submit_req_calls += 1
                return {site: make_reply(FLReturnCode.OK)}

            def new_context(self):
                ctx_manager = FLContextManager(engine=self, identity_name=site, job_id="job")
                return ctx_manager.new_context()

        stub_engine = _CountingEngine()
        caller_ctx = FLContextManager(engine=stub_engine, identity_name=site, job_id="job").new_context()
        stop_signal = Signal()

        # Task for round 1 whose designated aggregator is this very site.
        round_task = Shareable()
        for header_key, header_value in ((AppConstants.CURRENT_ROUND, 1), (Constant.AGGREGATOR, site)):
            round_task.set_header(header_key, header_value)

        trained_result = make_reply(FLReturnCode.OK)

        def _ignore(*_args, **_kwargs):
            return None

        def _ignore_keywords(**kwargs):
            return None

        def _config_prop(key, default=None):
            if key == Constant.CLIENTS:
                return [site]
            return default

        def _run_learn_task(_task_data, _ctx, _abort_signal):
            return trained_result

        # Build the controller without running __init__, then install only the
        # state and collaborators this test provides, in a fixed order.
        controller = object.__new__(SwarmClientController)
        controller_state = {
            "me": site,
            "is_trainer": True,
            "gatherer": None,
            "gatherer_waiter": threading.Event(),
            "metric_comparator": object(),
            "trainers": [site],
            "learn_task_timeout": 10,
            "min_responses_required": 1,
            "wait_time_after_min_resps_received": 0,
            "aggregator": object(),
            "max_concurrent_submissions": 1,
            "request_to_submit_result_max_wait": 10,
            "request_to_submit_result_msg_timeout": 1,
            "request_to_submit_result_interval": 0,
            "request_to_submit_learn_result_task_name": "request_submit",
            "report_learn_result_task_name": "report_result",
            "learn_task_ack_timeout": 5,
            "shareable_generator": SimpleNamespace(shareable_to_learnable=lambda _task_data, _ctx: Learnable()),
            "get_config_prop": _config_prop,
            "execute_learn_task": _run_learn_task,
            "is_task_secure": lambda _ctx: False,
            "update_status": _ignore_keywords,
            "fire_event": _ignore,
            "log_info": _ignore,
            "log_debug": _ignore,
            "log_warning": _ignore,
            "log_error": _ignore,
            # Any use of the broadcast path fails the test immediately.
            "broadcast_and_wait": mock.Mock(
                side_effect=AssertionError("broadcast_and_wait must not be called for local result submission")
            ),
            "_process_learn_result": mock.Mock(return_value=make_reply(FLReturnCode.OK)),
        }
        for attr_name, attr_value in controller_state.items():
            setattr(controller, attr_name, attr_value)

        with mock.patch("nvflare.app_common.ccwf.swarm_client_ctl.Gatherer", _StubGatherer):
            controller.do_learn_task("train", round_task, caller_ctx, stop_signal)

        local_processor = controller._process_learn_result
        controller.broadcast_and_wait.assert_not_called()
        local_processor.assert_called_once()
        self.assertEqual(stub_engine.submit_req_calls, 1, "submission permission request should still be sent once")

        # Exactly three positional arguments are expected: result, context, abort signal.
        seen_result, seen_ctx, seen_signal = local_processor.call_args[0]
        self.assertIs(seen_result, trained_result)
        self.assertIs(seen_signal, stop_signal)
        self.assertIsNot(seen_ctx, caller_ctx)
        self.assertEqual(seen_ctx.get_peer_context().get_identity_name(), site)
| 500 | + |
| 501 | + |
# Run this module's test cases via unittest when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
0 commit comments