fix: read epoch count from training config in simulation mode #4404
vijaygovindaraja wants to merge 5 commits into NVIDIA:main
Conversation
`ETTaskProcessor.process_task()` was calling `run_training()` without passing `total_epochs`, so simulation mode always trained for exactly 1 epoch regardless of the epoch value in `training_config`. This meant simulated devices behaved differently from real devices, which read the epoch count from `device_training_params`. Read epoch from `self.training_config` and pass it through to `run_training()`. Also add `epoch` to the default `training_config` (default 1 to preserve existing behavior) and update the example job scripts to set `epoch` in the simulator's `training_config` to match `device_training_params`. Addresses item NVIDIA#4 in NVIDIA#3827.
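The shape of the fix can be sketched as follows. This is a minimal, hypothetical stand-in: the names `ETTaskProcessor`, `run_training`, and `process_task` mirror the PR, but the bodies here are placeholders, not the real ExecuTorch training code.

```python
# Hedged sketch of the fix: the "epoch" value flows from training_config
# into run_training instead of being hardcoded to a single epoch.

class ETTaskProcessor:
    def __init__(self, training_config=None):
        # "epoch" joins the defaults; default 1 preserves the old behavior
        defaults = {"epoch": 1}
        self.training_config = {**defaults, **(training_config or {})}

    def run_training(self, et_model, total_epochs=1):
        # Stand-in training loop: just count how many epochs actually run.
        epochs_run = 0
        for _ in range(total_epochs):
            epochs_run += 1
        return epochs_run

    def process_task(self, et_model):
        # The fix: read epoch from config and pass it through.
        total_epochs = self.training_config.get("epoch", 1)
        return self.run_training(et_model, total_epochs=total_epochs)

proc = ETTaskProcessor(training_config={"epoch": 3})
print(proc.process_task(et_model=None))  # 3
```

Before the fix, the equivalent of `process_task` called `run_training(et_model)` with no epoch argument, so the config value never reached the loop.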
@chesterxgchen @nickl1234567 this is the next item from #3827: simulation mode was always running 1 epoch regardless of what's in `training_config`, so simulated devices didn't match real device behavior when epoch > 1 was configured. Small diff, same file as #4377.
**Greptile Summary**

This PR fixes a simulation parity bug where `process_task()` always trained for 1 epoch regardless of the `epoch` value in `training_config`. The implementation is clean and consistent with the existing validation pattern (`check_positive_int`).

**Confidence Score: 5/5**

Safe to merge — the fix is correct, well-validated, and has comprehensive test coverage. All remaining findings are P2 or lower. The primary bug (hardcoded epoch=1 in simulation) is correctly fixed. Both paths that could produce bad epochs are now guarded by `check_positive_int`. No files require special attention.
**Sequence Diagram**

```mermaid
sequenceDiagram
    participant C as Caller
    participant P as ETTaskProcessor.__init__
    participant PT as process_task()
    participant RT as run_training()
    C->>P: training_config={"epoch": 3}
    P->>P: merge with defaults
    P->>P: check_positive_int("epoch", 3)
    P-->>C: processor ready
    C->>PT: task (train)
    PT->>PT: total_epochs = training_config.get("epoch", 1)
    PT->>RT: run_training(et_model, total_epochs=3)
    RT->>RT: check_positive_int("total_epochs", 3)
    RT->>RT: DataLoader, raise ValueError if empty
    loop epoch in range(total_epochs)
        loop batch in dataloader
            RT->>RT: forward_backward → capture initial_params
            RT->>RT: optimizer.step()
        end
    end
    RT->>RT: calc_params_diff(initial_params, last_params)
    RT-->>PT: param diff dict
    PT-->>C: dxo_dict
```
Reviews (5): Last reviewed commit: "fix isort import ordering in test file"
Validate that `epoch` is a positive integer at init time using NVFlare's `check_positive_int`, consistent with validation patterns elsewhere in the codebase (e.g. `EdgeJob`, `BaseState`). Without this, `epoch=0` would silently skip the training loop and then crash in `calc_params_diff`. Add unit tests covering:

- default epoch config (1), override, and preservation of other defaults
- rejection of non-positive (0, -1, -10) and non-integer (1.5, str, None) values
- `run_training` respects `total_epochs` for 1 and 3 epochs (forward_backward call counts)
- `process_task` reads epoch from `training_config`
- helper functions `clone_params` and `calc_params_diff`
The init-time `check_positive_int` covers the normal config flow, but `run_training` is a public method that can be called directly. Add validation there too so callers get a clear `ValueError` instead of a cryptic `AttributeError` from `calc_params_diff(None, ...)`.
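Both guards together can be sketched as below. The local `check_positive_int` is modeled on NVFlare's validation helper but reimplemented here so the sketch is self-contained; the class body is a stand-in, not the real processor.

```python
# Hedged sketch: validate "epoch" at init time AND inside run_training,
# so callers invoking the public method directly also get a clear error.

def check_positive_int(name, value):
    """Stand-in for NVFlare's validation helper."""
    # bool is a subclass of int in Python, so exclude it explicitly
    if not isinstance(value, int) or isinstance(value, bool) or value <= 0:
        raise ValueError(f"{name} must be a positive int but got {value!r}")

class ETTaskProcessor:
    def __init__(self, training_config=None):
        self.training_config = {"epoch": 1, **(training_config or {})}
        # fail fast at construction time on bad config values
        check_positive_int("epoch", self.training_config["epoch"])

    def run_training(self, et_model, total_epochs=1):
        # guard the public entry point too: direct callers get a clear
        # ValueError instead of a crash deep inside calc_params_diff
        check_positive_int("total_epochs", total_epochs)
        return total_epochs  # training loop elided

try:
    ETTaskProcessor(training_config={"epoch": 0})
except ValueError as e:
    print(e)  # epoch must be a positive int but got 0
```

Validating in both places means `epoch=0` or `epoch="three"` is rejected whether the value arrives via config or via a direct `run_training` call.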
@vijaygovindaraja please resolve the comments once you fix it.
…ing its logic

The test was manually reading epoch from `training_config` and calling `run_training` directly, which just proved the test's own code worked. Now it constructs a real `TaskResponse` with a valid DXO payload, mocks the executorch loader, and calls `process_task()` end-to-end, verifying that `process_task` reads epoch from `training_config` and passes it through to `run_training()`.
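The pattern of that end-to-end test can be sketched with `unittest.mock`. This is a hypothetical reduction: the real test builds a `TaskResponse` with a DXO payload and mocks the executorch loader, while here a minimal stand-in processor and a plain dict task keep the sketch runnable.

```python
# Hedged sketch of the end-to-end test pattern: drive process_task and
# assert the epoch value actually flows through to run_training, instead
# of re-reading the config in the test itself.
from unittest.mock import MagicMock

class ETTaskProcessor:
    """Stand-in for the real processor; bodies are placeholders."""
    def __init__(self, training_config=None):
        self.training_config = {"epoch": 1, **(training_config or {})}

    def run_training(self, et_model, total_epochs=1):
        return {"diff": {}}  # param-diff computation elided

    def process_task(self, task):
        et_model = task["model"]  # stand-in for decoding the DXO payload
        total_epochs = self.training_config.get("epoch", 1)
        return self.run_training(et_model, total_epochs=total_epochs)

proc = ETTaskProcessor(training_config={"epoch": 3})
# spy on run_training while preserving its behavior
proc.run_training = MagicMock(wraps=proc.run_training)
proc.process_task({"model": object()})

# the assertion the rewritten test makes: epoch came from training_config
_, kwargs = proc.run_training.call_args
assert kwargs["total_epochs"] == 3
```

Spying with `MagicMock(wraps=...)` keeps the real call path intact while letting the test inspect the arguments that crossed the boundary.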
Addressed both review items:
All 16 tests pass, flake8 clean.
Addresses item #4 in #3827
**Summary**

`ETTaskProcessor.process_task()` calls `run_training()` without passing `total_epochs`, so simulation mode always trains for exactly 1 epoch regardless of the epoch value in `training_config`. This means simulated devices behave differently from real devices, which read the epoch count from `device_training_params`.

Before: `run_training(et_model)` — always 1 epoch, ignoring config.

After: `run_training(et_model, total_epochs=self.training_config.get("epoch", 1))` — reads epoch from training config, defaults to 1 to preserve existing behavior.

**Changes**

- `et_task_processor.py` — add `"epoch"` to default `training_config` (default: 1), read it in `process_task()` and pass to `run_training()`. Validate at init time using `check_positive_int` (consistent with `EdgeJob`, `BaseState`, etc.) so that `epoch=0` or `epoch="three"` fails fast instead of silently skipping training or crashing in `calc_params_diff`.
- `et_job.py` — add `"epoch": 3` to both CIFAR-10 and XOR task processor configs so simulation matches `device_training_params`.
- `et_task_processor_test.py` (new) — 14 unit tests covering:
  - `run_training` respects `total_epochs` for 1 and 3 epochs (verified via forward_backward call counts)
  - `process_task` reads epoch from `training_config`
  - `clone_params` and `calc_params_diff`

**Test plan**

- (`pytest tests/unit_test/recipe/et_task_processor_test.py`)
- (`pytest tests/unit_test/recipe/edge_recipe_test.py`)
- `flake8` passes on all changed files
- `./runtest.sh -s` in CI to verify full style checks
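The forward_backward call-count check from the test list above can be sketched like this. The loop structure mirrors the PR's description of `run_training`, but the model and dataloader are mock stand-ins, not the real ExecuTorch objects.

```python
# Hedged sketch: run_training should invoke forward_backward exactly
# (epochs x batches) times, which is how the unit test verifies that
# total_epochs is respected without inspecting model internals.
from unittest.mock import MagicMock

def run_training(et_model, dataloader, total_epochs=1):
    # minimal loop mirroring the structure described in the PR
    for _ in range(total_epochs):
        for batch in dataloader:
            et_model.forward_backward(batch)

model = MagicMock()
batches = [0, 1, 2, 3]  # four fake batches
run_training(model, batches, total_epochs=3)
assert model.forward_backward.call_count == 12  # 3 epochs x 4 batches
```

Counting calls on a mock keeps the test independent of the training math: if `total_epochs` ever stops reaching the loop, the count drops back to one epoch's worth and the test fails.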