
Issues and Improvements of edge training #3827

@nickl1234567


I am currently trying out the edge examples. While it is an amazing piece of code, I found a few issues that I want to keep track of here. For most of these issues, I have already mentioned the appropriate fixes. Maybe I will contribute some of those fixes, but I cannot promise it at the moment. Also, feel free to discuss the proposed fixes!

Having written this issue up, I want to say again that I am very grateful for NVIDIA FLARE, even though the issue itself might not reflect that 😅

1. Argument missing in the README.

I believe the provision command here is missing the -s argument, since you will later run ./start_all.sh, which is only generated when that argument is used.

nvflare provision -p project.yml

2. Wrong port in the README.

The README says that the port on the edge device must be set to one of those specified in lcp_map.json, but I believe it must be set to the port specified in ./start_rp.sh. Also, the proxy shows the current IP and port, which is very convenient.

You need to configure the server PORT to be the PORT shown in lcp_map.json (for example: 9003).

3. The global model is not saved by default.

I worked around this myself by extending PTFileModelPersistor and swapping it in here:

persistor_id = job.to_server(PTFileModelPersistor(model=self.model), id="persistor")

My model persistor looks like this:

from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.app_event_type import AppEventType
from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor


class MyPTFileModelPersistor(PTFileModelPersistor):
    def handle_event(self, event, fl_ctx):
        super().handle_event(event, fl_ctx)
        if event == AppEventType.GLOBAL_WEIGHTS_UPDATED:
            # Persist the freshly aggregated global model after every update.
            ml = fl_ctx.get_prop(AppConstants.GLOBAL_MODEL)
            if ml:
                self._get_persistence_manager(fl_ctx).update(ml)
            self.save_model_file(self._ckpt_save_path)

4. The number of training epochs is not the one specified.

In simulation mode, the number of epochs is always 1, which is not what I expected. I recommend modifying the ETTaskProcessor here

diff_dict = self.run_training(et_model)

to

diff_dict = self.run_training(et_model, total_epochs=self.training_config.get("epoch", 1))

Also, for clarification, the number of epochs should probably be set in the scripts for both the simulator and the devices, in these locations:

"batch_size": batch_size,

"batch_size": batch_size,

device_training_params={"epoch": 3, "lr": 0.0001, "batch_size": batch_size},

5. The optimizer is created in each iteration.

Creating the optimizer for every batch is inefficient and discards optimizer state such as parameter momentum. Hence, it should only be created once. I suggest changing

optimizer = get_sgd_optimizer(

to

if optimizer is None:
    optimizer = get_sgd_optimizer(
        et_model.named_parameters(),
        self.training_config["learning_rate"],
        self.training_config["momentum"],
        self.training_config["weight_decay"],
        self.training_config["dampening"],
        self.training_config["nesterov"],
    )

with optimizer = None set at the beginning of the method.

The same issue exists in the Android app (and potentially the iOS app), where the optimizer is re-created before every step:

val sgd = SGD.create(parameters, learningRate.toDouble(), momentum.toDouble(), SGD_WEIGHT_DECAY, SGD_WEIGHT_DECAY, SGD_NESTEROV)
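
Roughly, the fix on the Android side could look like the sketch below. Only the SGD.create call is taken from the quoted line; the step loop, totalSteps, gradients and the sgd.step call are my assumptions about the surrounding training code.

// Sketch: create the optimizer once, before the step loop, instead of once per step.
val sgd = SGD.create(
    parameters,
    learningRate.toDouble(),
    momentum.toDouble(),
    SGD_WEIGHT_DECAY,
    SGD_WEIGHT_DECAY,
    SGD_NESTEROV,
)
for (step in 0 until totalSteps) {
    // ...forward/backward pass producing the gradients for this step...
    sgd.step(gradients)  // reusing the same optimizer keeps momentum state across steps
}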

6. The training parameters are not used in the Android training.

When the SGD is created, most parameters are not loaded from the config but are hardcoded constants. Ideally, there would also be a safeguard somewhere that prevents users from setting parameters that cannot actually take effect.

val sgd = SGD.create(parameters, learningRate.toDouble(), momentum.toDouble(), SGD_WEIGHT_DECAY, SGD_WEIGHT_DECAY, SGD_NESTEROV)
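
As a rough sketch of what I mean; the TrainingConfig accessor names (weightDecay, dampening, nesterov) are my assumptions and the real field names may differ:

// Sketch: take the SGD hyperparameters from the received training config instead of
// hardcoding them. Field names below are assumptions, not the app's actual API.
val sgd = SGD.create(
    parameters,
    trainingConfig.learningRate.toDouble(),
    trainingConfig.momentum.toDouble(),
    trainingConfig.weightDecay.toDouble(),
    trainingConfig.dampening.toDouble(),
    trainingConfig.nesterov,
)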

7. Hyperparameters do not even reach the Android client.

So, this one is a bit more confusing. At first I thought the parameters specified in the job script never reach the app at all, because I expected them to be passed here, next to the model DXO:

val trainingConfig = extractTrainingConfig(taskData, ctx)

However, there is also a set of parameters that is transferred only once and does contain the values specified in the job; essentially, they are passed through here

val trainingConfig = TrainingConfig.fromMap(meta)

and stored here in the ETTrainer
private val meta: Map<String, Any>,

but never used!
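
A sketch of how the stored meta could actually be wired into the training. TrainingConfig.fromMap and the meta property exist in the app; the class shape and the epochs/learningRate/batchSize field names here are my assumptions for illustration.

// Sketch only: parse the per-job hyperparameters once from the meta that ETTrainer
// already stores, and use them to drive the training loop.
class ETTrainer(
    private val meta: Map<String, Any>,
    // ...other constructor parameters as in the app...
) {
    private val trainingConfig = TrainingConfig.fromMap(meta)

    fun train() {
        repeat(trainingConfig.epochs) {
            // ...run one epoch using trainingConfig.learningRate, trainingConfig.batchSize, ...
        }
    }
}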

8. Hardcoded tensor sizes.

While this is fine for a simple example, this should be updated to use the tensor sizes that have been specified in the job:

private fun createInputTensor(inputData: FloatArray, method: String, batchSize: Int): Tensor {

At the moment, I do not have a convenient fix for that.

9. The ExecuTorch magic bytes are probably not correct.

The ET magic bytes specified here

private const val EXECUTORCH_HEADER_PREFIX = "PAAAAEVU"

do not match the models I generate. I assume this might be caused by different ExecuTorch versions.
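
For what it's worth, base64-decoding "PAAAAEVU" gives the bytes 0x3C 0x00 0x00 0x00 'E' 'T', which as far as I can tell is a 4-byte flatbuffer root offset followed by the start of the "ET.." file identifier. The offset part can differ between ExecuTorch versions and models, so a check that only looks for the "ET" identifier at byte offset 4 should be more robust. A sketch of that idea; the helper name is mine, and I am assuming the model arrives as a base64 string, as the constant suggests:

import java.util.Base64

// Sketch: check only the "ET" file identifier at byte offset 4 instead of the full
// base64 prefix, which also bakes in the version/model-dependent flatbuffer root offset.
fun looksLikeExecutorchProgram(base64Model: String): Boolean {
    if (base64Model.length < 8) return false
    val header = Base64.getDecoder().decode(base64Model.substring(0, 8))
    return header.size >= 6 &&
        header[4] == 'E'.code.toByte() &&
        header[5] == 'T'.code.toByte()
}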

10. Semantics of supported jobs are misleading.

As far as I understand, the current app asks for a job with one specific name. So if I start the xor_et job on the server and start training on a client that supports both the CIFAR-10 and XOR jobs, the app will only ask for jobs called cifar10_et and will therefore never get the XOR job. Even worse, the XOR job does not seem to be supported by the app at all. I believe the reason is that here

val method = args?.get("method") as? String ?: "cnn"

and here
val method = args?.get("method") as? String ?: "cnn"

The "cnn" method is used as the default if the "method" parameter is not specified in the training config. However, as it is never specified in the job script, this will always be "cnn".
I did not investigate this further, so this might not actually be the real reason.

11. The formatting of the training logs looks a bit off.

This applies to the training logs generated in the app. The indentation from the Kotlin source ends up in the generated log text, which makes it inconvenient to read. Instead, in both places where the summary is built, the text can simply be replaced by

val summary = buildString {
    appendLine("========================================")
    appendLine("TRAINING COMPLETED")
    appendLine("========================================")
    appendLine()
    appendLine("Method: $method")
    appendLine("Epochs: $epochs")
    appendLine("Batch Size: $batchSize")
    appendLine("Learning Rate: $learningRate")
    appendLine("Total Steps: $totalSteps")
    appendLine("Average Loss: ${totalLoss / totalSteps}")
    appendLine()
    appendLine("Tensor Differences:")
    tensorDiff.entries.forEach { (key, value) ->
        appendLine("  $key: ${value["data"]?.let { data -> if (data is List<*>) "size=${data.size}" else "unknown" } ?: "unknown"}")
    }
    appendLine()
    appendLine("Artifacts Location: ${artifactsDir.absolutePath}")
    appendLine()
    appendLine("========================================")
}


12. fetchTask might result in a stack overflow.

The fetchTask function calls itself recursively here.

In my case, I am using synchronous FedAvg, and since the dataset on the phone is much smaller than those of the simulators, the app has to wait for the other clients until a new task becomes available.
For some reason, the retryWait is only 30 ms, so the stack overflows rapidly and crashes the app.
The fix is rather simple: split the logic into a private fetchTaskOnce, which does the actual fetching, and fetchTask, which calls fetchTaskOnce in a loop and only breaks out once the result is no longer TaskStatus.RETRY.
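
A sketch of what I mean; TaskResult, retryWait and the suspend/delay usage are assumptions about the surrounding app code, only TaskStatus.RETRY is taken from above:

import kotlinx.coroutines.delay

// Sketch: fetchTaskOnce holds the existing fetch logic (without the recursive call);
// fetchTask loops over it, so waiting for the next task no longer grows the stack.
private suspend fun fetchTaskOnce(): TaskResult {
    TODO("existing fetch logic, returning the result instead of calling fetchTask() again")
}

suspend fun fetchTask(): TaskResult {
    while (true) {
        val result = fetchTaskOnce()
        if (result.status != TaskStatus.RETRY) {
            return result
        }
        delay(retryWait)  // wait before asking the server again, iteratively instead of recursively
    }
}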
