Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion nvflare/edge/assessors/buff_device_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,10 @@ def remove_devices_from_used(self, devices: Set[str], fl_ctx) -> None:
self.used_devices.pop(device_id, None)

def has_enough_devices(self, fl_ctx) -> bool:
return len(self.available_devices) >= self.device_selection_size
num_holes = self.device_selection_size - len(self.current_selection)
usable_devices = set(self.available_devices.keys()) - set(self.used_devices.keys())
num_usable_devices = len(usable_devices)
return num_usable_devices >= num_holes

def should_fill_selection(self, fl_ctx) -> bool:
num_holes = self.device_selection_size - len(self.current_selection)
Expand Down
10 changes: 3 additions & 7 deletions nvflare/edge/assessors/buff_model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,19 +154,15 @@ def process_updates(self, model_updates: Dict[int, ModelUpdate], fl_ctx: FLConte
self.log_error(fl_ctx, f"bad child update version {model_version}: no update data")
continue

if self.current_model_version - model_version > self.max_model_history:
if self.current_model_version - model_version >= self.max_model_history:
# this version is too old
self.log_info(
self.log_warning(
fl_ctx,
f"dropped child update version {model_version}. Current version {self.current_model_version}",
f"dropped child update version {model_version}. Current version {self.current_model_version}. Max history {self.max_model_history}",
)
continue

model_state = self.updates.get(model_version)
if not model_state:
self.log_error(fl_ctx, f"No model state for version {model_version}")
continue

accepted = model_state.accept(model_update, fl_ctx)
self.log_info(
fl_ctx,
Expand Down
130 changes: 117 additions & 13 deletions nvflare/edge/assessors/model_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def __init__(
model_manager_id,
device_manager_id,
max_model_version,
device_wait_timeout: Optional[float] = None,
device_status_log_interval: Optional[float] = 30.0,
):
"""Initialize the ModelUpdateAssessor.
Enable both asynchronous and synchronous model updates from clients.
Expand All @@ -44,6 +46,8 @@ def __init__(
model_manager_id (str): ID of the model manager component.
device_manager_id (str): ID of the device manager component.
max_model_version (int): Maximum model version to stop the workflow.
device_wait_timeout (float, optional): Timeout in seconds for waiting for sufficient devices. None means no timeout. Default is None.
device_status_log_interval (float, optional): Interval in seconds for logging device status. Default is 30 seconds.
"""
Assessor.__init__(self)
self.persistor_id = persistor_id
Expand All @@ -54,9 +58,78 @@ def __init__(
self.device_manager = None
self.max_model_version = max_model_version
self.update_lock = threading.Lock()
self.start_time = None
self.device_wait_timeout = device_wait_timeout
self.device_wait_start_time = None
self._last_device_status_log_time = time.time()
self.device_status_log_interval = device_status_log_interval
self.register_event_handler(EventType.START_RUN, self._handle_start_run)

def _is_device_wait_timeout_exceeded(self, fl_ctx: FLContext) -> bool:
"""Check if device wait timeout has been exceeded.

Args:
fl_ctx: FL context

Returns:
bool: True if timeout exceeded, False otherwise
"""
if self.device_wait_start_time is None or self.device_wait_timeout is None:
return False

try:
elapsed = time.time() - self.device_wait_start_time
if elapsed > self.device_wait_timeout:
usable_devices = set(self.device_manager.available_devices.keys()) - set(
self.device_manager.used_devices.keys()
)
self.log_error(
fl_ctx,
f"Device wait timeout ({self.device_wait_timeout}s) exceeded. "
f"Elapsed time: {elapsed:.1f}s. "
f"Total devices: {len(self.device_manager.available_devices)}, "
f"usable: {len(usable_devices)}, "
f"expected: {self.device_manager.device_selection_size}. "
f"Device_reuse flag: {self.device_manager.device_reuse}. "
"Stopping the job.",
)
return True

return False
except Exception as e:
self.log_error(fl_ctx, f"Error checking device timeout: {e}")
return False

def _log_device_status(self, fl_ctx: FLContext):
"""Log device status information independently of timeout logic."""
if self.device_status_log_interval is None:
return

current_time = time.time()
elapsed = current_time - self._last_device_status_log_time

if elapsed >= self.device_status_log_interval:
usable_devices = set(self.device_manager.available_devices.keys()) - set(
self.device_manager.used_devices.keys()
)

# Add timeout info if we're actually waiting with a timeout
timeout_msg = ""
if self.device_wait_start_time is not None and self.device_wait_timeout is not None:
remaining_time = self.device_wait_timeout - (current_time - self.device_wait_start_time)
timeout_msg = f" Timeout in {remaining_time:.1f} seconds."
elif self.device_wait_start_time is not None:
timeout_msg = " No timeout set (waiting indefinitely)."

self.log_info(
fl_ctx,
f"Device Status: "
f"Total: {len(self.device_manager.available_devices)}, "
f"usable: {len(usable_devices)}, "
f"expected: {self.device_manager.device_selection_size}.{timeout_msg}",
)

self._last_device_status_log_time = current_time

def _handle_start_run(self, event_type: str, fl_ctx: FLContext):
engine = fl_ctx.get_engine()

Expand Down Expand Up @@ -96,7 +169,6 @@ def _handle_start_run(self, event_type: str, fl_ctx: FLContext):
return

def start_task(self, fl_ctx: FLContext) -> Shareable:
self.start_time = time.time()
# empty base state to start with
base_state = BaseState(
model_version=0,
Expand All @@ -116,6 +188,30 @@ def _do_child_update(self, update: Shareable, fl_ctx: FLContext) -> (bool, Optio
# Update available devices
if report.available_devices:
self.device_manager.update_available_devices(report.available_devices, fl_ctx)
# Reset wait timer if we now have enough devices
if self.device_wait_start_time is not None and self.device_manager.has_enough_devices(fl_ctx):
self.device_wait_start_time = None
self.log_info(fl_ctx, "Sufficient devices now available, resetting wait timer")

# Check for device wait timeout if we are waiting for devices
if self.device_wait_start_time is not None and self._is_device_wait_timeout_exceeded(fl_ctx):
# Timeout exceeded, prepare an empty reply and stop the job
usable_devices = set(self.device_manager.available_devices.keys()) - set(
self.device_manager.used_devices.keys()
)
self.log_error(
fl_ctx,
f"Total devices: {len(self.device_manager.available_devices)}, usable: {len(usable_devices)}, expected: {self.device_manager.device_selection_size}. "
f"Device_reuse flag is set to: {self.device_manager.device_reuse}. "
"Not enough devices joining, please adjust the server params. Stopping the job.",
)
reply = StateUpdateReply(
model_version=0,
model=None,
device_selection_version=self.device_manager.current_selection_version,
device_selection=self.device_manager.get_selection(fl_ctx),
)
return False, reply.to_shareable()

accepted = True
if report.model_updates:
Expand All @@ -133,21 +229,24 @@ def _do_child_update(self, update: Shareable, fl_ctx: FLContext) -> (bool, Optio
if self.device_manager.device_reuse:
self.device_manager.remove_devices_from_used(set(model_update.devices.keys()), fl_ctx)

# Handle device selection
if self.device_manager.should_fill_selection(fl_ctx):
self.device_manager.fill_selection(self.model_manager.current_model_version, fl_ctx)

else:
self.log_debug(fl_ctx, "no model updates")

# Handle initial model generation
if self.model_manager.current_model_version == 0:
# Handle device selection
if self.device_manager.should_fill_selection(fl_ctx):
# check if we have enough devices to fill selection
if self.device_manager.has_enough_devices(fl_ctx):
self.log_info(
fl_ctx, f"got {len(self.device_manager.available_devices)} devices - generate initial model"
)
self.model_manager.generate_new_model(fl_ctx)
if self.model_manager.current_model_version == 0:
self.log_info(fl_ctx, "Generate initial model and fill selection")
self.model_manager.generate_new_model(fl_ctx)
self.device_manager.fill_selection(self.model_manager.current_model_version, fl_ctx)
# Reset wait timer since we have enough devices
self.device_wait_start_time = None
else:
# Start wait timer if not already started
if self.device_wait_start_time is None:
self.device_wait_start_time = time.time()
self.log_info(fl_ctx, f"Starting device wait timer (timeout: {self.device_wait_timeout}s)")

# Prepare reply
model = None
Expand All @@ -163,7 +262,12 @@ def _do_child_update(self, update: Shareable, fl_ctx: FLContext) -> (bool, Optio
return accepted, reply.to_shareable()

def assess(self, fl_ctx: FLContext) -> Assessment:
if self.model_manager.current_model_version >= self.max_model_version:
# Check if we're waiting for devices and timeout exceeded
self._log_device_status(fl_ctx)
if self._is_device_wait_timeout_exceeded(fl_ctx):
self.log_error(fl_ctx, "Job stopped due to insufficient devices joining within timeout period")
return Assessment.WORKFLOW_DONE
elif self.model_manager.current_model_version >= self.max_model_version:
model_version = self.model_manager.current_model_version
self.log_info(fl_ctx, f"Max model version {self.max_model_version} reached: {model_version=}")
return Assessment.WORKFLOW_DONE
Expand Down
Loading