1 change: 1 addition & 0 deletions README.md
@@ -104,6 +104,7 @@ Flower Baselines is a collection of community-contributed projects that reproduc
- [FedPara](https://github.com/adap/flower/tree/main/baselines/fedpara)
- [FedAvg](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/fedavg_mnist)
- [FedOpt](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/adaptive_federated_optimization)
- [Floco](https://github.com/adap/flower/tree/main/baselines/floco)

Please refer to the [Flower Baselines Documentation](https://flower.ai/docs/baselines/) for a detailed categorization of baselines and for additional info including:
* [How to use Flower Baselines](https://flower.ai/docs/baselines/how-to-use-baselines.html)
15 changes: 10 additions & 5 deletions baselines/floco/README.md
@@ -75,24 +75,29 @@ pip install .

## Running the Experiments
In order to run using the default settings, simply run:

### Through your environment
```bash
flwr run .
flwr run . --federation-config "options.num-supernodes=100"
```

This will run Floco on the Dirichlet(0.5) split. If you want to test other splits, you can specify which config you want to use. For example, if you would like to test Floco<sup>+</sup> on the CIFAR-10 Five Fold split from the original paper, simply run:
```bash
flwr run . --run-config conf/cifar10_fold_floco_p.toml
flwr run . --federation-config "options.num-supernodes=100" --run-config conf/cifar10_fold_floco_p.toml
```
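
The `flwr` CLI also accepts inline key=value overrides through `--run-config`, which can be convenient for quick experiments. The snippet below is only an illustrative sketch: the key names (`algorithm`, `pers_lamda`, `local-epochs`) are assumed from the `run_config` lookups in `floco/client_app.py` and are not a verified command for this baseline.
```bash
# Sketch only: override individual run-config values inline instead of
# pointing at a TOML file. Key names are assumed from floco/client_app.py.
flwr run . \
  --federation-config "options.num-supernodes=100" \
  --run-config "algorithm='Floco' pers_lamda=1 local-epochs=1"
```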
If you want to run the benchmark shown int the Figures below, simply run:

If you want to run the benchmark shown in the Figures below, simply run:
```bash
bash ./run.sh
```
This will run FedAvg, Floco and Floco<sup>+</sup> on the CIFAR-10 Five-Fold and Dirichlet(0.5) split.
This will run FedAvg, Floco and Floco<sup>+</sup> on the CIFAR-10 Five-Fold and Dirichlet(0.5) split with `options.num-supernodes=100`.
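
For orientation, here is a minimal sketch of what a benchmark script like `run.sh` could look like, assuming it simply loops over the bundled run-config files; the actual script shipped with this baseline may use different file names or extra flags.
```bash
#!/bin/bash
# Sketch only: launch one run per bundled run-config file.
# Assumes the configs live under conf/ as referenced above.
for cfg in conf/*.toml; do
    echo "Running ${cfg}"
    flwr run . \
        --federation-config "options.num-supernodes=100" \
        --run-config "${cfg}"
done
```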

## Expected Results
In order to generate the result plots shown below, run:
```
python floco/plot_results.py
```

### CIFAR-10 Five-Fold split
<img src="_static/CIFAR10_Fold.png" width="600"/>

@@ -109,4 +114,4 @@ python floco/plot_results.py
booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems (NeurIPS'24)},
year={2024},
url={https://openreview.net/forum?id=JL2eMCfDW8}
}
}
Binary file modified baselines/floco/_static/CIFAR10_Dirichlet.png
Binary file modified baselines/floco/_static/CIFAR10_Fold.png
257 changes: 113 additions & 144 deletions baselines/floco/floco/client_app.py
@@ -3,153 +3,122 @@
import copy

import torch

from flwr.client import ClientApp, NumPyClient
from flwr.common import ArrayRecord, Context, bytes_to_ndarray
from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
from flwr.clientapp import ClientApp
from flwr.common import bytes_to_ndarray

from .dataset import get_federated_dataloaders
from .model import Net, SimplexModel, get_weights, set_weights, test, train


class FlowerClient(NumPyClient):
"""A class defining the client."""

def __init__(
self,
partition_id,
global_model,
pers_model,
pers_lamda,
trainloader,
valloader,
local_epochs,
device,
context,
):
self.partition_id = partition_id
self.global_model = global_model
self.pers_model = pers_model
self.trainloader = trainloader
self.valloader = valloader
self.local_epochs = local_epochs
self.device = device

# Floco+ params
self.pers_lamda = pers_lamda
if self.pers_lamda != 0:
self.client_state = context.state
if "pers_parameters" not in self.client_state:
self.client_state["pers_parameters"] = ArrayRecord()

def fit(self, parameters, config):
"""Train model using this client's data."""
set_weights(self.global_model, parameters)
reg_parameters = copy.deepcopy(list(self.global_model.parameters()))
train_loss = self._train(self.global_model, config)

if self.pers_lamda != 0:
array_record = self.client_state["pers_parameters"]
if len(array_record) > 0:
self.pers_model.load_state_dict(
self.client_state["pers_parameters"].to_torch_state_dict()
)
self._train(self.pers_model, config, reg_parameters, self.pers_lamda)
self.client_state["pers_parameters"] = ArrayRecord(
self.pers_model.state_dict()
)

return (
get_weights(self.global_model),
len(self.trainloader.dataset),
{"train_loss": train_loss},
)
from .model import SimplexModel, create_model, test, train

def evaluate(self, parameters, config):
"""Evaluate model using this client's data."""
set_weights(self.global_model, parameters)
model = self.global_model
if self.pers_lamda != 0:
array_record = self.client_state["pers_parameters"]
if len(array_record) > 0:
model = self.pers_model
self.pers_model.load_state_dict(
self.client_state["pers_parameters"].to_torch_state_dict()
)
else:
model = self.global_model
self._set_simplex_params(model, config, training=False)
loss, accuracy = test(model, self.valloader, self.device)
return loss, len(self.valloader.dataset), {"loss": loss, "accuracy": accuracy}

def get_properties(self, config):
"""Return the properties of this client."""
return {"partition-id": self.partition_id}

def _train(self, model, config, reg_parameters=None, lamda=0):
"""Set simplex parameters and train."""
self._set_simplex_params(model, config, training=True)
train_loss = train(
model,
self.trainloader,
self.local_epochs,
self.device,
reg_parameters,
lamda,
)
return train_loss

def _set_simplex_params(self, model, config, training=None):
"""Set simplex parameters, i.e. projected point and sampling radius."""
model.training = bool(training)
if all(key in config for key in ["center", "radius"]):
model.subregion_parameters = (
bytes_to_ndarray(config["center"]),
config["radius"],
)


def client_fn(context: Context):
"""Construct a Client that will be run in a ClientApp."""
# Load model and data
device = torch.device(
"cuda:0"
if torch.cuda.is_available()
else "mps" if torch.backends.mps.is_available() else "cpu"
)
seed = int(context.run_config["seed"])
endpoints = int(context.run_config["endpoints"])
DEVICE = torch.device(
"cuda:0"
if torch.cuda.is_available()
else "mps" if torch.backends.mps.is_available() else "cpu"
)

app = ClientApp()


@app.train()
def train_fn(msg: Message, context: Context) -> Message:
"""Train the model on local data."""
partition_id = int(context.node_config["partition-id"])
num_partitions = int(context.node_config["num-partitions"])
local_epochs = int(context.run_config["local-epochs"])
pers_model = None
pers_lamda = 0
pers_lamda = int(context.run_config["pers_lamda"])
algorithm = str(context.run_config["algorithm"])

# Load model from received arrays
global_model = create_model(context).to(DEVICE)
arrays = msg.content.array_records["arrays"]
global_model.load_state_dict(arrays.to_torch_state_dict())

# Load data
trainloader, _ = get_federated_dataloaders(partition_id, num_partitions, context)

# Set simplex params if present in config
config = msg.content.config_records.get("config", {})
if isinstance(global_model, SimplexModel):
_apply_simplex_config(global_model, config, training=True)

# Floco+ personalization: save reg params before training global model
use_pers = pers_lamda != 0 and algorithm == "Floco"
if use_pers:
reg_parameters = copy.deepcopy(list(global_model.parameters()))

# Train global model
train_loss = train(global_model, trainloader, local_epochs, DEVICE)

# Floco+ personalization: train personalized model
if use_pers:
endpoints = int(context.run_config["endpoints"])
pers_model = SimplexModel(endpoints=endpoints).to(DEVICE)
if "pers_parameters" not in context.state:
context.state["pers_parameters"] = ArrayRecord()
pers_record = context.state["pers_parameters"]
if len(pers_record) > 0:
pers_model.load_state_dict(pers_record.to_torch_state_dict())
_apply_simplex_config(pers_model, config, training=True)
train(pers_model, trainloader, local_epochs, DEVICE, reg_parameters, pers_lamda)
context.state["pers_parameters"] = ArrayRecord(pers_model.state_dict())

# Construct reply
model_record = ArrayRecord(global_model.state_dict())
metrics = MetricRecord({
"train_loss": train_loss,
"num-examples": len(trainloader.dataset),
})
content = RecordDict({"arrays": model_record, "metrics": metrics})
return Message(content=content, reply_to=msg)


@app.evaluate()
def evaluate_fn(msg: Message, context: Context) -> Message:
"""Evaluate the model on local data."""
partition_id = int(context.node_config["partition-id"])
num_partitions = int(context.node_config["num-partitions"])

if context.run_config["algorithm"] == "FedAvg":
global_model = Net(seed=seed).to(device)
elif context.run_config["algorithm"] == "Floco":
global_model = SimplexModel(endpoints=endpoints, seed=seed).to(device)
pers_lamda = int(context.run_config["pers_lamda"])
if pers_lamda != 0:
pers_model = SimplexModel(endpoints=endpoints, seed=seed).to(device)
else:
raise ValueError("Algorithm not implemented")
trainloader, valloader = get_federated_dataloaders(
partition_id, num_partitions, context
)

# Return Client instance
return FlowerClient(
partition_id,
global_model,
pers_model,
pers_lamda,
trainloader,
valloader,
local_epochs,
device,
context,
).to_client()


# Flower ClientApp
app = ClientApp(client_fn)
pers_lamda = int(context.run_config["pers_lamda"])
algorithm = str(context.run_config["algorithm"])

# Load model from received arrays
model = create_model(context).to(DEVICE)
arrays = msg.content.array_records["arrays"]
model.load_state_dict(arrays.to_torch_state_dict())

# Load data
_, valloader = get_federated_dataloaders(partition_id, num_partitions, context)

# Floco+ personalization: use personalized model if available
if pers_lamda != 0 and algorithm == "Floco":
if "pers_parameters" in context.state:
pers_record = context.state["pers_parameters"]
if len(pers_record) > 0:
endpoints = int(context.run_config["endpoints"])
model = SimplexModel(endpoints=endpoints).to(DEVICE)
model.load_state_dict(pers_record.to_torch_state_dict())

# Set simplex params for evaluation
config = msg.content.config_records.get("config", {})
if isinstance(model, SimplexModel):
_apply_simplex_config(model, config, training=False)

loss, accuracy = test(model, valloader, DEVICE)

# Construct reply
metrics = MetricRecord({
"loss": loss,
"accuracy": accuracy,
"num-examples": len(valloader.dataset),
})
content = RecordDict({"metrics": metrics})
return Message(content=content, reply_to=msg)


def _apply_simplex_config(model, config, training):
"""Apply simplex subregion parameters from a message config to a model."""
model.training = training
if "center" in config and "radius" in config:
model.subregion_parameters = (
bytes_to_ndarray(config["center"]),
config["radius"],
)
9 changes: 3 additions & 6 deletions baselines/floco/floco/dataset.py
@@ -1,15 +1,12 @@
"""floco: A Flower Baseline."""

from typing import Tuple

from datasets import load_dataset
from flwr.app import Context
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import InnerDirichletPartitioner
from torch.utils.data import DataLoader
from torchvision import transforms

from flwr.common import Context

from datasets import load_dataset
from .partitioners import FoldPartitioner

# pylint: disable=C0103, W0603
@@ -35,7 +32,7 @@ def get_testloader(dataset: str) -> DataLoader:

def get_federated_dataloaders(
partition_id: int, num_partitions: int, context: Context
) -> Tuple[DataLoader, DataLoader]:
) -> tuple[DataLoader, DataLoader]:
"""Create dataloaders for a specified dataset and partition.

partition_id : int