
fix: read epoch count from training config in simulation mode#4404

Open
vijaygovindaraja wants to merge 5 commits into NVIDIA:main from vijaygovindaraja:fix/3827-epoch-config

Conversation

Contributor

@vijaygovindaraja vijaygovindaraja commented Apr 4, 2026

Addresses item #4 in #3827

Summary

ETTaskProcessor.process_task() calls run_training() without passing total_epochs, so simulation mode always trains for exactly 1 epoch regardless of the epoch value in training_config. This means simulated devices behave differently from real devices, which read the epoch count from device_training_params.

Before: run_training(et_model) — always 1 epoch, ignoring config.

After: run_training(et_model, total_epochs=self.training_config.get("epoch", 1)) — reads epoch from training config, defaults to 1 to preserve existing behavior.
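The shape of the change can be sketched as follows. This is a minimal stand-in, not the actual NVFlare code; the class and method names mirror the PR description, and the bodies are placeholders:

```python
# Sketch of the fix: process_task forwards the configured epoch count
# to run_training instead of relying on run_training's default of 1.
class ETTaskProcessorSketch:
    def __init__(self, training_config=None):
        self.training_config = training_config or {}

    def run_training(self, et_model, total_epochs=1):
        # Placeholder body: return the epoch count actually used
        return total_epochs

    def process_task(self, et_model):
        # Before: self.run_training(et_model)  -> always 1 epoch
        # After: read "epoch" from training_config, defaulting to 1
        return self.run_training(
            et_model, total_epochs=self.training_config.get("epoch", 1)
        )

print(ETTaskProcessorSketch({"epoch": 3}).process_task(None))  # 3
print(ETTaskProcessorSketch().process_task(None))  # 1
```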

Changes

  1. et_task_processor.py — add "epoch" to default training_config (default: 1), read it in process_task() and pass to run_training(). Validate at init time using check_positive_int (consistent with EdgeJob, BaseState, etc.) so that epoch=0 or epoch="three" fails fast instead of silently skipping training or crashing in calc_params_diff.

  2. et_job.py — add "epoch": 3 to both CIFAR-10 and XOR task processor configs so simulation matches device_training_params.

  3. et_task_processor_test.py (new) — 14 unit tests covering:

    • Default epoch config, override, preservation of other defaults
    • Rejection of non-positive values (0, -1, -10) and non-integer types (1.5, str, None)
    • run_training respects total_epochs for 1 and 3 epochs (verified via forward_backward call counts)
    • process_task reads epoch from training_config
    • Helper functions clone_params and calc_params_diff
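The init-time validation described in change 1 can be sketched like this. The `check_positive_int` below is a local stand-in for NVFlare's helper (so the sketch is self-contained), and `make_training_config` is a hypothetical name for the config-merging step, not the real code:

```python
# Minimal stand-in for NVFlare's check_positive_int helper.
def check_positive_int(name, value):
    # bool is a subclass of int in Python, so reject it explicitly
    if not isinstance(value, int) or isinstance(value, bool) or value <= 0:
        raise ValueError(f"{name} must be a positive int, got {value!r}")

DEFAULT_TRAINING_CONFIG = {"epoch": 1}  # assumed default

def make_training_config(user_config=None):
    # Merge user overrides onto defaults, then fail fast on bad values
    config = {**DEFAULT_TRAINING_CONFIG, **(user_config or {})}
    check_positive_int("epoch", config["epoch"])
    return config
```

With this guard, `epoch=0` or `epoch="three"` raises a clear `ValueError` at construction time instead of silently skipping the training loop and crashing later in `calc_params_diff`.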

Test plan

  • 14 unit tests pass locally (pytest tests/unit_test/recipe/et_task_processor_test.py)
  • Existing edge recipe tests unaffected (pytest tests/unit_test/recipe/edge_recipe_test.py)
  • flake8 passes on all changed files
  • Run ./runtest.sh -s in CI to verify full style checks
  • Verify no regressions in existing edge training tests

The ETTaskProcessor.process_task() was calling run_training() without
passing total_epochs, so simulation mode always trained for exactly 1
epoch regardless of the epoch value in training_config. This meant
simulated devices behaved differently from real devices, which read
the epoch count from device_training_params.

Read epoch from self.training_config and pass it through to
run_training(). Also add epoch to the default training_config (default
1 to preserve existing behavior) and update the example job scripts to
set epoch in the simulator's training_config to match
device_training_params.

Addresses item NVIDIA#4 in NVIDIA#3827.
@vijaygovindaraja
Contributor Author

vijaygovindaraja commented Apr 4, 2026

@chesterxgchen @nickl1234567 this is the next item from #3827: simulation mode was always running 1 epoch regardless of what's in training_config, so simulated devices didn't match real-device behavior when epoch > 1 was configured.

Small diff, same file as #4377.

Contributor

greptile-apps bot commented Apr 4, 2026

Greptile Summary

This PR fixes a simulation parity bug where ETTaskProcessor.process_task() hardcoded total_epochs=1, causing simulated devices to always train for one epoch while real devices respected the epoch value in device_training_params. The fix reads epoch from training_config, validates it at init-time and inside run_training, updates the example job configs to match, and adds 14 unit tests.

The implementation is clean and consistent with the existing validation pattern (check_positive_int) used elsewhere in the codebase.

Confidence Score: 5/5

Safe to merge — the fix is correct, well-validated, and has comprehensive test coverage.

All remaining findings are P2 or lower. The primary bug (hardcoded epoch=1 in simulation) is correctly fixed. Both paths that could produce bad epochs are now guarded by check_positive_int. The example configs are updated to match. The 14-test suite exercises defaults, validation, epoch counts, and the process_task→run_training handoff. No blocking issues remain from the previous review thread.

No files require special attention.

Important Files Changed

| Filename | Overview |
| --- | --- |
| nvflare/edge/simulation/et_task_processor.py | Core fix: adds epoch default to training_config, validates with check_positive_int at init and inside run_training, and passes total_epochs from config in process_task. Empty-dataloader guard and all validation paths are correct. |
| examples/advanced/edge/jobs/et_job.py | Adds "epoch": 3 to both CIFAR-10 and XOR task_processor configs, aligning simulation with the existing device_training_params={"epoch": 3, ...} on the same recipe call. |
| tests/unit_test/recipe/et_task_processor_test.py | New test file with 14 tests covering default/override config, positive-int validation, run_training epoch counts, process_task epoch forwarding, and helper functions clone_params/calc_params_diff. |

Sequence Diagram

sequenceDiagram
    participant C as Caller
    participant P as ETTaskProcessor.__init__
    participant PT as process_task()
    participant RT as run_training()

    C->>P: training_config={"epoch": 3}
    P->>P: merge with defaults
    P->>P: check_positive_int("epoch", 3)
    P-->>C: processor ready

    C->>PT: task (train)
    PT->>PT: total_epochs = training_config.get("epoch", 1)
    PT->>RT: run_training(et_model, total_epochs=3)
    RT->>RT: check_positive_int("total_epochs", 3)
    RT->>RT: DataLoader, raise ValueError if empty
    loop epoch in range(total_epochs)
        loop batch in dataloader
            RT->>RT: forward_backward → capture initial_params
            RT->>RT: optimizer.step()
        end
    end
    RT->>RT: calc_params_diff(initial_params, last_params)
    RT-->>PT: param diff dict
    PT-->>C: dxo_dict
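The control flow in the diagram can be condensed into a short sketch. Everything here is simplified and hypothetical: no real model, optimizer, or DataLoader, just the guards and loop structure the diagram describes:

```python
# Simplified control flow of run_training per the sequence diagram:
# validate total_epochs, guard against an empty dataloader, then loop
# epochs x batches calling forward_backward.
def run_training(dataloader, total_epochs=1, forward_backward=lambda batch: None):
    if not isinstance(total_epochs, int) or total_epochs <= 0:
        raise ValueError("total_epochs must be a positive int")
    if len(dataloader) == 0:
        raise ValueError("dataloader is empty")  # empty-dataloader guard
    calls = 0
    for _epoch in range(total_epochs):
        for batch in dataloader:
            forward_backward(batch)  # stands in for forward/backward + optimizer.step()
            calls += 1
    return calls  # stands in for the param-diff result

run_training([1, 2], total_epochs=3)  # 6 forward_backward calls
```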

Reviews (5): Last reviewed commit: "fix isort import ordering in test file"

Validate that epoch is a positive integer at init time using NVFlare's
check_positive_int, consistent with validation patterns elsewhere in
the codebase (e.g. EdgeJob, BaseState). Without this, epoch=0 would
silently skip the training loop and then crash in calc_params_diff.

Add unit tests covering:
- default epoch config (1), override, and preservation of other defaults
- rejection of non-positive (0, -1, -10) and non-integer (1.5, str, None)
- run_training respects total_epochs for 1 and 3 epochs (forward_backward
  call counts)
- process_task reads epoch from training_config
- helper functions clone_params and calc_params_diff
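The call-count verification mentioned above can be sketched with a mock spy. The names here are assumptions for illustration, not the real test code:

```python
# Verify run_training drives forward_backward epochs x batches times by
# spying on it with a Mock (sketch of the test approach, not the real test).
from unittest import mock

def run_training_sketch(dataloader, total_epochs, forward_backward):
    # Minimal loop mirroring the production structure
    for _epoch in range(total_epochs):
        for batch in dataloader:
            forward_backward(batch)

spy = mock.Mock()
run_training_sketch([["b1"], ["b2"]], total_epochs=3, forward_backward=spy)
assert spy.call_count == 6  # 3 epochs x 2 batches
```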
The init-time check_positive_int covers the normal config flow, but
run_training is a public method that can be called directly. Add
validation there too so callers get a clear ValueError instead of a
cryptic AttributeError from calc_params_diff(None, ...).
@chesterxgchen
Collaborator

@vijaygovindaraja please resolve the comments once you fix it.

…ing its logic

The test was manually reading epoch from training_config and calling
run_training directly, which just proved the test's own code worked.
Now it constructs a real TaskResponse with a valid DXO payload, mocks
the executorch loader, and calls process_task() end-to-end, verifying
that process_task reads epoch from training_config and passes it
through to run_training.
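The non-tautological test pattern described in this commit, spying on run_training while driving process_task end-to-end, can be sketched as follows. The class and its methods are stand-ins (the real test builds a TaskResponse with a DXO payload and mocks the executorch loader):

```python
# Spy on run_training and call process_task, so the test exercises the
# real code path instead of re-implementing it (all names are sketches).
from unittest import mock

class ProcessorSketch:
    def __init__(self, training_config):
        self.training_config = training_config

    def run_training(self, model, total_epochs=1):
        return {"diff": total_epochs}  # placeholder result

    def process_task(self, model):
        # The code path under test: epoch flows from config to run_training
        return self.run_training(model, total_epochs=self.training_config.get("epoch", 1))

p = ProcessorSketch({"epoch": 3})
with mock.patch.object(p, "run_training", wraps=p.run_training) as spy:
    p.process_task(model=None)
spy.assert_called_once_with(None, total_epochs=3)
```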
@vijaygovindaraja
Contributor Author

Addressed both review items:

  1. P1 (tautological test) — rewrote test_process_task_passes_epoch_from_config to construct a real TaskResponse with a valid DXO payload, mock the executorch loader, and call process_task() end-to-end. The spy on run_training now verifies the actual code path at lines 259-260.

  2. P1 (non-positive epoch crash) — already addressed in previous push: check_positive_int guards both init-time config and run_training() directly.

All 16 tests pass, flake8 clean.
