Add PyTorch param validation for model exchange#4392
Add PyTorch param validation for model exchange#4392holgerroth wants to merge 5 commits intoNVIDIA:mainfrom
Conversation
|
@greptileai review |
Greptile SummaryThis PR adds fail-fast parameter-contract validation to NVFlare's PyTorch model exchange and persistence paths. A new Key changes:
Confidence Score: 5/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant FL as NVFlare FL Round
participant CB as FLCallback / feed_vars
participant V as inspect_model_params
participant PT as pl_module.load_state_dict / model.state_dict
FL->>CB: model.params (incoming global weights)
CB->>PT: pl_module.state_dict() / model.state_dict()
PT-->>CB: local_var_dict
CB->>V: inspect_model_params(local_var_dict, model.params)
V-->>CB: ModelParamMatchReport
alt shape_mismatches present
CB-->>FL: raise RuntimeError(format_shape_mismatch_error)
else matched_keys empty
CB-->>FL: raise RuntimeError(format_zero_match_error)
else unexpected_keys present AND strict=True
CB-->>FL: raise RuntimeError(format_unexpected_keys_error)
else unexpected_keys present AND strict=False
CB->>CB: warn + filter params_to_load to matched_keys only
CB->>PT: load_state_dict(params_to_load, strict=False)
PT-->>CB: IncompatibleKeys(missing_keys, [])
CB-->>FL: log warnings for missing_keys
else all keys matched
CB->>PT: load_state_dict(model.params, strict=strict)
PT-->>CB: success
end
Reviews (4): Last reviewed commit: "Merge branch 'main' into codex/pt-param-..." | Re-trigger Greptile |
|
Addressed the Greptile findings in Updates in this follow-up:
Validation rerun:
|
|
@greptileai review again |
|
@greptileai check the latest PR version |
|
@greptileai review again |
|
/build |
Summary
This change adds fail-fast parameter-contract validation to NVFlare's PyTorch model exchange and persistence paths, so key and shape drift becomes explicit instead of being silently skipped.
The same validation now applies consistently to both native PyTorch and PyTorch Lightning clients.
What changed
inspect_model_paramsandModelParamMatchReportinnvflare.app_opt.pt.utilsfeed_vars()raise on zero-match and shape mismatch, and warn on unexpected incoming keysPTModelPersistenceFormatManager.update()reject persistence updates that would introduce unexpected keys or mismatched shapesload_state_dict()strict=True, reject unexpected keys before callingload_state_dict()strict=False, warn, filter to matched keys, and preserve missing-key diagnostics for partial loadsWhy
This is aimed at failures like wrapper-induced key drift such as
model.*vs. unwrapped names, where the current path can silently ignore loads or save mixed keyspaces. The new validation makes those failures visible immediately and gives enough context to diagnose the mismatch.Before and after
Before this change, an incoming payload like
["model.fc.weight", "model.fc.bias"]sent to a model expecting["fc.weight", "fc.bias"]could be silently skipped by the load path and later merged into persistence as a mixed keyspace.After this change, the same mismatch fails fast with a diagnostic like:
Validation
python3 -m py_compile nvflare/app_opt/pt/utils.py nvflare/app_opt/pt/model_persistence_format_manager.py nvflare/app_opt/lightning/api.py tests/unit_test/app_opt/pt/pt_param_validation_test.py tests/unit_test/app_opt/lightning/api_test.pypytest tests/unit_test/app_opt/pt/pt_param_validation_test.py tests/unit_test/app_opt/lightning/api_test.py -qpytest tests/unit_test/app_opt/lightning tests/unit_test/app_opt/pt -q