Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions examples/advanced/edge/jobs/pt_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,6 @@ def create_edge_recipe(fl_mode, devices_per_leaf, num_leaf_nodes, global_rounds,
# need all devices to train for one global model version
num_updates_for_model=total_devices,
max_model_version=global_rounds,
# basic synchronous mode, no need to discard old model updates
max_model_history=1,
)
device_manager_config = DeviceManagerConfig(
# each leaf node has devices_per_leaf devices
Expand All @@ -92,16 +90,6 @@ def create_edge_recipe(fl_mode, devices_per_leaf, num_leaf_nodes, global_rounds,
# sync - each global model covers total_devices data
# async - each global model covers 1 device's data
max_model_version=global_rounds * total_devices,
# basic async mode, set max model update version diff so that
# the updater will not discard old model updates
# since the fastest device is 4 times faster than the slowest device,
# worst case is that there is only 1 slowest device and (total_devices - 1) fastest devices,
# to ensure that the updater will not discard old model updates,
# we need to allow (total_devices - 1) * 4 model updates
# on server side:
max_model_history=(total_devices - 1) * 4,
# on client/updater side:
max_num_active_model_versions=(total_devices - 1) * 4,
# increase the update timeout to allow for the slowest device to finish
update_timeout=500,
)
Expand Down
2 changes: 1 addition & 1 deletion examples/advanced/edge/jobs/pt_job_adv.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def main():
global_lr = 0.1
num_updates_for_model = 20
max_model_version = 200
max_model_history = 100
max_model_history = None
min_hole_to_fill = 10
eval_frequency = 1
local_batch_size = 10
Expand Down
16 changes: 13 additions & 3 deletions nvflare/edge/assessors/buff_device_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Dict, Set

from nvflare.edge.assessors.device_manager import DeviceManager
from nvflare.fuel.utils.validation_utils import check_positive_int


class BuffDeviceManager(DeviceManager):
Expand All @@ -39,12 +40,14 @@ def __init__(
device_reuse (bool): Whether to allow reusing devices across different model versions. Defaults to True.
"""
super().__init__()
check_positive_int("device_selection_size", device_selection_size)
check_positive_int("min_hole_to_fill", min_hole_to_fill)

self.device_selection_size = device_selection_size
self.min_hole_to_fill = min_hole_to_fill
self.device_reuse = device_reuse
# also keep track of the current selection version and used devices
self.current_selection_version = 0
self.used_devices = {}

def update_available_devices(self, devices: Dict, fl_ctx) -> None:
self.available_devices.update(devices)
Expand All @@ -66,8 +69,12 @@ def fill_selection(self, current_model_version: int, fl_ctx) -> None:
for _ in range(num_holes):
device_id = random.choice(list(usable_devices))
usable_devices.remove(device_id)
self.current_selection[device_id] = self.current_selection_version
self.used_devices[device_id] = current_model_version
# current_selection keeps track of devices selected for a particular model version
self.current_selection[device_id] = current_model_version
self.used_devices[device_id] = {
"model_version": current_model_version,
"selection_version": self.current_selection_version,
}
if not usable_devices:
break
self.log_info(
Expand Down Expand Up @@ -97,3 +104,6 @@ def has_enough_devices(self, fl_ctx) -> bool:
def should_fill_selection(self, fl_ctx) -> bool:
    """Decide whether the device selection needs to be refilled.

    The selection has "holes" when fewer devices are selected than
    ``device_selection_size`` allows; a refill is requested only once the
    hole count reaches ``min_hole_to_fill``.

    Args:
        fl_ctx: FLContext object (not used by this check).

    Returns:
        True if the number of open slots is at least ``min_hole_to_fill``.
    """
    open_slots = self.device_selection_size - len(self.current_selection)
    return open_slots >= self.min_hole_to_fill

def get_active_model_versions(self, fl_ctx) -> Set[int]:
    """Return the distinct model versions tied to the current selection.

    Each entry of ``current_selection`` maps a device id to the model
    version it was selected for; the set of those versions is "active".

    Args:
        fl_ctx: FLContext object (not used by this lookup).

    Returns:
        Set of model versions that still have selected devices.
    """
    return {version for version in self.current_selection.values()}
56 changes: 37 additions & 19 deletions nvflare/edge/assessors/buff_model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import time
from typing import Dict
from typing import Dict, Optional, Set

import numpy as np

Expand Down Expand Up @@ -43,16 +43,22 @@ class BuffModelManager(ModelManager):
def __init__(
self,
num_updates_for_model: int,
max_model_history: int,
max_model_history: Optional[int] = None,
global_lr: float = 1.0,
staleness_weight: bool = False,
):
"""Initialize the ModelManager.
The aggregation scheme and weights are calculated following FedBuff paper "Federated Learning with Buffered Asynchronous Aggregation".
The staleness_weight can be enabled to apply staleness weighting to model updates.

Special cases for max_model_history:
- If None: keep every model version; only remove a version once all devices processing it have reported back (i.e., the version is no longer associated with any device_id in the device manager's current_selection).

Args:
num_updates_for_model (int): Number of updates required before generating a new model version.
max_model_history (int): Maximum number of historical model versions to keep in memory.
- None (default): keep every version until all devices processing a particular version report back.
- positive integer: keep only the latest n versions
global_lr (float): Global learning rate for model aggregation, default is 1.0.
staleness_weight (bool): Whether to apply staleness weighting to model updates, default is False.
"""
Expand All @@ -69,11 +75,29 @@ def initialize_model(self, model: DXO, fl_ctx: FLContext):
# updates is a dict of model version to _ModelState
self.updates[self.current_model_version] = _ModelState(ModelUpdateDXOAggregator())

def prune_model_versions(self, versions_to_keep: Set[int], fl_ctx: FLContext) -> None:
    """Drop buffered model states that are no longer needed.

    A version is discarded when either condition holds:
    - it is not listed in ``versions_to_keep``; or
    - ``max_model_history`` is set (truthy) and the version lags
      ``current_model_version`` by at least ``max_model_history``.

    Args:
        versions_to_keep: model versions that must be retained (e.g. the
            versions still associated with selected devices).
        fl_ctx: FLContext object, used for logging.
    """
    stale = {
        version
        for version in self.updates
        if version not in versions_to_keep
        or (self.max_model_history and self.current_model_version - version >= self.max_model_history)
    }

    for version in stale:
        del self.updates[version]
        self.log_info(fl_ctx, f"removed model version {version}")

    # report how many versions remain buffered after pruning
    self.log_info(fl_ctx, f"current total number of active model versions: {len(self.updates)}")

def generate_new_model(self, fl_ctx: FLContext) -> None:
# New model generated based on the current global weights and all updates
new_model = {}
self.current_model_version += 1
old_model_versions = []

# counter to confirm the number of updates
num_updates = 0
Expand Down Expand Up @@ -103,9 +127,6 @@ def generate_new_model(self, fl_ctx: FLContext) -> None:
new_model[key] = value
else:
new_model[key] = new_model[key] + value
# If too old, remove it
if self.current_model_version - v >= self.max_model_history:
old_model_versions.append(v)

# Reset aggr after counting its contribution
ms.aggregator.reset(fl_ctx)
Expand All @@ -123,12 +144,6 @@ def generate_new_model(self, fl_ctx: FLContext) -> None:
self.updates[self.current_model_version] = _ModelState(ModelUpdateDXOAggregator())
self.log_info(fl_ctx, f"generated new model version {self.current_model_version} with {num_updates} updates")

if old_model_versions:
self.log_info(fl_ctx, f"removed old model versions {old_model_versions}")

for v in old_model_versions:
self.updates.pop(v)

# update the current model
# convert new_model items from numpy arrays to lists for serialization
new_model = {k: v.tolist() if isinstance(v, np.ndarray) else v for k, v in new_model.items()}
Expand All @@ -154,14 +169,17 @@ def process_updates(self, model_updates: Dict[int, ModelUpdate], fl_ctx: FLConte
self.log_error(fl_ctx, f"bad child update version {model_version}: no update data")
continue

if self.current_model_version - model_version >= self.max_model_history:
# this version is too old
self.log_warning(
fl_ctx,
f"dropped child update version {model_version}. Current version {self.current_model_version}. Max history {self.max_model_history}",
)
continue
# Check if version is too old before accepting
if self.max_model_history:
# if max_model_history is set, output warning for updates that are too old
if self.current_model_version - model_version >= self.max_model_history:
self.log_warning(
fl_ctx,
f"dropped child update version {model_version}. Current version {self.current_model_version}. Max history {self.max_model_history}",
)
continue

# Accept the update and aggregate it to the corresponding model version
model_state = self.updates.get(model_version)
accepted = model_state.accept(model_update, fl_ctx)
self.log_info(
Expand Down
38 changes: 37 additions & 1 deletion nvflare/edge/assessors/device_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any, Dict
from typing import Any, Dict, Set

from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_context import FLContext
Expand All @@ -32,9 +32,11 @@ def __init__(self):
DeviceManager keeps track of two dicts:
- current_selection for devices of current task distribution
- available_devices containing all devices that are available for selection
- used_devices dict kept for record keeping, containing all devices that have participated
"""
self.current_selection = {}
self.available_devices = {}
self.used_devices = {}

@abstractmethod
def update_available_devices(self, devices: Dict, fl_ctx: FLContext) -> None:
Expand Down Expand Up @@ -112,6 +114,18 @@ def has_enough_devices(self, fl_ctx: FLContext) -> bool:
"""
pass

@abstractmethod
def get_active_model_versions(self, fl_ctx: FLContext) -> Set[int]:
    """Get the active model versions associated with the current selection.

    Implementations should report every model version that still has at
    least one device in the current selection.

    Args:
        fl_ctx: FLContext object

    Returns:
        Set of active model versions
    """
    pass

def get_selection(self, fl_ctx: FLContext) -> Any:
"""Get the current device selection.

Expand All @@ -122,3 +136,25 @@ def get_selection(self, fl_ctx: FLContext) -> Any:
Current device selection
"""
return self.current_selection

def get_available_devices(self, fl_ctx: FLContext) -> Dict[str, Any]:
    """Get the available devices.

    Note: this returns the ``available_devices`` dict itself (keyed by
    device id), not a set and not a copy — callers that mutate the result
    mutate the manager's state. Values are whatever device records were
    passed to ``update_available_devices`` — presumably device info;
    confirm against callers.

    Args:
        fl_ctx: FLContext object (unused here; kept for interface consistency)

    Returns:
        Dict of available devices keyed by device id
    """
    return self.available_devices

def get_used_devices(self, fl_ctx: FLContext) -> Dict[str, Any]:
    """Get the used devices.

    Note: this returns the ``used_devices`` dict itself (keyed by device
    id), not a set and not a copy. The value format is
    implementation-defined (e.g. a per-device participation record kept by
    the concrete manager).

    Args:
        fl_ctx: FLContext object (unused here; kept for interface consistency)

    Returns:
        Dict of used devices keyed by device id
    """
    return self.used_devices
14 changes: 13 additions & 1 deletion nvflare/edge/assessors/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any
from typing import Any, Set

from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_context import FLContext
Expand Down Expand Up @@ -62,6 +62,18 @@ def generate_new_model(self, fl_ctx: FLContext) -> None:
"""
pass

@abstractmethod
def prune_model_versions(self, versions_to_keep: Set[int], fl_ctx: FLContext) -> None:
    """Prune the model versions that are no longer active.

    Implementations should discard any buffered state for model versions
    not present in ``versions_to_keep``, subject to their own retention
    policy.

    Args:
        versions_to_keep: Set of model versions to keep
        fl_ctx: FLContext object

    Returns:
        None
    """
    pass

@abstractmethod
def process_updates(self, model_updates: Any, fl_ctx: FLContext) -> bool:
"""Process incoming model updates from clients.
Expand Down
17 changes: 10 additions & 7 deletions nvflare/edge/assessors/model_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,14 @@ def _is_device_wait_timeout_exceeded(self, fl_ctx: FLContext) -> bool:
try:
elapsed = time.time() - self.device_wait_start_time
if elapsed > self.device_wait_timeout:
usable_devices = set(self.device_manager.available_devices.keys()) - set(
self.device_manager.used_devices.keys()
usable_devices = set(self.device_manager.get_available_devices(fl_ctx).keys()) - set(
self.device_manager.get_used_devices(fl_ctx).keys()
)
self.log_error(
fl_ctx,
f"Device wait timeout ({self.device_wait_timeout}s) exceeded. "
f"Elapsed time: {elapsed:.1f}s. "
f"Total devices: {len(self.device_manager.available_devices)}, "
f"Total devices: {len(self.device_manager.get_available_devices(fl_ctx))}, "
f"usable: {len(usable_devices)}, "
f"expected: {self.device_manager.device_selection_size}. "
f"Device_reuse flag: {self.device_manager.device_reuse}. "
Expand All @@ -108,8 +108,8 @@ def _log_device_status(self, fl_ctx: FLContext):
elapsed = current_time - self._last_device_status_log_time

if elapsed >= self.device_status_log_interval:
usable_devices = set(self.device_manager.available_devices.keys()) - set(
self.device_manager.used_devices.keys()
usable_devices = set(self.device_manager.get_available_devices(fl_ctx).keys()) - set(
self.device_manager.get_used_devices(fl_ctx).keys()
)

# Add timeout info if we're actually waiting with a timeout
Expand Down Expand Up @@ -196,8 +196,8 @@ def _do_child_update(self, update: Shareable, fl_ctx: FLContext) -> (bool, Optio
# Check for device wait timeout if we are waiting for devices
if self.device_wait_start_time is not None and self._is_device_wait_timeout_exceeded(fl_ctx):
# Timeout exceeded, prepare an empty reply and stop the job
usable_devices = set(self.device_manager.available_devices.keys()) - set(
self.device_manager.used_devices.keys()
usable_devices = set(self.device_manager.get_available_devices(fl_ctx).keys()) - set(
self.device_manager.get_used_devices(fl_ctx).keys()
)
self.log_error(
fl_ctx,
Expand Down Expand Up @@ -240,6 +240,9 @@ def _do_child_update(self, update: Shareable, fl_ctx: FLContext) -> (bool, Optio
self.log_info(fl_ctx, "Generate initial model and fill selection")
self.model_manager.generate_new_model(fl_ctx)
self.device_manager.fill_selection(self.model_manager.current_model_version, fl_ctx)
# prune old model versions that are no longer active
active_model_versions = self.device_manager.get_active_model_versions(fl_ctx)
self.model_manager.prune_model_versions(active_model_versions, fl_ctx)
# Reset wait timer since we have enough devices
self.device_wait_start_time = None
else:
Expand Down
8 changes: 7 additions & 1 deletion nvflare/edge/executors/edge_model_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,16 @@

class EdgeModelExecutor(EdgeTaskExecutor):

def __init__(self, aggr_factory_id: str, max_model_versions: int, update_timeout=60):
def __init__(
    self,
    aggr_factory_id: str,
    max_model_versions: Optional[int] = None,
    update_timeout=60.0,
):
    """Initialize the EdgeModelExecutor.

    Args:
        aggr_factory_id: component id used to locate the aggregator factory.
        max_model_versions: maximum number of model versions kept active on
            this side; None means no fixed cap.
        update_timeout: timeout in seconds, forwarded to EdgeTaskExecutor.

    NOTE(review): this signature now uses ``Optional``; the visible diff
    does not show a ``from typing import Optional`` being added to this
    module — confirm the import exists, otherwise this raises NameError.
    """
    EdgeTaskExecutor.__init__(self, "", update_timeout)
    self.aggr_factory_id = aggr_factory_id
    self.max_model_versions = max_model_versions

    # lock guarding conversion work shared across threads
    self.cvt_lock = threading.Lock()

def get_updater(self, fl_ctx: FLContext):
Expand Down
8 changes: 6 additions & 2 deletions nvflare/edge/tools/edge_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import json
import os.path
from typing import Optional

from nvflare.edge.assessor import Assessor
from nvflare.edge.controllers.sage import ScatterAndGatherForEdge
Expand Down Expand Up @@ -94,7 +95,7 @@ def configure_server(
def configure_client(
self,
aggregator_factory: AggregatorFactory,
max_model_versions: int,
max_model_versions: Optional[int] = None,
update_timeout=5.0,
executor_task_name="train",
simulation_config_file: str = None,
Expand All @@ -114,8 +115,11 @@ def configure_client(
if self.client_config_added:
raise RuntimeError("client config is already added")

# check the validity of max_model_versions if not None
if max_model_versions:
check_positive_int("max_model_versions", max_model_versions)

check_object_type("aggregator_factory", aggregator_factory, AggregatorFactory)
check_positive_int("max_model_versions", max_model_versions)
check_positive_number("update_timeout", update_timeout)
check_str("executor_task_name", executor_task_name)

Expand Down
Loading
Loading