Description
I am currently trying out the edge examples. While it is an amazing piece of code, I found a few issues that I want to keep track of here. For most of these issues, I have already mentioned the appropriate fixes. Maybe I will contribute some of those fixes, but I cannot promise it at the moment. Also, feel free to discuss the proposed fixes!
After preparing this issue, I want to mention again that I am very grateful for Nvidia Flare, even though this issue might not reflect that 😅
1. Argument missing in readme.
I believe the provision command here is missing the -s argument, since the README later runs ./start_all.sh, which is only generated when that argument is used.
NVFlare/examples/advanced/edge/README.md
Line 32 in ee8a5fe
| nvflare provision -p project.yml |
2. Wrong port in readme.
The readme says that the port in the edge device must be set to one of the ports specified in lcp_map.json, but I believe it must be set to the port specified in ./start_rp.sh. Also, the proxy shows the current IP and port, which is very convenient.
NVFlare/examples/advanced/edge/README.md
Line 132 in ee8a5fe
| You need to configure the server PORT to be the PORT shown in lcp_map.json (for example: 9003). |
3. The global model is not saved by default.
I worked around it by extending PTFileModelPersistor and replacing the persistor here:
| persistor_id = job.to_server(PTFileModelPersistor(model=self.model), id="persistor") |
My model persistor looks like this:
```python
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.app_event_type import AppEventType
from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor


class MyPTFileModelPersistor(PTFileModelPersistor):
    def handle_event(self, event, fl_ctx):
        super().handle_event(event, fl_ctx)
        # Persist the global model whenever the aggregated weights change.
        if event == AppEventType.GLOBAL_WEIGHTS_UPDATED:
            ml = fl_ctx.get_prop(AppConstants.GLOBAL_MODEL)
            if ml:
                self._get_persistence_manager(fl_ctx).update(ml)
                self.save_model_file(self._ckpt_save_path)
```
4. The number of training epochs is not the one specified.
In simulation mode, the number of epochs is always 1, which is not what I expected. I recommend modifying the ETTaskProcessor here
| diff_dict = self.run_training(et_model) |
to
```python
diff_dict = self.run_training(et_model, total_epochs=self.training_config.get("epoch", 1))
```
Also, for clarification, the number of epochs should probably be set for both the simulator and the devices in the job scripts, in these locations:
| "batch_size": batch_size, |
| "batch_size": batch_size, |
NVFlare/examples/advanced/edge/jobs/et_job.py
Line 108 in ee8a5fe
| device_training_params={"epoch": 3, "lr": 0.0001, "batch_size": batch_size}, |
5. The optimizer is re-created in each iteration.
Creating the optimizer for every batch is inefficient and discards optimizer state such as parameter momentum. Hence, it should only be created once. I suggest changing
| optimizer = get_sgd_optimizer( |
to
```python
if optimizer is None:
    optimizer = get_sgd_optimizer(
        et_model.named_parameters(),
        self.training_config["learning_rate"],
        self.training_config["momentum"],
        self.training_config["weight_decay"],
        self.training_config["dampening"],
        self.training_config["nesterov"],
    )
```
with optimizer = None set at the beginning of the method.
The same issue exists for the Android app (and potentially the iOS app), which creates the optimizer separately before each step:
| val sgd = SGD.create(parameters, learningRate.toDouble(), momentum.toDouble(), SGD_WEIGHT_DECAY, SGD_WEIGHT_DECAY, SGD_NESTEROV) |
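For the Android side, the same restructuring applies. Below is a minimal, self-contained Kotlin sketch of the create-once pattern; StubOptimizer is a stand-in I made up for the app's actual SGD class, which is not reproduced here:

```kotlin
// StubOptimizer is a hypothetical stand-in for the real SGD optimizer;
// only the create-once / step-many pattern matters for this sketch.
class StubOptimizer(private val lr: Double, private val momentum: Double) {
    fun step() {
        // Apply one update using internal state such as momentum buffers.
    }
}

fun trainLoop(totalSteps: Int, lr: Double, momentum: Double) {
    // Create the optimizer once, outside the per-step loop,
    // so its internal state survives across steps.
    val optimizer = StubOptimizer(lr, momentum)
    repeat(totalSteps) {
        // ... the app's existing forward/backward pass goes here ...
        optimizer.step()
    }
}
```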
6. The training parameters are not used in the Android training.
When creating the SGD optimizer, most parameters are not loaded from the config. Ideally, there would be a safeguard somewhere to prevent users from setting parameters that cannot take effect. A sketch of reading these values from the config follows after item 7.
| val sgd = SGD.create(parameters, learningRate.toDouble(), momentum.toDouble(), SGD_WEIGHT_DECAY, SGD_WEIGHT_DECAY, SGD_NESTEROV) |
7. Hyperparameters do not even reach the Android client.
This one is a bit more confusing. At first, I thought the parameters specified in the job script do not reach the app at all, because I expected them to be passed here, next to the model DXO:
| val trainingConfig = extractTrainingConfig(taskData, ctx) |
However, it seems there is also a set of parameters that is only transferred once and that does contain the values specified in the job; essentially, they are passed through here
| val trainingConfig = TrainingConfig.fromMap(meta) |
and stored here in the ETTrainer
| private val meta: Map<String, Any>, |
but never used!
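Related to both this point and the previous one: here is a minimal Kotlin sketch of how the SGD hyperparameters could be read from the meta/config map that already reaches the ETTrainer, instead of using the hardcoded SGD_* constants. The key names and fallback values below are my assumptions, not necessarily what the job script actually sends:

```kotlin
// Hypothetical container for the SGD hyperparameters taken from the
// training config delivered by the server.
data class SgdConfig(
    val learningRate: Double,
    val momentum: Double,
    val weightDecay: Double,
    val dampening: Double,
    val nesterov: Boolean,
)

// Key names ("lr", "momentum", ...) and defaults are assumptions.
fun sgdConfigFrom(meta: Map<String, Any>): SgdConfig = SgdConfig(
    learningRate = (meta["lr"] as? Number)?.toDouble() ?: 0.0001,
    momentum = (meta["momentum"] as? Number)?.toDouble() ?: 0.0,
    weightDecay = (meta["weight_decay"] as? Number)?.toDouble() ?: 0.0,
    dampening = (meta["dampening"] as? Number)?.toDouble() ?: 0.0,
    nesterov = (meta["nesterov"] as? Boolean) ?: false,
)
```

The resulting values would then be passed into SGD.create(...) in place of the hardcoded SGD_WEIGHT_DECAY / SGD_NESTEROV constants.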
8. Hardcoded tensor sizes.
While this is fine for a simple example, it should be updated to use the tensor sizes specified in the job:
| private fun createInputTensor(inputData: FloatArray, method: String, batchSize: Int): Tensor { |
At the moment, I do not have a convenient fix for that.
9. The ExecuTorch magic bytes are probably not correct.
The ET magic bytes specified here
| private const val EXECUTORCH_HEADER_PREFIX = "PAAAAEVU" |
are not correct on my setup. I assume this might be caused by a different ExecuTorch version.
10. Semantics of supported jobs are misleading.
As far as I understand, the app asks for a job with one specific name: if I start the xor_et job on the server and start training on the client with both the CIFAR-10 and XOR jobs marked as supported, the app will still only ask for jobs called cifar10_et, so it never gets the XOR job. Even worse, the XOR job does not seem to be supported in the app at all. I believe the reason is that here
| val method = args?.get("method") as? String ?: "cnn" |
and here
| val method = args?.get("method") as? String ?: "cnn" |
the "cnn" method is used as the default whenever the "method" parameter is not specified in the training config. Since it is never specified in the job script, the method will always be "cnn".
I did not investigate this further, so this might not actually be the real reason.
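One way to make this failure mode visible would be to fail loudly when the "method" parameter is absent, instead of silently falling back to "cnn". This is only a sketch with a hypothetical helper, not code from the app:

```kotlin
// Hypothetical helper: refuse to guess the model architecture when the
// training config does not specify the "method", instead of silently
// defaulting to "cnn".
fun resolveMethod(args: Map<String, Any?>?): String =
    requireNotNull(args?.get("method") as? String) {
        "Training config does not specify 'method'; refusing to fall back to 'cnn'."
    }
```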
11. The formatting of the training logs looks a bit off.
This applies to the training logs generated in the app. Kotlin keeps the indentation from the source code in the generated log text, which makes it inconvenient to read. Instead, the text can simply be replaced by
```kotlin
val summary = buildString {
    appendLine("========================================")
    appendLine("TRAINING COMPLETED")
    appendLine("========================================")
    appendLine()
    appendLine("Method: $method")
    appendLine("Epochs: $epochs")
    appendLine("Batch Size: $batchSize")
    appendLine("Learning Rate: $learningRate")
    appendLine("Total Steps: $totalSteps")
    appendLine("Average Loss: ${totalLoss / totalSteps}")
    appendLine()
    appendLine("Tensor Differences:")
    tensorDiff.entries.forEach { (key, value) ->
        appendLine(" $key: ${value["data"]?.let { data -> if (data is List<*>) "size=${data.size}" else "unknown" } ?: "unknown"}")
    }
    appendLine()
    appendLine("Artifacts Location: ${artifactsDir.absolutePath}")
    appendLine()
    appendLine("========================================")
}
```
This replacement applies here
| val summary = """ |
and here
| summaryLog.writeText(""" |
12. fetchTask might result in a stack overflow.
The fetchTask will call itself here
| fetchJob(jobName) |
In my case, I am using synchronous FedAvg, and since the dataset on the phone is much smaller than the datasets of the simulated clients, the app has to wait for the other clients until a new task becomes available.
For some reason, the retryWait is only 30ms, so the stack will overflow rapidly, crashing the app.
The fix is rather simple: split the logic into a private fetchTaskOnce, which does the actual fetching, and fetchTask, which calls fetchTaskOnce in a loop and only breaks out of the loop once the result is not TaskStatus.RETRY.
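Here is a self-contained Kotlin sketch of that split. TaskStatus, TaskResult, and the body of fetchTaskOnce are stand-ins I made up for the app's actual types and request logic; only the loop-instead-of-recursion structure is the point:

```kotlin
enum class TaskStatus { OK, RETRY, DONE }
data class TaskResult(val status: TaskStatus, val payload: Any? = null)

class TaskFetcher(private val retryWaitMs: Long = 5_000) {
    private var attempts = 0

    // Single fetch attempt; placeholder for the app's existing request logic.
    // Here it pretends that no task is available for the first two attempts.
    private fun fetchTaskOnce(jobName: String): TaskResult {
        attempts += 1
        return if (attempts < 3) TaskResult(TaskStatus.RETRY)
        else TaskResult(TaskStatus.OK, "task-for-$jobName")
    }

    // Retries in a loop instead of recursing, so waiting a long time for the
    // next round cannot blow the stack, no matter how small retryWaitMs is.
    fun fetchTask(jobName: String): TaskResult {
        while (true) {
            val result = fetchTaskOnce(jobName)
            if (result.status != TaskStatus.RETRY) return result
            Thread.sleep(retryWaitMs)
        }
    }
}
```

With this structure, fetchTask("xor_et") simply blocks until a non-RETRY result arrives, and a larger retryWait can be chosen without any risk of overflowing the stack.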