Conversation

@OutisLi OutisLi commented Jan 11, 2026

Summary by CodeRabbit

  • New Features

    • Per-model epoch targets (num_epoch, num_epoch_dict) with automatic derivation of total training steps and per-task scheduling; multi-task probabilities now resolved from these targets when provided.
  • Bug Fixes

    • Stricter validation and clearer logging for epoch/step resolution and sampler totals; robust handling when deriving stop steps from epoch counts.
  • Documentation

    • Updated training arguments, aliases, and mutual‑exclusion rules; multi‑task docs explain epoch-based scheduling.
  • Tests

    • New deterministic sampling tests for single- and multi-task scenarios.

Copilot AI review requested due to automatic review settings January 11, 2026 15:56
@dosubot dosubot bot added the new feature label Jan 11, 2026

Copilot AI left a comment

Pull request overview

This PR adds support for specifying training duration using num_epoch instead of requiring explicit numb_steps. The implementation computes num_steps automatically based on dataset size and sampling probabilities, supporting both single-task and multi-task training modes.
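
For reference, a minimal sketch of that derivation (one "epoch" is the smallest number of weighted draws covering the largest system at its sampling probability, then scaled by num_epoch); the helper name derive_num_steps is hypothetical and not part of the PR:

import numpy as np


def derive_num_steps(num_epoch, nbatches, sys_probs):
    # Hypothetical illustration of the num_epoch -> num_steps derivation.
    probs = np.asarray(sys_probs, dtype=np.float64)
    probs = probs / probs.sum()
    nbatches = np.asarray(nbatches, dtype=np.float64)
    valid = probs > 0.0
    # Total batches per "epoch": enough weighted draws to cover the
    # largest system at its sampling probability.
    total_numb_batch = int(np.ceil(np.max(nbatches[valid] / probs[valid])))
    return int(np.ceil(num_epoch * total_numb_batch))


# Three systems with 100/50/10 batches sampled with probabilities 0.5/0.3/0.2:
print(derive_num_steps(2, [100, 50, 10], [0.5, 0.3, 0.2]))  # 400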

Changes:

  • Made numb_steps parameter optional and added num_epoch parameter in training configuration
  • Implemented automatic computation of training steps from epoch count and dataset characteristics
  • Added comprehensive tests for sampling stability in both single-task and multi-task scenarios

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.

  • source/tests/pt/test_sampler.py: Added test infrastructure (_SerialPool, helper methods) and new tests for single-task and multi-task sampling stability to validate the num_epoch functionality.
  • deepmd/utils/argcheck.py: Updated training parameter documentation and made numb_steps optional while adding the num_epoch parameter definition.
  • deepmd/pt/train/training.py: Implemented compute_total_numb_batch and resolve_model_prob, and added logic to compute num_steps from num_epoch when numb_steps is not provided.


coderabbitai bot commented Jan 11, 2026

Warning

Rate limit exceeded

@OutisLi has exceeded the limit for the number of commits that can be reviewed per hour. Please wait 3 minutes and 44 seconds before requesting another review.

⌛ How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.

🚦 How do rate limits work?

CodeRabbit enforces hourly rate limits for each developer per organization.

Our paid plans have higher rate limits than the trial, open-source and free plans. In all cases, we re-allow further reviews after a brief timeout.

Please see our FAQ for further information.

📥 Commits

Reviewing files that changed from the base of the PR and between aedf98e and ad6be1c.

📒 Files selected for processing (2)
  • deepmd/utils/argcheck.py
  • source/tests/pt/test_sampler.py

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Adds utilities and flow to derive per-task probabilities and total training steps from per-task batch counts and epoch targets; exposes num_epoch/num_epoch_dict and model_prob state; updates PT/PD trainers, TF entrypoints, arg parsing, tests, docs, and dpmodel utils to support epoch-based multi-task scheduling.

Changes

Cohort / File(s) Summary
PT trainer (probability & step computation)
deepmd/pt/train/training.py
Adds model_prob, num_epoch, num_epoch_dict; uses compute_total_numb_batch, resolve_model_prob, resolve_model_prob_from_epochs; derives num_steps from num_epoch when numb_steps absent; multi-task model_prob resolution deferred until per-task totals known; logging/validation added.
PD trainer (mirrored logic)
deepmd/pd/train/training.py
Mirrors PT changes: optional numb_steps, num_epoch, num_epoch_dict; computes per-task totals and resolves model_prob/num_steps via new utilities; validation/logging added.
DP model utils (new utilities)
deepmd/dpmodel/utils/training_utils.py, deepmd/dpmodel/utils/__init__.py
New functions compute_total_numb_batch, resolve_model_prob, resolve_model_prob_from_epochs; input validation, normalization, and exports added.
Arg parsing / docs
deepmd/utils/argcheck.py, doc/train/multi-task-training.md
Makes numb_steps optional, adds num_epoch and num_epoch_dict arguments and docs; adds training_extra_check enforcing mutual exclusivity (single vs multi-task rules) and related validation; documents num_epoch_dict semantics and model_prob precedence.
TF entrypoints (stop_batch derivation)
deepmd/tf/entrypoints/train.py, deepmd/tf/entrypoints/change_bias.py
When numb_steps missing and num_epoch present, compute total_numb_batch via compute_total_numb_batch and set stop_batch = ceil(num_epoch * total_numb_batch) with validation and logging; preserves prior behavior when numb_steps provided.
Tests & CI shim (sampling stability)
source/tests/pt/test_sampler.py
Adds serial Pool shim for CI, deterministic sampling helpers, and tests verifying sampling stability and num_epoch_dict behavior (test_sampling_stability_single_task, test_sampling_stability_multi_task, test_num_epoch_dict); switches to module-qualified dataloader API usage.

Sequence Diagram(s)

sequenceDiagram
  participant Config
  participant Trainer
  participant Utils
  participant Sampler
  participant DataLoader

  Config->>Trainer: provide training params (model_keys, numb_steps/num_epoch/num_epoch_dict, model_prob)
  Trainer->>Utils: compute per-task totals (compute_total_numb_batch)
  alt num_epoch_dict provided
    Trainer->>Utils: resolve model_prob & per-task steps (resolve_model_prob_from_epochs)
  else multi-task and num_epoch_dict not provided
    Trainer->>Utils: resolve model_prob from config/counts (resolve_model_prob)
  else single-task and num_epoch provided
    Trainer->>Utils: compute total_numb_batch -> derive num_steps
  end
  Trainer->>Trainer: finalize num_steps and model_prob
  Trainer->>Sampler: initialize sampler with resolved model_prob and num_steps
  Sampler->>DataLoader: request batches according to probabilities
  DataLoader->>Trainer: yield batches into training loop

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • njzjz
  • iProzd
🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 26.67%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped because CodeRabbit's high-level summary is enabled.
  • Title check ✅ Passed: The title clearly and concisely summarizes the main feature: enabling num_epoch as a mechanism to set num_steps, which aligns with the PR's core objective of introducing epoch-based step configuration.





@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Fix all issues with AI agents
In @deepmd/pt/train/training.py:
- Around line 275-292: resolve_model_prob currently treats an empty dict as
"provided" and produces an all-zero vector; update resolve_model_prob so that if
model_prob_config is None or empty (len == 0) it falls back to using
model_training_data lengths for initial weights, otherwise read values from
model_prob_config, and validate each provided value is finite and non-negative
(use numpy.isfinite and value >= 0), raising a clear ValueError naming the
offending model_key for invalid entries; after building the weight vector (from
config or lengths) ensure sum_prob > 0 and then return the normalized
probabilities; reference resolve_model_prob, model_keys, model_prob_config,
model_training_data, and sum_prob in your changes.
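
A minimal sketch of the behavior this prompt asks for, not the repository's actual implementation; the names (resolve_model_prob, model_keys, model_prob_config, model_training_data, sum_prob) follow the prompt:

import numpy as np


def resolve_model_prob(model_keys, model_prob_config, model_training_data):
    # Empty/None config falls back to per-task dataset sizes; explicit
    # values must be finite and non-negative, then everything is normalized.
    model_prob = np.zeros(len(model_keys), dtype=np.float64)
    if not model_prob_config:  # treat None and {} the same way here
        for ii, key in enumerate(model_keys):
            model_prob[ii] = float(len(model_training_data[key]))
    else:
        for ii, key in enumerate(model_keys):
            value = float(model_prob_config.get(key, 0.0))
            if not np.isfinite(value) or value < 0.0:
                raise ValueError(f"Invalid model_prob for model key '{key}': {value}")
            model_prob[ii] = value
    sum_prob = model_prob.sum()
    if sum_prob <= 0.0:
        raise ValueError("model_prob must sum to a positive value.")
    return model_prob / sum_prob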

In @deepmd/utils/argcheck.py:
- Around line 3298-3305: The argcheck change made "numb_steps" optional but the
Paddle trainer still uses direct dict access training_params["numb_steps"],
causing runtime KeyError; either (A) update the Paddle training code that
references training_params["numb_steps"] to use
training_params.get("numb_steps") and implement the same fallback logic used by
the PyTorch/TensorFlow trainers (fall back to num_epoch or a default) or (B)
enforce cross-parameter validation in argcheck so that at least one of
"numb_steps" or "num_epoch" is required (add a validator that raises if both are
missing); pick one approach and apply it consistently to the Paddle trainer and
argcheck to avoid the KeyError.
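
A minimal sketch of option (B); the hook name training_extra_check is borrowed from the walkthrough above, and the exact signature is an assumption:

def training_extra_check(training_params: dict) -> None:
    # Option (B): require at least one of the two step-controlling keys,
    # so a Paddle trainer reading the config cannot hit a missing value.
    if (
        training_params.get("numb_steps") is None
        and training_params.get("num_epoch") is None
    ):
        raise ValueError(
            "One of training.numb_steps or training.num_epoch must be set."
        )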
🧹 Nitpick comments (3)
deepmd/pt/train/training.py (2)

476-525: Deriving num_steps from num_epoch is clean and well-guarded.

Checks for missing both, non-positive num_epoch, and non-positive total_numb_batch are good. The warning path when both are set matches the doc (“numb_steps takes precedence”).
Minor: self.model_prob gets resolved here in multi-task, making the later “Get model prob” block (Line 733+) redundant.


733-739: Redundant model_prob resolution block (likely dead).

After the new “Resolve training steps” section, self.model_prob should always be set for multi-task (or an exception is thrown), so this guard probably never runs now. Consider removing to avoid diverging behaviors in the future.

source/tests/pt/test_sampler.py (1)

84-155: Helper methods are fine; consider mirroring production validation in _normalize_probs / _compute_total_numb_batch.

Right now tests assume no zero-prob entries (division-by-zero would explode). If you expect users might set zero weights, you could align the test helper with compute_total_numb_batch()’s “filter probs>0” behavior to avoid brittle tests.

📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 82a5f32 and 94149a9.

📒 Files selected for processing (3)
  • deepmd/pt/train/training.py
  • deepmd/utils/argcheck.py
  • source/tests/pt/test_sampler.py
🧰 Additional context used
🧬 Code graph analysis (1)
deepmd/pt/train/training.py (2)
deepmd/pt/utils/dataloader.py (1)
  • DpLoaderSet (65-243)
deepmd/dpmodel/common.py (1)
  • to_numpy_array (106-128)
🪛 Ruff (0.14.10)
deepmd/pt/train/training.py

258-258: Avoid specifying long messages outside the exception class

(TRY003)


260-260: Avoid specifying long messages outside the exception class

(TRY003)


263-263: Avoid specifying long messages outside the exception class

(TRY003)


267-267: Avoid specifying long messages outside the exception class

(TRY003)


270-272: Avoid specifying long messages outside the exception class

(TRY003)


290-290: Avoid specifying long messages outside the exception class

(TRY003)


505-507: Avoid specifying long messages outside the exception class

(TRY003)


509-509: Avoid specifying long messages outside the exception class

(TRY003)


511-511: Avoid specifying long messages outside the exception class

(TRY003)

source/tests/pt/test_sampler.py

338-338: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (42)
  • GitHub Check: CodeQL analysis (python)
  • GitHub Check: Agent
  • GitHub Check: Test Python (7, 3.13)
  • GitHub Check: Test Python (5, 3.10)
  • GitHub Check: Test Python (9, 3.10)
  • GitHub Check: Test Python (10, 3.13)
  • GitHub Check: Test Python (11, 3.10)
  • GitHub Check: Test Python (12, 3.13)
  • GitHub Check: Test Python (9, 3.13)
  • GitHub Check: Test Python (5, 3.13)
  • GitHub Check: Test Python (11, 3.13)
  • GitHub Check: Test Python (8, 3.13)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Test Python (1, 3.13)
  • GitHub Check: Test Python (10, 3.10)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Test Python (7, 3.10)
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Test Python (6, 3.13)
  • GitHub Check: Test Python (12, 3.10)
  • GitHub Check: Test Python (8, 3.10)
  • GitHub Check: Test Python (2, 3.13)
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Test Python (3, 3.13)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Test C++ (true, true, true, false)
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Test C++ (true, false, false, true)
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Analyze (python)
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
🔇 Additional comments (4)
deepmd/utils/argcheck.py (1)

3216-3225: Docs are good, but double-check multi-task wording matches implementation.

total_numb_batch is described as a model-prob-weighted sum across tasks; that matches the new training.py logic (ceil of weighted sum of per-task totals).
Minor: consider explicitly stating what happens when training.model_prob is omitted (uniform? proportional to dataset size?), since training.py now needs a deterministic fallback.

deepmd/pt/train/training.py (1)

252-274: compute_total_numb_batch() logic looks right; consider behavior when sampler has no weights.

The validation + ceil(max(nbatches/prob)) matches the intended formula. One potential fragility: callers assume self.training_dataloader.sampler.weights exists; if a non-weighted sampler is used, this will crash before your validation runs. If that’s impossible by construction, OK—otherwise add a fallback (e.g., uniform weights).
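
A small sketch of the uniform-weight fallback suggested here; resolve_sampler_weights is a hypothetical helper, not code from the PR:

import numpy as np


def resolve_sampler_weights(sampler, numb_batches):
    # If the sampler exposes no per-system weights, treat every system
    # as equally likely instead of crashing on the missing attribute.
    weights = getattr(sampler, "weights", None)
    if weights is None:
        weights = np.ones(len(numb_batches), dtype=np.float64)
    return np.asarray(weights, dtype=np.float64)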

source/tests/pt/test_sampler.py (2)

36-55: CI-friendly multiprocessing workaround is reasonable; keep it tightly scoped.

The _SerialPool + pytest.MonkeyPatch approach should avoid SemLock/CUDA init issues and is correctly undone in tearDown().

Also applies to: 81-83


227-395: Good coverage for determinism and multi-task distributions.

These tests exercise:

  • single-task: empirical SID distribution ~= target probs, and exact equality between two runs when RNG seeds + num_steps match
  • multi-task: model selection frequencies ~= model_prob, and per-task SID distributions ~= per-task sampler probs

This is a solid regression net for the new num_epoch -> num_steps derivation behavior.


codecov bot commented Jan 11, 2026

Codecov Report

❌ Patch coverage is 47.77778% with 94 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.83%. Comparing base (567c5ba) to head (ad6be1c).
⚠️ Report is 4 commits behind head on master.

Files with missing lines | Patch % | Lines
deepmd/dpmodel/utils/training_utils.py | 41.86% | 50 Missing ⚠️
deepmd/pt/train/training.py | 55.17% | 13 Missing ⚠️
deepmd/pd/train/training.py | 58.62% | 12 Missing ⚠️
deepmd/tf/entrypoints/train.py | 35.29% | 11 Missing ⚠️
deepmd/tf/entrypoints/change_bias.py | 42.85% | 8 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5148      +/-   ##
==========================================
- Coverage   81.95%   81.83%   -0.12%     
==========================================
  Files         713      715       +2     
  Lines       72985    73472     +487     
  Branches     3617     3616       -1     
==========================================
+ Hits        59812    60129     +317     
- Misses      12010    12182     +172     
+ Partials     1163     1161       -2     

☔ View full report in Codecov by Sentry.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

🤖 Fix all issues with AI agents
In @deepmd/pt/train/training.py:
- Around line 275-296: resolve_model_prob currently treats an empty dict ({})
the same as None because it checks "if model_prob_config:" which makes "{}" fall
back to using len(model_training_data) instead of an explicit uniform
distribution; change the conditional to check identity against None (if
model_prob_config is None:) so only a true None triggers the
proportional-to-training-data behavior, and add a branch for an explicit empty
dict (or empty mapping) to assign uniform weights across model_keys (or log that
{} means uniform) while preserving validations in resolve_model_prob.
- Around line 252-274: In compute_total_numb_batch, explicitly validate
sampler_weights for finiteness and negativity: after converting to weights
(np.asarray(...)), raise a ValueError if not np.all(np.isfinite(weights)) or if
np.any(weights < 0); also check that weight_sum is finite
(np.isfinite(weight_sum)) before using it. Keep the existing checks (1D,
non-empty, positive sum) but ensure non-finite and negative entries are rejected
up-front to avoid producing misleading totals.

In @deepmd/utils/argcheck.py:
- Around line 3216-3232: The docstring for doc_numb_steps currently claims it
takes precedence over num_epoch globally; update it to scope that behavior to
the PyTorch backend (or otherwise clarify backend-specific behavior). Modify the
doc_numb_steps string to prepend or include "(Supported Backend: PyTorch)" and
rephrase the precedence sentence to indicate this rule applies only for the
PyTorch backend rather than universally, and ensure doc_num_epoch remains
consistent about backend-specific computations if needed.
🧹 Nitpick comments (2)
deepmd/pt/train/training.py (1)

736-742: Optional: remove the now-redundant model_prob re-resolution block.

Since self.model_prob is already set during step resolution for multi-task, this should be dead code in normal flows.

source/tests/pt/test_sampler.py (1)

91-109: Test helper is intentionally simplified; keep it aligned if training math changes again.

📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 94149a9 and 7111056.

📒 Files selected for processing (3)
  • deepmd/pt/train/training.py
  • deepmd/utils/argcheck.py
  • source/tests/pt/test_sampler.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2024-11-29T12:15:22.226Z
Learnt from: HydrogenSulfate
Repo: deepmodeling/deepmd-kit PR: 4414
File: deepmd/pd/train/training.py:66-66
Timestamp: 2024-11-29T12:15:22.226Z
Learning: The function `nvprof_context` is defined in `deepmd/pd/utils/utils.py`, so importing it in `deepmd/pd/train/training.py` is correct.

Applied to files:

  • source/tests/pt/test_sampler.py
🪛 Ruff (0.14.10)
deepmd/pt/train/training.py

258-258: Avoid specifying long messages outside the exception class

(TRY003)


260-260: Avoid specifying long messages outside the exception class

(TRY003)


263-263: Avoid specifying long messages outside the exception class

(TRY003)


267-267: Avoid specifying long messages outside the exception class

(TRY003)


270-272: Avoid specifying long messages outside the exception class

(TRY003)


289-289: Avoid specifying long messages outside the exception class

(TRY003)


291-291: Avoid specifying long messages outside the exception class

(TRY003)


294-294: Avoid specifying long messages outside the exception class

(TRY003)


509-511: Avoid specifying long messages outside the exception class

(TRY003)


513-513: Avoid specifying long messages outside the exception class

(TRY003)


515-515: Avoid specifying long messages outside the exception class

(TRY003)

source/tests/pt/test_sampler.py

100-102: Avoid specifying long messages outside the exception class

(TRY003)


104-107: Avoid specifying long messages outside the exception class

(TRY003)


355-355: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (39)
  • GitHub Check: Test Python (11, 3.10)
  • GitHub Check: Test Python (12, 3.13)
  • GitHub Check: Test Python (7, 3.10)
  • GitHub Check: Test Python (7, 3.13)
  • GitHub Check: Test Python (9, 3.10)
  • GitHub Check: Test Python (10, 3.13)
  • GitHub Check: Test Python (12, 3.10)
  • GitHub Check: Test Python (5, 3.10)
  • GitHub Check: Test Python (11, 3.13)
  • GitHub Check: Test Python (6, 3.13)
  • GitHub Check: Test Python (10, 3.10)
  • GitHub Check: Test Python (9, 3.13)
  • GitHub Check: Test Python (8, 3.10)
  • GitHub Check: Test Python (8, 3.13)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Test Python (2, 3.13)
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Test Python (5, 3.13)
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Test Python (1, 3.13)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Test Python (3, 3.13)
  • GitHub Check: Test C++ (true, true, true, false)
  • GitHub Check: Test C++ (true, false, false, true)
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Analyze (python)
🔇 Additional comments (3)
deepmd/pt/train/training.py (1)

142-147: Good: num_steps/num_epoch are now parsed via .get(), and model_prob is initialized.

deepmd/utils/argcheck.py (1)

3304-3316: Schema updates look consistent: numb_steps optional w/ aliases, plus new num_epoch.

source/tests/pt/test_sampler.py (1)

31-78: _SerialPool correctly implements the required Pool API.

The monkeypatch is safe and complete. _SerialPool provides all methods used by pt_dataloader (context manager protocol and map()), and the implementation matches the behavior of multiprocessing.dummy.Pool.

@OutisLi OutisLi changed the title feat(pt): use num_epoch to set num_steps feat: use num_epoch to set num_steps Jan 12, 2026
@OutisLi OutisLi requested review from Chengqian-Zhang, iProzd and njzjz and removed request for Chengqian-Zhang January 12, 2026 05:02

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Fix all issues with AI agents
In @deepmd/pt/train/training.py:
- Around line 252-305: In resolve_model_prob: when model_prob_config is
provided, do not silently leave unspecified model_keys at zero; validate that
every model_key is present in model_prob_config (or explicitly allowed), and if
any are missing raise a ValueError (include the missing key names) instead of
filling zeros; implement this check before building model_prob (using model_keys
and model_prob_config) and only fall back to using model_training_data counts
when model_prob_config is None or empty, otherwise require completeness.
🧹 Nitpick comments (3)
deepmd/tf/entrypoints/change_bias.py (1)

190-231: Validate nbatches too (and reconsider stop_batch=0 fallback).

compute_total_numb_batch() validates probabilities, but nbatches can still be non-1D / non-finite / negative, which can surface as opaque NumPy failures. Also, silently defaulting stop_batch to 0 if neither training.numb_steps nor training.num_epoch is set may break trainer.build() if it assumes a positive step count.

Proposed tightening for nbatches validation (local, minimal)
 def compute_total_numb_batch(nbatches, sys_probs) -> int:
     weights = np.asarray(sys_probs, dtype=np.float64)
     if weights.ndim != 1:
         raise ValueError("Sampler probabilities must be 1D.")
     if weights.size == 0:
         raise ValueError("Sampler probabilities are empty.")
     if not np.all(np.isfinite(weights)):
         raise ValueError("Sampler probabilities must be finite.")
     if np.any(weights < 0.0):
         raise ValueError("Sampler probabilities must be non-negative.")
     weight_sum = float(np.sum(weights))
     if weight_sum <= 0.0:
         raise ValueError("Sampler probabilities must sum to a positive value.")
     probs = weights / weight_sum
-    nbatches = np.asarray(nbatches, dtype=np.float64)
+    nbatches = np.asarray(nbatches, dtype=np.float64)
+    if nbatches.ndim != 1:
+        raise ValueError("Number of batches must be 1D.")
+    if nbatches.size == 0:
+        raise ValueError("Number of batches is empty.")
+    if not np.all(np.isfinite(nbatches)):
+        raise ValueError("Number of batches must be finite.")
+    if np.any(nbatches < 0.0):
+        raise ValueError("Number of batches must be non-negative.")
     if nbatches.shape[0] != probs.shape[0]:
         raise ValueError("Number of batches and sampler probabilities must match.")
     valid = probs > 0.0
     if not np.any(valid):
         raise ValueError(
             "Sampler probabilities must contain at least one positive entry."
         )
     return int(np.ceil(np.max(nbatches[valid] / probs[valid])))
deepmd/tf/entrypoints/train.py (1)

257-313: Good precedence + errors; add nbatches validation (and avoid helper drift).

The numb_steps/num_epoch resolution is solid and the warning on both-set is helpful. Same suggestion as elsewhere: validate nbatches (1D/finite/non-negative) to prevent odd NumPy errors, and consider centralizing compute_total_numb_batch to avoid cross-backend drift.

deepmd/utils/argcheck.py (1)

3308-3320: Schema change looks right; consider aligning “neither set” behavior across entrypoints.

numb_steps is now optional and num_epoch is introduced, but enforcement differs across entrypoints (some raise; change_bias.py falls back to 0). Consider making this consistent (either always raise, or explicitly document the one-off).

📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7111056 and 6955014.

📒 Files selected for processing (5)
  • deepmd/pd/train/training.py
  • deepmd/pt/train/training.py
  • deepmd/tf/entrypoints/change_bias.py
  • deepmd/tf/entrypoints/train.py
  • deepmd/utils/argcheck.py
🧰 Additional context used
🧬 Code graph analysis (3)
deepmd/pd/train/training.py (2)
deepmd/pt/train/training.py (2)
  • compute_total_numb_batch (252-277)
  • resolve_model_prob (279-304)
deepmd/tf/entrypoints/train.py (1)
  • compute_total_numb_batch (257-279)
deepmd/pt/train/training.py (3)
deepmd/pd/train/training.py (2)
  • compute_total_numb_batch (210-232)
  • resolve_model_prob (234-254)
deepmd/tf/entrypoints/change_bias.py (1)
  • compute_total_numb_batch (190-212)
deepmd/tf/entrypoints/train.py (1)
  • compute_total_numb_batch (257-279)
deepmd/tf/entrypoints/train.py (3)
deepmd/pd/train/training.py (1)
  • compute_total_numb_batch (210-232)
deepmd/pt/train/training.py (1)
  • compute_total_numb_batch (252-277)
deepmd/tf/entrypoints/change_bias.py (1)
  • compute_total_numb_batch (190-212)
🪛 Ruff (0.14.10)
deepmd/tf/entrypoints/change_bias.py

193-193: Avoid specifying long messages outside the exception class

(TRY003)


195-195: Avoid specifying long messages outside the exception class

(TRY003)


197-197: Avoid specifying long messages outside the exception class

(TRY003)


199-199: Avoid specifying long messages outside the exception class

(TRY003)


202-202: Avoid specifying long messages outside the exception class

(TRY003)


206-206: Avoid specifying long messages outside the exception class

(TRY003)


209-211: Avoid specifying long messages outside the exception class

(TRY003)


219-219: Avoid specifying long messages outside the exception class

(TRY003)


222-222: Avoid specifying long messages outside the exception class

(TRY003)

deepmd/pd/train/training.py

213-213: Avoid specifying long messages outside the exception class

(TRY003)


215-215: Avoid specifying long messages outside the exception class

(TRY003)


217-217: Avoid specifying long messages outside the exception class

(TRY003)


219-219: Avoid specifying long messages outside the exception class

(TRY003)


222-222: Avoid specifying long messages outside the exception class

(TRY003)


226-226: Avoid specifying long messages outside the exception class

(TRY003)


229-231: Avoid specifying long messages outside the exception class

(TRY003)


248-248: Avoid specifying long messages outside the exception class

(TRY003)


250-250: Avoid specifying long messages outside the exception class

(TRY003)


253-253: Avoid specifying long messages outside the exception class

(TRY003)


469-471: Avoid specifying long messages outside the exception class

(TRY003)


473-473: Avoid specifying long messages outside the exception class

(TRY003)


475-475: Avoid specifying long messages outside the exception class

(TRY003)

deepmd/pt/train/training.py

258-258: Avoid specifying long messages outside the exception class

(TRY003)


260-260: Avoid specifying long messages outside the exception class

(TRY003)


262-262: Avoid specifying long messages outside the exception class

(TRY003)


264-264: Avoid specifying long messages outside the exception class

(TRY003)


267-267: Avoid specifying long messages outside the exception class

(TRY003)


271-271: Avoid specifying long messages outside the exception class

(TRY003)


274-276: Avoid specifying long messages outside the exception class

(TRY003)


298-298: Avoid specifying long messages outside the exception class

(TRY003)


300-300: Avoid specifying long messages outside the exception class

(TRY003)


303-303: Avoid specifying long messages outside the exception class

(TRY003)


518-520: Avoid specifying long messages outside the exception class

(TRY003)


522-522: Avoid specifying long messages outside the exception class

(TRY003)


524-524: Avoid specifying long messages outside the exception class

(TRY003)

deepmd/tf/entrypoints/train.py

260-260: Avoid specifying long messages outside the exception class

(TRY003)


262-262: Avoid specifying long messages outside the exception class

(TRY003)


264-264: Avoid specifying long messages outside the exception class

(TRY003)


266-266: Avoid specifying long messages outside the exception class

(TRY003)


269-269: Avoid specifying long messages outside the exception class

(TRY003)


273-273: Avoid specifying long messages outside the exception class

(TRY003)


276-278: Avoid specifying long messages outside the exception class

(TRY003)


286-288: Avoid specifying long messages outside the exception class

(TRY003)


290-290: Avoid specifying long messages outside the exception class

(TRY003)


292-294: Avoid specifying long messages outside the exception class

(TRY003)


299-299: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Test Python (10, 3.10)
  • GitHub Check: Test Python (7, 3.13)
  • GitHub Check: Test Python (6, 3.13)
  • GitHub Check: Test Python (8, 3.10)
  • GitHub Check: Test Python (7, 3.10)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Test Python (9, 3.13)
  • GitHub Check: Test Python (12, 3.10)
  • GitHub Check: Test Python (5, 3.13)
  • GitHub Check: Test Python (9, 3.10)
  • GitHub Check: Test Python (12, 3.13)
  • GitHub Check: Test Python (11, 3.10)
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Test Python (11, 3.13)
  • GitHub Check: Test Python (10, 3.13)
  • GitHub Check: Test Python (5, 3.10)
  • GitHub Check: Test Python (8, 3.13)
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Test Python (2, 3.13)
  • GitHub Check: Test Python (3, 3.13)
  • GitHub Check: Test Python (1, 3.13)
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Test C++ (true, true, true, false)
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Test C++ (true, false, false, true)
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Analyze (python)
🔇 Additional comments (10)
deepmd/tf/entrypoints/train.py (1)

15-16: NumPy import looks fine.

deepmd/pt/train/training.py (3)

142-147: Reading numb_steps/num_epoch via .get() is a good compat move.


746-752: Guarded model_prob resolution is fine (keeps later paths consistent).


489-538: No action needed; .index correctly holds per-system batch counts.

training_data.index contains batch counts per system (populated via len(system_dataloader) for each system), not sample indices. The sampler weights are also per-system, ensuring shape consistency. The math in compute_total_numb_batch() is correct.

deepmd/utils/argcheck.py (2)

3216-3232: Docs for numb_steps precedence + num_epoch derivation look consistent.


3278-3282: model_prob default behavior is now clear; good.

deepmd/pd/train/training.py (4)

135-136: LGTM!

The change to use .get() for both numb_steps and num_epoch allows either to be optionally specified, with validation deferred to later in the initialization.


210-232: LGTM!

The compute_total_numb_batch function correctly validates sampler weights and computes the total number of batches needed for one epoch. The implementation is consistent with the reference in deepmd/pt/train/training.py.


467-488: LGTM!

The logic for deriving num_steps from num_epoch is well-implemented with proper validation:

  • Clear error when neither numb_steps nor num_epoch is provided
  • Validation that num_epoch and total_numb_batch are positive
  • Ceiling function ensures at least the specified number of epochs are trained
  • Helpful warning when both options are set

443-446: No issues found. The code is correct.

training_data.index is a list containing the number of batches for each system dataloader (populated via len(system_dataloader) at dataloader construction). This is exactly what compute_total_numb_batch expects—an array-like of batch counts to be paired with sampler_weights. The shape check in the function (line 216) validates that the number of batch counts matches the number of weights, confirming the usage is safe and correct.

@Chengqian-Zhang
Collaborator

Furthermore, I recommend implementing a new option: in multi-task mode, apply num_epoch to specific model branches. This is very useful in multi-task fine-tuning scenarios, where we jointly train a data-rich dataset (such as OMat24 or OC20) and a data-scarce downstream dataset. In this situation, we only care about the performance and training epochs of the data-scarce downstream task. However, I think it is fine to implement this in a future PR.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
deepmd/pt/train/training.py (1)

279-309: Improved validation over other backends, with a minor redundancy.

The missing-key validation (lines 286-290) and info logging (lines 295-299) are valuable enhancements compared to the PD/TF implementations. However, after the missing-key check, the condition on line 292 (if model_key in model_prob_config) is always true and can be simplified.

♻️ Optional: Remove redundant conditional
             for ii, model_key in enumerate(model_keys):
-                if model_key in model_prob_config:
-                    model_prob[ii] = float(model_prob_config[model_key])
+                model_prob[ii] = float(model_prob_config[model_key])
📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6955014 and d1871b5.

📒 Files selected for processing (2)
  • deepmd/pt/train/training.py
  • deepmd/utils/argcheck.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/utils/argcheck.py
🧰 Additional context used
🧬 Code graph analysis (1)
deepmd/pt/train/training.py (2)
deepmd/tf/entrypoints/train.py (1)
  • compute_total_numb_batch (257-279)
deepmd/pd/train/training.py (2)
  • compute_total_numb_batch (210-232)
  • resolve_model_prob (234-254)
🪛 Ruff (0.14.10)
deepmd/pt/train/training.py

258-258: Avoid specifying long messages outside the exception class

(TRY003)


260-260: Avoid specifying long messages outside the exception class

(TRY003)


262-262: Avoid specifying long messages outside the exception class

(TRY003)


264-264: Avoid specifying long messages outside the exception class

(TRY003)


267-267: Avoid specifying long messages outside the exception class

(TRY003)


271-271: Avoid specifying long messages outside the exception class

(TRY003)


274-276: Avoid specifying long messages outside the exception class

(TRY003)


288-290: Avoid specifying long messages outside the exception class

(TRY003)


303-303: Avoid specifying long messages outside the exception class

(TRY003)


305-305: Avoid specifying long messages outside the exception class

(TRY003)


308-308: Avoid specifying long messages outside the exception class

(TRY003)


523-525: Avoid specifying long messages outside the exception class

(TRY003)


527-527: Avoid specifying long messages outside the exception class

(TRY003)


529-529: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (25)
  • GitHub Check: Test Python (11, 3.13)
  • GitHub Check: Test Python (9, 3.10)
  • GitHub Check: Test Python (11, 3.10)
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Test Python (7, 3.10)
  • GitHub Check: Test Python (7, 3.13)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Test C++ (true, true, true, false)
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Test C++ (true, false, false, true)
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Analyze (python)
🔇 Additional comments (5)
deepmd/pt/train/training.py (5)

142-142: LGTM!

The model_prob attribute initialization is appropriate for deferred computation in multi-task scenarios.


145-146: LGTM!

Using .get() for both parameters correctly handles the case where either numb_steps or num_epoch may be omitted from the configuration.


252-277: LGTM!

The compute_total_numb_batch function correctly validates inputs and computes the total batch count needed to cover all systems given their sampling probabilities. The implementation aligns with the analogous functions in the PD and TF training modules.


494-542: LGTM!

The training steps resolution logic correctly handles both single-task and multi-task modes:

  • Single-task computes total_numb_batch from sampler weights
  • Multi-task computes a weighted sum of per-task batch counts using model probabilities
  • Validation ensures exactly one of numb_steps or num_epoch is effectively used
  • The warning when both are set provides clear feedback about precedence

751-756: LGTM!

The guarded model_prob resolution provides a defensive fallback. Given the current control flow, self.model_prob is always set at line 513 for multi-task mode, so this branch serves as a safety net for potential future refactoring or edge cases.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

🤖 Fix all issues with AI agents
In @deepmd/pd/train/training.py:
- Around line 235-256: In resolve_model_prob, ensure model_prob_config cannot
silently omit tasks: if model_prob_config is provided, validate that every
model_key in model_keys is present (use set(model_keys) -
set(model_prob_config.keys()) to find missing keys) and raise a ValueError
listing the missing keys (mention model_keys and model_prob_config in the
message) instead of filling zeros; keep the existing finite/negative/sum checks
and normalization so callers like num_epoch_dict won't end up dividing by zero
or skipping tasks.
- Around line 211-234: compute_total_numb_batch currently only validates
sampler_weights; validate the numb_batches input similarly: convert numb_batches
to a 1D numpy array (nbatches = np.asarray(numb_batches, dtype=np.float64)) and
check it is 1D and the same length as weights, contains only finite values
(np.isfinite), and contains no negative entries; additionally ensure there is at
least one positive nbatch where the corresponding prob > 0 before computing the
ceil(max(...)) to avoid NaNs/infs and raise ValueError with clear messages
referencing numb_batches/nbatches when checks fail.

In @deepmd/pt/train/training.py:
- Around line 550-554: The log is zipping self.model_keys (all tasks) with
per_task_steps (only tasks where epoch_value is not None), causing wrong
pairings; change the code that builds per_task_steps to also collect the
corresponding task keys (e.g., keys_for_steps) using the same filter, then use
zip(keys_for_steps, per_task_steps) (or build a dict from those paired lists) in
the logging call so each model key is matched to its correct ceil'd step value;
update any variable names referenced (self.model_keys, per_task_steps)
accordingly.
- Around line 538-545: The loop computing per_task_steps can raise
ZeroDivisionError if self.model_prob[ii] == 0 and ValueError if per_task_steps
stays empty; update the logic in the block that iterates over
self.model_keys/num_epoch_dict to skip entries where epoch_value is None or
self.model_prob[ii] == 0 (i.e., do not perform steps_i = epoch_value *
per_task_total[ii] / self.model_prob[ii] for those cases), collect valid steps_i
into per_task_steps, and after the loop set self.num_steps =
int(np.ceil(np.max(per_task_steps))) only if per_task_steps is non-empty else
set self.num_steps = 0 (or another safe default); refer to symbols:
per_task_steps, self.model_keys, self.num_epoch_dict, self.model_prob,
per_task_total, and self.num_steps.
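
A minimal sketch of the guarded per-task step computation this prompt describes; the helper name steps_from_epoch_dict is hypothetical, and raising when no task qualifies is shown here as one alternative to the zero default the prompt mentions:

import numpy as np


def steps_from_epoch_dict(model_keys, num_epoch_dict, per_task_total, model_prob):
    per_task_steps = []
    for ii, key in enumerate(model_keys):
        epoch_value = num_epoch_dict.get(key)
        # Skip tasks without an epoch target or with zero sampling probability,
        # so the division below cannot raise ZeroDivisionError.
        if epoch_value is None or model_prob[ii] == 0.0:
            continue
        per_task_steps.append(epoch_value * per_task_total[ii] / model_prob[ii])
    if not per_task_steps:
        raise ValueError("num_epoch_dict matched no task with a positive model_prob.")
    return int(np.ceil(max(per_task_steps)))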
🧹 Nitpick comments (4)
source/tests/pt/test_sampler.py (2)

46-78: Prefer self.addCleanup(self._monkeypatch.undo) to avoid patch leaks if setUp fails mid-way.
Right now, a failure after Line 49 but before tearDown registration can leave pt_dataloader.Pool patched for subsequent tests.

Proposed fix
 def setUp(self) -> None:
     self._monkeypatch = pytest.MonkeyPatch()
     # Avoid SemLock/CUDA initialization failures in restricted CI by forcing a serial pool.
     self._monkeypatch.setattr(pt_dataloader, "Pool", _SerialPool)
+    self.addCleanup(self._monkeypatch.undo)
@@
 def tearDown(self) -> None:
-    self._monkeypatch.undo()
+    self._monkeypatch.undo()

244-519: Sampling stability tests may be a bit tolerance-sensitive; consider atol by num_steps (optional).
Using a fixed atol=0.1 can be flaky if CI runtime changes cause small effective sample counts; scaling tolerance by sqrt(p(1-p)/n) (or increasing num_epoch) tends to be more robust.
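
A tiny sketch of that tolerance scaling, assuming p is the target probability and n the number of sampled steps; binomial_atol is a hypothetical helper:

import numpy as np


def binomial_atol(p, n, z=3.0):
    # ~z standard deviations of a binomial proportion estimated from n draws.
    return z * np.sqrt(p * (1.0 - p) / n)


# A target probability of 0.3 observed over 2000 sampled steps:
print(binomial_atol(0.3, 2000))  # ~0.031, much tighter than a fixed atol of 0.1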

deepmd/pt/train/training.py (1)

292-294: Minor redundancy in conditional check.

The check if model_key in model_prob_config at line 293 is redundant since missing keys are already caught and raised at lines 287-291. The loop at line 292 will only execute when all keys are present.

♻️ Suggested simplification
             for ii, model_key in enumerate(model_keys):
-                if model_key in model_prob_config:
-                    model_prob[ii] = float(model_prob_config[model_key])
+                model_prob[ii] = float(model_prob_config[model_key])
doc/train/multi-task-training.md (1)

84-91: Good documentation, minor terminology inconsistency.

The documentation clearly explains the num_epoch_dict feature and its precedence over num_epoch.

One minor issue: Line 86 uses "pretrained" while line 102 uses "pre-trained". For consistency within the document, consider using the same variant throughout.

📝 Suggested fix
-  where a data-rich pretrained model is jointly trained with a data-scarce downstream task.
+  where a data-rich pre-trained model is jointly trained with a data-scarce downstream task.
📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7c9813b and 8cef984.

📒 Files selected for processing (5)
  • deepmd/pd/train/training.py
  • deepmd/pt/train/training.py
  • deepmd/utils/argcheck.py
  • doc/train/multi-task-training.md
  • source/tests/pt/test_sampler.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2024-11-29T12:15:22.226Z
Learnt from: HydrogenSulfate
Repo: deepmodeling/deepmd-kit PR: 4414
File: deepmd/pd/train/training.py:66-66
Timestamp: 2024-11-29T12:15:22.226Z
Learning: The function `nvprof_context` is defined in `deepmd/pd/utils/utils.py`, so importing it in `deepmd/pd/train/training.py` is correct.

Applied to files:

  • source/tests/pt/test_sampler.py
🧬 Code graph analysis (1)
source/tests/pt/test_sampler.py (2)
deepmd/pt/utils/dataloader.py (3)
  • DpLoaderSet (65-243)
  • get_weighted_sampler (266-290)
  • get_sampler_from_params (293-306)
deepmd/utils/data_system.py (1)
  • set_sys_probs (413-434)
🪛 LanguageTool
doc/train/multi-task-training.md

[uncategorized] ~86-~86: Do not mix variants of the same word (‘pretrain’ and ‘pre-train’) within a single text.
Context: ...ne-tuning scenarios where a data-rich pretrained model is jointly trained with a data-sc...

(EN_WORD_COHERENCY)

🪛 Ruff (0.14.10)
source/tests/pt/test_sampler.py

100-102: Avoid specifying long messages outside the exception class

(TRY003)


104-107: Avoid specifying long messages outside the exception class

(TRY003)


355-355: Avoid specifying long messages outside the exception class

(TRY003)

deepmd/pt/train/training.py

259-259: Avoid specifying long messages outside the exception class

(TRY003)


261-261: Avoid specifying long messages outside the exception class

(TRY003)


263-263: Avoid specifying long messages outside the exception class

(TRY003)


265-265: Avoid specifying long messages outside the exception class

(TRY003)


268-268: Avoid specifying long messages outside the exception class

(TRY003)


272-272: Avoid specifying long messages outside the exception class

(TRY003)


275-277: Avoid specifying long messages outside the exception class

(TRY003)


289-291: Avoid specifying long messages outside the exception class

(TRY003)


304-304: Avoid specifying long messages outside the exception class

(TRY003)


306-306: Avoid specifying long messages outside the exception class

(TRY003)


309-309: Avoid specifying long messages outside the exception class

(TRY003)


527-529: Avoid specifying long messages outside the exception class

(TRY003)


534-536: Avoid specifying long messages outside the exception class

(TRY003)


552-552: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)


557-560: Avoid specifying long messages outside the exception class

(TRY003)


563-563: Avoid specifying long messages outside the exception class

(TRY003)


565-567: Avoid specifying long messages outside the exception class

(TRY003)

deepmd/pd/train/training.py

214-214: Avoid specifying long messages outside the exception class

(TRY003)


216-216: Avoid specifying long messages outside the exception class

(TRY003)


218-218: Avoid specifying long messages outside the exception class

(TRY003)


220-220: Avoid specifying long messages outside the exception class

(TRY003)


223-223: Avoid specifying long messages outside the exception class

(TRY003)


227-227: Avoid specifying long messages outside the exception class

(TRY003)


230-232: Avoid specifying long messages outside the exception class

(TRY003)


249-249: Avoid specifying long messages outside the exception class

(TRY003)


251-251: Avoid specifying long messages outside the exception class

(TRY003)


254-254: Avoid specifying long messages outside the exception class

(TRY003)


473-475: Avoid specifying long messages outside the exception class

(TRY003)


480-482: Avoid specifying long messages outside the exception class

(TRY003)


498-498: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)


503-506: Avoid specifying long messages outside the exception class

(TRY003)


509-509: Avoid specifying long messages outside the exception class

(TRY003)


511-513: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (32)
  • GitHub Check: Test Python (11, 3.13)
  • GitHub Check: Test Python (3, 3.13)
  • GitHub Check: Test Python (6, 3.13)
  • GitHub Check: Test Python (12, 3.13)
  • GitHub Check: Test Python (2, 3.13)
  • GitHub Check: Test Python (8, 3.10)
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Test Python (7, 3.13)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Test Python (1, 3.13)
  • GitHub Check: Test Python (5, 3.10)
  • GitHub Check: Test Python (7, 3.10)
  • GitHub Check: Test Python (8, 3.13)
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Test C++ (true, true, true, false)
  • GitHub Check: Test C++ (true, false, false, true)
  • GitHub Check: Analyze (python)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
🔇 Additional comments (8)
deepmd/pd/train/training.py (1)

135-137: Nice: numb_steps is now optional and epoch-based scheduling is surfaced (num_epoch, num_epoch_dict).
This makes the Trainer consistent with the new arg schema and enables the requested behavior.

source/tests/pt/test_sampler.py (1)

79-164: Test helpers look consistent with the new num_epoch → num_steps math.
The “derived steps == explicit steps” assertions are a good regression guard for deterministic sampling.

deepmd/utils/argcheck.py (2)

3216-3252: Docs clearly define precedence and the multi-task meaning of “epoch”.
This should reduce user confusion around fractional epochs and weighted sampling.


3321-3348: The numb_steps aliases (stop_batch, num_steps) are safe—no collision risk detected.

The training code correctly reads the resolved numb_steps field from training_params.get("numb_steps") after alias resolution in the Argument parser. The num_steps alias does not conflict with the internal self.num_steps attribute, which stores the retrieved value separately from the config field name.

deepmd/pt/train/training.py (4)

142-147: LGTM!

The initialization of model_prob and the use of .get() for optional parameters is appropriate.


253-278: LGTM!

The compute_total_numb_batch function correctly computes the total number of batches needed to ensure each system completes at least one epoch. The validation logic is thorough and the formula max(nbatches[valid] / probs[valid]) ensures proper coverage.


575-582: LGTM!

Appropriate warning when both numb_steps and num_epoch/num_epoch_dict are set, with clear precedence given to numb_steps.


791-796: LGTM!

The guard prevents redundant re-resolution of model_prob when it was already computed during the training steps resolution phase.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (3)
deepmd/pd/train/training.py (2)

211-241: LGTM! Comprehensive validation for batch count computation.

The validation logic correctly handles edge cases (empty arrays, non-finite values, zero weights). The formula ceil(max(nbatches[valid] / probs[valid])) correctly computes the minimum number of global steps needed to ensure each system is sampled according to its weight.

Note: This function is duplicated across PT, PD, and TF implementations. Consider extracting to a shared utility module in the future to reduce maintenance burden.


243-268: Consider adding a log message when defaulting to system counts.

The logic is correct. However, the PT implementation (lines 304-308) logs an info message when model_prob_config is not set, informing users that it defaults to the number of systems per task. Adding the same log here would improve consistency and help users understand the behavior.

♻️ Suggested enhancement
             else:
+                log.info(
+                    "training.model_prob is not set or empty; defaulting to the "
+                    "number of systems per task."
+                )
                 for ii, model_key in enumerate(model_keys):
                     model_prob[ii] = float(len(model_training_data[model_key]))
deepmd/pt/train/training.py (1)

804-809: Consider adding a clarifying comment for this defensive block.

In the current flow, self.model_prob is already set at lines 522-526 for multi-task mode, so this block would not execute. However, it serves as defensive code. A brief comment would help future maintainers understand its purpose.

♻️ Suggested enhancement
         # Get model prob for multi-task
+        # Defensive fallback: resolve model_prob if not already set during initialization
+        # (e.g., if future code paths skip the earlier resolution block)
         if self.multi_task and self.model_prob is None:
             self.model_prob = resolve_model_prob(
                 self.model_keys,
                 training_params.get("model_prob"),
                 training_data,
             )
📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8cef984 and 750a6d0.

📒 Files selected for processing (2)
  • deepmd/pd/train/training.py
  • deepmd/pt/train/training.py
🧰 Additional context used
🧬 Code graph analysis (1)
deepmd/pt/train/training.py (3)
deepmd/pd/train/training.py (2)
  • compute_total_numb_batch (211-241)
  • resolve_model_prob (243-268)
deepmd/tf/entrypoints/train.py (1)
  • compute_total_numb_batch (257-279)
deepmd/tf/entrypoints/change_bias.py (1)
  • compute_total_numb_batch (190-212)
🪛 Ruff (0.14.10)
deepmd/pt/train/training.py

259-259: Avoid specifying long messages outside the exception class

(TRY003)


261-261: Avoid specifying long messages outside the exception class

(TRY003)


263-263: Avoid specifying long messages outside the exception class

(TRY003)


265-265: Avoid specifying long messages outside the exception class

(TRY003)


268-268: Avoid specifying long messages outside the exception class

(TRY003)


272-272: Avoid specifying long messages outside the exception class

(TRY003)


274-274: Avoid specifying long messages outside the exception class

(TRY003)


276-276: Avoid specifying long messages outside the exception class

(TRY003)


278-278: Avoid specifying long messages outside the exception class

(TRY003)


280-280: Avoid specifying long messages outside the exception class

(TRY003)


283-285: Avoid specifying long messages outside the exception class

(TRY003)


297-299: Avoid specifying long messages outside the exception class

(TRY003)


312-312: Avoid specifying long messages outside the exception class

(TRY003)


314-314: Avoid specifying long messages outside the exception class

(TRY003)


317-317: Avoid specifying long messages outside the exception class

(TRY003)


535-537: Avoid specifying long messages outside the exception class

(TRY003)


542-544: Avoid specifying long messages outside the exception class

(TRY003)


551-553: Avoid specifying long messages outside the exception class

(TRY003)


558-560: Avoid specifying long messages outside the exception class

(TRY003)


570-573: Avoid specifying long messages outside the exception class

(TRY003)


576-576: Avoid specifying long messages outside the exception class

(TRY003)


578-580: Avoid specifying long messages outside the exception class

(TRY003)

deepmd/pd/train/training.py

214-214: Avoid specifying long messages outside the exception class

(TRY003)


216-216: Avoid specifying long messages outside the exception class

(TRY003)


218-218: Avoid specifying long messages outside the exception class

(TRY003)


220-220: Avoid specifying long messages outside the exception class

(TRY003)


223-223: Avoid specifying long messages outside the exception class

(TRY003)


227-227: Avoid specifying long messages outside the exception class

(TRY003)


229-229: Avoid specifying long messages outside the exception class

(TRY003)


231-231: Avoid specifying long messages outside the exception class

(TRY003)


233-233: Avoid specifying long messages outside the exception class

(TRY003)


235-235: Avoid specifying long messages outside the exception class

(TRY003)


238-240: Avoid specifying long messages outside the exception class

(TRY003)


252-254: Avoid specifying long messages outside the exception class

(TRY003)


262-262: Avoid specifying long messages outside the exception class

(TRY003)


264-264: Avoid specifying long messages outside the exception class

(TRY003)


267-267: Avoid specifying long messages outside the exception class

(TRY003)


486-488: Avoid specifying long messages outside the exception class

(TRY003)


493-495: Avoid specifying long messages outside the exception class

(TRY003)


502-504: Avoid specifying long messages outside the exception class

(TRY003)


509-511: Avoid specifying long messages outside the exception class

(TRY003)


521-524: Avoid specifying long messages outside the exception class

(TRY003)


527-527: Avoid specifying long messages outside the exception class

(TRY003)


529-531: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40)
  • GitHub Check: Test Python (10, 3.10)
  • GitHub Check: Test Python (8, 3.10)
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Test Python (8, 3.13)
  • GitHub Check: Test Python (11, 3.10)
  • GitHub Check: Test Python (11, 3.13)
  • GitHub Check: Test Python (7, 3.13)
  • GitHub Check: Test Python (9, 3.10)
  • GitHub Check: Test Python (12, 3.13)
  • GitHub Check: Test Python (10, 3.13)
  • GitHub Check: Test Python (9, 3.13)
  • GitHub Check: Test Python (12, 3.10)
  • GitHub Check: Test Python (5, 3.13)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Test Python (5, 3.10)
  • GitHub Check: Test Python (7, 3.10)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Test Python (6, 3.13)
  • GitHub Check: Test Python (2, 3.13)
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Test Python (3, 3.13)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Test Python (1, 3.13)
  • GitHub Check: Test C++ (true, false, false, true)
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Test C++ (true, true, true, false)
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Analyze (python)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C++ (clang, clang)
🔇 Additional comments (8)
deepmd/pd/train/training.py (3)

135-137: LGTM!

Proper use of .get() for optional config parameters, allowing either numb_steps or num_epoch/num_epoch_dict to drive step calculation.


481-546: LGTM! Comprehensive step resolution with proper precedence handling.

The logic correctly:

  1. Prioritizes num_epoch_dict for multi-task (allowing per-task epoch targets)
  2. Falls back to num_epoch for uniform epoch count
  3. Raises clear errors when neither is provided
  4. Guards against division by zero when computing per-task steps
  5. Logs computed values for transparency
  6. Warns when both numb_steps and epoch configs are provided

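For intuition, the precedence described above could be sketched roughly as follows (hypothetical helper, not the trainer's actual method; validation, warnings, and logging omitted):

import math

def resolve_num_steps_sketch(training_params, per_task_batches):
    # per_task_batches: {task_name: number of batches in one epoch of that task}
    numb_steps = training_params.get("numb_steps")
    num_epoch = training_params.get("num_epoch")
    num_epoch_dict = training_params.get("num_epoch_dict")
    if numb_steps is not None:
        return numb_steps  # explicit steps win; epoch settings only produce a warning
    if num_epoch_dict:
        # multi-task: per-task epoch targets drive both model_prob and total steps
        return math.ceil(
            sum(num_epoch_dict[k] * nb for k, nb in per_task_batches.items())
        )
    if num_epoch is not None:
        # uniform epoch target over all data
        return math.ceil(num_epoch * sum(per_task_batches.values()))
    raise ValueError("one of numb_steps, num_epoch, or num_epoch_dict must be set")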
453-480: No issues found with batch count calculation.

The code correctly uses training_data.index to obtain batch counts per system. Both single-task and multi-task modes properly pass these counts to compute_total_numb_batch, which validates and processes them appropriately.

deepmd/pt/train/training.py (5)

142-142: LGTM!

Good defensive initialization of self.model_prob to None. This ensures the attribute exists before the multi-task resolution block and allows the guarded check at line 804.


145-147: LGTM!

Consistent with PD implementation for reading optional config parameters.


253-286: LGTM!

Comprehensive validation matching the PD implementation. The logic correctly computes the minimum number of global steps needed to cover all batches according to their weights.


288-318: LGTM!

Good enhancement with the info log (lines 304-308) when model_prob_config is not set. This helps users understand the default behavior.


503-595: LGTM!

Consistent implementation with the PD version. The step resolution logic is comprehensive with proper validation, logging, and warning when both numb_steps and epoch configs are set.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In @deepmd/pd/train/training.py:
- Around line 132-137: The PD training class fails to initialize self.model_prob
which can raise AttributeError later; add a default initialization
(self.model_prob = None) alongside the other instance attributes (e.g., near
where self.num_model is set after self.model_keys) so downstream code that reads
self.model_prob is safe when num_epoch_dict or multi-task paths don’t set it.
- Around line 558-568: After the resume/model-loading block completes and after
the model wrapper initialization, add a defensive fallback that if
self.multi_task is True and self.model_prob is None then set self.model_prob =
resolve_model_prob(self.model_keys, training_params.get("model_prob"),
training_data); locate the resume logic and model wrapper init in training.py
and insert this check there (mirroring the PT trainer) so multi-task resumes
always have model_prob resolved.
🧹 Nitpick comments (6)
deepmd/tf/entrypoints/train.py (2)

257-279: Consider extracting compute_total_numb_batch to a shared utility module.

This function is now duplicated across four files:

  • deepmd/tf/entrypoints/train.py
  • deepmd/tf/entrypoints/change_bias.py
  • deepmd/pt/train/training.py
  • deepmd/pd/train/training.py

The implementations are nearly identical but have minor inconsistencies. For example, the PT/PD versions include additional validations for nbatches (1D shape, empty, finite, non-negative checks) that this TF version lacks.

Consider extracting this to a shared utility (e.g., deepmd/utils/training.py) to ensure consistent behavior and reduce maintenance burden.


282-283: Consider warning when both numb_steps and num_epoch are provided.

When numb_steps is set, num_epoch is silently ignored. This could confuse users who accidentally configure both. A warning would help clarify which value takes precedence.

💡 Suggested enhancement
     stop_batch = training_params.get("numb_steps")
     num_epoch = training_params.get("num_epoch")
+    if stop_batch is not None and num_epoch is not None:
+        log.warning(
+            "Both training.numb_steps and training.num_epoch are set; "
+            "using numb_steps=%d and ignoring num_epoch.",
+            stop_batch,
+        )
     if stop_batch is None:
source/tests/pt/test_sampler.py (1)

360-361: Replace raise AssertionError with self.fail() for unittest consistency.

Using raise AssertionError directly works but self.fail() is the idiomatic way to signal test failures in unittest, providing better integration with test runners.

💡 Suggested fix
         if model_counts_epoch[0] == 0 or model_counts_epoch[1] == 0:
-            raise AssertionError("Model sampling produced zero counts for a task.")
+            self.fail("Model sampling produced zero counts for a task.")
deepmd/pt/train/training.py (1)

300-302: Redundant conditional check.

The check if model_key in model_prob_config on line 301 is always true at this point because line 295-299 already raises an error if any key is missing from model_prob_config.

♻️ Suggested simplification
             for ii, model_key in enumerate(model_keys):
-                if model_key in model_prob_config:
-                    model_prob[ii] = float(model_prob_config[model_key])
+                model_prob[ii] = float(model_prob_config[model_key])
deepmd/pd/train/training.py (1)

243-268: Minor difference: Missing rank check for info log.

The PT version (lines 304-308) only logs the default model_prob message on rank 0:

if self.rank == 0:
    log.info("training.model_prob is not set or empty; ...")

The PD version doesn't have this guard, which could cause duplicate log messages in distributed training.

💡 Suggested fix for consistency
         else:
+            if self.rank == 0:
+                log.info(
+                    "training.model_prob is not set or empty; defaulting to the "
+                    "number of systems per task."
+                )
             for ii, model_key in enumerate(model_keys):
                 model_prob[ii] = float(len(model_training_data[model_key]))
doc/train/multi-task-training.md (1)

82-91: Documentation accurately reflects the new feature.

Minor nit: Line 86 uses "pretrained" while line 102 uses "pre-trained". Consider using consistent spelling throughout the document.

Suggested fix for consistency
-  where a data-rich pretrained model is jointly trained with a data-scarce downstream task.
+  where a data-rich pre-trained model is jointly trained with a data-scarce downstream task.
📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 750a6d0 and b309383.

📒 Files selected for processing (6)
  • deepmd/pd/train/training.py
  • deepmd/pt/train/training.py
  • deepmd/tf/entrypoints/train.py
  • deepmd/utils/argcheck.py
  • doc/train/multi-task-training.md
  • source/tests/pt/test_sampler.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2024-11-29T12:15:22.226Z
Learnt from: HydrogenSulfate
Repo: deepmodeling/deepmd-kit PR: 4414
File: deepmd/pd/train/training.py:66-66
Timestamp: 2024-11-29T12:15:22.226Z
Learning: The function `nvprof_context` is defined in `deepmd/pd/utils/utils.py`, so importing it in `deepmd/pd/train/training.py` is correct.

Applied to files:

  • source/tests/pt/test_sampler.py
🧬 Code graph analysis (3)
deepmd/tf/entrypoints/train.py (3)
deepmd/pd/train/training.py (1)
  • compute_total_numb_batch (211-241)
deepmd/pt/train/training.py (1)
  • compute_total_numb_batch (253-286)
deepmd/tf/entrypoints/change_bias.py (1)
  • compute_total_numb_batch (190-212)
source/tests/pt/test_sampler.py (2)
deepmd/pt/utils/dataloader.py (3)
  • DpLoaderSet (65-243)
  • get_weighted_sampler (266-290)
  • get_sampler_from_params (293-306)
deepmd/utils/data_system.py (1)
  • set_sys_probs (413-434)
deepmd/pd/train/training.py (1)
deepmd/pt/train/training.py (3)
  • compute_total_numb_batch (253-286)
  • resolve_model_prob (288-318)
  • resolve_model_prob_from_epochs (320-367)
🪛 LanguageTool
doc/train/multi-task-training.md

[uncategorized] ~86-~86: Do not mix variants of the same word (‘pretrain’ and ‘pre-train’) within a single text.
Context: ...ne-tuning scenarios where a data-rich pretrained model is jointly trained with a data-sc...

(EN_WORD_COHERENCY)

🪛 Ruff (0.14.10)
deepmd/tf/entrypoints/train.py

260-260: Avoid specifying long messages outside the exception class

(TRY003)


262-262: Avoid specifying long messages outside the exception class

(TRY003)


264-264: Avoid specifying long messages outside the exception class

(TRY003)


266-266: Avoid specifying long messages outside the exception class

(TRY003)


269-269: Avoid specifying long messages outside the exception class

(TRY003)


273-273: Avoid specifying long messages outside the exception class

(TRY003)


276-278: Avoid specifying long messages outside the exception class

(TRY003)


286-288: Avoid specifying long messages outside the exception class

(TRY003)


290-290: Avoid specifying long messages outside the exception class

(TRY003)


292-294: Avoid specifying long messages outside the exception class

(TRY003)


299-299: Avoid specifying long messages outside the exception class

(TRY003)

deepmd/utils/argcheck.py

3525-3527: Avoid specifying long messages outside the exception class

(TRY003)


3530-3532: Avoid specifying long messages outside the exception class

(TRY003)


3534-3536: Avoid specifying long messages outside the exception class

(TRY003)


3539-3541: Avoid specifying long messages outside the exception class

(TRY003)

source/tests/pt/test_sampler.py

100-102: Avoid specifying long messages outside the exception class

(TRY003)


104-107: Avoid specifying long messages outside the exception class

(TRY003)


361-361: Avoid specifying long messages outside the exception class

(TRY003)

deepmd/pt/train/training.py

259-259: Avoid specifying long messages outside the exception class

(TRY003)


261-261: Avoid specifying long messages outside the exception class

(TRY003)


263-263: Avoid specifying long messages outside the exception class

(TRY003)


265-265: Avoid specifying long messages outside the exception class

(TRY003)


268-268: Avoid specifying long messages outside the exception class

(TRY003)


272-272: Avoid specifying long messages outside the exception class

(TRY003)


274-274: Avoid specifying long messages outside the exception class

(TRY003)


276-276: Avoid specifying long messages outside the exception class

(TRY003)


278-278: Avoid specifying long messages outside the exception class

(TRY003)


280-280: Avoid specifying long messages outside the exception class

(TRY003)


283-285: Avoid specifying long messages outside the exception class

(TRY003)


297-299: Avoid specifying long messages outside the exception class

(TRY003)


312-312: Avoid specifying long messages outside the exception class

(TRY003)


314-314: Avoid specifying long messages outside the exception class

(TRY003)


317-317: Avoid specifying long messages outside the exception class

(TRY003)


326-328: Avoid specifying long messages outside the exception class

(TRY003)


331-334: Avoid specifying long messages outside the exception class

(TRY003)


339-341: Avoid specifying long messages outside the exception class

(TRY003)


344-346: Avoid specifying long messages outside the exception class

(TRY003)


350-350: Avoid specifying long messages outside the exception class

(TRY003)


352-352: Avoid specifying long messages outside the exception class

(TRY003)


354-354: Avoid specifying long messages outside the exception class

(TRY003)


356-356: Avoid specifying long messages outside the exception class

(TRY003)


360-360: Avoid specifying long messages outside the exception class

(TRY003)


562-564: Avoid specifying long messages outside the exception class

(TRY003)


566-566: Avoid specifying long messages outside the exception class

(TRY003)


568-570: Avoid specifying long messages outside the exception class

(TRY003)


609-612: Avoid specifying long messages outside the exception class

(TRY003)

deepmd/pd/train/training.py

214-214: Avoid specifying long messages outside the exception class

(TRY003)


216-216: Avoid specifying long messages outside the exception class

(TRY003)


218-218: Avoid specifying long messages outside the exception class

(TRY003)


220-220: Avoid specifying long messages outside the exception class

(TRY003)


223-223: Avoid specifying long messages outside the exception class

(TRY003)


227-227: Avoid specifying long messages outside the exception class

(TRY003)


229-229: Avoid specifying long messages outside the exception class

(TRY003)


231-231: Avoid specifying long messages outside the exception class

(TRY003)


233-233: Avoid specifying long messages outside the exception class

(TRY003)


235-235: Avoid specifying long messages outside the exception class

(TRY003)


238-240: Avoid specifying long messages outside the exception class

(TRY003)


252-254: Avoid specifying long messages outside the exception class

(TRY003)


262-262: Avoid specifying long messages outside the exception class

(TRY003)


264-264: Avoid specifying long messages outside the exception class

(TRY003)


267-267: Avoid specifying long messages outside the exception class

(TRY003)


276-278: Avoid specifying long messages outside the exception class

(TRY003)


281-284: Avoid specifying long messages outside the exception class

(TRY003)


289-291: Avoid specifying long messages outside the exception class

(TRY003)


294-296: Avoid specifying long messages outside the exception class

(TRY003)


300-300: Avoid specifying long messages outside the exception class

(TRY003)


302-302: Avoid specifying long messages outside the exception class

(TRY003)


304-304: Avoid specifying long messages outside the exception class

(TRY003)


306-306: Avoid specifying long messages outside the exception class

(TRY003)


310-310: Avoid specifying long messages outside the exception class

(TRY003)


513-515: Avoid specifying long messages outside the exception class

(TRY003)


517-517: Avoid specifying long messages outside the exception class

(TRY003)


519-521: Avoid specifying long messages outside the exception class

(TRY003)


560-563: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (34)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Test Python (11, 3.13)
  • GitHub Check: Test Python (12, 3.13)
  • GitHub Check: Test Python (10, 3.10)
  • GitHub Check: Test Python (12, 3.10)
  • GitHub Check: Test Python (11, 3.10)
  • GitHub Check: Test Python (10, 3.13)
  • GitHub Check: Test Python (1, 3.13)
  • GitHub Check: Test Python (3, 3.13)
  • GitHub Check: Test Python (8, 3.10)
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Test Python (8, 3.13)
  • GitHub Check: Test Python (5, 3.10)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Test Python (2, 3.13)
  • GitHub Check: Test Python (6, 3.13)
  • GitHub Check: Test Python (9, 3.10)
  • GitHub Check: Test Python (7, 3.13)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Test Python (7, 3.10)
  • GitHub Check: Test Python (5, 3.13)
  • GitHub Check: Test Python (9, 3.13)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Test C++ (true, false, false, true)
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Test C++ (true, true, true, false)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
🔇 Additional comments (17)
deepmd/tf/entrypoints/train.py (1)

281-306: Logic for computing stop_batch from num_epoch is correct.

The implementation correctly:

  1. Retrieves both numb_steps and num_epoch from training params
  2. Falls back to epoch-based calculation when numb_steps is absent
  3. Validates that num_epoch is positive
  4. Computes stop_batch = ceil(num_epoch * total_numb_batch)
  5. Logs the computed values for transparency

Note: The check on line 298-299 (total_numb_batch <= 0) is defensive but should never trigger given that compute_total_numb_batch already validates that probabilities sum to a positive value and at least one probability is positive.

source/tests/pt/test_sampler.py (4)

31-43: Serial pool shim is a clean solution for CI compatibility.

The _SerialPool class provides a minimal context-manager-compatible implementation that avoids multiprocessing issues in restricted CI environments. This is a pragmatic approach.

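The idea is small enough to sketch; something along these lines (illustrative, not the exact helper in the test file):

class SerialPoolSketch:
    # Drop-in stand-in for multiprocessing.Pool in tests: same context-manager
    # protocol and map(), but everything runs in-process, so no SemLock or
    # fork/CUDA initialization is needed in restricted CI environments.
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        return False

    def map(self, func, iterable):
        return [func(item) for item in iterable]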

95-108: Test helper _compute_total_numb_batch has stricter requirements than production code.

The test helper rejects zero probabilities (line 103-107), while the production compute_total_numb_batch filters them out with valid = probs > 0.0 and processes only positive entries. This means the test helper cannot be used to verify behavior with zero-probability systems.

This is acceptable since the test docstring notes it's "a simplified test-only variant," but be aware this limits test coverage for edge cases with zero probabilities.


244-290: Single-task sampling stability test is well-structured.

The test validates that:

  1. Epoch-derived steps produce sampling frequencies close to target probabilities (within atol=0.1)
  2. Reproducibility is maintained when using the same random seed

The atol=0.1 tolerance is reasonable for stochastic sampling with the number of steps used.


419-519: End-to-end epoch calculation test provides good coverage.

The test validates the mathematical properties of the epoch-based step calculation:

  1. Each task completes at least its target epochs
  2. Rounding scales all tasks consistently

The assertions on lines 495-518 effectively verify that the scheduling algorithm maintains proportional fairness across tasks.

deepmd/pt/train/training.py (5)

142-142: Good initialization of model_prob to None.

Initializing model_prob to None allows for deferred resolution based on configuration (either from num_epoch_dict or model_prob config), which is resolved later in the initialization flow.


253-286: compute_total_numb_batch has comprehensive validation.

This implementation includes all necessary validations:

  • Shape checks (1D arrays)
  • Empty checks
  • Finiteness checks
  • Non-negativity checks
  • Positive sum check
  • Shape matching between arrays

This is the most complete version compared to the TF implementation.


320-367: resolve_model_prob_from_epochs correctly implements per-task epoch scheduling.

The function:

  1. Validates epoch targets are positive and finite
  2. Computes per-task steps as per_task_total * epoch_targets
  3. Derives model_prob from the relative step contributions
  4. Returns the normalized probabilities, total steps, and per-task step map

This aligns with the multi-task training documentation describing num_epoch_dict behavior.

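A compressed sketch of that derivation (illustrative names; the real function also validates the epoch targets and returns the per-task step map keyed by model name):

import numpy as np

def prob_from_epochs_sketch(per_task_total, epoch_targets):
    # per_task_total: batches per epoch for each task; epoch_targets: desired epochs per task
    per_task_total = np.asarray(per_task_total, dtype=float)
    epoch_targets = np.asarray(epoch_targets, dtype=float)
    per_task_steps = per_task_total * epoch_targets       # steps each task should receive
    total_steps = int(np.ceil(per_task_steps.sum()))       # overall training steps
    model_prob = per_task_steps / per_task_steps.sum()     # relative step contribution
    return model_prob, total_steps, per_task_steps

# Example: tasks with 300 and 100 batches per epoch and targets of 1 and 3 epochs
# -> per_task_steps = [300, 300], model_prob = [0.5, 0.5], total_steps = 600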

552-617: Training step resolution logic correctly handles all configuration scenarios.

The implementation properly handles:

  1. Single-task: Computes num_steps from num_epoch when num_steps is not provided
  2. Multi-task with num_epoch_dict: Derives both model_prob and num_steps from per-task epoch targets
  3. Multi-task without num_epoch_dict: Requires explicit num_steps and resolves model_prob from config or data sizes

The logging on lines 572-577 and 599-606 provides good visibility into computed values.


826-831: Fallback model_prob resolution handles resuming scenarios.

This fallback ensures model_prob is resolved even when the primary initialization path didn't set it (e.g., when resuming without num_epoch_dict). This is a necessary safeguard.

deepmd/pd/train/training.py (3)

211-241: compute_total_numb_batch implementation matches PT version.

The validation logic is consistent with the PT implementation, ensuring uniform behavior across backends.


502-528: Single-task epoch-based step calculation is correct.

The logic correctly:

  1. Computes total_numb_batch from sampler weights and data indices
  2. Validates num_epoch is positive when used
  3. Computes num_steps = ceil(num_epoch * total_numb_batch)
  4. Logs the computed values

529-568: Multi-task training step resolution is correctly implemented.

The implementation handles both scenarios:

  1. With num_epoch_dict: Derives model_prob and num_steps from per-task epoch targets
  2. Without num_epoch_dict: Requires explicit num_steps and resolves model_prob from config

The per-task total calculation loop (lines 530-539) correctly aggregates batch counts for each model.

deepmd/utils/argcheck.py (4)

3216-3244: LGTM! Documentation is comprehensive.

The documentation strings clearly explain the semantics, formulas for computing total steps, and mutual exclusivity constraints for all training duration parameters.


3290-3295: LGTM!

Documentation correctly clarifies that model_prob is only used when num_epoch_dict is not set.


3315-3346: LGTM!

The new arguments are well-defined with appropriate types, aliases for backward compatibility, and proper optional flags.


3516-3552: Mutual exclusivity checks are correct; model key validation is properly handled at runtime.

The validation logic correctly enforces mutual exclusivity between num_epoch, num_epoch_dict, num_steps, and model_prob parameters. The constraint that "all model keys must be specified in the dictionary" (from doc_num_epoch_dict) is intentionally validated at runtime in training.py rather than in argcheck.py, since model keys are only known when training initialization occurs. This is appropriate separation of concerns: argcheck.py handles configuration-level mutual exclusivity, while training.py enforces semantic constraints with full context.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
deepmd/pd/train/training.py (1)

259-261: Missing informational log message for default model_prob.

The PT implementation logs when model_prob defaults to system counts, but this PD version doesn't. Consider adding the log for consistency:

♻️ Suggested fix
         else:
+            if self.rank == 0:
+                log.info(
+                    "training.model_prob is not set or empty; defaulting to the "
+                    "number of systems per task."
+                )
             for ii, model_key in enumerate(model_keys):
                 model_prob[ii] = float(len(model_training_data[model_key]))
📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b309383 and 227bad8.

📒 Files selected for processing (1)
  • deepmd/pd/train/training.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2024-11-25T07:42:55.735Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4302
File: deepmd/pd/infer/inference.py:35-38
Timestamp: 2024-11-25T07:42:55.735Z
Learning: In the file `deepmd/pd/infer/inference.py`, when loading the model checkpoint in the `Tester` class, it's acceptable to not include additional error handling for loading the model state dictionary.

Applied to files:

  • deepmd/pd/train/training.py
🧬 Code graph analysis (1)
deepmd/pd/train/training.py (3)
deepmd/tf/entrypoints/train.py (1)
  • compute_total_numb_batch (257-279)
deepmd/tf/entrypoints/change_bias.py (1)
  • compute_total_numb_batch (190-212)
deepmd/pt/train/training.py (3)
  • compute_total_numb_batch (253-286)
  • resolve_model_prob (288-318)
  • resolve_model_prob_from_epochs (320-367)
🪛 Ruff (0.14.10)
deepmd/pd/train/training.py

215-215: Avoid specifying long messages outside the exception class

(TRY003)


217-217: Avoid specifying long messages outside the exception class

(TRY003)


219-219: Avoid specifying long messages outside the exception class

(TRY003)


221-221: Avoid specifying long messages outside the exception class

(TRY003)


224-224: Avoid specifying long messages outside the exception class

(TRY003)


228-228: Avoid specifying long messages outside the exception class

(TRY003)


230-230: Avoid specifying long messages outside the exception class

(TRY003)


232-232: Avoid specifying long messages outside the exception class

(TRY003)


234-234: Avoid specifying long messages outside the exception class

(TRY003)


236-236: Avoid specifying long messages outside the exception class

(TRY003)


239-241: Avoid specifying long messages outside the exception class

(TRY003)


253-255: Avoid specifying long messages outside the exception class

(TRY003)


263-263: Avoid specifying long messages outside the exception class

(TRY003)


265-265: Avoid specifying long messages outside the exception class

(TRY003)


268-268: Avoid specifying long messages outside the exception class

(TRY003)


277-279: Avoid specifying long messages outside the exception class

(TRY003)


282-285: Avoid specifying long messages outside the exception class

(TRY003)


290-292: Avoid specifying long messages outside the exception class

(TRY003)


295-297: Avoid specifying long messages outside the exception class

(TRY003)


301-301: Avoid specifying long messages outside the exception class

(TRY003)


303-303: Avoid specifying long messages outside the exception class

(TRY003)


305-305: Avoid specifying long messages outside the exception class

(TRY003)


307-307: Avoid specifying long messages outside the exception class

(TRY003)


311-311: Avoid specifying long messages outside the exception class

(TRY003)


514-516: Avoid specifying long messages outside the exception class

(TRY003)


518-518: Avoid specifying long messages outside the exception class

(TRY003)


520-522: Avoid specifying long messages outside the exception class

(TRY003)


561-564: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40)
  • GitHub Check: Test Python (7, 3.13)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Test Python (12, 3.13)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Test Python (10, 3.13)
  • GitHub Check: Test Python (8, 3.10)
  • GitHub Check: Test Python (6, 3.13)
  • GitHub Check: Test Python (5, 3.13)
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Test Python (9, 3.10)
  • GitHub Check: Test Python (9, 3.13)
  • GitHub Check: Test Python (3, 3.13)
  • GitHub Check: Test Python (1, 3.13)
  • GitHub Check: Test Python (12, 3.10)
  • GitHub Check: Test Python (7, 3.10)
  • GitHub Check: Test Python (2, 3.13)
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Test Python (10, 3.10)
  • GitHub Check: Test Python (5, 3.10)
  • GitHub Check: Test Python (11, 3.13)
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Test Python (8, 3.13)
  • GitHub Check: Test Python (11, 3.10)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Analyze (python)
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Test C++ (true, true, true, false)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Test C++ (true, false, false, true)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
🔇 Additional comments (6)
deepmd/pd/train/training.py (6)

133-138: LGTM!

The initialization changes correctly use .get() for optional parameters (numb_steps, num_epoch, num_epoch_dict) and initialize model_prob to None for deferred resolution. This aligns with the PT implementation pattern.


212-242: LGTM!

The compute_total_numb_batch function correctly validates input arrays and computes the total number of batches needed to cover all systems according to their sampling probabilities. The implementation matches the PT version.


271-318: LGTM!

The resolve_model_prob_from_epochs function correctly derives model probabilities and total training steps from per-task epoch targets. The validation logic is thorough and the implementation matches the PT version.


503-529: LGTM!

The single-task flow correctly derives num_steps from num_epoch when not explicitly provided. The validation ensures both num_epoch and total_numb_batch are positive before computation, with appropriate logging.


530-569: LGTM!

The multi-task flow correctly handles both num_epoch_dict-based scheduling (deriving probabilities from epoch targets) and explicit numb_steps with model_prob configuration. The per-task total computation and logging are appropriate.


756-762: This is intentional defensive code, not dead code.

The fallback at lines 756-762 serves an important purpose for resumed multi-task training. During resuming (lines 609-755), self.wrapper.set_state_dict() loads model weights from the checkpoint but does not persist self.model_prob, which is a computed variable. The fallback ensures that if model_prob is not set during the initial multi-task initialization (lines 541-569), it gets recomputed before training continues. This is necessary for consistent multi-task sampling behavior across training resumptions.

Collaborator

@iProzd iProzd left a comment

There are common methods in pt and pd which can be merged.

@OutisLi
Collaborator Author

OutisLi commented Jan 13, 2026

There are common methods in pt and pd which can be merged.
This requires creating a new file.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
deepmd/dpmodel/utils/training_utils.py (1)

103-105: Minor: Redundant key check.

The check if model_key in model_prob_config at line 104 is always true because lines 98-102 already validate that all model_keys exist in model_prob_config and raise if any are missing.

♻️ Suggested simplification
         for ii, model_key in enumerate(model_keys):
-            if model_key in model_prob_config:
-                model_prob[ii] = float(model_prob_config[model_key])
+            model_prob[ii] = float(model_prob_config[model_key])
source/tests/pt/test_sampler.py (1)

443-444: Remove unused variables.

probs_1 and probs_2 are computed but never used in this test. They appear to be leftover from copying test setup code.

♻️ Suggested fix
         sampler_2 = pt_dataloader.get_sampler_from_params(
             dataset_2, {"sys_probs": [0.4, 0.6], "auto_prob": "prob_sys_size"}
         )
-        probs_1 = self._normalize_probs(np.asarray(sampler_1.weights))
-        probs_2 = self._normalize_probs(np.asarray(sampler_2.weights))

         # === Step 2. Compute per-task total_numb_batch ===
📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 227bad8 and a37f0c7.

📒 Files selected for processing (7)
  • deepmd/dpmodel/utils/__init__.py
  • deepmd/dpmodel/utils/training_utils.py
  • deepmd/pd/train/training.py
  • deepmd/pt/train/training.py
  • deepmd/tf/entrypoints/change_bias.py
  • deepmd/tf/entrypoints/train.py
  • source/tests/pt/test_sampler.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2024-11-25T07:42:55.735Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4302
File: deepmd/pd/infer/inference.py:35-38
Timestamp: 2024-11-25T07:42:55.735Z
Learning: In the file `deepmd/pd/infer/inference.py`, when loading the model checkpoint in the `Tester` class, it's acceptable to not include additional error handling for loading the model state dictionary.

Applied to files:

  • deepmd/pd/train/training.py
📚 Learning: 2024-11-29T12:15:22.226Z
Learnt from: HydrogenSulfate
Repo: deepmodeling/deepmd-kit PR: 4414
File: deepmd/pd/train/training.py:66-66
Timestamp: 2024-11-29T12:15:22.226Z
Learning: The function `nvprof_context` is defined in `deepmd/pd/utils/utils.py`, so importing it in `deepmd/pd/train/training.py` is correct.

Applied to files:

  • source/tests/pt/test_sampler.py
🧬 Code graph analysis (4)
deepmd/tf/entrypoints/change_bias.py (1)
deepmd/dpmodel/utils/training_utils.py (1)
  • compute_total_numb_batch (12-64)
deepmd/dpmodel/utils/__init__.py (1)
deepmd/dpmodel/utils/training_utils.py (3)
  • compute_total_numb_batch (12-64)
  • resolve_model_prob (67-121)
  • resolve_model_prob_from_epochs (124-188)
deepmd/pt/train/training.py (1)
deepmd/dpmodel/utils/training_utils.py (3)
  • compute_total_numb_batch (12-64)
  • resolve_model_prob (67-121)
  • resolve_model_prob_from_epochs (124-188)
deepmd/tf/entrypoints/train.py (1)
deepmd/dpmodel/utils/training_utils.py (1)
  • compute_total_numb_batch (12-64)
🪛 Ruff (0.14.10)
deepmd/tf/entrypoints/change_bias.py

198-198: Avoid specifying long messages outside the exception class

(TRY003)


201-201: Avoid specifying long messages outside the exception class

(TRY003)

deepmd/dpmodel/utils/training_utils.py

37-37: Avoid specifying long messages outside the exception class

(TRY003)


39-39: Avoid specifying long messages outside the exception class

(TRY003)


41-41: Avoid specifying long messages outside the exception class

(TRY003)


43-43: Avoid specifying long messages outside the exception class

(TRY003)


46-46: Avoid specifying long messages outside the exception class

(TRY003)


50-50: Avoid specifying long messages outside the exception class

(TRY003)


52-52: Avoid specifying long messages outside the exception class

(TRY003)


54-54: Avoid specifying long messages outside the exception class

(TRY003)


56-56: Avoid specifying long messages outside the exception class

(TRY003)


58-58: Avoid specifying long messages outside the exception class

(TRY003)


61-63: Avoid specifying long messages outside the exception class

(TRY003)


100-102: Avoid specifying long messages outside the exception class

(TRY003)


115-115: Avoid specifying long messages outside the exception class

(TRY003)


117-117: Avoid specifying long messages outside the exception class

(TRY003)


120-120: Avoid specifying long messages outside the exception class

(TRY003)


151-151: Avoid specifying long messages outside the exception class

(TRY003)


154-156: Avoid specifying long messages outside the exception class

(TRY003)


161-163: Avoid specifying long messages outside the exception class

(TRY003)


166-168: Avoid specifying long messages outside the exception class

(TRY003)


172-172: Avoid specifying long messages outside the exception class

(TRY003)


174-174: Avoid specifying long messages outside the exception class

(TRY003)


176-176: Avoid specifying long messages outside the exception class

(TRY003)


178-178: Avoid specifying long messages outside the exception class

(TRY003)


182-182: Avoid specifying long messages outside the exception class

(TRY003)

deepmd/pt/train/training.py

451-453: Avoid specifying long messages outside the exception class

(TRY003)


455-455: Avoid specifying long messages outside the exception class

(TRY003)


457-459: Avoid specifying long messages outside the exception class

(TRY003)


498-501: Avoid specifying long messages outside the exception class

(TRY003)

deepmd/tf/entrypoints/train.py

265-267: Avoid specifying long messages outside the exception class

(TRY003)


269-269: Avoid specifying long messages outside the exception class

(TRY003)


271-273: Avoid specifying long messages outside the exception class

(TRY003)


278-278: Avoid specifying long messages outside the exception class

(TRY003)

deepmd/pd/train/training.py

411-413: Avoid specifying long messages outside the exception class

(TRY003)


415-415: Avoid specifying long messages outside the exception class

(TRY003)


417-419: Avoid specifying long messages outside the exception class

(TRY003)


458-461: Avoid specifying long messages outside the exception class

(TRY003)

source/tests/pt/test_sampler.py

353-353: Avoid specifying long messages outside the exception class

(TRY003)


443-443: Local variable probs_1 is assigned to but never used

Remove assignment to unused variable probs_1

(F841)


444-444: Local variable probs_2 is assigned to but never used

Remove assignment to unused variable probs_2

(F841)

🔇 Additional comments (16)
deepmd/dpmodel/utils/__init__.py (1)

39-43: LGTM!

The new utility functions are correctly imported and exported. The __all__ list maintains alphabetical ordering, consistent with the existing style.

Also applies to: 57-57, 69-70

deepmd/dpmodel/utils/training_utils.py (2)

12-64: LGTM!

The function correctly computes the minimum number of training steps required to ensure each data system with positive probability is sampled at least its target batch count times in expectation. Input validation is thorough.


124-188: LGTM!

The epoch-based probability resolution logic is sound. By making model_prob proportional to per_task_steps (epochs × batches), tasks with more target steps are sampled proportionally, ensuring each task trains for approximately its target number of epochs.

deepmd/tf/entrypoints/change_bias.py (1)

193-210: LGTM!

The stop_batch computation logic correctly mirrors the train.py implementation, computing steps from num_epoch when numb_steps is not provided. The fallback to stop_batch = 0 is appropriate for the change_bias entrypoint where actual training may not occur.

deepmd/tf/entrypoints/train.py (1)

260-285: LGTM!

The epoch-based step derivation is correctly implemented with proper validation. The requirement for train_data when using num_epoch is appropriately enforced, and the logic is consistent with the PT and PD backends.

deepmd/pt/train/training.py (4)

26-30: LGTM!

Imports and attribute initialization are correctly set up. Using .get() for optional parameters enables graceful handling when keys are absent.

Also applies to: 147-147, 150-152


441-466: LGTM!

Single-task step resolution correctly derives num_steps from num_epoch when not explicitly provided, with appropriate validation and logging.


467-507: LGTM!

Multi-task step resolution properly handles both num_epoch_dict (epoch-based) and explicit numb_steps paths. The delegation to utility functions (resolve_model_prob_from_epochs, resolve_model_prob) keeps the trainer code clean.


716-722: LGTM!

The fallback ensures model_prob is always initialized for multi-task training, covering edge cases like resuming from checkpoints where the epoch-based path wasn't taken during initialization.

source/tests/pt/test_sampler.py (4)

34-52: LGTM!

The _SerialPool shim is a clean solution to avoid SemLock/CUDA initialization failures in restricted CI environments. The context manager protocol and map method are correctly implemented.

Also applies to: 79-80


82-151: LGTM!

The helper methods are well-structured for testing sampling behavior. _sample_multitask_counts correctly mirrors the trainer's task selection logic using dp_random.choice.


232-280: LGTM!

Comprehensive test validating that epoch-derived step counts produce equivalent sampling behavior to explicit step counts. The use of identical random seeds ensures deterministic comparison.


282-409: LGTM!

Thorough multi-task sampling stability test that validates both model selection and within-task system selection frequencies match expected probabilities.

deepmd/pd/train/training.py (3)

33-37: LGTM!

Imports and initialization are consistent with the PT backend implementation, maintaining code alignment across backends.

Also applies to: 138-138, 141-143


400-467: LGTM!

Step resolution logic correctly mirrors the PT implementation, with appropriate adjustments for Paddle's BatchSampler structure. Both single-task and multi-task paths are properly handled.


654-661: LGTM!

Fallback model_prob resolution is consistent with the PT backend, ensuring multi-task training always has task probabilities initialized.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🤖 Fix all issues with AI agents
In `@deepmd/utils/argcheck.py`:
- Around line 3540-3566: training_extra_check currently only enforces mutual
exclusivity but allows neither "numb_steps" nor "num_epoch" to be set in
single-task mode; update the function (training_extra_check) so that when
multi_task is False and both num_steps and num_epoch are None it raises a
ValueError indicating one of training.numb_steps or training.num_epoch must be
provided; keep existing mutual-exclusion checks and ensure the error message
names the symbols ("numb_steps"/"num_epoch") to aid users.
- Around line 3547-3560: In the multi_task branch add a validation that mirrors
single-task: if multi_task is True and both num_steps is None and num_epoch_dict
is empty/None, raise a ValueError; specifically inside the existing multi_task
handling block (alongside the checks referencing num_epoch, num_epoch_dict,
num_steps, and model_prob) add a check like: if not num_epoch_dict and num_steps
is None: raise ValueError("training.num_epoch_dict or training.numb_steps must be
set in multi-task mode."). Ensure you treat an empty dict as invalid (use a
falsy check) and keep the error message descriptive.

In `@source/tests/pt/test_sampler.py`:
- Around line 443-444: test_num_epoch_dict computes probs_1 and probs_2 via
self._normalize_probs(np.asarray(sampler_1.weights)) and
self._normalize_probs(np.asarray(sampler_2.weights)) but never uses them; remove
the unused variables or add assertions that validate the normalized
distributions. Update the test_num_epoch_dict function to either delete the
lines that assign probs_1 and probs_2, or replace them with meaningful checks
(e.g., assert sums equal 1, non-negativity, or expected values) referencing
_normalize_probs, sampler_1.weights and sampler_2.weights so the normalization
is actually verified.
🧹 Nitpick comments (2)
doc/train/multi-task-training.md (1)

82-90: Good documentation for the new num_epoch_dict parameter.

The explanation of the mutual exclusivity and the formula for computing model_prob and total steps is clear and useful.

One minor inconsistency: Line 86 uses "pretrained" while line 101 uses "pre-trained". Consider using consistent spelling throughout.

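For orientation, the option could look roughly like this in the multi-task training section (task names and values are placeholders, not taken from the PR):

training_section = {
    # Target 1 epoch of the data-rich task and 20 epochs of the small downstream task.
    "num_epoch_dict": {"pretrain_task": 1.0, "downstream_task": 20.0},
    # num_epoch_dict is mutually exclusive with numb_steps/num_steps and model_prob:
    # the per-task sampling probabilities and the total number of training steps
    # are both derived from these epoch targets and the per-task batch counts.
}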
deepmd/dpmodel/utils/training_utils.py (1)

97-105: Consider logging when model_prob_config is provided.

Currently, there's an info log when model_prob_config is empty (defaulting to system counts), but no log when user-specified config is used. For consistency and debugging, consider logging the user-specified probabilities as well.

♻️ Optional enhancement
     if model_prob_config:
         missing = [k for k in model_keys if k not in model_prob_config]
         if missing:
             raise ValueError(
                 f"training.model_prob must specify all tasks; missing: {missing}"
             )
+        if rank == 0:
+            log.info("Using user-specified model_prob: %s", model_prob_config)
         for ii, model_key in enumerate(model_keys):
📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a37f0c7 and b2b9bc4.

📒 Files selected for processing (9)
  • deepmd/dpmodel/utils/__init__.py
  • deepmd/dpmodel/utils/training_utils.py
  • deepmd/pd/train/training.py
  • deepmd/pt/train/training.py
  • deepmd/tf/entrypoints/change_bias.py
  • deepmd/tf/entrypoints/train.py
  • deepmd/utils/argcheck.py
  • doc/train/multi-task-training.md
  • source/tests/pt/test_sampler.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/dpmodel/utils/__init__.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2024-11-29T12:15:22.226Z
Learnt from: HydrogenSulfate
Repo: deepmodeling/deepmd-kit PR: 4414
File: deepmd/pd/train/training.py:66-66
Timestamp: 2024-11-29T12:15:22.226Z
Learning: The function `nvprof_context` is defined in `deepmd/pd/utils/utils.py`, so importing it in `deepmd/pd/train/training.py` is correct.

Applied to files:

  • source/tests/pt/test_sampler.py
📚 Learning: 2024-11-25T07:42:55.735Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4302
File: deepmd/pd/infer/inference.py:35-38
Timestamp: 2024-11-25T07:42:55.735Z
Learning: In the file `deepmd/pd/infer/inference.py`, when loading the model checkpoint in the `Tester` class, it's acceptable to not include additional error handling for loading the model state dictionary.

Applied to files:

  • deepmd/pd/train/training.py
🧬 Code graph analysis (2)
deepmd/pt/train/training.py (1)
deepmd/dpmodel/utils/training_utils.py (3)
  • compute_total_numb_batch (12-64)
  • resolve_model_prob (67-121)
  • resolve_model_prob_from_epochs (124-188)
deepmd/pd/train/training.py (3)
deepmd/dpmodel/utils/training_utils.py (3)
  • compute_total_numb_batch (12-64)
  • resolve_model_prob (67-121)
  • resolve_model_prob_from_epochs (124-188)
deepmd/pd/utils/utils.py (3)
  • to_numpy_array (230-230)
  • to_numpy_array (234-234)
  • to_numpy_array (237-254)
deepmd/dpmodel/common.py (1)
  • to_numpy_array (106-128)
🪛 LanguageTool
doc/train/multi-task-training.md

[uncategorized] ~86-~86: Do not mix variants of the same word (‘pretrain’ and ‘pre-train’) within a single text.
Context: ...ne-tuning scenarios where a data-rich pretrained model is jointly trained with a data-sc...

(EN_WORD_COHERENCY)

🪛 Ruff (0.14.11)
source/tests/pt/test_sampler.py

353-353: Avoid specifying long messages outside the exception class

(TRY003)


443-443: Local variable probs_1 is assigned to but never used

Remove assignment to unused variable probs_1

(F841)


444-444: Local variable probs_2 is assigned to but never used

Remove assignment to unused variable probs_2

(F841)

deepmd/pt/train/training.py

451-453: Avoid specifying long messages outside the exception class

(TRY003)


455-455: Avoid specifying long messages outside the exception class

(TRY003)


457-459: Avoid specifying long messages outside the exception class

(TRY003)


498-501: Avoid specifying long messages outside the exception class

(TRY003)

deepmd/utils/argcheck.py

3549-3551: Avoid specifying long messages outside the exception class

(TRY003)


3554-3556: Avoid specifying long messages outside the exception class

(TRY003)


3558-3560: Avoid specifying long messages outside the exception class

(TRY003)


3563-3565: Avoid specifying long messages outside the exception class

(TRY003)

deepmd/pd/train/training.py

406-408: Avoid specifying long messages outside the exception class

(TRY003)


410-410: Avoid specifying long messages outside the exception class

(TRY003)


412-414: Avoid specifying long messages outside the exception class

(TRY003)


453-456: Avoid specifying long messages outside the exception class

(TRY003)

deepmd/dpmodel/utils/training_utils.py

37-37: Avoid specifying long messages outside the exception class

(TRY003)


39-39: Avoid specifying long messages outside the exception class

(TRY003)


41-41: Avoid specifying long messages outside the exception class

(TRY003)


43-43: Avoid specifying long messages outside the exception class

(TRY003)


46-46: Avoid specifying long messages outside the exception class

(TRY003)


50-50: Avoid specifying long messages outside the exception class

(TRY003)


52-52: Avoid specifying long messages outside the exception class

(TRY003)


54-54: Avoid specifying long messages outside the exception class

(TRY003)


56-56: Avoid specifying long messages outside the exception class

(TRY003)


58-58: Avoid specifying long messages outside the exception class

(TRY003)


61-63: Avoid specifying long messages outside the exception class

(TRY003)


100-102: Avoid specifying long messages outside the exception class

(TRY003)


115-115: Avoid specifying long messages outside the exception class

(TRY003)


117-117: Avoid specifying long messages outside the exception class

(TRY003)


120-120: Avoid specifying long messages outside the exception class

(TRY003)


151-151: Avoid specifying long messages outside the exception class

(TRY003)


154-156: Avoid specifying long messages outside the exception class

(TRY003)


161-163: Avoid specifying long messages outside the exception class

(TRY003)


166-168: Avoid specifying long messages outside the exception class

(TRY003)


172-172: Avoid specifying long messages outside the exception class

(TRY003)


174-174: Avoid specifying long messages outside the exception class

(TRY003)


176-176: Avoid specifying long messages outside the exception class

(TRY003)


178-178: Avoid specifying long messages outside the exception class

(TRY003)


182-182: Avoid specifying long messages outside the exception class

(TRY003)

deepmd/tf/entrypoints/change_bias.py

198-198: Avoid specifying long messages outside the exception class

(TRY003)


201-201: Avoid specifying long messages outside the exception class

(TRY003)

deepmd/tf/entrypoints/train.py

265-267: Avoid specifying long messages outside the exception class

(TRY003)


269-269: Avoid specifying long messages outside the exception class

(TRY003)


271-273: Avoid specifying long messages outside the exception class

(TRY003)


278-278: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40)
  • GitHub Check: Test Python (12, 3.13)
  • GitHub Check: Test Python (10, 3.10)
  • GitHub Check: Test Python (11, 3.10)
  • GitHub Check: Test Python (9, 3.10)
  • GitHub Check: Test Python (10, 3.13)
  • GitHub Check: Test Python (12, 3.10)
  • GitHub Check: Test Python (11, 3.13)
  • GitHub Check: Test Python (6, 3.13)
  • GitHub Check: Test Python (7, 3.13)
  • GitHub Check: Test Python (8, 3.13)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Test Python (7, 3.10)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Test Python (5, 3.10)
  • GitHub Check: Test Python (5, 3.13)
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Test Python (9, 3.13)
  • GitHub Check: Test Python (8, 3.10)
  • GitHub Check: Test Python (3, 3.13)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Test Python (2, 3.13)
  • GitHub Check: Test Python (1, 3.13)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Analyze (python)
  • GitHub Check: Test C++ (true, true, true, false)
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Test C++ (true, false, false, true)
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
🔇 Additional comments (21)
source/tests/pt/test_sampler.py (3)

16-22: Module-qualified import pattern looks good.

Using import deepmd.pt.utils.dataloader as pt_dataloader and accessing components via the module alias (e.g., pt_dataloader.DpLoaderSet) is cleaner than mixing import and from imports. This addresses the static analysis warning about mixed import styles.


34-45: Good CI workaround with _SerialPool.

The serial pool shim avoids SemLock/CUDA initialization issues in restricted CI environments while maintaining the same interface as the multiprocessing Pool.


411-513: Well-structured epoch calculation test.

The test_num_epoch_dict test thoroughly validates:

  1. Per-task total_numb_batch computation
  2. Step calculation from epoch targets
  3. Rounding consistency across tasks (scale_0 ≈ scale_1)

The assertions verify that rounding up (ceil) never reduces the expected number of epochs and that all tasks are scaled by the same factor.
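To make that rounding argument concrete, here is a minimal self-contained sketch of the same consistency property; the epoch targets, batch totals, and variable names below are illustrative stand-ins, not the fixtures actually used in test_num_epoch_dict.

```python
import math

# Illustrative epoch targets and per-task batch totals (not the real test fixtures).
num_epoch_dict = {"model_1": 2.0, "model_2": 4.0}
per_task_total = {"model_1": 10, "model_2": 7}

# Steps implied by the epoch targets, rounded up so no task loses epochs.
raw_steps = sum(num_epoch_dict[k] * per_task_total[k] for k in num_epoch_dict)
num_steps = math.ceil(raw_steps)
assert num_steps >= raw_steps  # ceil never reduces the requested training volume

# Every task's effective epoch count should be scaled by the same factor.
weights = {k: num_epoch_dict[k] * per_task_total[k] for k in num_epoch_dict}
total_w = sum(weights.values())
scales = {
    k: (num_steps * weights[k] / total_w) / (num_epoch_dict[k] * per_task_total[k])
    for k in num_epoch_dict
}
assert math.isclose(scales["model_1"], scales["model_2"])
```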

deepmd/utils/argcheck.py (1)

3216-3244: Documentation for numb_steps and num_epoch is clear and comprehensive.

The docstrings properly explain (see the config sketch after this list):

  • The mutual exclusivity between parameters
  • The formula for computing steps from epochs
  • The accepted aliases
  • Backend-specific behavior (single-task vs multi-task)
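As an illustration of the documented mutual exclusion, a hypothetical training section could look like the following; everything other than numb_steps and num_epoch is abbreviated, and the exact alias spellings should be checked against argcheck.py.

```python
# Hypothetical single-task training sections: specify either numb_steps or num_epoch.
training_with_steps = {
    "training_data": {"systems": ["path/to/system"], "batch_size": "auto"},
    "numb_steps": 100000,  # explicit step count
}

training_with_epochs = {
    "training_data": {"systems": ["path/to/system"], "batch_size": "auto"},
    "num_epoch": 10,  # steps derived as ceil(num_epoch * total_numb_batch)
}

# Supplying both keys at once is rejected during argument checking.
```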
deepmd/dpmodel/utils/training_utils.py (3)

12-64: Well-implemented compute_total_numb_batch with thorough validation.

The function correctly computes ceil(max(n_batches[i] / prob[i])) over entries with positive probability. The comprehensive input validation (1-D arrays, finite values, non-negative entries, positive sum) yields clear error messages for misconfigurations.
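A minimal re-implementation of that formula, assuming n_batches and prob are 1-D NumPy arrays as described; the real utility additionally performs the validation listed above.

```python
import math

import numpy as np


def total_numb_batch_sketch(n_batches: np.ndarray, prob: np.ndarray) -> int:
    """Smallest batch count so every system with prob > 0 is visited roughly once per epoch."""
    mask = prob > 0
    return int(math.ceil(np.max(n_batches[mask] / prob[mask])))


# Example: three systems with 30/20/10 batches sampled at probabilities 0.5/0.3/0.2.
print(total_numb_batch_sketch(np.array([30, 20, 10]), np.array([0.5, 0.3, 0.2])))  # 67
```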


67-121: resolve_model_prob correctly handles both user-specified and default probabilities.

Good design choices (a small fallback sketch follows the list):

  • Requires all model keys when user specifies model_prob_config
  • Defaults to system counts when config is empty
  • Logs the fallback behavior only on rank 0 (avoids log spam in distributed training)
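A hedged sketch of that fallback logic; the function name, arguments, and error handling are illustrative and do not reproduce the exact signature in training_utils.py.

```python
import numpy as np


def resolve_model_prob_sketch(model_keys, model_prob_config, n_systems_per_model):
    """Illustrative fallback: use configured weights if given, else per-model system counts."""
    if model_prob_config:
        # When the user supplies weights, every model key must be present.
        missing = [k for k in model_keys if k not in model_prob_config]
        if missing:
            raise KeyError(f"model_prob missing keys: {missing}")
        weights = np.array([model_prob_config[k] for k in model_keys], dtype=float)
    else:
        # Default: weight each task by how many systems it owns.
        weights = np.array([n_systems_per_model[k] for k in model_keys], dtype=float)
    return weights / weights.sum()
```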

124-188: resolve_model_prob_from_epochs correctly implements the documented formula.

The implementation matches the documented behavior:

  • model_prob[i] = num_epoch_dict[i] * per_task_total[i] / sum_j(...)
  • num_steps = ceil(sum_i(num_epoch_dict[i] * per_task_total[i]))

The validation for positive epoch values and matching array lengths is appropriate.
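A hedged restatement of that formula in code form; the variable names mirror the comment above rather than the exact signature in training_utils.py.

```python
import math

import numpy as np


def prob_and_steps_from_epochs(num_epoch_dict, per_task_total, model_keys):
    """Weight each task by epochs * batches-per-epoch, then derive the total step count."""
    weights = np.array(
        [num_epoch_dict[k] * per_task_total[k] for k in model_keys], dtype=float
    )
    model_prob = weights / weights.sum()
    num_steps = math.ceil(weights.sum())
    return model_prob, num_steps
```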

deepmd/tf/entrypoints/train.py (2)

260-285: Correct implementation of num_epoch support for TensorFlow backend.

The logic properly:

  1. Falls back to num_epoch when numb_steps is not set
  2. Validates num_epoch is positive
  3. Requires train_data for computing total_numb_batch
  4. Logs the derived values for transparency

The formula stop_batch = ceil(num_epoch * total_numb_batch) matches the documentation.
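As a concrete (hypothetical) example: with num_epoch = 3 and a training set that yields total_numb_batch = 125 batches per epoch, the derived value is stop_batch = ceil(3 * 125) = 375.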


263-267: Runtime validation is necessary, though this case would ideally be caught earlier.

This runtime check for missing both numb_steps and num_epoch is correct, but with the suggested fix to training_extra_check in argcheck.py, this would be caught during config validation instead. The runtime check provides defense-in-depth.

deepmd/tf/entrypoints/change_bias.py (2)

18-20: LGTM!

The import of compute_total_numb_batch from the shared utilities module aligns with the PR's goal of consolidating common epoch-based scheduling logic.


193-210: LGTM!

The logic to derive stop_batch from num_epoch is correctly implemented:

  • Validates num_epoch > 0 and total_numb_batch > 0 before computation
  • Uses np.ceil for conservative step counting
  • Maintains backward compatibility with the stop_batch = 0 fallback

This aligns well with the epoch-based scheduling introduced in the PT/PD trainers.

deepmd/pt/train/training.py (5)

26-30: LGTM!

The imports from deepmd.dpmodel.utils properly bring in the shared utilities for epoch-based scheduling. This consolidation aligns with the PR objective of sharing common methods between PT and PD code paths.


147-147: LGTM!

The explicit initialization of self.model_prob = None prevents potential AttributeError in edge cases where model_prob might be accessed before being set, addressing the concern from past reviews.


150-152: LGTM!

The configuration extraction correctly uses .get() to allow both numb_steps and num_epoch/num_epoch_dict to be optional, enabling the new epoch-based scheduling while maintaining backward compatibility.


732-738: LGTM!

The defensive fallback ensures model_prob is always resolved for multi-task training, even in edge cases where the earlier resolution paths might not have been executed (e.g., certain resuming scenarios). This addresses the concern raised in past reviews.


441-508: LGTM.

The training steps resolution logic is well-structured:

  • Single-task: derives num_steps from num_epoch × total_numb_batch
  • Multi-task with num_epoch_dict: uses resolve_model_prob_from_epochs to compute both model_prob and num_steps
  • Multi-task without num_epoch_dict: requires explicit num_steps and resolves model_prob separately

The shared utilities in training_utils.py handle edge case validation (empty weights, non-finite values, missing keys). training_data.index correctly provides the batch count per system (populated from len(system_dataloader) in the dataloader), which is passed to compute_total_numb_batch as expected.
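A schematic of that three-way dispatch; the function signature and config access below are hypothetical and only mirror the branches described above, not the trainer's actual code.

```python
import math


def resolve_steps_sketch(config, multi_task, total_batches):
    """Schematic resolution of num_steps (and, for multi-task, model_prob).

    total_batches: an int for single-task, a dict of per-task totals for multi-task.
    """
    numb_steps = config.get("numb_steps")
    if not multi_task:
        if numb_steps is not None:
            return numb_steps, None
        # Single-task: derive steps from the epoch target.
        return math.ceil(config["num_epoch"] * total_batches), None

    epoch_dict = config.get("num_epoch_dict")
    if epoch_dict is not None:
        # Per-model epoch targets: probabilities and steps are derived together.
        weights = {k: epoch_dict[k] * total_batches[k] for k in epoch_dict}
        total = sum(weights.values())
        model_prob = {k: w / total for k, w in weights.items()}
        return math.ceil(total), model_prob

    if numb_steps is None:
        raise ValueError("multi-task without num_epoch_dict requires numb_steps")
    return numb_steps, None  # model_prob is resolved separately in this branch


# Example usage (hypothetical): single-task, 200 batches per epoch, 5 epochs requested.
print(resolve_steps_sketch({"num_epoch": 5}, multi_task=False, total_batches=200))  # (1000, None)
```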

deepmd/pd/train/training.py (5)

33-38: LGTM!

The imports correctly bring in the shared utilities from deepmd.dpmodel.utils. Note that the BaseLR import on line 33 is separate from the utility imports, which is syntactically fine.


136-136: LGTM!

The explicit initialization of self.model_prob = None addresses the past review concern and maintains consistency with the PT trainer implementation.


139-141: LGTM!

The configuration extraction mirrors the PT implementation, correctly using .get() for the new epoch-based configuration options.


395-462: LGTM!

The resolution logic correctly mirrors the PT implementation while adapting to Paddle's API:

  • Uses batch_sampler.sampler.weights (Paddle) instead of sampler.weights (PyTorch)
  • Same validation flow and utility usage
  • Consistent error messages and logging

This maintains parity between the PT and PD trainers for the epoch-based scheduling feature.


649-656: LGTM!

The defensive fallback for model_prob resolution addresses the past review concern and maintains consistency with the PT trainer. This ensures multi-task training always has valid model probabilities, even in edge cases involving model resumption.
