nvflare/edge/assessors/buff_device_manager.py (4 additions & 1 deletion)
@@ -89,7 +89,10 @@ def remove_devices_from_used(self, devices: Set[str], fl_ctx) -> None:
self.used_devices.pop(device_id, None)

def has_enough_devices(self, fl_ctx) -> bool:
return len(self.available_devices) >= self.device_selection_size
num_holes = self.device_selection_size - len(self.current_selection)
usable_devices = set(self.available_devices.keys()) - set(self.used_devices.keys())
num_usable_devices = len(usable_devices)
return num_usable_devices >= num_holes

def should_fill_selection(self, fl_ctx) -> bool:
num_holes = self.device_selection_size - len(self.current_selection)
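The revised `has_enough_devices` compares the number of unused devices against the open slots ("holes") in the current selection, instead of comparing all reported devices against the full selection size. A standalone toy sketch of that check (the values below are made up, not taken from the PR):

```python
# Toy illustration of the revised check (standalone sketch, not the NVFlare class itself).
device_selection_size = 4
current_selection = {"dev-1": 1, "dev-2": 1}                   # 2 of 4 slots filled -> 2 holes
available_devices = {f"dev-{i}": None for i in range(1, 6)}    # 5 devices have reported in
used_devices = {"dev-1": 1, "dev-2": 1, "dev-3": 1}            # 3 already consumed by selections

num_holes = device_selection_size - len(current_selection)     # 2
usable = set(available_devices) - set(used_devices)            # {"dev-4", "dev-5"}
print(len(usable) >= num_holes)                                # True: 2 unused devices fill 2 holes
# The old check (len(available_devices) >= device_selection_size) would also pass here,
# but it over-counts whenever most available devices are already in used_devices.
```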
nvflare/edge/assessors/buff_model_manager.py (3 additions & 7 deletions)
@@ -154,19 +154,15 @@ def process_updates(self, model_updates: Dict[int, ModelUpdate], fl_ctx: FLConte
self.log_error(fl_ctx, f"bad child update version {model_version}: no update data")
continue

if self.current_model_version - model_version > self.max_model_history:
if self.current_model_version - model_version >= self.max_model_history:
# this version is too old
self.log_info(
self.log_warning(
fl_ctx,
f"dropped child update version {model_version}. Current version {self.current_model_version}",
f"dropped child update version {model_version}. Current version {self.current_model_version}. Max history {self.max_model_history}",
)
continue

model_state = self.updates.get(model_version)
if not model_state:
self.log_error(fl_ctx, f"No model state for version {model_version}")
continue

accepted = model_state.accept(model_update, fl_ctx)
self.log_info(
fl_ctx,
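The staleness condition changes from `>` to `>=`, so an update whose version lags the current model by exactly `max_model_history` is now dropped (with a warning instead of an info log). A small worked example of the boundary, using illustrative numbers only:

```python
# Boundary check for the staleness rule (illustrative numbers, not from the PR).
current_model_version = 10
max_model_history = 3

for model_version in (10, 9, 8, 7):
    age = current_model_version - model_version
    dropped_now = age >= max_model_history      # rule after this PR
    dropped_before = age > max_model_history    # rule before this PR
    print(model_version, age, dropped_now, dropped_before)

# Version 7 (age 3) is now dropped, so only updates at most
# max_model_history - 1 = 2 versions behind the current model are accepted.
```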
nvflare/edge/assessors/model_update.py (187 additions & 13 deletions)
@@ -14,6 +14,7 @@

import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

from nvflare.apis.event_type import EventType
@@ -33,6 +34,7 @@ def __init__(
model_manager_id,
device_manager_id,
max_model_version,
device_wait_timeout: float = 30.0,
):
"""Initialize the ModelUpdateAssessor.
Enable both asynchronous and synchronous model updates from clients.
@@ -44,6 +46,7 @@ def __init__(
model_manager_id (str): ID of the model manager component.
device_manager_id (str): ID of the device manager component.
max_model_version (int): Maximum model version to stop the workflow.
device_wait_timeout (float): Timeout in seconds for waiting for sufficient devices. Default is 30 seconds.
"""
Assessor.__init__(self)
self.persistor_id = persistor_id
@@ -53,10 +56,132 @@ def __init__(
self.model_manager = None
self.device_manager = None
self.max_model_version = max_model_version
self.device_wait_timeout = device_wait_timeout
self.update_lock = threading.Lock()
self.start_time = None
self.device_wait_start_time = None
self.should_stop_job = False
self.timeout_future = None
self.thread_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="DeviceTimeout")
self.register_event_handler(EventType.START_RUN, self._handle_start_run)

def _check_device_timeout(self, fl_ctx: FLContext) -> bool:
"""Check if device wait timeout has been exceeded.

Args:
fl_ctx: FL context

Returns:
bool: True if timeout exceeded, False otherwise
"""
if self.device_wait_start_time is None:
return False

if time.time() - self.device_wait_start_time > self.device_wait_timeout:
return True
return False

def _log_device_wait_status(self, fl_ctx: FLContext, message_prefix: str = ""):
"""Log current device wait status with countdown information."""
if self.device_wait_start_time is not None:
current_time = time.time()
elapsed = current_time - self.device_wait_start_time

# Only log if we haven't logged recently (rate limiting)
if not hasattr(self, "_last_status_log_time"):
self._last_status_log_time = 0

# Adaptive logging frequency based on urgency
remaining_time = self.device_wait_timeout - elapsed
if remaining_time > 0:
# Determine logging interval based on remaining time
if remaining_time > 60: # More than 1 minute: log every 30 seconds
log_interval = 30.0
elif remaining_time > 30: # 30 seconds to 1 minute: log every 15 seconds
log_interval = 15.0
elif remaining_time > 10: # 10 to 30 seconds: log every 5 seconds
log_interval = 5.0
else: # Final 10 seconds: log every 2 seconds
log_interval = 2.0

# Check if enough time has passed since last log
if current_time - self._last_status_log_time >= log_interval:
usable_devices = set(self.device_manager.available_devices.keys()) - set(
self.device_manager.used_devices.keys()
)
self.log_info(
fl_ctx,
f"{message_prefix}Device wait status: "
f"Total devices: {len(self.device_manager.available_devices)}, "
f"usable: {len(usable_devices)}, "
f"expected: {self.device_manager.device_selection_size}. "
f"Timeout in {remaining_time:.1f} seconds.",
)
self._last_status_log_time = current_time

def _start_timeout_tracking(self, fl_ctx: FLContext):
"""Start independent timeout tracking using thread pool."""
if self.timeout_future is not None and not self.timeout_future.done():
# Cancel existing timeout if still running
self.timeout_future.cancel()

# Submit new timeout task to thread pool
self.timeout_future = self.thread_pool.submit(self._timeout_tracker, fl_ctx)
self.log_debug(fl_ctx, f"Started device wait timeout tracking for {self.device_wait_timeout}s")

def _stop_timeout_tracking(self):
"""Stop timeout tracking."""
if self.timeout_future is not None and not self.timeout_future.done():
self.timeout_future.cancel()
self.timeout_future = None

def _timeout_tracker(self, fl_ctx: FLContext):
"""Independent timeout tracker that runs in thread pool."""
try:
# Periodic logging during countdown to keep users informed
check_interval = 10.0 # Log every 10 seconds
elapsed = 0

while (
self.device_wait_start_time is not None
and not self.should_stop_job
and elapsed < self.device_wait_timeout
):

time.sleep(min(check_interval, self.device_wait_timeout - elapsed))
elapsed += check_interval

# Check if we should stop monitoring
if self.device_wait_start_time is None or self.should_stop_job:
return

# Log periodic status update
if elapsed < self.device_wait_timeout:
remaining = self.device_wait_timeout - elapsed
self.log_info(fl_ctx, f"Device wait countdown: {remaining:.1f} seconds remaining")

# Check if we're still waiting for devices and timeout exceeded
with self.update_lock:
if (
self.device_wait_start_time is not None
and not self.should_stop_job
and self._check_device_timeout(fl_ctx)
):

# Timeout exceeded, set the stop flag
self.should_stop_job = True
self.log_error(
fl_ctx,
f"Device wait timeout ({self.device_wait_timeout}s) exceeded. "
f"Setting stop job flag to terminate workflow.",
)

except Exception as e:
# Log error but don't crash the thread
self.log_error(fl_ctx, f"Error in timeout tracker: {e}")
finally:
# Clean up
self.timeout_future = None

def _handle_start_run(self, event_type: str, fl_ctx: FLContext):
engine = fl_ctx.get_engine()

@@ -96,7 +221,6 @@ def _handle_start_run(self, event_type: str, fl_ctx: FLContext):
return

def start_task(self, fl_ctx: FLContext) -> Shareable:
self.start_time = time.time()
# empty base state to start with
base_state = BaseState(
model_version=0,
Expand All @@ -116,6 +240,38 @@ def _do_child_update(self, update: Shareable, fl_ctx: FLContext) -> (bool, Optio
# Update available devices
if report.available_devices:
self.device_manager.update_available_devices(report.available_devices, fl_ctx)
# Reset wait timer if we now have enough devices
if self.device_wait_start_time is not None and self.device_manager.has_enough_devices(fl_ctx):
self.device_wait_start_time = None
self.should_stop_job = False # Reset stop job flag
# Stop timeout tracking since we have enough devices
self._stop_timeout_tracking()
self.log_info(fl_ctx, "Sufficient devices now available, resetting wait timer and stop job flag")

# Check for device wait timeout if we are waiting for devices
if self.device_wait_start_time is not None:
if self.should_stop_job:
# Timeout exceeded, prepare an empty reply
usable_devices = set(self.device_manager.available_devices.keys()) - set(
self.device_manager.used_devices.keys()
)
self.log_error(
fl_ctx,
f"Device wait timeout ({self.device_wait_timeout}s) exceeded. "
f"Total devices: {len(self.device_manager.available_devices)}, usable: {len(usable_devices)}, expected: {self.device_manager.device_selection_size}. "
f"Device_reuse flag is set to: {self.device_manager.device_reuse}. "
"Not enough devices joining, please adjust the server params. Stopping the job.",
)
reply = StateUpdateReply(
model_version=0,
model=None,
device_selection_version=self.device_manager.current_selection_version,
device_selection=self.device_manager.get_selection(fl_ctx),
)
return False, reply.to_shareable()
else:
# Log current wait status for user information (rate limited)
self._log_device_wait_status(fl_ctx, "Waiting for devices: ")

accepted = True
if report.model_updates:
@@ -133,21 +289,27 @@ def _do_child_update(self, update: Shareable, fl_ctx: FLContext) -> (bool, Optio
if self.device_manager.device_reuse:
self.device_manager.remove_devices_from_used(set(model_update.devices.keys()), fl_ctx)

# Handle device selection
if self.device_manager.should_fill_selection(fl_ctx):
self.device_manager.fill_selection(self.model_manager.current_model_version, fl_ctx)

else:
self.log_debug(fl_ctx, "no model updates")

# Handle initial model generation
if self.model_manager.current_model_version == 0:
# Handle device selection
if self.device_manager.should_fill_selection(fl_ctx):
# check if we have enough devices to fill selection
if self.device_manager.has_enough_devices(fl_ctx):
self.log_info(
fl_ctx, f"got {len(self.device_manager.available_devices)} devices - generate initial model"
)
self.model_manager.generate_new_model(fl_ctx)
if self.model_manager.current_model_version == 0:
self.log_info(fl_ctx, "Generate initial model and fill selection")
self.model_manager.generate_new_model(fl_ctx)
self.device_manager.fill_selection(self.model_manager.current_model_version, fl_ctx)
# Reset wait timer since we have enough devices
self.device_wait_start_time = None
# Stop timeout tracking since we have enough devices
self._stop_timeout_tracking()
else:
# Start or continue wait timer since we don't have enough devices
if self.device_wait_start_time is None:
self.device_wait_start_time = time.time()
# Start independent timeout tracking
self._start_timeout_tracking(fl_ctx)

# Prepare reply
model = None
@@ -163,9 +325,21 @@ def _do_child_update(self, update: Shareable, fl_ctx: FLContext) -> (bool, Optio
return accepted, reply.to_shareable()

def assess(self, fl_ctx: FLContext) -> Assessment:
if self.model_manager.current_model_version >= self.max_model_version:
if self.should_stop_job:
# Stop timeout tracking before ending the job
self._stop_timeout_tracking()
self.log_error(fl_ctx, "Job stopped due to insufficient devices joining within timeout period")
return Assessment.WORKFLOW_DONE
elif self.model_manager.current_model_version >= self.max_model_version:
# Stop timeout tracking before ending the job
self._stop_timeout_tracking()
model_version = self.model_manager.current_model_version
self.log_info(fl_ctx, f"Max model version {self.max_model_version} reached: {model_version=}")
return Assessment.WORKFLOW_DONE
else:
return Assessment.CONTINUE

def __del__(self):
"""Cleanup thread pool on destruction."""
if hasattr(self, "thread_pool"):
self.thread_pool.shutdown(wait=False)
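For context, a hypothetical sketch of how the new `device_wait_timeout` argument might be wired up when the assessor is instantiated. The component IDs are placeholders and the import path is inferred from this PR's file layout; it is not a definitive configuration.

```python
# Hypothetical wiring of the new device_wait_timeout argument; the component IDs
# below are placeholders, and the import path is inferred from this PR's file layout.
from nvflare.edge.assessors.model_update import ModelUpdateAssessor

assessor = ModelUpdateAssessor(
    persistor_id="persistor",            # assumed ID of the model persistor component
    model_manager_id="model_manager",    # assumed ID of the model manager component
    device_manager_id="device_manager",  # assumed ID of the device manager component
    max_model_version=20,                # stop the workflow once this version is reached
    device_wait_timeout=60.0,            # give devices 60s to report in before stopping the job
)
```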