Skip to content

Add PyTorch param validation for model exchange#4392

Open
holgerroth wants to merge 5 commits intoNVIDIA:mainfrom
holgerroth:codex/pt-param-validation
Open

Add PyTorch param validation for model exchange#4392
holgerroth wants to merge 5 commits intoNVIDIA:mainfrom
holgerroth:codex/pt-param-validation

Conversation

@holgerroth
Copy link
Copy Markdown
Collaborator

@holgerroth holgerroth commented Apr 2, 2026

Summary

This change adds fail-fast parameter-contract validation to NVFlare's PyTorch model exchange and persistence paths, so key and shape drift becomes explicit instead of being silently skipped.

The same validation now applies consistently to both native PyTorch and PyTorch Lightning clients.

What changed

  • add inspect_model_params and ModelParamMatchReport in nvflare.app_opt.pt.utils
  • make feed_vars() raise on zero-match and shape mismatch, and warn on unexpected incoming keys
  • make PTModelPersistenceFormatManager.update() reject persistence updates that would introduce unexpected keys or mismatched shapes
  • reuse the same validation in the PyTorch Lightning callback before load_state_dict()
  • in Lightning strict=True, reject unexpected keys before calling load_state_dict()
  • in Lightning strict=False, warn, filter to matched keys, and preserve missing-key diagnostics for partial loads
  • keep the new code and tests compatible with Python 3.8
  • add focused unit coverage for native PyTorch and Lightning mismatch paths

Why

This is aimed at failures like wrapper-induced key drift such as model.* vs. unwrapped names, where the current path can silently ignore loads or save mixed keyspaces. The new validation makes those failures visible immediately and gives enough context to diagnose the mismatch.

Before and after

Before this change, an incoming payload like ["model.fc.weight", "model.fc.bias"] sent to a model expecting ["fc.weight", "fc.bias"] could be silently skipped by the load path and later merged into persistence as a mixed keyspace.

After this change, the same mismatch fails fast with a diagnostic like:

None of the 2 incoming model parameter(s) matched the local model's 2 parameter(s). Incoming keys: 2 sample=['model.fc.bias', 'model.fc.weight']. Local keys: 2 sample=['fc.bias', 'fc.weight']. Unexpected keys sample: ['model.fc.bias', 'model.fc.weight']. Hint: stripping common prefix 'model.' would match 2/2 incoming key(s).

Validation

  • python3 -m py_compile nvflare/app_opt/pt/utils.py nvflare/app_opt/pt/model_persistence_format_manager.py nvflare/app_opt/lightning/api.py tests/unit_test/app_opt/pt/pt_param_validation_test.py tests/unit_test/app_opt/lightning/api_test.py
  • pytest tests/unit_test/app_opt/pt/pt_param_validation_test.py tests/unit_test/app_opt/lightning/api_test.py -q
  • pytest tests/unit_test/app_opt/lightning tests/unit_test/app_opt/pt -q

@holgerroth holgerroth changed the title [codex] Add PyTorch param validation for model exchange Add PyTorch param validation for model exchange Apr 2, 2026
@holgerroth
Copy link
Copy Markdown
Collaborator Author

@greptileai review

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps bot commented Apr 2, 2026

Greptile Summary

This PR adds fail-fast parameter-contract validation to NVFlare's PyTorch model exchange and persistence paths. A new inspect_model_params utility produces a ModelParamMatchReport describing key matches, shape mismatches, and unexpected keys (with a prefix-strip hint for common wrapper patterns). The same validation is wired into feed_vars, PTModelPersistenceFormatManager.update(), and the PyTorch Lightning FLCallback, with strict/non-strict branching handled consistently. Prior review concerns (contradictory "Ignoring" wording in raised exceptions, PEP 585 annotations on Python 3.8, misleading strict-mode warning, parenthesised with syntax) are all resolved in the latest commits.

Key changes:

  • nvflare/app_opt/pt/utils.py: new ModelParamMatchReport dataclass + inspect_model_params; feed_vars upgraded to use pre-validation
  • nvflare/app_opt/pt/model_persistence_format_manager.py: update() rejects unexpected keys and shape mismatches before writing
  • nvflare/app_opt/lightning/api.py: _receive_and_update_model splits strict/non-strict paths, filters unexpected keys before load_state_dict, and preserves missing-key diagnostics
  • Two new test files with focused coverage for all mismatch paths
  • Remaining P2 observation: the validation RuntimeErrors are still raised inside the existing try/except in _receive_and_update_model, so they get caught and re-logged/re-wrapped as "Failed to load state dict / model state dict", which is a misleading error context for pre-validation failures (not a correctness issue; tests pass)

Confidence Score: 5/5

  • Safe to merge — all previously raised P0/P1 concerns are resolved; remaining findings are P2 style/diagnostic improvements that do not affect correctness.
  • All four prior review issues (contradictory error wording, Python 3.8 annotation incompatibility, misleading strict-mode warning, parenthesised with syntax) have been resolved in subsequent commits. The only remaining finding is a P2 UX concern: validation errors inside _receive_and_update_model are caught and re-wrapped under "Failed to load model state dict:", which is a misleading log/exception prefix but does not affect correctness or data integrity. Tests pass and cover the key paths.
  • nvflare/app_opt/lightning/api.py — the try/except block scope wraps pre-validation errors under a PyTorch load-failure message; minor but worth a follow-up cleanup.

Important Files Changed

Filename Overview
nvflare/app_opt/pt/utils.py Adds ModelParamMatchReport, ParamShapeMismatch, inspect_model_params, and helpers for fail-fast param validation; feed_vars updated to use the new validation path; from __future__ import annotations added for Python 3.8 compat.
nvflare/app_opt/lightning/api.py Integrates inspect_model_params before load_state_dict; strict-mode raises on unexpected keys, non-strict warns and filters; missing-key diagnostics preserved — but the validation RuntimeErrors are inside the existing broad except block, causing them to be re-wrapped and spuriously logged as state-dict load failures.
nvflare/app_opt/pt/model_persistence_format_manager.py Adds pre-update validation in PTModelPersistenceFormatManager.update(): shape mismatch, zero-match, and unexpected-key checks all raise ValueError before any write; uses dedicated format_unexpected_keys_error() for clear rejection messages.
tests/unit_test/app_opt/pt/pt_param_validation_test.py New focused unit tests covering feed_vars assignment, zero-match with prefix hint, shape mismatch, unexpected-key warning, and persistence-manager partial/reject paths; Python 3.8 compatible.
tests/unit_test/app_opt/lightning/api_test.py New Lightning callback tests covering zero-match prefix hint, shape mismatch, unexpected-key warning (non-strict), unexpected-key rejection (strict), and partial-load missing-key diagnostics; fixed from with (A, B): to nested with for Python 3.8 compat.

Sequence Diagram

sequenceDiagram
    participant FL as NVFlare FL Round
    participant CB as FLCallback / feed_vars
    participant V as inspect_model_params
    participant PT as pl_module.load_state_dict / model.state_dict

    FL->>CB: model.params (incoming global weights)
    CB->>PT: pl_module.state_dict() / model.state_dict()
    PT-->>CB: local_var_dict
    CB->>V: inspect_model_params(local_var_dict, model.params)
    V-->>CB: ModelParamMatchReport

    alt shape_mismatches present
        CB-->>FL: raise RuntimeError(format_shape_mismatch_error)
    else matched_keys empty
        CB-->>FL: raise RuntimeError(format_zero_match_error)
    else unexpected_keys present AND strict=True
        CB-->>FL: raise RuntimeError(format_unexpected_keys_error)
    else unexpected_keys present AND strict=False
        CB->>CB: warn + filter params_to_load to matched_keys only
        CB->>PT: load_state_dict(params_to_load, strict=False)
        PT-->>CB: IncompatibleKeys(missing_keys, [])
        CB-->>FL: log warnings for missing_keys
    else all keys matched
        CB->>PT: load_state_dict(model.params, strict=strict)
        PT-->>CB: success
    end
Loading

Reviews (4): Last reviewed commit: "Merge branch 'main' into codex/pt-param-..." | Re-trigger Greptile

Copy link
Copy Markdown
Collaborator Author

Addressed the Greptile findings in 1d023138e.

Updates in this follow-up:

  • added format_unexpected_keys_error() and now use rejection wording for persistence failures
  • fixed the Lightning callback so unexpected keys are rejected in strict=True, but only warned/ignored in strict=False
  • restored missing-key diagnostics for partial loads under strict=False
  • added from __future__ import annotations in nvflare.app_opt.pt.utils to avoid eager evaluation of the new annotations
  • expanded unit coverage for strict-mode Lightning behavior and partial-load missing-key logging

Validation rerun:

  • python3 -m py_compile nvflare/app_opt/pt/utils.py nvflare/app_opt/pt/model_persistence_format_manager.py nvflare/app_opt/lightning/api.py tests/unit_test/app_opt/pt/pt_param_validation_test.py tests/unit_test/app_opt/lightning/api_test.py
  • pytest tests/unit_test/app_opt/pt/pt_param_validation_test.py tests/unit_test/app_opt/lightning/api_test.py -q
  • pytest tests/unit_test/app_opt/lightning tests/unit_test/app_opt/pt -q

@holgerroth
Copy link
Copy Markdown
Collaborator Author

@greptileai review again

@holgerroth
Copy link
Copy Markdown
Collaborator Author

@greptileai check the latest PR version

@holgerroth
Copy link
Copy Markdown
Collaborator Author

@greptileai review again

@holgerroth holgerroth marked this pull request as ready for review April 2, 2026 19:17
@holgerroth
Copy link
Copy Markdown
Collaborator Author

/build

@holgerroth holgerroth requested a review from chesterxgchen April 2, 2026 19:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant