feat: use num_epoch to set num_steps #5148
base: master
Conversation
Pull request overview
This PR adds support for specifying training duration using num_epoch instead of requiring explicit numb_steps. The implementation computes num_steps automatically based on dataset size and sampling probabilities, supporting both single-task and multi-task training modes.
Changes:
- Made `numb_steps` optional and added a `num_epoch` parameter in the training configuration
- Implemented automatic computation of training steps from the epoch count and dataset characteristics
- Added comprehensive tests for sampling stability in both single-task and multi-task scenarios
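For intuition, here is a minimal sketch of the derivation (illustrative only; the PR's actual helper is `compute_total_numb_batch` in deepmd/pt/train/training.py, and its exact signature and handling may differ): one "epoch" is sized by the least-frequently-sampled system, and `num_steps` is that count multiplied by `num_epoch`.

```python
import numpy as np


def steps_from_epochs(nbatches, probs, num_epoch):
    """Illustrative: nbatches = batches per system, probs = normalized sampling probabilities."""
    nbatches = np.asarray(nbatches, dtype=np.float64)
    probs = np.asarray(probs, dtype=np.float64)
    valid = probs > 0.0
    # One "epoch" ends when the least-frequently-sampled system has, in expectation,
    # seen all of its batches once.
    total_numb_batch = int(np.ceil(np.max(nbatches[valid] / probs[valid])))
    return num_epoch * total_numb_batch


print(steps_from_epochs(nbatches=[100, 300], probs=[0.5, 0.5], num_epoch=2))  # 1200
```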
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| source/tests/pt/test_sampler.py | Added test infrastructure (_SerialPool, helper methods) and new tests for single-task and multi-task sampling stability to validate the num_epoch functionality |
| deepmd/utils/argcheck.py | Updated training parameter documentation and made numb_steps optional while adding num_epoch parameter definition |
| deepmd/pt/train/training.py | Implemented compute_total_numb_batch and resolve_model_prob functions, added logic to compute num_steps from num_epoch when numb_steps is not provided |
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Adds utilities and flow to derive per-task probabilities and total training steps from per-task batch counts and epoch targets; exposes …

Changes
Sequence Diagram(s)

sequenceDiagram
participant Config
participant Trainer
participant Utils
participant Sampler
participant DataLoader
Config->>Trainer: provide training params (model_keys, numb_steps/num_epoch/num_epoch_dict, model_prob)
Trainer->>Utils: compute per-task totals (compute_total_numb_batch)
alt num_epoch_dict provided
Trainer->>Utils: resolve model_prob & per-task steps (resolve_model_prob_from_epochs)
else multi-task and num_epoch_dict not provided
Trainer->>Utils: resolve model_prob from config/counts (resolve_model_prob)
else single-task and num_epoch provided
Trainer->>Utils: compute total_numb_batch -> derive num_steps
end
Trainer->>Trainer: finalize num_steps and model_prob
Trainer->>Sampler: initialize sampler with resolved model_prob and num_steps
Sampler->>DataLoader: request batches according to probabilities
DataLoader->>Trainer: yield batches into training loop
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In @deepmd/pt/train/training.py:
- Around line 275-292: resolve_model_prob currently treats an empty dict as
"provided" and produces an all-zero vector; update resolve_model_prob so that if
model_prob_config is None or empty (len == 0) it falls back to using
model_training_data lengths for initial weights, otherwise read values from
model_prob_config, and validate each provided value is finite and non-negative
(use numpy.isfinite and value >= 0), raising a clear ValueError naming the
offending model_key for invalid entries; after building the weight vector (from
config or lengths) ensure sum_prob > 0 and then return the normalized
probabilities; reference resolve_model_prob, model_keys, model_prob_config,
model_training_data, and sum_prob in your changes.
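A minimal sketch of the behavior this comment asks for (names and fallbacks follow the comment above; the actual signature in training.py may differ):

```python
import numpy as np


def resolve_model_prob(model_keys, model_prob_config, model_training_data):
    # None or empty config: weight each task by its number of training systems.
    if not model_prob_config:
        weights = np.array(
            [float(len(model_training_data[k])) for k in model_keys], dtype=np.float64
        )
    else:
        weights = np.zeros(len(model_keys), dtype=np.float64)
        for ii, key in enumerate(model_keys):
            value = float(model_prob_config.get(key, 0.0))
            if not np.isfinite(value) or value < 0.0:
                raise ValueError(f"Invalid model_prob for model_key {key!r}: {value}")
            weights[ii] = value
    sum_prob = float(weights.sum())
    if sum_prob <= 0.0:
        raise ValueError("model_prob must sum to a positive value.")
    return weights / sum_prob
```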
In @deepmd/utils/argcheck.py:
- Around line 3298-3305: The argcheck change made "numb_steps" optional but the
Paddle trainer still uses direct dict access training_params["numb_steps"],
causing runtime KeyError; either (A) update the Paddle training code that
references training_params["numb_steps"] to use
training_params.get("numb_steps") and implement the same fallback logic used by
the PyTorch/TensorFlow trainers (fall back to num_epoch or a default) or (B)
enforce cross-parameter validation in argcheck so that at least one of
"numb_steps" or "num_epoch" is required (add a validator that raises if both are
missing); pick one approach and apply it consistently to the Paddle trainer and
argcheck to avoid the KeyError.
🧹 Nitpick comments (3)
deepmd/pt/train/training.py (2)
476-525: Deriving `num_steps` from `num_epoch` is clean and well-guarded. Checks for missing both, non-positive `num_epoch`, and non-positive `total_numb_batch` are good. The warning path when both are set matches the doc (“numb_steps takes precedence”).
Minor: `self.model_prob` gets resolved here in multi-task, making the later “Get model prob” block (Line 733+) redundant.
733-739: Redundant `model_prob` resolution block (likely dead). After the new “Resolve training steps” section, `self.model_prob` should always be set for multi-task (or an exception is thrown), so this guard probably never runs now. Consider removing it to avoid diverging behaviors in the future.

source/tests/pt/test_sampler.py (1)
84-155: Helper methods are fine; consider mirroring the production validation in `_normalize_probs`/`_compute_total_numb_batch`. Right now the tests assume no zero-prob entries (a division by zero would explode). If you expect users might set zero weights, you could align the test helpers with `compute_total_numb_batch()`’s “filter probs > 0” behavior to avoid brittle tests.
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- deepmd/pt/train/training.py
- deepmd/utils/argcheck.py
- source/tests/pt/test_sampler.py
🧰 Additional context used
🧬 Code graph analysis (1)
deepmd/pt/train/training.py (2)
deepmd/pt/utils/dataloader.py (1)
- `DpLoaderSet` (65-243)

deepmd/dpmodel/common.py (1)
- `to_numpy_array` (106-128)
🪛 Ruff (0.14.10)
deepmd/pt/train/training.py
258-258: Avoid specifying long messages outside the exception class
(TRY003)
260-260: Avoid specifying long messages outside the exception class
(TRY003)
263-263: Avoid specifying long messages outside the exception class
(TRY003)
267-267: Avoid specifying long messages outside the exception class
(TRY003)
270-272: Avoid specifying long messages outside the exception class
(TRY003)
290-290: Avoid specifying long messages outside the exception class
(TRY003)
505-507: Avoid specifying long messages outside the exception class
(TRY003)
509-509: Avoid specifying long messages outside the exception class
(TRY003)
511-511: Avoid specifying long messages outside the exception class
(TRY003)
source/tests/pt/test_sampler.py
338-338: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (42)
- GitHub Check: CodeQL analysis (python)
- GitHub Check: Agent
- GitHub Check: Test Python (7, 3.13)
- GitHub Check: Test Python (5, 3.10)
- GitHub Check: Test Python (9, 3.10)
- GitHub Check: Test Python (10, 3.13)
- GitHub Check: Test Python (11, 3.10)
- GitHub Check: Test Python (12, 3.13)
- GitHub Check: Test Python (9, 3.13)
- GitHub Check: Test Python (5, 3.13)
- GitHub Check: Test Python (11, 3.13)
- GitHub Check: Test Python (8, 3.13)
- GitHub Check: Test Python (4, 3.13)
- GitHub Check: Test Python (1, 3.13)
- GitHub Check: Test Python (10, 3.10)
- GitHub Check: Test Python (4, 3.10)
- GitHub Check: Test Python (7, 3.10)
- GitHub Check: Test Python (2, 3.10)
- GitHub Check: Test Python (6, 3.13)
- GitHub Check: Test Python (12, 3.10)
- GitHub Check: Test Python (8, 3.10)
- GitHub Check: Test Python (2, 3.13)
- GitHub Check: Test Python (6, 3.10)
- GitHub Check: Test Python (1, 3.10)
- GitHub Check: Test Python (3, 3.13)
- GitHub Check: Test Python (3, 3.10)
- GitHub Check: Test C++ (true, true, true, false)
- GitHub Check: Test C++ (false, false, false, true)
- GitHub Check: Test C++ (false, true, true, false)
- GitHub Check: Test C++ (true, false, false, true)
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Analyze (python)
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
🔇 Additional comments (4)
deepmd/utils/argcheck.py (1)
3216-3225: Docs are good, but double-check that the multi-task wording matches the implementation. `total_numb_batch` is described as a model-prob-weighted sum across tasks; that matches the new training.py logic (ceil of the weighted sum of per-task totals).
Minor: consider explicitly stating what happens when `training.model_prob` is omitted (uniform? proportional to dataset size?), since training.py now needs a deterministic fallback.
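As a purely hypothetical worked example of that model-prob-weighted reading (made-up numbers, not taken from the PR):

```python
import numpy as np

per_task_total = np.array([600.0, 40.0])  # per-task batches needed for one epoch
model_prob = np.array([0.8, 0.2])         # task-selection probabilities
total_numb_batch = int(np.ceil(np.sum(model_prob * per_task_total)))
print(total_numb_batch)  # 0.8 * 600 + 0.2 * 40 = 488
```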
deepmd/pt/train/training.py (1)

252-274: `compute_total_numb_batch()` logic looks right; consider the behavior when the sampler has no weights. The validation + `ceil(max(nbatches/prob))` matches the intended formula. One potential fragility: callers assume `self.training_dataloader.sampler.weights` exists; if a non-weighted sampler is used, this will crash before your validation runs. If that’s impossible by construction, OK; otherwise add a fallback (e.g., uniform weights).
source/tests/pt/test_sampler.py (2)

36-55: The CI-friendly multiprocessing workaround is reasonable; keep it tightly scoped. The `_SerialPool` + `pytest.MonkeyPatch` approach should avoid SemLock/CUDA init issues and is correctly undone in `tearDown()`.
Also applies to: 81-83
227-395: Good coverage for determinism and multi-task distributions. These tests exercise:
- single-task: the empirical SID distribution ~= the target probs, and exact equality between two runs when RNG seeds + num_steps match
- multi-task: model selection frequencies ~= `model_prob`, and per-task SID distributions ~= the per-task sampler probs

This is a solid regression net for the new `num_epoch -> num_steps` derivation behavior.
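For readers unfamiliar with this kind of test, a self-contained illustration of the check (not the actual test code): draw system indices with the target probabilities and compare the empirical frequencies against them.

```python
import numpy as np

rng = np.random.default_rng(seed=0)
probs = np.array([0.1, 0.3, 0.6])
num_steps = 10_000
draws = rng.choice(len(probs), size=num_steps, p=probs)
empirical = np.bincount(draws, minlength=len(probs)) / num_steps
np.testing.assert_allclose(empirical, probs, atol=0.05)
```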
Codecov Report

❌ Patch coverage is …

Additional details and impacted files

@@ Coverage Diff @@
## master #5148 +/- ##
==========================================
- Coverage 81.95% 81.83% -0.12%
==========================================
Files 713 715 +2
Lines 72985 73472 +487
Branches 3617 3616 -1
==========================================
+ Hits 59812 60129 +317
- Misses 12010 12182 +172
+ Partials 1163 1161 -2

☔ View full report in Codecov by Sentry.
Actionable comments posted: 5
🤖 Fix all issues with AI agents
In @deepmd/pt/train/training.py:
- Around line 275-296: resolve_model_prob currently treats an empty dict ({})
the same as None because it checks "if model_prob_config:" which makes "{}" fall
back to using len(model_training_data) instead of an explicit uniform
distribution; change the conditional to check identity against None (if
model_prob_config is None:) so only a true None triggers the
proportional-to-training-data behavior, and add a branch for an explicit empty
dict (or empty mapping) to assign uniform weights across model_keys (or log that
{} means uniform) while preserving validations in resolve_model_prob.
- Around line 252-274: In compute_total_numb_batch, explicitly validate
sampler_weights for finiteness and negativity: after converting to weights
(np.asarray(...)), raise a ValueError if not np.all(np.isfinite(weights)) or if
np.any(weights < 0); also check that weight_sum is finite
(np.isfinite(weight_sum)) before using it. Keep the existing checks (1D,
non-empty, positive sum) but ensure non-finite and negative entries are rejected
up-front to avoid producing misleading totals.
In @deepmd/utils/argcheck.py:
- Around line 3216-3232: The docstring for doc_numb_steps currently claims it
takes precedence over num_epoch globally; update it to scope that behavior to
the PyTorch backend (or otherwise clarify backend-specific behavior). Modify the
doc_numb_steps string to prepend or include "(Supported Backend: PyTorch)" and
rephrase the precedence sentence to indicate this rule applies only for the
PyTorch backend rather than universally, and ensure doc_num_epoch remains
consistent about backend-specific computations if needed.
🧹 Nitpick comments (2)
deepmd/pt/train/training.py (1)
736-742: Optional: remove the now-redundant `model_prob` re-resolution block. Since `self.model_prob` is already set during step resolution for multi-task, this should be dead code in normal flows.

source/tests/pt/test_sampler.py (1)

91-109: The test helper is intentionally simplified; keep it aligned if the training math changes again.
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- deepmd/pt/train/training.py
- deepmd/utils/argcheck.py
- source/tests/pt/test_sampler.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2024-11-29T12:15:22.226Z
Learnt from: HydrogenSulfate
Repo: deepmodeling/deepmd-kit PR: 4414
File: deepmd/pd/train/training.py:66-66
Timestamp: 2024-11-29T12:15:22.226Z
Learning: The function `nvprof_context` is defined in `deepmd/pd/utils/utils.py`, so importing it in `deepmd/pd/train/training.py` is correct.
Applied to files:
source/tests/pt/test_sampler.py
🪛 Ruff (0.14.10)
deepmd/pt/train/training.py
258-258: Avoid specifying long messages outside the exception class
(TRY003)
260-260: Avoid specifying long messages outside the exception class
(TRY003)
263-263: Avoid specifying long messages outside the exception class
(TRY003)
267-267: Avoid specifying long messages outside the exception class
(TRY003)
270-272: Avoid specifying long messages outside the exception class
(TRY003)
289-289: Avoid specifying long messages outside the exception class
(TRY003)
291-291: Avoid specifying long messages outside the exception class
(TRY003)
294-294: Avoid specifying long messages outside the exception class
(TRY003)
509-511: Avoid specifying long messages outside the exception class
(TRY003)
513-513: Avoid specifying long messages outside the exception class
(TRY003)
515-515: Avoid specifying long messages outside the exception class
(TRY003)
source/tests/pt/test_sampler.py
100-102: Avoid specifying long messages outside the exception class
(TRY003)
104-107: Avoid specifying long messages outside the exception class
(TRY003)
355-355: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (39)
- GitHub Check: Test Python (11, 3.10)
- GitHub Check: Test Python (12, 3.13)
- GitHub Check: Test Python (7, 3.10)
- GitHub Check: Test Python (7, 3.13)
- GitHub Check: Test Python (9, 3.10)
- GitHub Check: Test Python (10, 3.13)
- GitHub Check: Test Python (12, 3.10)
- GitHub Check: Test Python (5, 3.10)
- GitHub Check: Test Python (11, 3.13)
- GitHub Check: Test Python (6, 3.13)
- GitHub Check: Test Python (10, 3.10)
- GitHub Check: Test Python (9, 3.13)
- GitHub Check: Test Python (8, 3.10)
- GitHub Check: Test Python (8, 3.13)
- GitHub Check: Test Python (4, 3.13)
- GitHub Check: Test Python (2, 3.13)
- GitHub Check: Test Python (6, 3.10)
- GitHub Check: Test Python (5, 3.13)
- GitHub Check: Test Python (2, 3.10)
- GitHub Check: Test Python (1, 3.10)
- GitHub Check: Test Python (1, 3.13)
- GitHub Check: Test Python (3, 3.10)
- GitHub Check: Test Python (4, 3.10)
- GitHub Check: Test Python (3, 3.13)
- GitHub Check: Test C++ (true, true, true, false)
- GitHub Check: Test C++ (true, false, false, true)
- GitHub Check: Test C++ (false, false, false, true)
- GitHub Check: Test C++ (false, true, true, false)
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Analyze (python)
🔇 Additional comments (3)
deepmd/pt/train/training.py (1)
142-147: Good: `num_steps`/`num_epoch` are now parsed via `.get()`, and `model_prob` is initialized.

deepmd/utils/argcheck.py (1)

3304-3316: Schema updates look consistent: `numb_steps` is optional w/ aliases, plus the new `num_epoch`.

source/tests/pt/test_sampler.py (1)
31-78: `_SerialPool` correctly implements the required `Pool` API. The monkeypatch is safe and complete. `_SerialPool` provides all methods used by `pt_dataloader` (the context manager protocol and `map()`), and the implementation matches the behavior of `multiprocessing.dummy.Pool`.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In @deepmd/pt/train/training.py:
- Around line 252-305: In resolve_model_prob: when model_prob_config is
provided, do not silently leave unspecified model_keys at zero; validate that
every model_key is present in model_prob_config (or explicitly allowed), and if
any are missing raise a ValueError (include the missing key names) instead of
filling zeros; implement this check before building model_prob (using model_keys
and model_prob_config) and only fall back to using model_training_data counts
when model_prob_config is None or empty, otherwise require completeness.
🧹 Nitpick comments (3)
deepmd/tf/entrypoints/change_bias.py (1)
190-231: Validate `nbatches` too (and reconsider the `stop_batch=0` fallback). `compute_total_numb_batch()` validates probabilities, but `nbatches` can still be non-1D / non-finite / negative, which can surface as opaque NumPy failures. Also, silently defaulting `stop_batch` to 0 if neither `training.numb_steps` nor `training.num_epoch` is set may break `trainer.build()` if it assumes a positive step count.

Proposed tightening for nbatches validation (local, minimal):

```diff
 def compute_total_numb_batch(nbatches, sys_probs) -> int:
     weights = np.asarray(sys_probs, dtype=np.float64)
     if weights.ndim != 1:
         raise ValueError("Sampler probabilities must be 1D.")
     if weights.size == 0:
         raise ValueError("Sampler probabilities are empty.")
     if not np.all(np.isfinite(weights)):
         raise ValueError("Sampler probabilities must be finite.")
     if np.any(weights < 0.0):
         raise ValueError("Sampler probabilities must be non-negative.")
     weight_sum = float(np.sum(weights))
     if weight_sum <= 0.0:
         raise ValueError("Sampler probabilities must sum to a positive value.")
     probs = weights / weight_sum
-    nbatches = np.asarray(nbatches, dtype=np.float64)
+    nbatches = np.asarray(nbatches, dtype=np.float64)
+    if nbatches.ndim != 1:
+        raise ValueError("Number of batches must be 1D.")
+    if nbatches.size == 0:
+        raise ValueError("Number of batches is empty.")
+    if not np.all(np.isfinite(nbatches)):
+        raise ValueError("Number of batches must be finite.")
+    if np.any(nbatches < 0.0):
+        raise ValueError("Number of batches must be non-negative.")
     if nbatches.shape[0] != probs.shape[0]:
         raise ValueError("Number of batches and sampler probabilities must match.")
     valid = probs > 0.0
     if not np.any(valid):
         raise ValueError(
             "Sampler probabilities must contain at least one positive entry."
         )
     return int(np.ceil(np.max(nbatches[valid] / probs[valid])))
```

deepmd/tf/entrypoints/train.py (1)
257-313: Good precedence + errors; add `nbatches` validation (and avoid helper drift). The `numb_steps`/`num_epoch` resolution is solid and the warning when both are set is helpful. Same suggestion as elsewhere: validate `nbatches` (1D/finite/non-negative) to prevent odd NumPy errors, and consider centralizing `compute_total_numb_batch` to avoid cross-backend drift.

deepmd/utils/argcheck.py (1)
3308-3320: The schema change looks right; consider aligning the “neither set” behavior across entrypoints. `numb_steps` is now optional and `num_epoch` is introduced, but enforcement differs across entrypoints (some raise; change_bias.py falls back to 0). Consider making this consistent (either always raise, or explicitly document the one-off).
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- deepmd/pd/train/training.py
- deepmd/pt/train/training.py
- deepmd/tf/entrypoints/change_bias.py
- deepmd/tf/entrypoints/train.py
- deepmd/utils/argcheck.py
🧰 Additional context used
🧬 Code graph analysis (3)
deepmd/pd/train/training.py (2)
deepmd/pt/train/training.py (2)
- `compute_total_numb_batch` (252-277)
- `resolve_model_prob` (279-304)
deepmd/tf/entrypoints/train.py (1)
- `compute_total_numb_batch` (257-279)

deepmd/pt/train/training.py (3)
deepmd/pd/train/training.py (2)
- `compute_total_numb_batch` (210-232)
- `resolve_model_prob` (234-254)
deepmd/tf/entrypoints/change_bias.py (1)
- `compute_total_numb_batch` (190-212)
deepmd/tf/entrypoints/train.py (1)
- `compute_total_numb_batch` (257-279)

deepmd/tf/entrypoints/train.py (3)
deepmd/pd/train/training.py (1)
- `compute_total_numb_batch` (210-232)
deepmd/pt/train/training.py (1)
- `compute_total_numb_batch` (252-277)
deepmd/tf/entrypoints/change_bias.py (1)
- `compute_total_numb_batch` (190-212)
🪛 Ruff (0.14.10)
deepmd/tf/entrypoints/change_bias.py
193-193: Avoid specifying long messages outside the exception class
(TRY003)
195-195: Avoid specifying long messages outside the exception class
(TRY003)
197-197: Avoid specifying long messages outside the exception class
(TRY003)
199-199: Avoid specifying long messages outside the exception class
(TRY003)
202-202: Avoid specifying long messages outside the exception class
(TRY003)
206-206: Avoid specifying long messages outside the exception class
(TRY003)
209-211: Avoid specifying long messages outside the exception class
(TRY003)
219-219: Avoid specifying long messages outside the exception class
(TRY003)
222-222: Avoid specifying long messages outside the exception class
(TRY003)
deepmd/pd/train/training.py
213-213: Avoid specifying long messages outside the exception class
(TRY003)
215-215: Avoid specifying long messages outside the exception class
(TRY003)
217-217: Avoid specifying long messages outside the exception class
(TRY003)
219-219: Avoid specifying long messages outside the exception class
(TRY003)
222-222: Avoid specifying long messages outside the exception class
(TRY003)
226-226: Avoid specifying long messages outside the exception class
(TRY003)
229-231: Avoid specifying long messages outside the exception class
(TRY003)
248-248: Avoid specifying long messages outside the exception class
(TRY003)
250-250: Avoid specifying long messages outside the exception class
(TRY003)
253-253: Avoid specifying long messages outside the exception class
(TRY003)
469-471: Avoid specifying long messages outside the exception class
(TRY003)
473-473: Avoid specifying long messages outside the exception class
(TRY003)
475-475: Avoid specifying long messages outside the exception class
(TRY003)
deepmd/pt/train/training.py
258-258: Avoid specifying long messages outside the exception class
(TRY003)
260-260: Avoid specifying long messages outside the exception class
(TRY003)
262-262: Avoid specifying long messages outside the exception class
(TRY003)
264-264: Avoid specifying long messages outside the exception class
(TRY003)
267-267: Avoid specifying long messages outside the exception class
(TRY003)
271-271: Avoid specifying long messages outside the exception class
(TRY003)
274-276: Avoid specifying long messages outside the exception class
(TRY003)
298-298: Avoid specifying long messages outside the exception class
(TRY003)
300-300: Avoid specifying long messages outside the exception class
(TRY003)
303-303: Avoid specifying long messages outside the exception class
(TRY003)
518-520: Avoid specifying long messages outside the exception class
(TRY003)
522-522: Avoid specifying long messages outside the exception class
(TRY003)
524-524: Avoid specifying long messages outside the exception class
(TRY003)
deepmd/tf/entrypoints/train.py
260-260: Avoid specifying long messages outside the exception class
(TRY003)
262-262: Avoid specifying long messages outside the exception class
(TRY003)
264-264: Avoid specifying long messages outside the exception class
(TRY003)
266-266: Avoid specifying long messages outside the exception class
(TRY003)
269-269: Avoid specifying long messages outside the exception class
(TRY003)
273-273: Avoid specifying long messages outside the exception class
(TRY003)
276-278: Avoid specifying long messages outside the exception class
(TRY003)
286-288: Avoid specifying long messages outside the exception class
(TRY003)
290-290: Avoid specifying long messages outside the exception class
(TRY003)
292-294: Avoid specifying long messages outside the exception class
(TRY003)
299-299: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40)
- GitHub Check: Test Python (4, 3.13)
- GitHub Check: Test Python (10, 3.10)
- GitHub Check: Test Python (7, 3.13)
- GitHub Check: Test Python (6, 3.13)
- GitHub Check: Test Python (8, 3.10)
- GitHub Check: Test Python (7, 3.10)
- GitHub Check: Test Python (4, 3.10)
- GitHub Check: Test Python (9, 3.13)
- GitHub Check: Test Python (12, 3.10)
- GitHub Check: Test Python (5, 3.13)
- GitHub Check: Test Python (9, 3.10)
- GitHub Check: Test Python (12, 3.13)
- GitHub Check: Test Python (11, 3.10)
- GitHub Check: Test Python (6, 3.10)
- GitHub Check: Test Python (11, 3.13)
- GitHub Check: Test Python (10, 3.13)
- GitHub Check: Test Python (5, 3.10)
- GitHub Check: Test Python (8, 3.13)
- GitHub Check: Test Python (2, 3.10)
- GitHub Check: Test Python (3, 3.10)
- GitHub Check: Test Python (1, 3.10)
- GitHub Check: Test Python (2, 3.13)
- GitHub Check: Test Python (3, 3.13)
- GitHub Check: Test Python (1, 3.13)
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Test C++ (true, true, true, false)
- GitHub Check: Test C++ (false, true, true, false)
- GitHub Check: Test C++ (true, false, false, true)
- GitHub Check: Test C++ (false, false, false, true)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Analyze (python)
🔇 Additional comments (10)
deepmd/tf/entrypoints/train.py (1)
15-16: The NumPy import looks fine.

deepmd/pt/train/training.py (3)

142-147: Reading `numb_steps`/`num_epoch` via `.get()` is a good compatibility move.

746-752: The guarded `model_prob` resolution is fine (keeps later paths consistent).

489-538: No action needed: `.index` correctly holds per-system batch counts. `training_data.index` contains batch counts per system (populated via `len(system_dataloader)` for each system), not sample indices. The sampler weights are also per-system, ensuring shape consistency. The math in `compute_total_numb_batch()` is correct.

deepmd/utils/argcheck.py (2)

3216-3232: Docs for `numb_steps` precedence + `num_epoch` derivation look consistent.

3278-3282: `model_prob` default behavior is now clear; good.

deepmd/pd/train/training.py (4)
135-136: LGTM! The change to use `.get()` for both `numb_steps` and `num_epoch` allows either to be optionally specified, with validation deferred to later in the initialization.

210-232: LGTM! The `compute_total_numb_batch` function correctly validates sampler weights and computes the total number of batches needed for one epoch. The implementation is consistent with the reference in deepmd/pt/train/training.py.

467-488: LGTM! The logic for deriving `num_steps` from `num_epoch` is well implemented, with proper validation:
- Clear error when neither `numb_steps` nor `num_epoch` is provided
- Validation that `num_epoch` and `total_numb_batch` are positive
- The ceiling function ensures at least the specified number of epochs are trained
- Helpful warning when both options are set

443-446: No issues found. The code is correct. `training_data.index` is a list containing the number of batches for each system dataloader (populated via `len(system_dataloader)` at dataloader construction). This is exactly what `compute_total_numb_batch` expects: an array-like of batch counts to be paired with `sampler_weights`. The shape check in the function (line 216) validates that the number of batch counts matches the number of weights, confirming the usage is safe and correct.
Furthermore, I recommend implementing a new option: in multi-task mode, …
Actionable comments posted: 0
🧹 Nitpick comments (1)
deepmd/pt/train/training.py (1)
279-309: Improved validation over the other backends, with a minor redundancy. The missing-key validation (lines 286-290) and info logging (lines 295-299) are valuable enhancements compared to the PD/TF implementations. However, after the missing-key check, the condition on line 292 (`if model_key in model_prob_config`) is always true and can be simplified.

♻️ Optional: remove the redundant conditional

```diff
 for ii, model_key in enumerate(model_keys):
-    if model_key in model_prob_config:
-        model_prob[ii] = float(model_prob_config[model_key])
+    model_prob[ii] = float(model_prob_config[model_key])
```
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- deepmd/pt/train/training.py
- deepmd/utils/argcheck.py
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/utils/argcheck.py
🧰 Additional context used
🧬 Code graph analysis (1)
deepmd/pt/train/training.py (2)
deepmd/tf/entrypoints/train.py (1)
- `compute_total_numb_batch` (257-279)

deepmd/pd/train/training.py (2)
- `compute_total_numb_batch` (210-232)
- `resolve_model_prob` (234-254)
🪛 Ruff (0.14.10)
deepmd/pt/train/training.py
258-258: Avoid specifying long messages outside the exception class
(TRY003)
260-260: Avoid specifying long messages outside the exception class
(TRY003)
262-262: Avoid specifying long messages outside the exception class
(TRY003)
264-264: Avoid specifying long messages outside the exception class
(TRY003)
267-267: Avoid specifying long messages outside the exception class
(TRY003)
271-271: Avoid specifying long messages outside the exception class
(TRY003)
274-276: Avoid specifying long messages outside the exception class
(TRY003)
288-290: Avoid specifying long messages outside the exception class
(TRY003)
303-303: Avoid specifying long messages outside the exception class
(TRY003)
305-305: Avoid specifying long messages outside the exception class
(TRY003)
308-308: Avoid specifying long messages outside the exception class
(TRY003)
523-525: Avoid specifying long messages outside the exception class
(TRY003)
527-527: Avoid specifying long messages outside the exception class
(TRY003)
529-529: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (25)
- GitHub Check: Test Python (11, 3.13)
- GitHub Check: Test Python (9, 3.10)
- GitHub Check: Test Python (11, 3.10)
- GitHub Check: Test Python (1, 3.10)
- GitHub Check: Test Python (3, 3.10)
- GitHub Check: Test Python (6, 3.10)
- GitHub Check: Test Python (7, 3.10)
- GitHub Check: Test Python (7, 3.13)
- GitHub Check: Test Python (4, 3.13)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Test C++ (true, true, true, false)
- GitHub Check: Test C++ (false, true, true, false)
- GitHub Check: Test C++ (true, false, false, true)
- GitHub Check: Test C++ (false, false, false, true)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Analyze (python)
🔇 Additional comments (5)
deepmd/pt/train/training.py (5)
142-142: LGTM! The `model_prob` attribute initialization is appropriate for deferred computation in multi-task scenarios.

145-146: LGTM! Using `.get()` for both parameters correctly handles the case where either `numb_steps` or `num_epoch` may be omitted from the configuration.

252-277: LGTM! The `compute_total_numb_batch` function correctly validates inputs and computes the total batch count needed to cover all systems given their sampling probabilities. The implementation aligns with the analogous functions in the PD and TF training modules.

494-542: LGTM! The training-steps resolution logic correctly handles both single-task and multi-task modes:
- Single-task computes `total_numb_batch` from sampler weights
- Multi-task computes a weighted sum of per-task batch counts using model probabilities
- Validation ensures exactly one of `numb_steps` or `num_epoch` is effectively used
- The warning when both are set provides clear feedback about precedence

751-756: LGTM! The guarded `model_prob` resolution provides a defensive fallback. Given the current control flow, `self.model_prob` is always set at line 513 for multi-task mode, so this branch serves as a safety net for potential future refactoring or edge cases.
Actionable comments posted: 5
🤖 Fix all issues with AI agents
In @deepmd/pd/train/training.py:
- Around line 235-256: In resolve_model_prob, ensure model_prob_config cannot
silently omit tasks: if model_prob_config is provided, validate that every
model_key in model_keys is present (use set(model_keys) -
set(model_prob_config.keys()) to find missing keys) and raise a ValueError
listing the missing keys (mention model_keys and model_prob_config in the
message) instead of filling zeros; keep the existing finite/negative/sum checks
and normalization so callers like num_epoch_dict won't end up dividing by zero
or skipping tasks.
- Around line 211-234: compute_total_numb_batch currently only validates
sampler_weights; validate the numb_batches input similarly: convert numb_batches
to a 1D numpy array (nbatches = np.asarray(numb_batches, dtype=np.float64)) and
check it is 1D and the same length as weights, contains only finite values
(np.isfinite), and contains no negative entries; additionally ensure there is at
least one positive nbatch where the corresponding prob > 0 before computing the
ceil(max(...)) to avoid NaNs/infs and raise ValueError with clear messages
referencing numb_batches/nbatches when checks fail.
In @deepmd/pt/train/training.py:
- Around line 550-554: The log is zipping self.model_keys (all tasks) with
per_task_steps (only tasks where epoch_value is not None), causing wrong
pairings; change the code that builds per_task_steps to also collect the
corresponding task keys (e.g., keys_for_steps) using the same filter, then use
zip(keys_for_steps, per_task_steps) (or build a dict from those paired lists) in
the logging call so each model key is matched to its correct ceil'd step value;
update any variable names referenced (self.model_keys, per_task_steps)
accordingly.
- Around line 538-545: The loop computing per_task_steps can raise
ZeroDivisionError if self.model_prob[ii] == 0 and ValueError if per_task_steps
stays empty; update the logic in the block that iterates over
self.model_keys/num_epoch_dict to skip entries where epoch_value is None or
self.model_prob[ii] == 0 (i.e., do not perform steps_i = epoch_value *
per_task_total[ii] / self.model_prob[ii] for those cases), collect valid steps_i
into per_task_steps, and after the loop set self.num_steps =
int(np.ceil(np.max(per_task_steps))) only if per_task_steps is non-empty else
set self.num_steps = 0 (or another safe default); refer to symbols:
per_task_steps, self.model_keys, self.num_epoch_dict, self.model_prob,
per_task_total, and self.num_steps.
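A sketch combining the two fixes above: skip tasks with no epoch target or zero probability, and keep the task keys aligned with the computed steps for logging. Names (`model_keys`, `num_epoch_dict`, `model_prob`, `per_task_total`) follow the review comments; the real attributes in training.py may differ.

```python
import logging

import numpy as np

log = logging.getLogger(__name__)


def resolve_steps_from_epoch_dict(model_keys, num_epoch_dict, model_prob, per_task_total):
    keys_for_steps, per_task_steps = [], []
    for ii, key in enumerate(model_keys):
        epoch_value = num_epoch_dict.get(key)
        if epoch_value is None or model_prob[ii] == 0.0:
            continue  # no epoch target for this task, or the task is never sampled
        steps_i = int(np.ceil(epoch_value * per_task_total[ii] / model_prob[ii]))
        keys_for_steps.append(key)
        per_task_steps.append(steps_i)
    for key, steps in zip(keys_for_steps, per_task_steps):
        log.info("task %s needs %d steps to reach its epoch target", key, steps)
    return max(per_task_steps) if per_task_steps else 0
```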
🧹 Nitpick comments (4)
source/tests/pt/test_sampler.py (2)
46-78: Prefer `self.addCleanup(self._monkeypatch.undo)` to avoid patch leaks if `setUp` fails mid-way. Right now, a failure after Line 49 but before tearDown registration can leave `pt_dataloader.Pool` patched for subsequent tests.

Proposed fix:

```diff
 def setUp(self) -> None:
     self._monkeypatch = pytest.MonkeyPatch()
     # Avoid SemLock/CUDA initialization failures in restricted CI by forcing a serial pool.
     self._monkeypatch.setattr(pt_dataloader, "Pool", _SerialPool)
+    self.addCleanup(self._monkeypatch.undo)
@@
 def tearDown(self) -> None:
-    self._monkeypatch.undo()
+    self._monkeypatch.undo()
```
244-519: The sampling-stability tests may be a bit tolerance-sensitive; consider scaling `atol` with `num_steps` (optional). Using a fixed `atol=0.1` can be flaky if CI runtime changes cause small effective sample counts; scaling the tolerance by `sqrt(p(1-p)/n)` (or increasing `num_epoch`) tends to be more robust.
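An illustrative way to size such a tolerance (an assumption about how one might do it, not code from this PR): scale the allowed deviation by the binomial standard error.

```python
import numpy as np


def empirical_atol(probs, num_samples, n_sigma=4.0):
    # Allowed deviation of an empirical frequency from its target probability.
    probs = np.asarray(probs, dtype=np.float64)
    return n_sigma * np.sqrt(probs * (1.0 - probs) / num_samples)


print(empirical_atol([0.2, 0.8], num_samples=2000))  # ~[0.036, 0.036]
```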
deepmd/pt/train/training.py (1)

292-294: Minor redundancy in the conditional check. The check `if model_key in model_prob_config` at line 293 is redundant since missing keys are already caught and raised at lines 287-291. The loop at line 292 will only execute when all keys are present.

♻️ Suggested simplification:

```diff
 for ii, model_key in enumerate(model_keys):
-    if model_key in model_prob_config:
-        model_prob[ii] = float(model_prob_config[model_key])
+    model_prob[ii] = float(model_prob_config[model_key])
```

doc/train/multi-task-training.md (1)
84-91: Good documentation, minor terminology inconsistency. The documentation clearly explains the `num_epoch_dict` feature and its precedence over `num_epoch`. One minor issue: line 86 uses "pretrained" while line 102 uses "pre-trained". For consistency within the document, consider using the same variant throughout.

📝 Suggested fix:

```diff
- where a data-rich pretrained model is jointly trained with a data-scarce downstream task.
+ where a data-rich pre-trained model is jointly trained with a data-scarce downstream task.
```
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- deepmd/pd/train/training.py
- deepmd/pt/train/training.py
- deepmd/utils/argcheck.py
- doc/train/multi-task-training.md
- source/tests/pt/test_sampler.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2024-11-29T12:15:22.226Z
Learnt from: HydrogenSulfate
Repo: deepmodeling/deepmd-kit PR: 4414
File: deepmd/pd/train/training.py:66-66
Timestamp: 2024-11-29T12:15:22.226Z
Learning: The function `nvprof_context` is defined in `deepmd/pd/utils/utils.py`, so importing it in `deepmd/pd/train/training.py` is correct.
Applied to files:
source/tests/pt/test_sampler.py
🧬 Code graph analysis (1)
source/tests/pt/test_sampler.py (2)
deepmd/pt/utils/dataloader.py (3)
- `DpLoaderSet` (65-243)
- `get_weighted_sampler` (266-290)
- `get_sampler_from_params` (293-306)

deepmd/utils/data_system.py (1)
- `set_sys_probs` (413-434)
🪛 LanguageTool
doc/train/multi-task-training.md
[uncategorized] ~86-~86: Do not mix variants of the same word (‘pretrain’ and ‘pre-train’) within a single text.
Context: ...ne-tuning scenarios where a data-rich pretrained model is jointly trained with a data-sc...
(EN_WORD_COHERENCY)
🪛 Ruff (0.14.10)
source/tests/pt/test_sampler.py
100-102: Avoid specifying long messages outside the exception class
(TRY003)
104-107: Avoid specifying long messages outside the exception class
(TRY003)
355-355: Avoid specifying long messages outside the exception class
(TRY003)
deepmd/pt/train/training.py
259-259: Avoid specifying long messages outside the exception class
(TRY003)
261-261: Avoid specifying long messages outside the exception class
(TRY003)
263-263: Avoid specifying long messages outside the exception class
(TRY003)
265-265: Avoid specifying long messages outside the exception class
(TRY003)
268-268: Avoid specifying long messages outside the exception class
(TRY003)
272-272: Avoid specifying long messages outside the exception class
(TRY003)
275-277: Avoid specifying long messages outside the exception class
(TRY003)
289-291: Avoid specifying long messages outside the exception class
(TRY003)
304-304: Avoid specifying long messages outside the exception class
(TRY003)
306-306: Avoid specifying long messages outside the exception class
(TRY003)
309-309: Avoid specifying long messages outside the exception class
(TRY003)
527-529: Avoid specifying long messages outside the exception class
(TRY003)
534-536: Avoid specifying long messages outside the exception class
(TRY003)
552-552: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
557-560: Avoid specifying long messages outside the exception class
(TRY003)
563-563: Avoid specifying long messages outside the exception class
(TRY003)
565-567: Avoid specifying long messages outside the exception class
(TRY003)
deepmd/pd/train/training.py
214-214: Avoid specifying long messages outside the exception class
(TRY003)
216-216: Avoid specifying long messages outside the exception class
(TRY003)
218-218: Avoid specifying long messages outside the exception class
(TRY003)
220-220: Avoid specifying long messages outside the exception class
(TRY003)
223-223: Avoid specifying long messages outside the exception class
(TRY003)
227-227: Avoid specifying long messages outside the exception class
(TRY003)
230-232: Avoid specifying long messages outside the exception class
(TRY003)
249-249: Avoid specifying long messages outside the exception class
(TRY003)
251-251: Avoid specifying long messages outside the exception class
(TRY003)
254-254: Avoid specifying long messages outside the exception class
(TRY003)
473-475: Avoid specifying long messages outside the exception class
(TRY003)
480-482: Avoid specifying long messages outside the exception class
(TRY003)
498-498: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
503-506: Avoid specifying long messages outside the exception class
(TRY003)
509-509: Avoid specifying long messages outside the exception class
(TRY003)
511-513: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (32)
- GitHub Check: Test Python (11, 3.13)
- GitHub Check: Test Python (3, 3.13)
- GitHub Check: Test Python (6, 3.13)
- GitHub Check: Test Python (12, 3.13)
- GitHub Check: Test Python (2, 3.13)
- GitHub Check: Test Python (8, 3.10)
- GitHub Check: Test Python (1, 3.10)
- GitHub Check: Test Python (2, 3.10)
- GitHub Check: Test Python (7, 3.13)
- GitHub Check: Test Python (4, 3.10)
- GitHub Check: Test Python (3, 3.10)
- GitHub Check: Test Python (1, 3.13)
- GitHub Check: Test Python (5, 3.10)
- GitHub Check: Test Python (7, 3.10)
- GitHub Check: Test Python (8, 3.13)
- GitHub Check: Test Python (6, 3.10)
- GitHub Check: Test C++ (false, true, true, false)
- GitHub Check: Test C++ (false, false, false, true)
- GitHub Check: Test C++ (true, true, true, false)
- GitHub Check: Test C++ (true, false, false, true)
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
🔇 Additional comments (8)
deepmd/pd/train/training.py (1)
135-137: Nice: `numb_steps` is now optional and epoch-based scheduling is surfaced (`num_epoch`, `num_epoch_dict`). This makes the Trainer consistent with the new arg schema and enables the requested behavior.

source/tests/pt/test_sampler.py (1)

79-164: The test helpers look consistent with the new `num_epoch -> num_steps` math. The “derived steps == explicit steps” assertions are a good regression guard for deterministic sampling.

deepmd/utils/argcheck.py (2)
3216-3252: Docs clearly define precedence and the multi-task meaning of “epoch”.
This should reduce user confusion around fractional epochs and weighted sampling.
3321-3348: The `numb_steps` aliases (`stop_batch`, `num_steps`) are safe; no collision risk detected. The training code correctly reads the resolved `numb_steps` field from `training_params.get("numb_steps")` after alias resolution in the Argument parser. The `num_steps` alias does not conflict with the internal `self.num_steps` attribute, which stores the retrieved value separately from the config field name.

deepmd/pt/train/training.py (4)

142-147: LGTM! The initialization of `model_prob` and the use of `.get()` for optional parameters is appropriate.

253-278: LGTM! The `compute_total_numb_batch` function correctly computes the total number of batches needed to ensure each system completes at least one epoch. The validation logic is thorough and the formula `max(nbatches[valid] / probs[valid])` ensures proper coverage.

575-582: LGTM! Appropriate warning when both `numb_steps` and `num_epoch`/`num_epoch_dict` are set, with clear precedence given to `numb_steps`.

791-796: LGTM! The guard prevents redundant re-resolution of `model_prob` when it was already computed during the training-steps resolution phase.
Actionable comments posted: 0
🧹 Nitpick comments (3)
deepmd/pd/train/training.py (2)
211-241: LGTM! Comprehensive validation for the batch count computation. The validation logic correctly handles edge cases (empty arrays, non-finite values, zero weights). The formula `ceil(max(nbatches[valid] / probs[valid]))` correctly computes the minimum number of global steps needed to ensure each system is sampled according to its weight.
Note: this function is duplicated across the PT, PD, and TF implementations. Consider extracting it to a shared utility module in the future to reduce the maintenance burden.
243-268: Consider adding a log message when defaulting to system counts. The logic is correct. However, the PT implementation (lines 304-308) logs an info message when `model_prob_config` is not set, informing users that it defaults to the number of systems per task. Adding the same log here would improve consistency and help users understand the behavior.

♻️ Suggested enhancement:

```diff
 else:
+    log.info(
+        "training.model_prob is not set or empty; defaulting to the "
+        "number of systems per task."
+    )
     for ii, model_key in enumerate(model_keys):
         model_prob[ii] = float(len(model_training_data[model_key]))
```

deepmd/pt/train/training.py (1)
804-809: Consider adding a clarifying comment for this defensive block. In the current flow, `self.model_prob` is already set at lines 522-526 for multi-task mode, so this block would not execute. However, it serves as defensive code. A brief comment would help future maintainers understand its purpose.

♻️ Suggested enhancement:

```diff
 # Get model prob for multi-task
+# Defensive fallback: resolve model_prob if not already set during initialization
+# (e.g., if future code paths skip the earlier resolution block)
 if self.multi_task and self.model_prob is None:
     self.model_prob = resolve_model_prob(
         self.model_keys,
         training_params.get("model_prob"),
         training_data,
     )
```
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- deepmd/pd/train/training.py
- deepmd/pt/train/training.py
🧰 Additional context used
🧬 Code graph analysis (1)
deepmd/pt/train/training.py (3)
deepmd/pd/train/training.py (2)
- `compute_total_numb_batch` (211-241)
- `resolve_model_prob` (243-268)

deepmd/tf/entrypoints/train.py (1)
- `compute_total_numb_batch` (257-279)

deepmd/tf/entrypoints/change_bias.py (1)
- `compute_total_numb_batch` (190-212)
🪛 Ruff (0.14.10)
deepmd/pt/train/training.py
259-259: Avoid specifying long messages outside the exception class
(TRY003)
261-261: Avoid specifying long messages outside the exception class
(TRY003)
263-263: Avoid specifying long messages outside the exception class
(TRY003)
265-265: Avoid specifying long messages outside the exception class
(TRY003)
268-268: Avoid specifying long messages outside the exception class
(TRY003)
272-272: Avoid specifying long messages outside the exception class
(TRY003)
274-274: Avoid specifying long messages outside the exception class
(TRY003)
276-276: Avoid specifying long messages outside the exception class
(TRY003)
278-278: Avoid specifying long messages outside the exception class
(TRY003)
280-280: Avoid specifying long messages outside the exception class
(TRY003)
283-285: Avoid specifying long messages outside the exception class
(TRY003)
297-299: Avoid specifying long messages outside the exception class
(TRY003)
312-312: Avoid specifying long messages outside the exception class
(TRY003)
314-314: Avoid specifying long messages outside the exception class
(TRY003)
317-317: Avoid specifying long messages outside the exception class
(TRY003)
535-537: Avoid specifying long messages outside the exception class
(TRY003)
542-544: Avoid specifying long messages outside the exception class
(TRY003)
551-553: Avoid specifying long messages outside the exception class
(TRY003)
558-560: Avoid specifying long messages outside the exception class
(TRY003)
570-573: Avoid specifying long messages outside the exception class
(TRY003)
576-576: Avoid specifying long messages outside the exception class
(TRY003)
578-580: Avoid specifying long messages outside the exception class
(TRY003)
deepmd/pd/train/training.py
214-214: Avoid specifying long messages outside the exception class
(TRY003)
216-216: Avoid specifying long messages outside the exception class
(TRY003)
218-218: Avoid specifying long messages outside the exception class
(TRY003)
220-220: Avoid specifying long messages outside the exception class
(TRY003)
223-223: Avoid specifying long messages outside the exception class
(TRY003)
227-227: Avoid specifying long messages outside the exception class
(TRY003)
229-229: Avoid specifying long messages outside the exception class
(TRY003)
231-231: Avoid specifying long messages outside the exception class
(TRY003)
233-233: Avoid specifying long messages outside the exception class
(TRY003)
235-235: Avoid specifying long messages outside the exception class
(TRY003)
238-240: Avoid specifying long messages outside the exception class
(TRY003)
252-254: Avoid specifying long messages outside the exception class
(TRY003)
262-262: Avoid specifying long messages outside the exception class
(TRY003)
264-264: Avoid specifying long messages outside the exception class
(TRY003)
267-267: Avoid specifying long messages outside the exception class
(TRY003)
486-488: Avoid specifying long messages outside the exception class
(TRY003)
493-495: Avoid specifying long messages outside the exception class
(TRY003)
502-504: Avoid specifying long messages outside the exception class
(TRY003)
509-511: Avoid specifying long messages outside the exception class
(TRY003)
521-524: Avoid specifying long messages outside the exception class
(TRY003)
527-527: Avoid specifying long messages outside the exception class
(TRY003)
529-531: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40)
- GitHub Check: Test Python (10, 3.10)
- GitHub Check: Test Python (8, 3.10)
- GitHub Check: Test Python (6, 3.10)
- GitHub Check: Test Python (8, 3.13)
- GitHub Check: Test Python (11, 3.10)
- GitHub Check: Test Python (11, 3.13)
- GitHub Check: Test Python (7, 3.13)
- GitHub Check: Test Python (9, 3.10)
- GitHub Check: Test Python (12, 3.13)
- GitHub Check: Test Python (10, 3.13)
- GitHub Check: Test Python (9, 3.13)
- GitHub Check: Test Python (12, 3.10)
- GitHub Check: Test Python (5, 3.13)
- GitHub Check: Test Python (4, 3.13)
- GitHub Check: Test Python (5, 3.10)
- GitHub Check: Test Python (7, 3.10)
- GitHub Check: Test Python (4, 3.10)
- GitHub Check: Test Python (2, 3.10)
- GitHub Check: Test Python (6, 3.13)
- GitHub Check: Test Python (2, 3.13)
- GitHub Check: Test Python (1, 3.10)
- GitHub Check: Test Python (3, 3.13)
- GitHub Check: Test Python (3, 3.10)
- GitHub Check: Test Python (1, 3.13)
- GitHub Check: Test C++ (true, false, false, true)
- GitHub Check: Test C++ (false, false, false, true)
- GitHub Check: Test C++ (false, true, true, false)
- GitHub Check: Test C++ (true, true, true, false)
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (clang, clang)
🔇 Additional comments (8)
deepmd/pd/train/training.py (3)
135-137: LGTM! Proper use of `.get()` for optional config parameters, allowing either `numb_steps` or `num_epoch`/`num_epoch_dict` to drive the step calculation.

481-546: LGTM! Comprehensive step resolution with proper precedence handling (see the illustrative config fragments below). The logic correctly:
- Prioritizes `num_epoch_dict` for multi-task (allowing per-task epoch targets)
- Falls back to `num_epoch` for a uniform epoch count
- Raises clear errors when neither is provided
- Guards against division by zero when computing per-task steps
- Logs computed values for transparency
- Warns when both `numb_steps` and epoch configs are provided
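Purely illustrative config fragments for this precedence, written as Python dicts (the key names `numb_steps`, `num_epoch`, `num_epoch_dict`, and `model_prob` come from this PR; the task names are hypothetical, and argcheck.py remains the authoritative schema):

```python
# Single-task: derive num_steps from a uniform epoch target.
training_single = {"num_epoch": 10}

# Multi-task: per-task epoch targets (num_epoch_dict) take precedence over num_epoch;
# tasks omitted from the dict simply impose no epoch target.
training_multi = {
    "model_prob": {"pretrain": 0.8, "downstream": 0.2},
    "num_epoch_dict": {"downstream": 50},
}

# If numb_steps is also set, it wins and the epoch options are ignored (with a warning).
training_explicit = {"numb_steps": 1000000, "num_epoch": 10}
```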
453-480: No issues found with batch count calculation.The code correctly uses
training_data.indexto obtain batch counts per system. Both single-task and multi-task modes properly pass these counts tocompute_total_numb_batch, which validates and processes them appropriately.deepmd/pt/train/training.py (5)
142-142: LGTM!Good defensive initialization of
self.model_probtoNone. This ensures the attribute exists before the multi-task resolution block and allows the guarded check at line 804.
145-147: LGTM!Consistent with PD implementation for reading optional config parameters.
253-286: LGTM!Comprehensive validation matching the PD implementation. The logic correctly computes the minimum number of global steps needed to cover all batches according to their weights.
288-318: LGTM! Good enhancement with the info log (lines 304-308) when `model_prob_config` is not set. This helps users understand the default behavior.
503-595: LGTM! Consistent implementation with the PD version. The step resolution logic is comprehensive, with proper validation, logging, and a warning when both `numb_steps` and epoch configs are set.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In @deepmd/pd/train/training.py:
- Around line 132-137: The PD training class fails to initialize self.model_prob
which can raise AttributeError later; add a default initialization
(self.model_prob = None) alongside the other instance attributes (e.g., near
where self.num_model is set after self.model_keys) so downstream code that reads
self.model_prob is safe when num_epoch_dict or multi-task paths don’t set it.
- Around line 558-568: After the resume/model-loading block completes and after
the model wrapper initialization, add a defensive fallback that if
self.multi_task is True and self.model_prob is None then set self.model_prob =
resolve_model_prob(self.model_keys, training_params.get("model_prob"),
training_data); locate the resume logic and model wrapper init in training.py
and insert this check there (mirroring the PT trainer) so multi-task resumes
always have model_prob resolved.
🧹 Nitpick comments (6)
deepmd/tf/entrypoints/train.py (2)
257-279: Consider extracting `compute_total_numb_batch` to a shared utility module. This function is now duplicated across four files:
- deepmd/tf/entrypoints/train.py
- deepmd/tf/entrypoints/change_bias.py
- deepmd/pt/train/training.py
- deepmd/pd/train/training.py
The implementations are nearly identical but have minor inconsistencies. For example, the PT/PD versions include additional validations for `nbatches` (1D shape, empty, finite, non-negative checks) that this TF version lacks. Consider extracting this to a shared utility (e.g., `deepmd/utils/training.py`) to ensure consistent behavior and reduce maintenance burden.
282-283: Consider warning when both `numb_steps` and `num_epoch` are provided. When `numb_steps` is set, `num_epoch` is silently ignored. This could confuse users who accidentally configure both. A warning would help clarify which value takes precedence.
💡 Suggested enhancement
```diff
 stop_batch = training_params.get("numb_steps")
 num_epoch = training_params.get("num_epoch")
+if stop_batch is not None and num_epoch is not None:
+    log.warning(
+        "Both training.numb_steps and training.num_epoch are set; "
+        "using numb_steps=%d and ignoring num_epoch.",
+        stop_batch,
+    )
 if stop_batch is None:
```
source/tests/pt/test_sampler.py (1)
360-361: Replace `raise AssertionError` with `self.fail()` for unittest consistency. Using `raise AssertionError` directly works, but `self.fail()` is the idiomatic way to signal test failures in unittest, providing better integration with test runners.
💡 Suggested fix
```diff
 if model_counts_epoch[0] == 0 or model_counts_epoch[1] == 0:
-    raise AssertionError("Model sampling produced zero counts for a task.")
+    self.fail("Model sampling produced zero counts for a task.")
```
deepmd/pt/train/training.py (1)
300-302: Redundant conditional check. The check `if model_key in model_prob_config` on line 301 is always true at this point, because lines 295-299 already raise an error if any key is missing from `model_prob_config`.
♻️ Suggested simplification
```diff
 for ii, model_key in enumerate(model_keys):
-    if model_key in model_prob_config:
-        model_prob[ii] = float(model_prob_config[model_key])
+    model_prob[ii] = float(model_prob_config[model_key])
```
deepmd/pd/train/training.py (1)
243-268: Minor difference: missing rank check for the info log. The PT version (lines 304-308) only logs the default model_prob message on rank 0:
```python
if self.rank == 0:
    log.info("training.model_prob is not set or empty; ...")
```
The PD version doesn't have this guard, which could cause duplicate log messages in distributed training.
💡 Suggested fix for consistency
```diff
 else:
+    if self.rank == 0:
+        log.info(
+            "training.model_prob is not set or empty; defaulting to the "
+            "number of systems per task."
+        )
     for ii, model_key in enumerate(model_keys):
         model_prob[ii] = float(len(model_training_data[model_key]))
```
doc/train/multi-task-training.md (1)
82-91: Documentation accurately reflects the new feature. Minor nit: line 86 uses "pretrained" while line 102 uses "pre-trained". Consider using consistent spelling throughout the document.
Suggested fix for consistency
```diff
- where a data-rich pretrained model is jointly trained with a data-scarce downstream task.
+ where a data-rich pre-trained model is jointly trained with a data-scarce downstream task.
```
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- deepmd/pd/train/training.py
- deepmd/pt/train/training.py
- deepmd/tf/entrypoints/train.py
- deepmd/utils/argcheck.py
- doc/train/multi-task-training.md
- source/tests/pt/test_sampler.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2024-11-29T12:15:22.226Z
Learnt from: HydrogenSulfate
Repo: deepmodeling/deepmd-kit PR: 4414
File: deepmd/pd/train/training.py:66-66
Timestamp: 2024-11-29T12:15:22.226Z
Learning: The function `nvprof_context` is defined in `deepmd/pd/utils/utils.py`, so importing it in `deepmd/pd/train/training.py` is correct.
Applied to files:
source/tests/pt/test_sampler.py
🧬 Code graph analysis (3)
deepmd/tf/entrypoints/train.py (3)
deepmd/pd/train/training.py (1)
compute_total_numb_batch (211-241)
deepmd/pt/train/training.py (1)
compute_total_numb_batch (253-286)
deepmd/tf/entrypoints/change_bias.py (1)
compute_total_numb_batch(190-212)
source/tests/pt/test_sampler.py (2)
deepmd/pt/utils/dataloader.py (3)
DpLoaderSet (65-243), get_weighted_sampler (266-290), get_sampler_from_params (293-306)
deepmd/utils/data_system.py (1)
set_sys_probs(413-434)
deepmd/pd/train/training.py (1)
deepmd/pt/train/training.py (3)
compute_total_numb_batch (253-286), resolve_model_prob (288-318), resolve_model_prob_from_epochs (320-367)
🪛 LanguageTool
doc/train/multi-task-training.md
[uncategorized] ~86-~86: Do not mix variants of the same word (‘pretrain’ and ‘pre-train’) within a single text.
Context: ...ne-tuning scenarios where a data-rich pretrained model is jointly trained with a data-sc...
(EN_WORD_COHERENCY)
🪛 Ruff (0.14.10)
deepmd/tf/entrypoints/train.py
260-260: Avoid specifying long messages outside the exception class
(TRY003)
262-262: Avoid specifying long messages outside the exception class
(TRY003)
264-264: Avoid specifying long messages outside the exception class
(TRY003)
266-266: Avoid specifying long messages outside the exception class
(TRY003)
269-269: Avoid specifying long messages outside the exception class
(TRY003)
273-273: Avoid specifying long messages outside the exception class
(TRY003)
276-278: Avoid specifying long messages outside the exception class
(TRY003)
286-288: Avoid specifying long messages outside the exception class
(TRY003)
290-290: Avoid specifying long messages outside the exception class
(TRY003)
292-294: Avoid specifying long messages outside the exception class
(TRY003)
299-299: Avoid specifying long messages outside the exception class
(TRY003)
deepmd/utils/argcheck.py
3525-3527: Avoid specifying long messages outside the exception class
(TRY003)
3530-3532: Avoid specifying long messages outside the exception class
(TRY003)
3534-3536: Avoid specifying long messages outside the exception class
(TRY003)
3539-3541: Avoid specifying long messages outside the exception class
(TRY003)
source/tests/pt/test_sampler.py
100-102: Avoid specifying long messages outside the exception class
(TRY003)
104-107: Avoid specifying long messages outside the exception class
(TRY003)
361-361: Avoid specifying long messages outside the exception class
(TRY003)
deepmd/pt/train/training.py
259-259: Avoid specifying long messages outside the exception class
(TRY003)
261-261: Avoid specifying long messages outside the exception class
(TRY003)
263-263: Avoid specifying long messages outside the exception class
(TRY003)
265-265: Avoid specifying long messages outside the exception class
(TRY003)
268-268: Avoid specifying long messages outside the exception class
(TRY003)
272-272: Avoid specifying long messages outside the exception class
(TRY003)
274-274: Avoid specifying long messages outside the exception class
(TRY003)
276-276: Avoid specifying long messages outside the exception class
(TRY003)
278-278: Avoid specifying long messages outside the exception class
(TRY003)
280-280: Avoid specifying long messages outside the exception class
(TRY003)
283-285: Avoid specifying long messages outside the exception class
(TRY003)
297-299: Avoid specifying long messages outside the exception class
(TRY003)
312-312: Avoid specifying long messages outside the exception class
(TRY003)
314-314: Avoid specifying long messages outside the exception class
(TRY003)
317-317: Avoid specifying long messages outside the exception class
(TRY003)
326-328: Avoid specifying long messages outside the exception class
(TRY003)
331-334: Avoid specifying long messages outside the exception class
(TRY003)
339-341: Avoid specifying long messages outside the exception class
(TRY003)
344-346: Avoid specifying long messages outside the exception class
(TRY003)
350-350: Avoid specifying long messages outside the exception class
(TRY003)
352-352: Avoid specifying long messages outside the exception class
(TRY003)
354-354: Avoid specifying long messages outside the exception class
(TRY003)
356-356: Avoid specifying long messages outside the exception class
(TRY003)
360-360: Avoid specifying long messages outside the exception class
(TRY003)
562-564: Avoid specifying long messages outside the exception class
(TRY003)
566-566: Avoid specifying long messages outside the exception class
(TRY003)
568-570: Avoid specifying long messages outside the exception class
(TRY003)
609-612: Avoid specifying long messages outside the exception class
(TRY003)
deepmd/pd/train/training.py
214-214: Avoid specifying long messages outside the exception class
(TRY003)
216-216: Avoid specifying long messages outside the exception class
(TRY003)
218-218: Avoid specifying long messages outside the exception class
(TRY003)
220-220: Avoid specifying long messages outside the exception class
(TRY003)
223-223: Avoid specifying long messages outside the exception class
(TRY003)
227-227: Avoid specifying long messages outside the exception class
(TRY003)
229-229: Avoid specifying long messages outside the exception class
(TRY003)
231-231: Avoid specifying long messages outside the exception class
(TRY003)
233-233: Avoid specifying long messages outside the exception class
(TRY003)
235-235: Avoid specifying long messages outside the exception class
(TRY003)
238-240: Avoid specifying long messages outside the exception class
(TRY003)
252-254: Avoid specifying long messages outside the exception class
(TRY003)
262-262: Avoid specifying long messages outside the exception class
(TRY003)
264-264: Avoid specifying long messages outside the exception class
(TRY003)
267-267: Avoid specifying long messages outside the exception class
(TRY003)
276-278: Avoid specifying long messages outside the exception class
(TRY003)
281-284: Avoid specifying long messages outside the exception class
(TRY003)
289-291: Avoid specifying long messages outside the exception class
(TRY003)
294-296: Avoid specifying long messages outside the exception class
(TRY003)
300-300: Avoid specifying long messages outside the exception class
(TRY003)
302-302: Avoid specifying long messages outside the exception class
(TRY003)
304-304: Avoid specifying long messages outside the exception class
(TRY003)
306-306: Avoid specifying long messages outside the exception class
(TRY003)
310-310: Avoid specifying long messages outside the exception class
(TRY003)
513-515: Avoid specifying long messages outside the exception class
(TRY003)
517-517: Avoid specifying long messages outside the exception class
(TRY003)
519-521: Avoid specifying long messages outside the exception class
(TRY003)
560-563: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (34)
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Test Python (11, 3.13)
- GitHub Check: Test Python (12, 3.13)
- GitHub Check: Test Python (10, 3.10)
- GitHub Check: Test Python (12, 3.10)
- GitHub Check: Test Python (11, 3.10)
- GitHub Check: Test Python (10, 3.13)
- GitHub Check: Test Python (1, 3.13)
- GitHub Check: Test Python (3, 3.13)
- GitHub Check: Test Python (8, 3.10)
- GitHub Check: Test Python (6, 3.10)
- GitHub Check: Test Python (2, 3.10)
- GitHub Check: Test Python (8, 3.13)
- GitHub Check: Test Python (5, 3.10)
- GitHub Check: Test Python (3, 3.10)
- GitHub Check: Test Python (1, 3.10)
- GitHub Check: Test Python (2, 3.13)
- GitHub Check: Test Python (6, 3.13)
- GitHub Check: Test Python (9, 3.10)
- GitHub Check: Test Python (7, 3.13)
- GitHub Check: Test Python (4, 3.10)
- GitHub Check: Test Python (4, 3.13)
- GitHub Check: Test Python (7, 3.10)
- GitHub Check: Test Python (5, 3.13)
- GitHub Check: Test Python (9, 3.13)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Test C++ (false, false, false, true)
- GitHub Check: Test C++ (true, false, false, true)
- GitHub Check: Test C++ (false, true, true, false)
- GitHub Check: Test C++ (true, true, true, false)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
🔇 Additional comments (17)
deepmd/tf/entrypoints/train.py (1)
281-306: Logic for computing `stop_batch` from `num_epoch` is correct. The implementation correctly:
- Retrieves both `numb_steps` and `num_epoch` from training params
- Falls back to epoch-based calculation when `numb_steps` is absent
- Validates that `num_epoch` is positive
- Computes `stop_batch = ceil(num_epoch * total_numb_batch)`
- Logs the computed values for transparency
Note: the check on lines 298-299 (`total_numb_batch <= 0`) is defensive but should never trigger, given that `compute_total_numb_batch` already validates that probabilities sum to a positive value and at least one probability is positive.
source/tests/pt/test_sampler.py (4)
31-43: Serial pool shim is a clean solution for CI compatibility. The `_SerialPool` class provides a minimal context-manager-compatible implementation that avoids multiprocessing issues in restricted CI environments. This is a pragmatic approach.
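For illustration only, a shim of this kind can be as small as the following sketch; the PR's `_SerialPool` may differ in detail.

```python
class SerialPool:
    """Illustrative stand-in for multiprocessing.Pool that runs work in-process."""

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        return False  # do not suppress exceptions

    def map(self, func, iterable):
        # Serial evaluation avoids SemLock/CUDA re-initialization in sandboxed CI.
        return [func(item) for item in iterable]

with SerialPool() as pool:
    assert pool.map(len, ["a", "bb", "ccc"]) == [1, 2, 3]
```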
95-108: Test helper `_compute_total_numb_batch` has stricter requirements than production code. The test helper rejects zero probabilities (lines 103-107), while the production `compute_total_numb_batch` filters them out with `valid = probs > 0.0` and processes only positive entries. This means the test helper cannot be used to verify behavior with zero-probability systems. This is acceptable since the test docstring notes it's "a simplified test-only variant," but be aware this limits test coverage for edge cases with zero probabilities.
244-290: Single-task sampling stability test is well-structured. The test validates that:
- Epoch-derived steps produce sampling frequencies close to target probabilities (within `atol=0.1`)
- Reproducibility is maintained when using the same random seed
The `atol=0.1` tolerance is reasonable for stochastic sampling with the number of steps used.
419-519: End-to-end epoch calculation test provides good coverage. The test validates the mathematical properties of the epoch-based step calculation:
- Each task completes at least its target epochs
- Rounding scales all tasks consistently
The assertions on lines 495-518 effectively verify that the scheduling algorithm maintains proportional fairness across tasks.
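As back-of-the-envelope arithmetic for that property, with all numbers invented for illustration:

```python
import math

per_task_total = {"task_0": 123, "task_1": 80}   # batches per "epoch", illustrative
epoch_targets = {"task_0": 2.5, "task_1": 3.0}

target_steps = {k: per_task_total[k] * epoch_targets[k] for k in per_task_total}
num_steps = math.ceil(sum(target_steps.values()))                     # 548
model_prob = {k: v / sum(target_steps.values()) for k, v in target_steps.items()}

# Every task is scaled by the same factor num_steps / sum(target_steps) >= 1,
# so rounding up never drops a task below its target epoch count.
achieved = {k: num_steps * model_prob[k] / per_task_total[k] for k in per_task_total}
assert all(achieved[k] >= epoch_targets[k] for k in epoch_targets)
```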
deepmd/pt/train/training.py (5)
142-142: Good initialization of `model_prob` to `None`. Initializing `model_prob` to `None` allows for deferred resolution based on configuration (either from `num_epoch_dict` or the `model_prob` config), which happens later in the initialization flow.
253-286: `compute_total_numb_batch` has comprehensive validation. This implementation includes all necessary validations:
- Shape checks (1D arrays)
- Empty checks
- Finiteness checks
- Non-negativity checks
- Positive sum check
- Shape matching between arrays
This is the most complete version compared to the TF implementation.
320-367: `resolve_model_prob_from_epochs` correctly implements per-task epoch scheduling. The function:
- Validates epoch targets are positive and finite
- Computes per-task steps as `per_task_total * epoch_targets`
- Derives `model_prob` from the relative step contributions
- Returns the normalized probabilities, total steps, and per-task step map
This aligns with the multi-task training documentation describing `num_epoch_dict` behavior; see the sketch below.
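A schematic, hedged version of that resolution (dictionary-based for readability; the real function works on arrays and adds the validation listed above, so names here are illustrative):

```python
import math

def probs_and_steps_from_epochs(per_task_total, epoch_targets):
    """Illustrative version of the epoch-based resolution described above."""
    per_task_steps = {k: per_task_total[k] * epoch_targets[k] for k in epoch_targets}
    total = sum(per_task_steps.values())
    model_prob = {k: v / total for k, v in per_task_steps.items()}
    num_steps = int(math.ceil(total))
    return model_prob, num_steps, per_task_steps

probs, steps, per_task = probs_and_steps_from_epochs(
    {"pretrain": 500, "finetune": 40}, {"pretrain": 1.0, "finetune": 10.0}
)
# pretrain and finetune contribute 500 and 400 target steps respectively,
# so probs ≈ {0.556, 0.444} and steps == 900
```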
552-617: Training step resolution logic correctly handles all configuration scenarios. The implementation properly handles:
- Single-task: computes `num_steps` from `num_epoch` when `num_steps` is not provided
- Multi-task with `num_epoch_dict`: derives both `model_prob` and `num_steps` from per-task epoch targets
- Multi-task without `num_epoch_dict`: requires explicit `num_steps` and resolves `model_prob` from config or data sizes
The logging on lines 572-577 and 599-606 provides good visibility into computed values.
826-831: Fallback `model_prob` resolution handles resuming scenarios. This fallback ensures `model_prob` is resolved even when the primary initialization path didn't set it (e.g., when resuming without `num_epoch_dict`). This is a necessary safeguard.
deepmd/pd/train/training.py (3)
211-241: `compute_total_numb_batch` implementation matches the PT version. The validation logic is consistent with the PT implementation, ensuring uniform behavior across backends.
502-528: Single-task epoch-based step calculation is correct. The logic correctly:
- Computes `total_numb_batch` from sampler weights and data indices
- Validates `num_epoch` is positive when used
- Computes `num_steps = ceil(num_epoch * total_numb_batch)` (see the sketch below)
- Logs the computed values
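The single-task formula is simple enough to state directly; a sketch with illustrative names (not the trainer's exact helpers) follows:

```python
import math

def steps_from_epochs(num_epoch, total_numb_batch):
    """num_steps = ceil(num_epoch * total_numb_batch), as described above."""
    if num_epoch <= 0:
        raise ValueError("training.num_epoch must be positive")
    if total_numb_batch <= 0:
        raise ValueError("total_numb_batch must be positive")
    return int(math.ceil(num_epoch * total_numb_batch))

# 2.5 epochs over 400 weighted batches -> 1000 training steps
assert steps_from_epochs(2.5, 400) == 1000
```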
529-568: Multi-task training step resolution is correctly implemented. The implementation handles both scenarios:
- With `num_epoch_dict`: derives `model_prob` and `num_steps` from per-task epoch targets
- Without `num_epoch_dict`: requires explicit `num_steps` and resolves `model_prob` from config
The per-task total calculation loop (lines 530-539) correctly aggregates batch counts for each model.
deepmd/utils/argcheck.py (4)
3216-3244: LGTM! Documentation is comprehensive. The documentation strings clearly explain the semantics, formulas for computing total steps, and mutual exclusivity constraints for all training duration parameters.
3290-3295: LGTM! Documentation correctly clarifies that `model_prob` is only used when `num_epoch_dict` is not set.
3315-3346: LGTM! The new arguments are well-defined with appropriate types, aliases for backward compatibility, and proper optional flags.
3516-3552: Mutual exclusivity checks are correct; model key validation is properly handled at runtime. The validation logic correctly enforces mutual exclusivity between the `num_epoch`, `num_epoch_dict`, `num_steps`, and `model_prob` parameters. The constraint that "all model keys must be specified in the dictionary" (from `doc_num_epoch_dict`) is intentionally validated at runtime in `training.py` rather than in `argcheck.py`, since model keys are only known when training initialization occurs. This is an appropriate separation of concerns: `argcheck.py` handles configuration-level mutual exclusivity, while `training.py` enforces semantic constraints with full context.
Actionable comments posted: 0
🧹 Nitpick comments (1)
deepmd/pd/train/training.py (1)
259-261: Missing informational log message for default model_prob. The PT implementation logs when `model_prob` defaults to system counts, but this PD version doesn't. Consider adding the log for consistency:
♻️ Suggested fix
```diff
 else:
+    if self.rank == 0:
+        log.info(
+            "training.model_prob is not set or empty; defaulting to the "
+            "number of systems per task."
+        )
     for ii, model_key in enumerate(model_keys):
         model_prob[ii] = float(len(model_training_data[model_key]))
```
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
deepmd/pd/train/training.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2024-11-25T07:42:55.735Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4302
File: deepmd/pd/infer/inference.py:35-38
Timestamp: 2024-11-25T07:42:55.735Z
Learning: In the file `deepmd/pd/infer/inference.py`, when loading the model checkpoint in the `Tester` class, it's acceptable to not include additional error handling for loading the model state dictionary.
Applied to files:
deepmd/pd/train/training.py
🧬 Code graph analysis (1)
deepmd/pd/train/training.py (3)
deepmd/tf/entrypoints/train.py (1)
compute_total_numb_batch(257-279)deepmd/tf/entrypoints/change_bias.py (1)
compute_total_numb_batch(190-212)deepmd/pt/train/training.py (3)
compute_total_numb_batch(253-286)resolve_model_prob(288-318)resolve_model_prob_from_epochs(320-367)
🪛 Ruff (0.14.10)
deepmd/pd/train/training.py
215-215: Avoid specifying long messages outside the exception class
(TRY003)
217-217: Avoid specifying long messages outside the exception class
(TRY003)
219-219: Avoid specifying long messages outside the exception class
(TRY003)
221-221: Avoid specifying long messages outside the exception class
(TRY003)
224-224: Avoid specifying long messages outside the exception class
(TRY003)
228-228: Avoid specifying long messages outside the exception class
(TRY003)
230-230: Avoid specifying long messages outside the exception class
(TRY003)
232-232: Avoid specifying long messages outside the exception class
(TRY003)
234-234: Avoid specifying long messages outside the exception class
(TRY003)
236-236: Avoid specifying long messages outside the exception class
(TRY003)
239-241: Avoid specifying long messages outside the exception class
(TRY003)
253-255: Avoid specifying long messages outside the exception class
(TRY003)
263-263: Avoid specifying long messages outside the exception class
(TRY003)
265-265: Avoid specifying long messages outside the exception class
(TRY003)
268-268: Avoid specifying long messages outside the exception class
(TRY003)
277-279: Avoid specifying long messages outside the exception class
(TRY003)
282-285: Avoid specifying long messages outside the exception class
(TRY003)
290-292: Avoid specifying long messages outside the exception class
(TRY003)
295-297: Avoid specifying long messages outside the exception class
(TRY003)
301-301: Avoid specifying long messages outside the exception class
(TRY003)
303-303: Avoid specifying long messages outside the exception class
(TRY003)
305-305: Avoid specifying long messages outside the exception class
(TRY003)
307-307: Avoid specifying long messages outside the exception class
(TRY003)
311-311: Avoid specifying long messages outside the exception class
(TRY003)
514-516: Avoid specifying long messages outside the exception class
(TRY003)
518-518: Avoid specifying long messages outside the exception class
(TRY003)
520-522: Avoid specifying long messages outside the exception class
(TRY003)
561-564: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40)
- GitHub Check: Test Python (7, 3.13)
- GitHub Check: Test Python (4, 3.10)
- GitHub Check: Test Python (3, 3.10)
- GitHub Check: Test Python (12, 3.13)
- GitHub Check: Test Python (4, 3.13)
- GitHub Check: Test Python (10, 3.13)
- GitHub Check: Test Python (8, 3.10)
- GitHub Check: Test Python (6, 3.13)
- GitHub Check: Test Python (5, 3.13)
- GitHub Check: Test Python (1, 3.10)
- GitHub Check: Test Python (9, 3.10)
- GitHub Check: Test Python (9, 3.13)
- GitHub Check: Test Python (3, 3.13)
- GitHub Check: Test Python (1, 3.13)
- GitHub Check: Test Python (12, 3.10)
- GitHub Check: Test Python (7, 3.10)
- GitHub Check: Test Python (2, 3.13)
- GitHub Check: Test Python (2, 3.10)
- GitHub Check: Test Python (10, 3.10)
- GitHub Check: Test Python (5, 3.10)
- GitHub Check: Test Python (11, 3.13)
- GitHub Check: Test Python (6, 3.10)
- GitHub Check: Test Python (8, 3.13)
- GitHub Check: Test Python (11, 3.10)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Analyze (python)
- GitHub Check: Test C++ (false, true, true, false)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Test C++ (true, true, true, false)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Test C++ (true, false, false, true)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Test C++ (false, false, false, true)
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
🔇 Additional comments (6)
deepmd/pd/train/training.py (6)
133-138: LGTM! The initialization changes correctly use `.get()` for optional parameters (`numb_steps`, `num_epoch`, `num_epoch_dict`) and initialize `model_prob` to `None` for deferred resolution. This aligns with the PT implementation pattern.
212-242: LGTM! The `compute_total_numb_batch` function correctly validates input arrays and computes the total number of batches needed to cover all systems according to their sampling probabilities. The implementation matches the PT version.
271-318: LGTM! The `resolve_model_prob_from_epochs` function correctly derives model probabilities and total training steps from per-task epoch targets. The validation logic is thorough and the implementation matches the PT version.
503-529: LGTM! The single-task flow correctly derives `num_steps` from `num_epoch` when not explicitly provided. The validation ensures both `num_epoch` and `total_numb_batch` are positive before computation, with appropriate logging.
530-569: LGTM! The multi-task flow correctly handles both `num_epoch_dict`-based scheduling (deriving probabilities from epoch targets) and explicit `numb_steps` with `model_prob` configuration. The per-task total computation and logging are appropriate.
756-762: This is intentional defensive code, not dead code. The fallback at lines 756-762 serves an important purpose for resumed multi-task training. During resuming (lines 609-755), `self.wrapper.set_state_dict()` loads model weights from the checkpoint but does not persist `self.model_prob`, which is a computed variable. The fallback ensures that if `model_prob` is not set during the initial multi-task initialization (lines 541-569), it gets recomputed before training continues. This is necessary for consistent multi-task sampling behavior across training resumptions.
iProzd
left a comment
There are common methods in pt and pd which can be merged.
Actionable comments posted: 0
🧹 Nitpick comments (2)
deepmd/dpmodel/utils/training_utils.py (1)
103-105: Minor: redundant key check. The check `if model_key in model_prob_config` at line 104 is always true, because lines 98-102 already validate that all model_keys exist in model_prob_config and raise if any are missing.
♻️ Suggested simplification
```diff
 for ii, model_key in enumerate(model_keys):
-    if model_key in model_prob_config:
-        model_prob[ii] = float(model_prob_config[model_key])
+    model_prob[ii] = float(model_prob_config[model_key])
```
source/tests/pt/test_sampler.py (1)
443-444: Remove unused variables.
`probs_1` and `probs_2` are computed but never used in this test. They appear to be leftover from copying test setup code.
♻️ Suggested fix
```diff
 sampler_2 = pt_dataloader.get_sampler_from_params(
     dataset_2, {"sys_probs": [0.4, 0.6], "auto_prob": "prob_sys_size"}
 )
-probs_1 = self._normalize_probs(np.asarray(sampler_1.weights))
-probs_2 = self._normalize_probs(np.asarray(sampler_2.weights))

 # === Step 2. Compute per-task total_numb_batch ===
```
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- deepmd/dpmodel/utils/__init__.py
- deepmd/dpmodel/utils/training_utils.py
- deepmd/pd/train/training.py
- deepmd/pt/train/training.py
- deepmd/tf/entrypoints/change_bias.py
- deepmd/tf/entrypoints/train.py
- source/tests/pt/test_sampler.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2024-11-25T07:42:55.735Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4302
File: deepmd/pd/infer/inference.py:35-38
Timestamp: 2024-11-25T07:42:55.735Z
Learning: In the file `deepmd/pd/infer/inference.py`, when loading the model checkpoint in the `Tester` class, it's acceptable to not include additional error handling for loading the model state dictionary.
Applied to files:
deepmd/pd/train/training.py
📚 Learning: 2024-11-29T12:15:22.226Z
Learnt from: HydrogenSulfate
Repo: deepmodeling/deepmd-kit PR: 4414
File: deepmd/pd/train/training.py:66-66
Timestamp: 2024-11-29T12:15:22.226Z
Learning: The function `nvprof_context` is defined in `deepmd/pd/utils/utils.py`, so importing it in `deepmd/pd/train/training.py` is correct.
Applied to files:
source/tests/pt/test_sampler.py
🧬 Code graph analysis (4)
deepmd/tf/entrypoints/change_bias.py (1)
deepmd/dpmodel/utils/training_utils.py (1)
compute_total_numb_batch(12-64)
deepmd/dpmodel/utils/__init__.py (1)
deepmd/dpmodel/utils/training_utils.py (3)
compute_total_numb_batch(12-64)resolve_model_prob(67-121)resolve_model_prob_from_epochs(124-188)
deepmd/pt/train/training.py (1)
deepmd/dpmodel/utils/training_utils.py (3)
compute_total_numb_batch(12-64)resolve_model_prob(67-121)resolve_model_prob_from_epochs(124-188)
deepmd/tf/entrypoints/train.py (1)
deepmd/dpmodel/utils/training_utils.py (1)
compute_total_numb_batch(12-64)
🪛 Ruff (0.14.10)
deepmd/tf/entrypoints/change_bias.py
198-198: Avoid specifying long messages outside the exception class
(TRY003)
201-201: Avoid specifying long messages outside the exception class
(TRY003)
deepmd/dpmodel/utils/training_utils.py
37-37: Avoid specifying long messages outside the exception class
(TRY003)
39-39: Avoid specifying long messages outside the exception class
(TRY003)
41-41: Avoid specifying long messages outside the exception class
(TRY003)
43-43: Avoid specifying long messages outside the exception class
(TRY003)
46-46: Avoid specifying long messages outside the exception class
(TRY003)
50-50: Avoid specifying long messages outside the exception class
(TRY003)
52-52: Avoid specifying long messages outside the exception class
(TRY003)
54-54: Avoid specifying long messages outside the exception class
(TRY003)
56-56: Avoid specifying long messages outside the exception class
(TRY003)
58-58: Avoid specifying long messages outside the exception class
(TRY003)
61-63: Avoid specifying long messages outside the exception class
(TRY003)
100-102: Avoid specifying long messages outside the exception class
(TRY003)
115-115: Avoid specifying long messages outside the exception class
(TRY003)
117-117: Avoid specifying long messages outside the exception class
(TRY003)
120-120: Avoid specifying long messages outside the exception class
(TRY003)
151-151: Avoid specifying long messages outside the exception class
(TRY003)
154-156: Avoid specifying long messages outside the exception class
(TRY003)
161-163: Avoid specifying long messages outside the exception class
(TRY003)
166-168: Avoid specifying long messages outside the exception class
(TRY003)
172-172: Avoid specifying long messages outside the exception class
(TRY003)
174-174: Avoid specifying long messages outside the exception class
(TRY003)
176-176: Avoid specifying long messages outside the exception class
(TRY003)
178-178: Avoid specifying long messages outside the exception class
(TRY003)
182-182: Avoid specifying long messages outside the exception class
(TRY003)
deepmd/pt/train/training.py
451-453: Avoid specifying long messages outside the exception class
(TRY003)
455-455: Avoid specifying long messages outside the exception class
(TRY003)
457-459: Avoid specifying long messages outside the exception class
(TRY003)
498-501: Avoid specifying long messages outside the exception class
(TRY003)
deepmd/tf/entrypoints/train.py
265-267: Avoid specifying long messages outside the exception class
(TRY003)
269-269: Avoid specifying long messages outside the exception class
(TRY003)
271-273: Avoid specifying long messages outside the exception class
(TRY003)
278-278: Avoid specifying long messages outside the exception class
(TRY003)
deepmd/pd/train/training.py
411-413: Avoid specifying long messages outside the exception class
(TRY003)
415-415: Avoid specifying long messages outside the exception class
(TRY003)
417-419: Avoid specifying long messages outside the exception class
(TRY003)
458-461: Avoid specifying long messages outside the exception class
(TRY003)
source/tests/pt/test_sampler.py
353-353: Avoid specifying long messages outside the exception class
(TRY003)
443-443: Local variable probs_1 is assigned to but never used
Remove assignment to unused variable probs_1
(F841)
444-444: Local variable probs_2 is assigned to but never used
Remove assignment to unused variable probs_2
(F841)
🔇 Additional comments (16)
deepmd/dpmodel/utils/__init__.py (1)
39-43: LGTM! The new utility functions are correctly imported and exported. The `__all__` list maintains alphabetical ordering, consistent with the existing style.
Also applies to: 57-57, 69-70
deepmd/dpmodel/utils/training_utils.py (2)
12-64: LGTM! The function correctly computes the minimum number of training steps required to ensure each data system with positive probability is sampled at least its target batch count times in expectation. Input validation is thorough.
124-188: LGTM! The epoch-based probability resolution logic is sound. By making `model_prob` proportional to `per_task_steps` (epochs × batches), tasks with more target steps are sampled proportionally, ensuring each task trains for approximately its target number of epochs.
deepmd/tf/entrypoints/change_bias.py (1)
193-210: LGTM! The stop_batch computation logic correctly mirrors the train.py implementation, computing steps from `num_epoch` when `numb_steps` is not provided. The fallback to `stop_batch = 0` is appropriate for the change_bias entrypoint, where actual training may not occur.
deepmd/tf/entrypoints/train.py (1)
260-285: LGTM! The epoch-based step derivation is correctly implemented with proper validation. The requirement for `train_data` when using `num_epoch` is appropriately enforced, and the logic is consistent with the PT and PD backends.
deepmd/pt/train/training.py (4)
26-30: LGTM! Imports and attribute initialization are correctly set up. Using `.get()` for optional parameters enables graceful handling when keys are absent.
Also applies to: 147-147, 150-152
441-466: LGTM! Single-task step resolution correctly derives `num_steps` from `num_epoch` when not explicitly provided, with appropriate validation and logging.
467-507: LGTM! Multi-task step resolution properly handles both the `num_epoch_dict` (epoch-based) and explicit `numb_steps` paths. The delegation to utility functions (`resolve_model_prob_from_epochs`, `resolve_model_prob`) keeps the trainer code clean.
716-722: LGTM! The fallback ensures `model_prob` is always initialized for multi-task training, covering edge cases like resuming from checkpoints where the epoch-based path wasn't taken during initialization.
source/tests/pt/test_sampler.py (4)
34-52: LGTM! The `_SerialPool` shim is a clean solution to avoid SemLock/CUDA initialization failures in restricted CI environments. The context manager protocol and `map` method are correctly implemented.
Also applies to: 79-80
82-151: LGTM! The helper methods are well-structured for testing sampling behavior. `_sample_multitask_counts` correctly mirrors the trainer's task selection logic using `dp_random.choice`.
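As a rough stand-in for that task-selection loop, using NumPy's generator rather than deepmd's `dp_random` wrapper (so everything below is illustrative, not the test's exact code):

```python
import numpy as np

def sample_task_counts(model_prob, n_steps, seed=0):
    """Count how often each task index is drawn over n_steps weighted draws."""
    rng = np.random.default_rng(seed)
    picks = rng.choice(len(model_prob), size=n_steps, p=model_prob)
    return np.bincount(picks, minlength=len(model_prob))

# With model_prob = [0.25, 0.75] the second task should be drawn roughly 3x as often.
counts = sample_task_counts([0.25, 0.75], n_steps=10_000)
print(counts / counts.sum())
```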
232-280: LGTM! Comprehensive test validating that epoch-derived step counts produce sampling behavior equivalent to explicit step counts. The use of identical random seeds ensures a deterministic comparison.
282-409: LGTM! Thorough multi-task sampling stability test that validates that both model selection and within-task system selection frequencies match the expected probabilities.
deepmd/pd/train/training.py (3)
33-37: LGTM! Imports and initialization are consistent with the PT backend implementation, maintaining code alignment across backends.
Also applies to: 138-138, 141-143
400-467: LGTM! Step resolution logic correctly mirrors the PT implementation, with appropriate adjustments for Paddle's `BatchSampler` structure. Both single-task and multi-task paths are properly handled.
654-661: LGTM! Fallback `model_prob` resolution is consistent with the PT backend, ensuring multi-task training always has task probabilities initialized.
for more information, see https://pre-commit.ci
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In `@deepmd/utils/argcheck.py`:
- Around line 3540-3566: training_extra_check currently only enforces mutual
exclusivity but allows neither "numb_steps" nor "num_epoch" to be set in
single-task mode; update the function (training_extra_check) so that when
multi_task is False and both num_steps and num_epoch are None it raises a
ValueError indicating one of training.numb_steps or training.num_epoch must be
provided; keep existing mutual-exclusion checks and ensure the error message
names the symbols ("numb_steps"/"num_epoch") to aid users.
- Around line 3547-3560: In the multi_task branch add a validation that mirrors
single-task: if multi_task is True and both num_steps is None and num_epoch_dict
is empty/None, raise a ValueError; specifically inside the existing multi_task
handling block (alongside the checks referencing num_epoch, num_epoch_dict,
num_steps, and model_prob) add a check like: if not num_epoch_dict and num_steps
is None: raise ValueError("training.num_epoch_dict or training.num_step must be
set in multi-task mode."). Ensure you treat an empty dict as invalid (use a
falsy check) and keep the error message descriptive.
In `@source/tests/pt/test_sampler.py`:
- Around line 443-444: test_num_epoch_dict computes probs_1 and probs_2 via
self._normalize_probs(np.asarray(sampler_1.weights)) and
self._normalize_probs(np.asarray(sampler_2.weights)) but never uses them; remove
the unused variables or add assertions that validate the normalized
distributions. Update the test_num_epoch_dict function to either delete the
lines that assign probs_1 and probs_2, or replace them with meaningful checks
(e.g., assert sums equal 1, non-negativity, or expected values) referencing
_normalize_probs, sampler_1.weights and sampler_2.weights so the normalization
is actually verified.
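For reference, the kind of check requested in the two argcheck items above could look roughly like the following sketch; the function name and error messages are illustrative, and the real hook is `training_extra_check` in `deepmd/utils/argcheck.py`:

```python
def check_training_duration(multi_task, numb_steps, num_epoch, num_epoch_dict):
    """Illustrative validation mirroring the rules requested above."""
    if not multi_task:
        if numb_steps is None and num_epoch is None:
            raise ValueError(
                "One of training.numb_steps or training.num_epoch must be set."
            )
    else:
        # An empty dict counts as "not provided" (falsy check).
        if numb_steps is None and not num_epoch_dict:
            raise ValueError(
                "training.num_epoch_dict or training.numb_steps must be set "
                "in multi-task mode."
            )
```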
🧹 Nitpick comments (2)
doc/train/multi-task-training.md (1)
82-90: Good documentation for the new `num_epoch_dict` parameter. The explanation of the mutual exclusivity and the formula for computing `model_prob` and total steps is clear and useful. One minor inconsistency: line 86 uses "pretrained" while line 101 uses "pre-trained". Consider using consistent spelling throughout.
deepmd/dpmodel/utils/training_utils.py (1)
97-105: Consider logging when `model_prob_config` is provided. Currently, there's an info log when `model_prob_config` is empty (defaulting to system counts), but no log when a user-specified config is used. For consistency and debugging, consider logging the user-specified probabilities as well.
♻️ Optional enhancement
```diff
 if model_prob_config:
     missing = [k for k in model_keys if k not in model_prob_config]
     if missing:
         raise ValueError(
             f"training.model_prob must specify all tasks; missing: {missing}"
         )
+    if rank == 0:
+        log.info("Using user-specified model_prob: %s", model_prob_config)
     for ii, model_key in enumerate(model_keys):
```
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
- deepmd/dpmodel/utils/__init__.py
- deepmd/dpmodel/utils/training_utils.py
- deepmd/pd/train/training.py
- deepmd/pt/train/training.py
- deepmd/tf/entrypoints/change_bias.py
- deepmd/tf/entrypoints/train.py
- deepmd/utils/argcheck.py
- doc/train/multi-task-training.md
- source/tests/pt/test_sampler.py
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/dpmodel/utils/__init__.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2024-11-29T12:15:22.226Z
Learnt from: HydrogenSulfate
Repo: deepmodeling/deepmd-kit PR: 4414
File: deepmd/pd/train/training.py:66-66
Timestamp: 2024-11-29T12:15:22.226Z
Learning: The function `nvprof_context` is defined in `deepmd/pd/utils/utils.py`, so importing it in `deepmd/pd/train/training.py` is correct.
Applied to files:
source/tests/pt/test_sampler.py
📚 Learning: 2024-11-25T07:42:55.735Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4302
File: deepmd/pd/infer/inference.py:35-38
Timestamp: 2024-11-25T07:42:55.735Z
Learning: In the file `deepmd/pd/infer/inference.py`, when loading the model checkpoint in the `Tester` class, it's acceptable to not include additional error handling for loading the model state dictionary.
Applied to files:
deepmd/pd/train/training.py
🧬 Code graph analysis (2)
deepmd/pt/train/training.py (1)
deepmd/dpmodel/utils/training_utils.py (3)
compute_total_numb_batch(12-64)resolve_model_prob(67-121)resolve_model_prob_from_epochs(124-188)
deepmd/pd/train/training.py (3)
deepmd/dpmodel/utils/training_utils.py (3)
compute_total_numb_batch(12-64)resolve_model_prob(67-121)resolve_model_prob_from_epochs(124-188)deepmd/pd/utils/utils.py (3)
to_numpy_array(230-230)to_numpy_array(234-234)to_numpy_array(237-254)deepmd/dpmodel/common.py (1)
to_numpy_array(106-128)
🪛 LanguageTool
doc/train/multi-task-training.md
[uncategorized] ~86-~86: Do not mix variants of the same word (‘pretrain’ and ‘pre-train’) within a single text.
Context: ...ne-tuning scenarios where a data-rich pretrained model is jointly trained with a data-sc...
(EN_WORD_COHERENCY)
🪛 Ruff (0.14.11)
source/tests/pt/test_sampler.py
353-353: Avoid specifying long messages outside the exception class
(TRY003)
443-443: Local variable probs_1 is assigned to but never used
Remove assignment to unused variable probs_1
(F841)
444-444: Local variable probs_2 is assigned to but never used
Remove assignment to unused variable probs_2
(F841)
deepmd/pt/train/training.py
451-453: Avoid specifying long messages outside the exception class
(TRY003)
455-455: Avoid specifying long messages outside the exception class
(TRY003)
457-459: Avoid specifying long messages outside the exception class
(TRY003)
498-501: Avoid specifying long messages outside the exception class
(TRY003)
deepmd/utils/argcheck.py
3549-3551: Avoid specifying long messages outside the exception class
(TRY003)
3554-3556: Avoid specifying long messages outside the exception class
(TRY003)
3558-3560: Avoid specifying long messages outside the exception class
(TRY003)
3563-3565: Avoid specifying long messages outside the exception class
(TRY003)
deepmd/pd/train/training.py
406-408: Avoid specifying long messages outside the exception class
(TRY003)
410-410: Avoid specifying long messages outside the exception class
(TRY003)
412-414: Avoid specifying long messages outside the exception class
(TRY003)
453-456: Avoid specifying long messages outside the exception class
(TRY003)
deepmd/dpmodel/utils/training_utils.py
37-37: Avoid specifying long messages outside the exception class
(TRY003)
39-39: Avoid specifying long messages outside the exception class
(TRY003)
41-41: Avoid specifying long messages outside the exception class
(TRY003)
43-43: Avoid specifying long messages outside the exception class
(TRY003)
46-46: Avoid specifying long messages outside the exception class
(TRY003)
50-50: Avoid specifying long messages outside the exception class
(TRY003)
52-52: Avoid specifying long messages outside the exception class
(TRY003)
54-54: Avoid specifying long messages outside the exception class
(TRY003)
56-56: Avoid specifying long messages outside the exception class
(TRY003)
58-58: Avoid specifying long messages outside the exception class
(TRY003)
61-63: Avoid specifying long messages outside the exception class
(TRY003)
100-102: Avoid specifying long messages outside the exception class
(TRY003)
115-115: Avoid specifying long messages outside the exception class
(TRY003)
117-117: Avoid specifying long messages outside the exception class
(TRY003)
120-120: Avoid specifying long messages outside the exception class
(TRY003)
151-151: Avoid specifying long messages outside the exception class
(TRY003)
154-156: Avoid specifying long messages outside the exception class
(TRY003)
161-163: Avoid specifying long messages outside the exception class
(TRY003)
166-168: Avoid specifying long messages outside the exception class
(TRY003)
172-172: Avoid specifying long messages outside the exception class
(TRY003)
174-174: Avoid specifying long messages outside the exception class
(TRY003)
176-176: Avoid specifying long messages outside the exception class
(TRY003)
178-178: Avoid specifying long messages outside the exception class
(TRY003)
182-182: Avoid specifying long messages outside the exception class
(TRY003)
deepmd/tf/entrypoints/change_bias.py
198-198: Avoid specifying long messages outside the exception class
(TRY003)
201-201: Avoid specifying long messages outside the exception class
(TRY003)
deepmd/tf/entrypoints/train.py
265-267: Avoid specifying long messages outside the exception class
(TRY003)
269-269: Avoid specifying long messages outside the exception class
(TRY003)
271-273: Avoid specifying long messages outside the exception class
(TRY003)
278-278: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40)
- GitHub Check: Test Python (12, 3.13)
- GitHub Check: Test Python (10, 3.10)
- GitHub Check: Test Python (11, 3.10)
- GitHub Check: Test Python (9, 3.10)
- GitHub Check: Test Python (10, 3.13)
- GitHub Check: Test Python (12, 3.10)
- GitHub Check: Test Python (11, 3.13)
- GitHub Check: Test Python (6, 3.13)
- GitHub Check: Test Python (7, 3.13)
- GitHub Check: Test Python (8, 3.13)
- GitHub Check: Test Python (4, 3.10)
- GitHub Check: Test Python (7, 3.10)
- GitHub Check: Test Python (4, 3.13)
- GitHub Check: Test Python (1, 3.10)
- GitHub Check: Test Python (5, 3.10)
- GitHub Check: Test Python (5, 3.13)
- GitHub Check: Test Python (2, 3.10)
- GitHub Check: Test Python (9, 3.13)
- GitHub Check: Test Python (8, 3.10)
- GitHub Check: Test Python (3, 3.13)
- GitHub Check: Test Python (3, 3.10)
- GitHub Check: Test Python (6, 3.10)
- GitHub Check: Test Python (2, 3.13)
- GitHub Check: Test Python (1, 3.13)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Analyze (python)
- GitHub Check: Test C++ (true, true, true, false)
- GitHub Check: Test C++ (false, false, false, true)
- GitHub Check: Test C++ (false, true, true, false)
- GitHub Check: Test C++ (true, false, false, true)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
🔇 Additional comments (21)
source/tests/pt/test_sampler.py (3)
16-22: Module-qualified import pattern looks good. Using
`import deepmd.pt.utils.dataloader as pt_dataloader` and accessing components via the module alias (e.g., `pt_dataloader.DpLoaderSet`) is cleaner than mixing `import` and `from` imports. This addresses the static analysis warning about mixed import styles.
34-45: Good CI workaround with `_SerialPool`. The serial pool shim avoids SemLock/CUDA initialization issues in restricted CI environments while maintaining the same interface as the multiprocessing Pool.
411-513: Well-structured epoch calculation test.The
test_num_epoch_dicttest thoroughly validates:
- Per-task
total_numb_batchcomputation- Step calculation from epoch targets
- Rounding consistency across tasks (scale_0 ≈ scale_1)
The assertions properly verify that rounding up (
ceil) doesn't reduce expected epochs and scales all tasks consistently.deepmd/utils/argcheck.py (1)
3216-3244: Documentation fornumb_stepsandnum_epochis clear and comprehensive.The docstrings properly explain:
- The mutual exclusivity between parameters
- The formula for computing steps from epochs
- The accepted aliases
- Backend-specific behavior (single-task vs multi-task)
deepmd/dpmodel/utils/training_utils.py (3)
12-64: Well-implemented `compute_total_numb_batch` with thorough validation. The function correctly computes
`ceil(max(n_batches[i] / prob[i]))` over positive-probability entries. The comprehensive input validation (1D arrays, finite values, non-negative, positive sum) will provide clear error messages for misconfigurations.
67-121: `resolve_model_prob` correctly handles both user-specified and default probabilities. Good design choices:
- Requires all model keys when user specifies
`model_prob_config`
- Defaults to system counts when config is empty
- Logs the fallback behavior only on rank 0 (avoids log spam in distributed training)
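A hedged sketch of that fallback behavior (names and messages are illustrative, not the library's exact code):

```python
import logging

log = logging.getLogger(__name__)

def resolve_probs(model_keys, model_prob_config, systems_per_task, rank=0):
    """User-specified per-task weights, else default to number of systems per task."""
    if model_prob_config:
        missing = [k for k in model_keys if k not in model_prob_config]
        if missing:
            raise ValueError(f"training.model_prob must specify all tasks; missing: {missing}")
        weights = [float(model_prob_config[k]) for k in model_keys]
    else:
        if rank == 0:
            log.info("training.model_prob not set; defaulting to systems per task")
        weights = [float(systems_per_task[k]) for k in model_keys]
    total = sum(weights)
    return [w / total for w in weights]
```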
124-188: `resolve_model_prob_from_epochs` correctly implements the documented formula. The implementation matches the documented behavior:
- `model_prob[i] = num_epoch_dict[i] * per_task_total[i] / sum_j(...)`
- `num_steps = ceil(sum_i(num_epoch_dict[i] * per_task_total[i]))`
The validation for positive epoch values and matching array lengths is appropriate.
deepmd/tf/entrypoints/train.py (2)
260-285: Correct implementation of `num_epoch` support for the TensorFlow backend. The logic properly:
- Falls back to
num_epochwhennumb_stepsis not set- Validates
num_epochis positive- Requires
train_datafor computingtotal_numb_batch- Logs the derived values for transparency
The formula
`stop_batch = ceil(num_epoch * total_numb_batch)` matches the documentation.
263-267: Runtime validation is necessary but ideally caught earlier.This runtime check for missing both
numb_stepsandnum_epochis correct, but with the suggested fix totraining_extra_checkin argcheck.py, this would be caught during config validation instead. The runtime check provides defense-in-depth.deepmd/tf/entrypoints/change_bias.py (2)
18-20: LGTM!The import of
compute_total_numb_batchfrom the shared utilities module aligns with the PR's goal of consolidating common epoch-based scheduling logic.
193-210: LGTM!The logic to derive
stop_batchfromnum_epochis correctly implemented:
- Validates
num_epoch > 0andtotal_numb_batch > 0before computation- Uses
np.ceilfor conservative step counting- Maintains backward compatibility with the
stop_batch = 0fallbackThis aligns well with the epoch-based scheduling introduced in the PT/PD trainers.
deepmd/pt/train/training.py (5)
26-30: LGTM!The imports from
deepmd.dpmodel.utilsproperly bring in the shared utilities for epoch-based scheduling. This consolidation aligns with the PR objective of sharing common methods between PT and PD code paths.
147-147: LGTM!The explicit initialization of
self.model_prob = Noneprevents potentialAttributeErrorin edge cases wheremodel_probmight be accessed before being set, addressing the concern from past reviews.
150-152: LGTM!The configuration extraction correctly uses
.get()to allow bothnumb_stepsandnum_epoch/num_epoch_dictto be optional, enabling the new epoch-based scheduling while maintaining backward compatibility.
732-738: LGTM!The defensive fallback ensures
model_probis always resolved for multi-task training, even in edge cases where the earlier resolution paths might not have been executed (e.g., certain resuming scenarios). This addresses the concern raised in past reviews.
441-508: LGTM.The training steps resolution logic is well-structured:
- Single-task: derives
num_stepsfromnum_epoch × total_numb_batch- Multi-task with
num_epoch_dict: usesresolve_model_prob_from_epochsto compute bothmodel_probandnum_steps- Multi-task without
num_epoch_dict: requires explicitnum_stepsand resolvesmodel_probseparatelyThe shared utilities in
training_utils.pyhandle edge case validation (empty weights, non-finite values, missing keys).training_data.indexcorrectly provides the batch count per system (populated fromlen(system_dataloader)in the dataloader), which is passed tocompute_total_numb_batchas expected.deepmd/pd/train/training.py (5)
33-38: LGTM!The imports correctly bring in the shared utilities from
deepmd.dpmodel.utils. Note theBaseLRimport on line 33 is separate from the utility imports, which is fine syntactically.
136-136: LGTM!The explicit initialization of
self.model_prob = Noneaddresses the past review concern and maintains consistency with the PT trainer implementation.
139-141: LGTM!The configuration extraction mirrors the PT implementation, correctly using
.get()for the new epoch-based configuration options.
395-462: LGTM!The resolution logic correctly mirrors the PT implementation while adapting to Paddle's API:
- Uses
batch_sampler.sampler.weights(Paddle) instead ofsampler.weights(PyTorch)- Same validation flow and utility usage
- Consistent error messages and logging
This maintains parity between the PT and PD trainers for the epoch-based scheduling feature.
649-656: LGTM!The defensive fallback for
model_probresolution addresses the past review concern and maintains consistency with the PT trainer. This ensures multi-task training always has valid model probabilities, even in edge cases involving model resumption.
Summary by CodeRabbit
New Features
Bug Fixes
Documentation
Tests