Skip to content

Commit 902bf98

Browse files
authored
Cherry-pick NVIDIA#4174 and NVIDIA#4186 from 2.7 to main (NVIDIA#4193)
## Summary

- Cherry-pick of NVIDIA#4174: reduce lock scope in `Cacheable._get_item` — `produce_item` now runs outside the lock so concurrent receivers aren't blocked.
- Cherry-pick of NVIDIA#4186: avoid a self-message deadlock when the swarm trainer submits a learn result to itself — local submission bypasses `broadcast_and_wait`; adds unit-test coverage.
1 parent ecf608c commit 902bf98

File tree

3 files changed

+130
-23
lines changed

3 files changed

+130
-23
lines changed

nvflare/app_common/ccwf/swarm_client_ctl.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -845,23 +845,31 @@ def do_learn_task(self, name: str, task_data: Shareable, fl_ctx: FLContext, abor
845845
time.sleep(self.request_to_submit_result_interval)
846846

847847
# send the result to the aggr
848-
self.log_info(fl_ctx, f"sending training result to aggregation client {aggr}")
848+
if aggr == self.me:
849+
# Avoid synchronous self-message path through CoreCell._send_direct_message.
850+
self.log_info(fl_ctx, "submitting training result locally (aggregation client is self)")
851+
engine = fl_ctx.get_engine()
852+
local_fl_ctx = fl_ctx.clone()
853+
local_fl_ctx.set_peer_context(engine.new_context())
854+
reply = self._process_learn_result(result, local_fl_ctx, abort_signal)
855+
else:
856+
self.log_info(fl_ctx, f"sending training result to aggregation client {aggr}")
849857

850-
task = Task(
851-
name=self.report_learn_result_task_name,
852-
data=result,
853-
timeout=int(self.learn_task_ack_timeout),
854-
secure=self.is_task_secure(fl_ctx),
855-
)
858+
task = Task(
859+
name=self.report_learn_result_task_name,
860+
data=result,
861+
timeout=int(self.learn_task_ack_timeout),
862+
secure=self.is_task_secure(fl_ctx),
863+
)
856864

857-
resp = self.broadcast_and_wait(
858-
task=task,
859-
targets=[aggr],
860-
min_responses=1,
861-
fl_ctx=fl_ctx,
862-
)
865+
resp = self.broadcast_and_wait(
866+
task=task,
867+
targets=[aggr],
868+
min_responses=1,
869+
fl_ctx=fl_ctx,
870+
)
863871

864-
reply = resp.get(aggr)
872+
reply = resp.get(aggr)
865873
if not reply:
866874
self.log_error(fl_ctx, f"failed to receive reply from aggregation client: {aggr}")
867875
self.update_status(action="receive_learn_result_reply", error=ReturnCode.EXECUTION_EXCEPTION)

nvflare/fuel/f3/streaming/cacheable.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -90,20 +90,35 @@ def clear_cache(self):
9090
def _get_item(self, index: int, requester: str) -> bytes:
9191
with self.lock:
9292
if not self.cache:
93-
# the cache has been cleared
93+
cache_available = False
9494
data = None
9595
else:
96+
cache_available = True
9697
data, _ = self.cache[index]
9798

98-
if data is None:
99-
data = self.produce_item(index)
100-
if self.cache:
101-
self.cache[index] = (data, 0)
102-
self.logger.debug(f"created and cached item {index} for {requester}: {len(data)} bytes")
103-
else:
104-
self.logger.debug(f"got item {index} from cache for {requester}")
99+
if not cache_available:
100+
return self.produce_item(index)
101+
102+
if data is not None:
103+
self.logger.debug(f"got item {index} from cache for {requester}")
105104
return data
106105

106+
# Produce outside the lock so concurrent receivers aren't blocked.
107+
# If two receivers produce the same item simultaneously, the first
108+
# to re-acquire the lock stores its result; the second uses it.
109+
data = self.produce_item(index)
110+
111+
with self.lock:
112+
if self.cache:
113+
existing, count = self.cache[index]
114+
if existing is None:
115+
self.cache[index] = (data, count)
116+
self.logger.debug(f"created and cached item {index} for {requester}: {len(data)} bytes")
117+
else:
118+
data = existing
119+
self.logger.debug(f"got item {index} from cache for {requester} (produced concurrently)")
120+
return data
121+
107122
def _adjust_cache(self, start: int, count: int):
108123
with self.lock:
109124
if not self.cache:

tests/unit_test/app_common/ccwf/test_swarm_self_message_deadlock.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,17 @@
3434
import threading
3535
import time
3636
import unittest
37-
37+
from types import SimpleNamespace
38+
from unittest import mock
39+
40+
from nvflare.apis.fl_constant import ReturnCode as FLReturnCode
41+
from nvflare.apis.fl_context import FLContextManager
42+
from nvflare.apis.shareable import Shareable, make_reply
43+
from nvflare.apis.signal import Signal
44+
from nvflare.app_common.abstract.learnable import Learnable
45+
from nvflare.app_common.app_constant import AppConstants
46+
from nvflare.app_common.ccwf.common import Constant
47+
from nvflare.app_common.ccwf.swarm_client_ctl import SwarmClientController
3848
from nvflare.fuel.f3.cellnet.core_cell import CoreCell, Message, MessageHeaderKey, TargetMessage
3949
from nvflare.fuel.f3.cellnet.defs import ReturnCode
4050
from nvflare.fuel.utils.network_utils import get_open_ports
@@ -415,5 +425,79 @@ def blocking_handler(message: Message):
415425
self.assertTrue(deadlock_detected.is_set(), "Deadlock should be detected - tensor wait timed out")
416426

417427

428+
class TestSwarmResultSubmissionFix(unittest.TestCase):
429+
def test_local_submit_when_aggregator_is_self(self):
430+
class _DummyGatherer:
431+
def __init__(self, **kwargs):
432+
self.for_round = kwargs.get("for_round", 0)
433+
434+
class _DummyEngine:
435+
def __init__(self):
436+
self.submit_req_calls = 0
437+
438+
def send_aux_request(self, **kwargs):
439+
self.submit_req_calls += 1
440+
return {"site-1": make_reply(FLReturnCode.OK)}
441+
442+
def new_context(self):
443+
return FLContextManager(engine=self, identity_name="site-1", job_id="job").new_context()
444+
445+
engine = _DummyEngine()
446+
fl_ctx = FLContextManager(engine=engine, identity_name="site-1", job_id="job").new_context()
447+
abort_signal = Signal()
448+
449+
task_data = Shareable()
450+
task_data.set_header(AppConstants.CURRENT_ROUND, 1)
451+
task_data.set_header(Constant.AGGREGATOR, "site-1")
452+
453+
learn_result = make_reply(FLReturnCode.OK)
454+
455+
ctl = object.__new__(SwarmClientController)
456+
ctl.me = "site-1"
457+
ctl.is_trainer = True
458+
ctl.gatherer = None
459+
ctl.gatherer_waiter = threading.Event()
460+
ctl.metric_comparator = object()
461+
ctl.trainers = ["site-1"]
462+
ctl.learn_task_timeout = 10
463+
ctl.min_responses_required = 1
464+
ctl.wait_time_after_min_resps_received = 0
465+
ctl.aggregator = object()
466+
ctl.max_concurrent_submissions = 1
467+
ctl.request_to_submit_result_max_wait = 10
468+
ctl.request_to_submit_result_msg_timeout = 1
469+
ctl.request_to_submit_result_interval = 0
470+
ctl.request_to_submit_learn_result_task_name = "request_submit"
471+
ctl.report_learn_result_task_name = "report_result"
472+
ctl.learn_task_ack_timeout = 5
473+
ctl.shareable_generator = SimpleNamespace(shareable_to_learnable=lambda _task_data, _ctx: Learnable())
474+
ctl.get_config_prop = lambda key, default=None: ["site-1"] if key == Constant.CLIENTS else default
475+
ctl.execute_learn_task = lambda _task_data, _ctx, _abort_signal: learn_result
476+
ctl.is_task_secure = lambda _ctx: False
477+
ctl.update_status = lambda **kwargs: None
478+
ctl.fire_event = lambda *_args, **_kwargs: None
479+
ctl.log_info = lambda *_args, **_kwargs: None
480+
ctl.log_debug = lambda *_args, **_kwargs: None
481+
ctl.log_warning = lambda *_args, **_kwargs: None
482+
ctl.log_error = lambda *_args, **_kwargs: None
483+
ctl.broadcast_and_wait = mock.Mock(
484+
side_effect=AssertionError("broadcast_and_wait must not be called for local result submission")
485+
)
486+
ctl._process_learn_result = mock.Mock(return_value=make_reply(FLReturnCode.OK))
487+
488+
with mock.patch("nvflare.app_common.ccwf.swarm_client_ctl.Gatherer", _DummyGatherer):
489+
ctl.do_learn_task("train", task_data, fl_ctx, abort_signal)
490+
491+
ctl.broadcast_and_wait.assert_not_called()
492+
ctl._process_learn_result.assert_called_once()
493+
self.assertEqual(engine.submit_req_calls, 1, "submission permission request should still be sent once")
494+
495+
called_result, called_fl_ctx, called_abort_signal = ctl._process_learn_result.call_args[0]
496+
self.assertIs(called_result, learn_result)
497+
self.assertIs(called_abort_signal, abort_signal)
498+
self.assertIsNot(called_fl_ctx, fl_ctx)
499+
self.assertEqual(called_fl_ctx.get_peer_context().get_identity_name(), "site-1")
500+
501+
418502
if __name__ == "__main__":
419503
unittest.main()

0 commit comments

Comments (0)