base: devel
fix(stat): Calculate correct fitting stat when using default fparam and using share fitting. #5038
Conversation
📝 Walkthrough

Adds default frame-parameter (fparam) exposure and population, extends fitting statistics computation with on-disk persistence and NumPy aggregation, and threads multitask-aware parameter-sharing probabilities and protection factors through model wrapper, trainer, and fitting layers.

Changes
Sequence Diagram(s)

sequenceDiagram
actor Trainer as Training
participant Wrapper as ModelWrapper
participant Fit as Fitting
participant Atom as DPAtomicModel
Training->>Training: compute model_key_prob_map & data_stat_protect
Training->>Wrapper: share_params(shared_links, model_key_prob_map, data_stat_protect)
activate Wrapper
Wrapper->>Wrapper: for each link compute frac_prob = prob_link/prob_base
Wrapper->>Fit: share_params(base_class, shared_level, model_prob=frac_prob, protection=data_stat_protect, resume)
deactivate Wrapper
Training->>Atom: compute_or_load_stat(stat_file_path)
activate Atom
Atom->>Fit: compute_input_stats(merged, protection, stat_file_path)
activate Fit
alt stat_file_path exists
Fit->>Fit: restore_fparam/aparam_from_file(stat_file_path)
else
Fit->>Fit: aggregate stats from data (NumPy), apply protection
Fit->>Fit: save_to_file_fparam/aparam(stat_file_path)
end
Fit->>Atom: return stats/default_fparam
deactivate Fit
Atom->>Training: provide default fparam for data requirements
deactivate Atom
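The alt branch above is a load-or-compute pattern. A minimal sketch, using plain pathlib/NumPy stand-ins for the DPPath helpers (function names and the "fparam" key layout are illustrative, not the exact deepmd-kit API):

```python
from pathlib import Path

import numpy as np


def compute_fparam_stats(merged, protection, stat_dir=None):
    """Load fparam mean/std from disk if present, else compute and persist.

    `merged` is assumed to be a list of frame dicts carrying an "fparam"
    array; `protection` is a lower bound that keeps the std away from zero.
    """
    if stat_dir is not None and (Path(stat_dir) / "fparam_avg.npy").exists():
        # A previous run persisted the statistics: restore them.
        avg = np.load(Path(stat_dir) / "fparam_avg.npy")
        std = np.load(Path(stat_dir) / "fparam_std.npy")
    else:
        # Aggregate over all sampled frames with NumPy.
        data = np.concatenate([frame["fparam"] for frame in merged], axis=0)
        avg = data.mean(axis=0)
        # Protection clamps near-zero stds so later normalization is safe.
        std = np.maximum(data.std(axis=0), protection)
        if stat_dir is not None:
            Path(stat_dir).mkdir(parents=True, exist_ok=True)
            np.save(Path(stat_dir) / "fparam_avg.npy", avg)
            np.save(Path(stat_dir) / "fparam_std.npy", std)
    return avg, std
```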
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes

Areas requiring extra attention:
Possibly related PRs
Suggested reviewers
Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (8)
✅ Files skipped from review due to trivial changes (1)
🚧 Files skipped from review as they are similar to previous changes (1)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
📄 CodeRabbit inference engine (AGENTS.md)
Files:
🧬 Code graph analysis (5)

source/tests/pt/test_fitting_stat.py (5)
deepmd/pt/model/model/make_model.py (2)
deepmd/pt/train/wrapper.py (1)
deepmd/pt/train/training.py (3)
deepmd/pt/model/task/fitting.py (3)
🪛 Ruff (0.14.3)

source/tests/pt/test_fitting_stat.py
114-114: Local variable … — Remove assignment to unused variable (F841)

deepmd/pt/model/task/fitting.py
160-160: Avoid specifying long messages outside the exception class (TRY003)
186-186: Avoid specifying long messages outside the exception class (TRY003)
349-349: Avoid specifying long messages outside the exception class (TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (13)
🔇 Additional comments (16)
Actionable comments posted: 2
🧹 Nitpick comments (2)
deepmd/pt/model/model/make_model.py (1)
9-9: Remove unused numpy import. The numpy import is not used anywhere in this file.
Apply this diff:
-import numpy as np

deepmd/pt/train/training.py (1)
636-642: Fix unnecessary f-string prefix. The assertion message on line 637 uses an f-string without any placeholders.
Apply this diff:
- assert np.allclose(_data_stat_protect, _data_stat_protect[0]), f"Model key 'data_stat_protect' must be the same in each branch when multitask!"
+ assert np.allclose(_data_stat_protect, _data_stat_protect[0]), "Model key 'data_stat_protect' must be the same in each branch when multitask!"

The logic correctly validates consistency and propagates the protection value to parameter sharing.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (3 hunks)
- deepmd/pt/model/model/make_model.py (2 hunks)
- deepmd/pt/model/task/fitting.py (6 hunks)
- deepmd/pt/train/training.py (2 hunks)
- deepmd/pt/train/wrapper.py (2 hunks)
- deepmd/utils/env_mat_stat.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
📄 CodeRabbit inference engine (AGENTS.md)
Always run `ruff check .` and `ruff format .` before committing changes to Python code
Files:
- deepmd/pt/train/wrapper.py
- deepmd/pt/model/atomic_model/dp_atomic_model.py
- deepmd/utils/env_mat_stat.py
- deepmd/pt/train/training.py
- deepmd/pt/model/task/fitting.py
- deepmd/pt/model/model/make_model.py
🧬 Code graph analysis (5)
deepmd/pt/train/wrapper.py (1)
deepmd/pt/model/task/fitting.py (1)
- share_params (66-128)
deepmd/pt/model/atomic_model/dp_atomic_model.py (4)
deepmd/pt/model/model/make_model.py (2)
- has_default_fparam (530-532)
- get_default_fparam (535-536)

deepmd/pt/model/task/fitting.py (3)
- has_default_fparam (599-601)
- get_default_fparam (603-604)
- compute_input_stats (208-269)

deepmd/pd/model/atomic_model/dp_atomic_model.py (2)
- has_default_fparam (414-416)
- wrapped_sampler (387-397)

deepmd/pt/model/atomic_model/base_atomic_model.py (1)
- has_default_fparam (138-140)
deepmd/pt/train/training.py (4)
deepmd/pt/model/task/fitting.py (4)
- share_params (66-128)
- get_default_fparam (603-604)
- has_default_fparam (599-601)
- get_dim_fparam (595-597)

deepmd/pt/train/wrapper.py (1)
- share_params (63-139)

deepmd/pt/model/atomic_model/dp_atomic_model.py (3)
- get_default_fparam (355-356)
- has_default_fparam (351-353)
- get_dim_fparam (347-349)

deepmd/utils/data.py (1)
- DataRequirementItem (745-825)
deepmd/pt/model/task/fitting.py (5)
deepmd/utils/path.py (13)
- DPPath (28-158)
- mkdir (149-158, 270-282, 472-490)
- save_numpy (70-77, 200-211, 358-370)
- load_numpy (50-57, 180-188, 335-343)
- is_dir (115-116, 249-251, 439-445)

deepmd/utils/env_mat_stat.py (3)
- StatItem (26-98)
- compute_avg (58-73)
- compute_std (75-98)

deepmd/pt/utils/utils.py (6)
- to_numpy_array (224-224, 228-228, 231-247)
- to_torch_tensor (251-251, 255-255, 258-276)

deepmd/pt/model/atomic_model/dp_atomic_model.py (1)
- get_default_fparam (355-356)

deepmd/pt/model/model/make_model.py (1)
- get_default_fparam (535-536)
deepmd/pt/model/model/make_model.py (3)
deepmd/pt/model/atomic_model/dp_atomic_model.py (1)
- get_default_fparam (355-356)

deepmd/pt/model/task/fitting.py (1)
- get_default_fparam (603-604)

deepmd/pt/model/network/network.py (1)
- Tensor (36-37)
🪛 Ruff (0.14.3)
deepmd/pt/train/training.py
637-637: f-string without any placeholders
Remove extraneous f prefix
(F541)
deepmd/pt/model/task/fitting.py
269-270: Expected an indented block after if statement
(invalid-syntax)
272-272: unindent does not match any outer indentation level
(invalid-syntax)
272-272: Expected a statement
(invalid-syntax)
272-272: Expected a statement
(invalid-syntax)
272-273: Expected a statement
(invalid-syntax)
273-273: Unexpected indentation
(invalid-syntax)
297-297: unindent does not match any outer indentation level
(invalid-syntax)
298-298: Unexpected indentation
(invalid-syntax)
304-304: unindent does not match any outer indentation level
(invalid-syntax)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (29)
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Analyze (python)
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build C library (2.14, >=2.5.0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Test C++ (false)
- GitHub Check: Test C++ (true)
🔇 Additional comments (13)
deepmd/utils/env_mat_stat.py (1)
51-56: LGTM! The scalar multiplication operator correctly scales all statistical components for probability-weighted aggregation in multitask training. The implementation properly supports the weighted averaging workflow where statistics from multiple models are combined using probability weights.
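For illustration, a probability-weighted merge with the new operator could look like this (assumed usage: the values and probabilities are hypothetical, and `StatItem` addition is used as in the rest of the stats code):

```python
from deepmd.utils.env_mat_stat import StatItem

# Per-branch statistics for the same shared fitting input.
stats = {
    "model_1": StatItem(number=100, sum=250.0, squared_sum=700.0),
    "model_2": StatItem(number=60, sum=120.0, squared_sum=300.0),
}
probs = {"model_1": 0.7, "model_2": 0.3}  # model_key_prob_map entries

# Scale each branch by its sampling probability, then sum into one item.
merged = StatItem()
for key, item in stats.items():
    merged = merged + item * probs[key]

avg = merged.compute_avg()
std = merged.compute_std(protection=1e-2)  # protection keeps std > 0
```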
deepmd/pt/model/model/make_model.py (1)
534-536: LGTM! The method correctly delegates to the atomic model and follows the established pattern for other similar accessors in this class.
deepmd/pt/train/wrapper.py (1)
63-63: LGTM! The extended signature correctly supports probability-weighted parameter sharing for multitask training. The parameters align with the updated share_params implementation in the fitting net.

deepmd/pt/model/atomic_model/dp_atomic_model.py (2)
329-337: LGTM! The logic correctly populates missing fparam with default values when available. The check for both "find_fparam" and "fparam" ensures proper handling of data loading states. A simplified sketch of that population step follows the next comment.
342-342: LGTM! The stat_file_path propagation enables proper persistence of fparam/aparam statistics, and the get_default_fparam method correctly delegates to the fitting net.

Also applies to: 355-356
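A simplified sketch of the fparam population logic approved above (lines 329-337); the helper name, dict keys, and broadcast are assumptions, not the literal hunk:

```python
import torch


def fill_default_fparam(data: dict, fitting_net) -> None:
    """Sketch: broadcast the fitting net's default fparam over all frames
    when the batch did not come with one ("find_fparam" absent or falsy)."""
    if not fitting_net.has_default_fparam():
        return
    missing = "fparam" not in data or not data.get("find_fparam", False)
    if missing:
        nframes = data["coord"].shape[0]
        default = fitting_net.get_default_fparam()  # shape (dim_fparam,)
        data["fparam"] = default.unsqueeze(0).expand(nframes, -1)
```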
deepmd/pt/train/training.py (2)
619-632: LGTM! The model probability calculation correctly supports both explicit configuration and data-driven defaults, with proper normalization and validation to ensure a valid probability distribution.
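In sketch form, assuming a hypothetical helper (the real logic sits around lines 619-632 of deepmd/pt/train/training.py):

```python
import numpy as np


def build_model_key_prob_map(model_keys, prob_cfg=None, nbatches=None):
    """Return a normalized probability per model key.

    prob_cfg: optional explicit {key: weight} from the input script;
    nbatches: {key: number of training batches}, the data-driven default
    used when no explicit weights are given.
    """
    if prob_cfg is not None:
        weights = np.array([prob_cfg[k] for k in model_keys], dtype=np.float64)
    else:
        weights = np.array([nbatches[k] for k in model_keys], dtype=np.float64)
    # Validate before normalizing so a bad config fails loudly.
    assert (weights >= 0).all() and weights.sum() > 0, "invalid model weights"
    return dict(zip(model_keys, weights / weights.sum()))
```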
1344-1351: LGTM! The default fparam handling correctly retrieves and converts the default value from the model, passing it to the data requirement with proper type conversion.
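A sketch of that wiring; the `default` keyword of `DataRequirementItem` and the exact conversion are assumptions drawn from the call graph above:

```python
from deepmd.pt.utils.utils import to_numpy_array
from deepmd.utils.data import DataRequirementItem


def make_fparam_requirement(model) -> DataRequirementItem:
    """Sketch: fparam becomes an optional data item when the model
    carries a default that can fill in missing values."""
    if model.has_default_fparam():
        return DataRequirementItem(
            "fparam",
            ndof=model.get_dim_fparam(),
            atomic=False,
            must=False,  # data may omit fparam; the default fills it in
            default=to_numpy_array(model.get_default_fparam()),
        )
    return DataRequirementItem(
        "fparam", ndof=model.get_dim_fparam(), atomic=False, must=True
    )
```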
deepmd/pt/model/task/fitting.py (6)
66-128: LGTM! The extended share_params correctly implements probability-weighted parameter sharing for multitask training. The logic properly accumulates weighted statistics for fparam/aparam buffers and links them to the base class.
130-206: LGTM! The persistence methods correctly save and restore fparam/aparam statistics using numpy arrays, with proper path handling and logging.
208-266: LGTM! The fparam statistics computation correctly implements the load-or-compute pattern with proper persistence and type conversions.
304-310: LGTM! The get_stats method properly validates that statistics have been computed before returning them.
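The approved guard pattern, in miniature (names assumed):

```python
class FittingStatsMixin:
    """Sketch of the validate-before-return accessor pattern."""

    stats = None  # populated by compute_input_stats()

    def get_stats(self):
        # Fail loudly if stats are requested before they exist.
        if self.stats is None:
            raise RuntimeError(
                "fparam/aparam statistics requested before "
                "compute_input_stats() ran"
            )
        return self.stats
```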
603-604: LGTM! The method correctly exposes the default fparam tensor and aligns with the existing has_default_fparam accessor.
11-11: LGTM! The new imports are properly used throughout the file for type hints and statistics handling.
Also applies to: 45-50
Codecov Report

❌ Patch coverage is …
Additional details and impacted files

@@ Coverage Diff @@
## devel #5038 +/- ##
========================================
Coverage 84.19% 84.19%
========================================
Files 709 709
Lines 70216 70326 +110
Branches 3621 3618 -3
========================================
+ Hits 59116 59213 +97
- Misses 9933 9945 +12
- Partials 1167 1168 +1

☔ View full report in Codecov by Sentry.
    def __mul__(self, scalar: float) -> "StatItem":
        return StatItem(
            number=self.number * scalar,
            sum=self.sum * scalar,
            squared_sum=self.squared_sum * scalar,
        )
There are some type issues here:

number is int and scalar is float; int * float = float, so the product cannot be assigned to number (an int is expected).
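One possible resolution, sketched under the assumption that downstream code consumes these fields as floats, is to widen the annotation on number (this standalone copy is illustrative; the real StatItem lives in deepmd/utils/env_mat_stat.py):

```python
class StatItem:
    """Illustrative copy: number widened to float so scaling is type-safe."""

    def __init__(
        self, number: float = 0.0, sum: float = 0.0, squared_sum: float = 0.0
    ) -> None:
        self.number = number                # (possibly fractional) sample count
        self.sum = sum                      # weighted sum of values
        self.squared_sum = squared_sum      # weighted sum of squared values

    def __mul__(self, scalar: float) -> "StatItem":
        # All three components scale together, so avg/std stay consistent.
        return StatItem(
            number=self.number * scalar,
            sum=self.sum * scalar,
            squared_sum=self.squared_sum * scalar,
        )
```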
In this PR:
- saving fitting stat to stat_file and loading fitting stat from stat_file
- default_fparam
- share_fitting in multitask mode
- log.info

Summary by CodeRabbit
New Features
Refactor
Tests