```
project
│   README.md
│   .gitignore
│   flower_train.py               # Run Flower training
│   centralized_training.py       # Train centralized models
│   biased_training.py            # Test N matrix prediction
│
├───src                           # Source code
│   │   flower_client.py          # Client Flower class: client optimizers, client info
│   │   flower_strategy.py        # Server Flower class: server optimizers, client weighting
│   │   flower_manager.py         # Flower manager for client selection
│   │   corr.py                   # N matrix metrics SC, CI, AI
│   │   matrix_inference.py       # N matrix inference
│   │   weighting_strategy.py     # Helpers for imported selection and weighting strategies
│   ├───datasets                  # Dataset loaders, splits
│   ├───models                    # PyTorch models
│   └───optimizers                # Handling subpopulation shift optimizers
├───conf                          # Config files following Hydra
├───datasets                      # Dataset files or links
└───checkpoints                   # Experiment files
```
- FedProx is a client optimizer method: change `client_opt.subpop_optimizer` from `ERM` to `Prox` to enable it.
- `FedAvgM` or `FedAvg` sets the server aggregation in `server_opt.optimizer`.
- Client participation is controlled by weighting or selection: `server_opt.participation` can be `weighting` or `selection`.
  - If `weighting` is used, `server_opt.weight_clients` sets the specific weighting method. Weighting is applied even when `selection` is set (as FedPNS does); set `server_opt.weight_clients=same` to disable it.
  - If `selection` is used, `server_opt.selection_method` sets the specific selection method. Options include `random`, `groupweights` (original matrix with Oracle ReWeight), and `triplets_stochasticmatrix` (triplets).
- Some methods use information from the clients. `server_opt.client_info` specifies what data is shared.
  - If `groupweights`, `triplets` or `nova` appears in the string, that information is passed to the server.
  - If `Npredicted` appears in the string, the numbers are generated with the N matrix prediction algorithm.
- The named data splits are controlled with `dataset_options.split_mode`. Most options are defined in `src/datasets/data_splits.py`. `dataset_options.num_clients` must be set together with `split_mode`.
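The participation and `client_info` rules above reduce to a few flag and substring checks. As a reading aid, here is a minimal sketch of those rules in plain Python. `ServerOpt`, `resolve_participation` and the default values are hypothetical stand-ins, not the repository's code, and only the substrings described above are modeled (the `gloss`, `oort`, `compgrad` and `groupacc` tokens used by other methods are ignored here).

```python
from dataclasses import dataclass


@dataclass
class ServerOpt:
    """Hypothetical stand-in for the Hydra `server_opt` group (defaults are assumptions)."""
    participation: str = "selection"   # "weighting" or "selection"
    selection_method: str = "random"   # e.g. "random", "groupweights", "triplets_stochasticmatrix"
    weight_clients: str = "same"       # "same" disables extra client weighting
    client_info: str = ""              # substrings decide what clients share


# Substrings that, per the rules above, make clients pass information to the server.
SHARED_KEYS = ("groupweights", "triplets", "nova")


def resolve_participation(opt: ServerOpt) -> dict:
    """Summarise which mechanisms the options would activate (sketch, not repo code)."""
    if opt.participation not in ("weighting", "selection"):
        raise ValueError(f"unknown participation mode: {opt.participation!r}")
    uses_selection = opt.participation == "selection"
    # Weighting is applied even under selection (e.g. FedPNS) unless weight_clients=same.
    uses_weighting = opt.weight_clients != "same"
    return {
        "selection_method": opt.selection_method if uses_selection else None,
        "weighting_method": opt.weight_clients if uses_weighting else None,
        "shared_info": [key for key in SHARED_KEYS if key in opt.client_info],
        "n_matrix_predicted": "Npredicted" in opt.client_info,
    }
```

For example, a FedDiverse-style config (`selection`, `triplets_stochasticmatrix`, `client_info=triplets_Npredicted`) resolves to triplet selection with shared triplet information and predicted N matrix numbers, while a FedPNS-style config keeps its `server_post_fedpns` weighting on top of selection.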

Methods and the overrides that enable them:

- FedProx: `client_opt.subpop_optimizer=Prox`
- FedAvgM: `server_opt.optimizer=FedAvgM`
- FedDiverse: `server_opt.participation=selection`, `server_opt.selection_method=triplets_stochasticmatrix`, `triplets_Npredicted` in `server_opt.client_info`
  - as weighting: `server_opt.weight_clients=server_post_triplets_stochasticmatrix_noreplacement`, `server_opt.participation=weighting`
- FedNova: `nova` in `server_opt.client_info`, `server_opt.participation=weighting`, `server_opt.weight_clients=server_post_nova`, `server_opt.optimizer=FedAvgM`
- pow-d: `gloss` in `server_opt.client_info`, `server_opt.participation=weighting`, `server_opt.weight_clients=server_post_powd`
- Round robin: `server_opt.participation=selection`, `server_opt.selection_method=roundrobin`
- FedPNS: `server_opt.participation=selection`, `server_opt.selection_method=fedpns`, `server_opt.weight_clients=server_post_fedpns`, `gloss` in `server_opt.client_info`
- FedPNS without weights: `server_opt.participation=selection`, `server_opt.selection_method=fedpns`, `server_opt.weight_clients=same`, `gloss` in `server_opt.client_info`
- Oort: `oort` in `server_opt.client_info`, `server_opt.selection_method=oort`, `server_opt.participation=selection`
- HCSFed: `compgrad` in `server_opt.client_info`, `server_opt.selection_method=hcsfed`
- FairFed: `groupacc` in `server_opt.client_info`, plus `server_opt.selection_method=fairfed` or `server_opt.weight_clients=server_post_FairFed`
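To keep runs reproducible, the method list can be transcribed into a lookup table of Hydra overrides. This is a minimal sketch with several assumptions: `METHOD_OVERRIDES` and `method_overrides` are hypothetical names; setting `server_opt.client_info` to exactly the required token is assumed to satisfy "`X` in `server_opt.client_info`"; the FedDiverse weighting variant is assumed to keep the same `client_info`; and FairFed uses the selection variant.

```python
# Hydra overrides transcribed from the method list (hypothetical helper, not repo code).
# Assumption: client_info set to exactly the token satisfies "X in server_opt.client_info".
METHOD_OVERRIDES: dict[str, list[str]] = {
    "FedProx": ["client_opt.subpop_optimizer=Prox"],
    "FedAvgM": ["server_opt.optimizer=FedAvgM"],
    "FedDiverse": [
        "server_opt.participation=selection",
        "server_opt.selection_method=triplets_stochasticmatrix",
        "server_opt.client_info=triplets_Npredicted",
    ],
    # Assumption: the weighting variant keeps the same client_info.
    "FedDiverse (weighting)": [
        "server_opt.participation=weighting",
        "server_opt.weight_clients=server_post_triplets_stochasticmatrix_noreplacement",
        "server_opt.client_info=triplets_Npredicted",
    ],
    "FedNova": [
        "server_opt.client_info=nova",
        "server_opt.participation=weighting",
        "server_opt.weight_clients=server_post_nova",
        "server_opt.optimizer=FedAvgM",
    ],
    "pow-d": [
        "server_opt.client_info=gloss",
        "server_opt.participation=weighting",
        "server_opt.weight_clients=server_post_powd",
    ],
    "Round robin": [
        "server_opt.participation=selection",
        "server_opt.selection_method=roundrobin",
    ],
    "FedPNS": [
        "server_opt.participation=selection",
        "server_opt.selection_method=fedpns",
        "server_opt.weight_clients=server_post_fedpns",
        "server_opt.client_info=gloss",
    ],
    "FedPNS w/o weights": [
        "server_opt.participation=selection",
        "server_opt.selection_method=fedpns",
        "server_opt.weight_clients=same",
        "server_opt.client_info=gloss",
    ],
    "Oort": [
        "server_opt.client_info=oort",
        "server_opt.selection_method=oort",
        "server_opt.participation=selection",
    ],
    "HCSFed": ["server_opt.client_info=compgrad", "server_opt.selection_method=hcsfed"],
    # Selection variant; the alternative is server_opt.weight_clients=server_post_FairFed.
    "FairFed": ["server_opt.client_info=groupacc", "server_opt.selection_method=fairfed"],
}


def method_overrides(name: str) -> list[str]:
    """Return a fresh copy of the overrides for one method, or raise on unknown names."""
    if name not in METHOD_OVERRIDES:
        raise ValueError(f"unknown method {name!r}; choose from {sorted(METHOD_OVERRIDES)}")
    return list(METHOD_OVERRIDES[name])
```

Returning a copy means callers can append dataset overrides without mutating the shared table.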

Datasets and the overrides that select them:

- Spawrious_GSC: `dataset_options.split_mode=spawrious2`, `dataset_options.num_clients=24`
- Spawrious_GCI: `dataset_options.split_mode=spawrious_GCI`, `dataset_options.num_clients=24`
- Spawrious_GAI: `dataset_options.split_mode=spawrious_GAI_2`, `dataset_options.num_clients=25`
- Waterbirds_dist: `dataset_options.split_mode=waterbirds_dist`, `dataset_options.num_clients=30`
- Spawrious_4: `dataset_options.split_mode=spawrious4`, `dataset_options.num_clients=25`, `dataset_options.num_targets=4`
- CMNIST_GSC: `dataset_options.split_mode=spawrious2`, `dataset_options.num_clients=24`, `dataset_options.name=CMNIST`, `dataset_options.input_size=28`
- Spawrious_GCI_100: `dataset_options.split_mode=spawrious_GCI_100`, `dataset_options.num_clients=100`, `server_opt.num_active_clients=12`
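The dataset presets can also be stored as data and turned into Hydra overrides, with a guard for the rule that `split_mode` and `num_clients` are set together. A minimal sketch: `DATASET_PRESETS` and `dataset_overrides` are hypothetical names, and CMNIST_GSC is assumed to use the same `spawrious2` split as Spawrious_GSC.

```python
# Dataset presets transcribed from the dataset list (hypothetical helper, not repo code).
DATASET_PRESETS: dict[str, dict[str, object]] = {
    "Spawrious_GSC": {"dataset_options.split_mode": "spawrious2", "dataset_options.num_clients": 24},
    "Spawrious_GCI": {"dataset_options.split_mode": "spawrious_GCI", "dataset_options.num_clients": 24},
    "Spawrious_GAI": {"dataset_options.split_mode": "spawrious_GAI_2", "dataset_options.num_clients": 25},
    "Waterbirds_dist": {"dataset_options.split_mode": "waterbirds_dist", "dataset_options.num_clients": 30},
    "Spawrious_4": {
        "dataset_options.split_mode": "spawrious4",
        "dataset_options.num_clients": 25,
        "dataset_options.num_targets": 4,
    },
    # Assumption: same split as Spawrious_GSC, on Colored MNIST at 28x28.
    "CMNIST_GSC": {
        "dataset_options.split_mode": "spawrious2",
        "dataset_options.num_clients": 24,
        "dataset_options.name": "CMNIST",
        "dataset_options.input_size": 28,
    },
    "Spawrious_GCI_100": {
        "dataset_options.split_mode": "spawrious_GCI_100",
        "dataset_options.num_clients": 100,
        "server_opt.num_active_clients": 12,
    },
}

# split_mode and num_clients must always be set together.
REQUIRED_KEYS = ("dataset_options.split_mode", "dataset_options.num_clients")


def dataset_overrides(name: str) -> list[str]:
    """Render one preset as Hydra key=value strings, checking the required pair."""
    if name not in DATASET_PRESETS:
        raise ValueError(f"unknown dataset preset {name!r}")
    preset = DATASET_PRESETS[name]
    missing = [key for key in REQUIRED_KEYS if key not in preset]
    if missing:
        raise ValueError(f"{name}: split_mode and num_clients must be set together; missing {missing}")
    return [f"{key}={value}" for key, value in preset.items()]
```

The guard never fires for the built-in entries; it is there to catch a newly added preset that sets only one of the two keys.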
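
Example run: `python flower_train.py`, with any of the overrides above appended, e.g. `python flower_train.py client_opt.subpop_optimizer=Prox dataset_options.split_mode=waterbirds_dist dataset_options.num_clients=30` (FedProx on Waterbirds_dist).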
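Runs can also be scripted from Python, for instance to loop over several methods and datasets. The sketch below assumes `flower_train.py` is a Hydra entry point launched from the repository root; `build_command` and `run_training` are hypothetical helpers. With `multirun=True` it adds Hydra's `-m` flag, under which comma-separated values such as `server_opt.optimizer=FedAvg,FedAvgM` are swept as separate jobs.

```python
import subprocess
import sys
from typing import Sequence


def build_command(overrides: Sequence[str], multirun: bool = False,
                  script: str = "flower_train.py") -> list[str]:
    """Assemble argv for a Hydra-driven training run (sketch).

    sys.executable is used so the run uses the current interpreter/virtualenv.
    Each override is a Hydra `key=value` string from the lists above.
    """
    argv = [sys.executable, script]
    if multirun:
        argv.append("-m")  # Hydra multirun: sweep comma-separated values
    argv.extend(overrides)
    return argv


def run_training(overrides: Sequence[str], multirun: bool = False) -> None:
    """Launch the training script; raises CalledProcessError on a non-zero exit."""
    subprocess.run(build_command(overrides, multirun), check=True)
```

For example, `run_training(["client_opt.subpop_optimizer=Prox", "dataset_options.split_mode=waterbirds_dist", "dataset_options.num_clients=30"])` would start FedProx on the Waterbirds_dist split. Passing an argv list instead of a shell string avoids quoting problems with values that contain commas or brackets.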

Requirements (Python 3.10.11):

```
flwr[simulation]==1.8.0
torch
torchvision
matplotlib
wilds
hydra-core
timm
cvxopt
wandb
git+https://github.com/aengusl/spawrious.git
submitit
hydra-submitit-launcher
datasets[vision]
aif360
```