
Conversation

waliwali777 commented Aug 21, 2025

Summary by CodeRabbit

  • New Features

    • Distributed multi-device training with mesh-aware execution, input sharding, and distributed loss handling.
    • Dynamic neighbor selection in descriptors; local mapping enabled by default.
    • Optional exponential switching for environment weighting and distance-based edge initialization.
    • New SiLUT activation with optional JIT-backed path (env-controlled).
  • Behavior Changes

    • Custom-op JIT disabled by default; device handling streamlined.
    • Data loading simplified; single-task validation often skipped; example training steps reduced.
  • Examples

    • Added multi-GPU run script and updated water/dpa3 config.
  • Chores

    • Pre-commit hooks temporarily disabled.


coderabbitai bot commented Aug 21, 2025

📝 Walkthrough

Walkthrough

Adds distributed training/forward sharding and fleet mesh integration; enables dynamic neighbor selection and graph-index utilities; introduces exponential switch weighting and SiLUT activation with optional JIT path; changes DPA3 default local mapping; several runtime/env and example script updates.

Changes

Cohort / File(s) Summary
Tooling config
./.pre-commit-config.yaml
Commented out prettier and bibtex-tidy hooks; other config unchanged.
Env & activation plumbing
deepmd/pd/utils/env.py, deepmd/pd/entrypoints/main.py, deepmd/pd/utils/utils.py, deepmd/pd/utils/spin.py, deepmd/pd/model/descriptor/repformers.py
Env reads PADDLE_LOCAL_RANK; adds CUSTOM_OP_USE_JIT export. main.train sets env.CUSTOM_OP_USE_JIT = False. Adds SiLUT scripted/Python variants gated by env flag. Removed device kwarg in some tensor allocations.
Distributed training & I/O
deepmd/pd/train/training.py, deepmd/pd/utils/dataloader.py, deepmd/pd/model/model/make_model.py, deepmd/pd/loss/ener.py, examples/water/dpa3/run.sh, examples/water/dpa3/input_torch.json
Initializes fleet mesh and fleet.auto usage; switches DataLoader to batch_size usage; shards input/label tensors before forward; wraps neighbor-list build and lower forward with dist.local_map; modifies energy/force/virial loss to reshard/replicate for distributed contexts; adds example run script and config changes.
Descriptor: dynamic selection & mapping
deepmd/pd/model/descriptor/dpa3.py, .../repflows.py, .../repflow_layer.py
DPA3 default use_loc_mapping=True. DescrptBlockRepflows: new params edge_init_use_dist, use_exp_switch, use_dynamic_sel, sel_reduce_factor, use_loc_mapping; computes edge/angle indices when dynamic and threads indices into layer calls. RepFlowLayer adds dynamic symmetrization/update helpers and extends forward signature to accept edge_index and angle_index.
Env-matrix & preprocessing
deepmd/pd/model/descriptor/env_mat.py, deepmd/pd/utils/preprocess.py
Adds use_exp_switch flag to env-matrix paths; env_mat uses padded coords; compute_smooth_weight simplified to clip-based polynomial; new compute_exp_sw exponential switch function.
Graph utilities
deepmd/pd/model/network/utils.py
New helpers aggregate(data, owners, average=True, num_owner=None) and get_graph_index(nlist, nlist_mask, a_nlist_mask, nall) for owner-wise aggregation and edge/angle index construction.
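
A minimal usage sketch of the new aggregate helper described above (toy tensors; behavior inferred from the signature and docstring added in this PR, import path taken from the file listed in this row):

import paddle

from deepmd.pd.model.network.utils import aggregate  # helper added in this PR

# five edge rows owned by three atoms; owner 1 has no rows
data = paddle.to_tensor([[1.0], [2.0], [3.0], [4.0], [5.0]])
owners = paddle.to_tensor([0, 0, 2, 2, 2], dtype="int64")

summed = aggregate(data, owners, average=False, num_owner=3)  # [[3.], [0.], [12.]] (owner-wise sums)
mean = aggregate(data, owners, average=True, num_owner=3)     # [[1.5], [0.], [4.]] (owner-wise means)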

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant User
  participant Trainer
  participant Fleet as fleet/mesh
  participant Dist as dist.local_map
  participant Model as forward_common_lower
  participant Data as DataLoader

  User->>Trainer: start training()
  Trainer->>Fleet: fleet.auto.create_mesh / fleet.init
  Trainer->>Data: next batch
  Trainer->>Trainer: build label_dict_spec if CINN
  Trainer->>Trainer: shard inputs/labels on mesh
  rect rgba(200,230,255,0.25)
    note over Trainer,Dist: Distributed stage 1
    Trainer->>Dist: extend_input_and_build_neighbor_list(...)
    Dist-->>Trainer: extended_coord, extended_atype, mapping, nlist
  end
  rect rgba(200,230,255,0.25)
    note over Trainer,Dist: Distributed stage 2
    Trainer->>Dist: forward_common_lower(..., extended tensors)
    Dist->>Model: compute predictions
    Model-->>Dist: model_predict_lower
    Dist-->>Trainer: results
  end
  Trainer->>Trainer: communicate outputs, cast, loss (dist.reshard)
  Trainer-->>User: step complete
sequenceDiagram
  autonumber
  participant DPA3 as DescrptDPA3
  participant RepBlk as DescrptBlockRepflows
  participant Utils as get_graph_index
  participant Layer as RepFlowLayer

  DPA3->>RepBlk: forward(..., use_dynamic_sel?, use_loc_mapping)
  alt use_dynamic_sel
    RepBlk->>Utils: get_graph_index(nlist, masks, nall)
    Utils-->>RepBlk: edge_index, angle_index
    RepBlk->>Layer: forward(..., edge_index, angle_index)
    note over Layer: Dynamic gather, aggregate, symmetrize, updates
  else static
    RepBlk->>Layer: forward(..., dummy indices)
    note over Layer: Static nlist-based path
  end
  RepBlk-->>DPA3: node_ebd, edge_ebd, h2, rot_mat, sw
sequenceDiagram
  autonumber
  participant Entry as entrypoints/main.train
  participant Env as env
  participant Act as ActivationFn

  Entry->>Env: set CUSTOM_OP_USE_JIT = False
  Act->>Env: read CUSTOM_OP_USE_JIT
  alt CUSTOM_OP_USE_JIT True
    Act-->>Act: use SiLUTScript (JIT)
  else False
    Act-->>Act: use SiLUT (Python)
  end

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Suggested labels

Core

Suggested reviewers

  • njzjz
  • iProzd
  • caic99


coderabbitai bot left a comment

Actionable comments posted: 9

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
deepmd/pd/utils/dataloader.py (1)

191-199: Reintroduce DistributedBatchSampler into DataLoader to restore per-rank sharding

The current implementation in deepmd/pd/utils/dataloader.py (around lines 191–199) constructs each DataLoader using batch_size, which ignores the previously built DistributedBatchSampler in self.sampler_list. As a result, all ranks will load the full dataset each epoch, duplicating samples across processes. To fix this:

• In the for system, batch_size in zip(self.systems, self.batch_sizes): loop
– Replace the batch_size=int(batch_size) argument with batch_sampler=system_batch_sampler.
– Remove the now-unused batch_size parameter.

• Immediately after constructing system_dataloader, update the bookkeeping to use the sampler’s length:
– Change self.index.append(len(system_dataloader))self.index.append(len(system_batch_sampler))
– Change self.total_batch += len(system_dataloader)self.total_batch += len(system_batch_sampler)

Apply:

--- a/deepmd/pd/utils/dataloader.py
+++ b/deepmd/pd/utils/dataloader.py
@@ -172,13 +172,12 @@ class DpLoaderSet(Dataset):
             if dist.is_available() and dist.is_initialized():
                 system_batch_sampler = DistributedBatchSampler(
                     system,
                     shuffle=shuffle,
                     batch_size=int(batch_size),
                 )
+                # ...
             else:
                 system_batch_sampler = BatchSampler(
                     batch_size=int(batch_size),
                 )
             self.sampler_list.append(system_batch_sampler)
-
-            system_dataloader = DataLoader(
-                dataset=system,
-                num_workers=0,  # Should be 0 to avoid too many threads forked
-                batch_size=int(batch_size),
+            system_dataloader = DataLoader(
+                dataset=system,
+                num_workers=0,  # Should be 0 to avoid too many threads forked
+                batch_sampler=system_batch_sampler,
                 collate_fn=collate_batch,
                 use_buffer_reader=False,
                 places=["cpu"],
             )
             self.dataloaders.append(system_dataloader)
-            self.index.append(len(system_dataloader))
-            self.total_batch += len(system_dataloader)
+            self.index.append(len(system_batch_sampler))
+            self.total_batch += len(system_batch_sampler)

This ensures each process only loads its shard of data and that epoch sizing is consistent across ranks.

🧹 Nitpick comments (26)
.pre-commit-config.yaml (2)

68-74: Don’t drop Prettier entirely; pin a stable rev and scope it to docs.

Commenting out the Prettier hook removes consistent formatting for Markdown/YAML/CSS across contributors. Rather than disabling, re-enable it with a stable tag and keep the existing excludes to avoid pre-commit.ci issues.

Apply:

-  # - repo: https://github.com/pre-commit/mirrors-prettier
-  #   rev: v4.0.0-alpha.8
-  #   hooks:
-  #     - id: prettier
-  #       types_or: [markdown, yaml, css]
-  #       # workflow files cannot be modified by pre-commit.ci
-  #       exclude: ^(source/3rdparty|\.github/workflows|\.clang-format)
+  - repo: https://github.com/pre-commit/mirrors-prettier
+    rev: v3.3.3
+    hooks:
+      - id: prettier
+        types_or: [markdown, yaml, css]
+        # workflow files cannot be modified by pre-commit.ci
+        exclude: ^(source/3rdparty|\.github/workflows|\.clang-format)

86-104: Keep bibtex-tidy but make it manual-only to avoid noise.

If the goal is to avoid unintended reordering or CI edits, keep the hook gated to manual runs so maintainers can enforce style when needed.

Apply:

-  # - repo: https://github.com/njzjz/mirrors-bibtex-tidy
-  #   rev: v1.13.0
-  #   hooks:
-  #     - id: bibtex-tidy
-  #       args:
-  #         - --curly
-  #         - --numeric
-  #         - --align=13
-  #         - --blank-lines
-  #         # disable sort: the order of keys and fields has explict meanings
-  #         #- --sort=key
-  #         - --duplicates=key,doi,citation,abstract
-  #         - --merge=combine
-  #         #- --sort-fields
-  #         #- --strip-comments
-  #         - --trailing-commas
-  #         - --encode-urls
-  #         - --remove-empty-fields
-  #         - --wrap=80
+  - repo: https://github.com/njzjz/mirrors-bibtex-tidy
+    rev: v1.13.0
+    hooks:
+      - id: bibtex-tidy
+        stages: [manual]
+        args:
+          - --curly
+          - --numeric
+          - --align=13
+          - --blank-lines
+          # disable sort: the order of keys and fields has explict meanings
+          #- --sort=key
+          - --duplicates=key,doi,citation,abstract
+          - --merge=combine
+          #- --sort-fields
+          #- --strip-comments
+          - --trailing-commas
+          - --encode-urls
+          - --remove-empty-fields
+          - --wrap=80
examples/water/dpa3/input_torch.json (1)

88-95: numb_steps=2000 clashes with decay_steps=5000 and save_freq=2000.

With only 2000 steps you won't complete the LR decay schedule (decay_steps=5000), and save_freq=2000 yields at most one checkpoint. If this example is meant for quick verification, consider saving more frequently and/or aligning the LR schedule.

Minimal adjustment inside this block to ensure at least a couple of checkpoints:

-    "numb_steps": 2000,
+    "numb_steps": 2000,
@@
-    "save_freq": 2000,
+    "save_freq": 500,

Optionally align LR decay with the shortened run (outside this block):

-  "decay_steps": 5000,
+  "decay_steps": 2000,

Please confirm the intended purpose of this config (smoke test vs. meaningful short run) so we can tune disp_freq/save_freq/decay accordingly.

examples/water/dpa3/run.sh (3)

1-18: Add a shebang, fail-fast flags, and ensure log dir exists.

Without a shebang, the shell is unspecified (ShellCheck SC2148). Also, add strict modes to fail fast and create the logs directory before writing to it.

+#!/usr/bin/env bash
+set -Eeuo pipefail
+
 # unset PADDLE_ELASTIC_JOB_ID
 # unset PADDLE_TRAINER_ENDPOINTS
 # unset DISTRIBUTED_TRAINER_ENDPOINTS
 # unset FLAGS_START_PORT
 # unset PADDLE_ELASTIC_TIMEOUT
 # export NNODES=1
 # export PADDLE_TRAINERS_NUM=1
 unset CUDA_DEVICE_MAX_CONNECTIONS
 
-HDFS_USE_FILE_LOCKING=0 python -m paddle.distributed.launch --gpus="0,1,2,3,4,5,6,7" --log_dir "logs" dp --pd train input_torch.json -l dp_train.log
+mkdir -p logs
+HDFS_USE_FILE_LOCKING=0 python -m paddle.distributed.launch --gpus="0,1,2,3,4,5,6,7" --log_dir "logs" dp --pd train input_torch.json -l dp_train.log

10-10: Consider parameterizing device selection and using paddle.distributed.run (if supported).

Hardcoding GPU IDs (0–7) can be brittle across environments. Consider reading from CUDA_VISIBLE_DEVICES or a script argument. Also, newer Paddle versions favor paddle.distributed.run over launch; confirm your target Paddle version and migrate if applicable.

Would you like a minimal wrapper to accept --gpus or default to CUDA_VISIBLE_DEVICES and fall back to all visible GPUs?


14-17: Scrub or anonymize commented private IPs or move them to a template.

Even commented-out internal IPs can cause confusion in public repos. Replace with placeholders to reduce noise.

-#    --ips=10.67.200.17,10.67.200.11,10.67.200.13,10.67.200.15 \
+#    --ips=<node1_ip>,<node2_ip>,<node3_ip>,<node4_ip> \
deepmd/pd/model/model/make_model.py (1)

423-524: Minor typos in docstring/comments.

“Nothong” → “nothing”, “effact” → “effect”, “mapps” → “maps”. Purely cosmetic, but these are user-facing docstrings.

deepmd/pd/model/descriptor/env_mat.py (1)

38-43: Switch selection is fine; consider naming consistency (rcut_smth) and documenting behavior.

Parameter ruct_smth appears to be rcut_smth elsewhere (prod_env_mat). It works positionally, but the name typo can mislead readers. Also, mention that compute_exp_sw yields ~exp(-1) at rmin and decays faster than the polynomial near rcut.

-    ruct_smth: float,
+    rcut_smth: float,
 ...
-        compute_smooth_weight(length, ruct_smth, rcut)
+        compute_smooth_weight(length, rcut_smth, rcut)
 ...
-        else compute_exp_sw(length, ruct_smth, rcut)
+        else compute_exp_sw(length, rcut_smth, rcut)
deepmd/pd/loss/ener.py (1)

7-7: DTensor import is appropriate; ensure single-process paths still work.

Adding paddle.distributed as dist is aligned with the rest of the PR. Just make sure all dist-only ops are guarded so single-device training/inference continues to run.

If you prefer, centralize the guards via a small helper (see earlier comment) and reuse it for energy/force/virial paths.

deepmd/pd/model/network/utils.py (1)

89-105: Commented-out code should be removed

Lines 88-89 contain commented-out code calculating n_angle which appears to be unused. Consider removing these lines to keep the code clean.

     a_nlist_mask_3d = a_nlist_mask[:, :, :, None] & a_nlist_mask[:, :, None, :]
     n_edge = nlist_mask.sum().item()
-    # n_angle = a_nlist_mask_3d.sum().item()
-
deepmd/pd/model/descriptor/dpa3.py (1)

496-496: Unused variable nall

The variable nall is computed but never used in the function. Consider removing it to clean up the code.

     extended_coord = extended_coord.to(dtype=self.prec)
     nframes, nloc, nnei = nlist.shape
-    nall = extended_coord.reshape([nframes, -1]).shape[1] // 3
deepmd/pd/train/training.py (3)

103-106: Hardcoded mesh dimension may limit scalability

The mesh dimension is hardcoded to 32 (mesh_dims = [("dp", 32)]). This may not be optimal for all training scenarios and hardware configurations.

Consider making the mesh dimension configurable through the training parameters:

-        mesh_dims = [("dp", 32)]
+        mesh_dims = [("dp", training_params.get("mesh_dim", 32))]

Would you like me to open an issue to track making the distributed mesh configuration more flexible?


632-632: Simplify dictionary key check

Use key in dict instead of key in dict.keys() for better readability and performance.

-                k: spec_templates[k] for k in label_dict.keys() if k in spec_templates
+                k: spec_templates[k] for k in label_dict if k in spec_templates

747-751: Unused synchronization context variable

The sync_context variable is defined but never used. The actual synchronization code is commented out (lines 753-769). Consider either removing the unused code or implementing the synchronization properly.

If synchronization is not needed yet, remove the unused variable:

-                sync_context = (
-                    self.wrapper.no_sync
-                    if self.world_size > 1
-                    else contextlib.nullcontext
-                )
-

Or if synchronization will be implemented soon, consider adding a TODO comment explaining the plan.

deepmd/pd/model/descriptor/repflows.py (6)

226-241: Validation on dynamic selection and sel_reduce_factor is appropriate

The “dynamic requires smooth” constraint and sel_reduce_factor > 0 check are good early failures. Consider upgrading NotImplementedError to ValueError for invalid user configuration (it’s not an implementation missing, but an invalid combination).

-            raise NotImplementedError(
+            raise ValueError(
                 "smooth_edge_update must be True when use_dynamic_sel is True!"
             )

405-407: Parallel path still unimplemented; assert is fine but message could be clearer

In non-parallel mode you assert mapping is present; in parallel you still raise NotImplementedError downstream. Consider raising a clear error here if comm_dict is not None to fail fast with guidance (e.g., “parallel_mode not yet supported in se_repflow.forward”).

-        parallel_mode = comm_dict is not None
-        if not parallel_mode:
+        parallel_mode = comm_dict is not None
+        if not parallel_mode:
             assert mapping is not None
+        else:
+            raise NotImplementedError("parallel_mode is not yet supported in DescrptBlockRepflows.forward().")

452-456: Converting -1 paddings to 0: confirm index 0 is always safe

You zero out padded neighbor indices before later take_along_axis/indexing. Please confirm index 0 is guaranteed valid (and masked) for all frames. If not, use a dedicated out-of-range index for gather plus masked_fill afterwards to avoid accidental leakage from atom 0.

If 0 might be a real atom, prefer: shift indices to a safe dummy row by padding node_ebd_ext/edge tensors with a leading zero row and indexing that.
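
A minimal sketch of that dummy-row alternative, using tensor names from this comment (the helper itself is hypothetical, not code from the PR):

import paddle

def gather_with_dummy_row(node_ebd_ext: paddle.Tensor, nlist: paddle.Tensor) -> paddle.Tensor:
    # node_ebd_ext: [nb, nall, dim]; nlist: [nb, nloc, nnei] with -1 marking padded neighbors
    nb, nall, dim = node_ebd_ext.shape
    # append one all-zero row per frame so that index `nall` means "padded neighbor"
    padded = paddle.concat([node_ebd_ext, paddle.zeros([nb, 1, dim], dtype=node_ebd_ext.dtype)], axis=1)
    # redirect -1 entries to the dummy row instead of to atom 0
    safe_nlist = paddle.where(nlist >= 0, nlist, paddle.full_like(nlist, nall))
    index = safe_nlist.reshape([nb, -1, 1]).expand([-1, -1, dim])
    gathered = paddle.take_along_axis(padded, index, axis=1)
    return gathered.reshape([nb, nlist.shape[1], -1, dim])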


457-466: Remove unused local n_dim

n_dim is assigned from node_ebd but unused. Trim it to keep the forward tight.

-        n_dim = node_ebd.shape[-1]

485-494: Local mapping application is correct; add dtype/device assertions for robustness

Mapping-based conversion from global to local neighbor indices is a good move. Add quick asserts to ensure mapping dtype is integer and on the same device as nlist to prevent subtle JIT/device issues.

         if not parallel_mode and self.use_loc_mapping:
             assert mapping is not None
+            assert mapping.dtype in (paddle.int32, paddle.int64)
+            assert mapping.place == nlist.place

561-575: Scale factor matches layer’s dynamic normalization; consider reusing a single definition

You use (self.nnei / self.sel_reduce_factor) ** (-0.5) here and self.dynamic_e_sel ** (-0.5) in RepFlowLayer. Unify to one source of truth (e.g., compute once in init and reuse) to avoid drift.

-                scale_factor=(self.nnei / self.sel_reduce_factor) ** (-0.5),
+                scale_factor=(self.nnei / self.sel_reduce_factor) ** (-0.5),
+                # Consider: precompute self.dynamic_e_sel = self.nnei / self.sel_reduce_factor in __init__
+                # and reference it both here and in RepFlowLayer.
deepmd/pd/model/descriptor/repflow_layer.py (6)

332-385: Dynamic HG computation: check aggregate semantics and owner dtype

The implementation is clean and matches the math. Two small safeguards:

  • Ensure owner (n2e_index) is int64 on the same device as flat_h2g2 to avoid implicit casts in aggregate.
  • Consider guarding against zero-division when dynamic_e_sel can be < 1 for tiny sel_reduce_factor choices.
-        h2g2 = (
+        # owners must be integer indices on same place
+        h2g2 = (
             aggregate(flat_h2g2, owner, average=False, num_owner=num_owner).reshape(
                 [nb, nloc, 3, e_dim]
             )
             * scale_factor
         )

If you want, I can add explicit type/device asserts before calling aggregate.


558-607: optim_angle_update_dynamic: matrix partitioning looks correct; add small JIT/shape guards

Given frequent JIT use, add asserts on the split sizes against matrix.shape[0] to catch config mismatches early.

-        sub_angle, sub_node, sub_edge_ik, sub_edge_ij = paddle.split(
+        # Sanity check partitioning to avoid silent shape drift
+        assert matrix.shape[0] == angle_dim + node_dim + 2 * edge_dim
+        sub_angle, sub_node, sub_edge_ik, sub_edge_ij = paddle.split(
             matrix, [angle_dim, node_dim, edge_dim, edge_dim]
         )

643-688: optim_edge_update_dynamic: same partitioning guard and a tiny readability win

Mirror the partitioning assert here and name the flattened node dims for readability.

-        node, node_ext, edge = paddle.split(matrix, [node_dim, node_dim, edge_dim])
+        assert matrix.shape[0] == node_dim + node_dim + edge_dim
+        node, node_ext, edge = paddle.split(matrix, [node_dim, node_dim, edge_dim])

789-811: Dynamic sym op uses identical scaling as HG; nice consistency

The use of self.dynamic_e_sel ** (-0.5) matches the outer HG scaling. Consider computing dynamic_e_sel in DescrptBlockRepflows once, passing it down, to ensure both places stay consistent if the definition evolves.


836-903: Node-edge message reduction: correct aggregate path; consider int dtype assertion for owners

Aggregate with average=False then explicit division by dynamic_e_sel is clear and matches static normalization. Add an int dtype check on n2e_index to avoid hard-to-debug runtime errors when JIT tracing.

-                aggregate(
+                # n2e_index must be integer indices
+                aggregate(
                     node_edge_update,
                     n2e_index,
                     average=False,
                     num_owner=nb * nloc,
                 ).reshape(nb, nloc, node_edge_update.shape[-1])

1042-1072: Dynamic angle-to-edge reduction: scaling matches intent; add owner count check

Dividing by (self.dynamic_a_sel**0.5) parallels the static path. Optionally, assert num_owner == n_edge to catch mismatches when eij2a_index is constructed.

-                padding_edge_angle_update = aggregate(
+                padding_edge_angle_update = aggregate(
                     weighted_edge_angle_update,
                     eij2a_index,
                     average=False,
                     num_owner=n_edge,
                 ) / (self.dynamic_a_sel**0.5)
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between b494a0d and 09dc8dc.

📒 Files selected for processing (18)
  • .pre-commit-config.yaml (2 hunks)
  • deepmd/pd/entrypoints/main.py (2 hunks)
  • deepmd/pd/loss/ener.py (4 hunks)
  • deepmd/pd/model/descriptor/dpa3.py (5 hunks)
  • deepmd/pd/model/descriptor/env_mat.py (6 hunks)
  • deepmd/pd/model/descriptor/repflow_layer.py (17 hunks)
  • deepmd/pd/model/descriptor/repflows.py (10 hunks)
  • deepmd/pd/model/descriptor/repformers.py (2 hunks)
  • deepmd/pd/model/model/make_model.py (2 hunks)
  • deepmd/pd/model/network/utils.py (1 hunks)
  • deepmd/pd/train/training.py (9 hunks)
  • deepmd/pd/utils/dataloader.py (1 hunks)
  • deepmd/pd/utils/env.py (3 hunks)
  • deepmd/pd/utils/preprocess.py (1 hunks)
  • deepmd/pd/utils/spin.py (0 hunks)
  • deepmd/pd/utils/utils.py (3 hunks)
  • examples/water/dpa3/input_torch.json (1 hunks)
  • examples/water/dpa3/run.sh (1 hunks)
💤 Files with no reviewable changes (1)
  • deepmd/pd/utils/spin.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2024-10-05T03:11:02.922Z
Learnt from: 1azyking
PR: deepmodeling/deepmd-kit#4169
File: deepmd/pt/loss/ener_hess.py:341-348
Timestamp: 2024-10-05T03:11:02.922Z
Learning: In `deepmd/pt/loss/ener_hess.py`, the `label` uses the key `"atom_ener"` intentionally to maintain consistency with the forked version.

Applied to files:

  • deepmd/pd/loss/ener.py
🧬 Code Graph Analysis (6)
deepmd/pd/model/model/make_model.py (2)
deepmd/pd/utils/nlist.py (1)
  • extend_input_and_build_neighbor_list (19-49)
deepmd/pt/model/model/make_model.py (1)
  • forward_common_lower (234-304)
deepmd/pd/model/descriptor/env_mat.py (1)
deepmd/pd/utils/preprocess.py (2)
  • compute_exp_sw (20-29)
  • compute_smooth_weight (9-17)
deepmd/pd/train/training.py (4)
deepmd/utils/data_system.py (1)
  • get_data (764-810)
deepmd/pt/train/training.py (1)
  • get_data (1096-1138)
source/tests/pd/model/test_saveload_dpa1.py (1)
  • get_data (117-134)
deepmd/pd/utils/utils.py (1)
  • nvprof_context (357-366)
deepmd/pd/utils/utils.py (2)
deepmd/pt/utils/utils.py (17)
  • silut_forward (24-30)
  • sigmoid (137-138)
  • silu (140-141)
  • silut_backward (33-43)
  • silut_double_backward (46-68)
  • SiLUTScript (71-130)
  • get_script_code (84-127)
  • SiLUTFunction (89-105)
  • forward (91-96)
  • forward (109-114)
  • forward (129-130)
  • forward (151-155)
  • forward (177-203)
  • backward (99-105)
  • backward (117-125)
  • SiLUTGradFunction (107-125)
  • SiLUT (133-155)
deepmd/pd/model/network/network.py (1)
  • Tensor (30-33)
deepmd/pd/model/descriptor/repflow_layer.py (2)
deepmd/pd/model/network/utils.py (1)
  • aggregate (9-44)
deepmd/pd/model/descriptor/repformer_layer.py (1)
  • _make_nei_g1 (86-116)
deepmd/pd/model/descriptor/repflows.py (2)
deepmd/pt/model/network/utils.py (1)
  • get_graph_index (54-143)
deepmd/pd/model/descriptor/repflow_layer.py (2)
  • _cal_hg (288-330)
  • _cal_hg_dynamic (333-385)
🪛 Ruff (0.12.2)
deepmd/pd/loss/ener.py

210-210: Undefined name tmp

(F821)


210-210: Undefined name tmp

(F821)

deepmd/pd/train/training.py

632-632: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


747-747: Local variable sync_context is assigned to but never used

Remove assignment to unused variable sync_context

(F841)

deepmd/pd/model/descriptor/dpa3.py

496-496: Local variable nall is assigned to but never used

Remove assignment to unused variable nall

(F841)

deepmd/pd/model/descriptor/repflows.py

465-465: Local variable n_dim is assigned to but never used

Remove assignment to unused variable n_dim

(F841)

🪛 Shellcheck (0.10.0)
examples/water/dpa3/run.sh

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

🔇 Additional comments (18)
deepmd/pd/utils/env.py (1)

75-75: Good: CUSTOM_OP_USE_JIT is explicitly exported.

Making the flag public in all is helpful for env-driven toggling from entrypoints.

Also applies to: 200-204

deepmd/pd/model/descriptor/repformers.py (1)

268-275: LGTM: .to(env.DEVICE) usage is correct and consistent.

Paddle’s Tensor.to supports device as the first positional argument (e.g., "gpu" or "gpu:0"), so this change is valid and simplifies the calls. (paddlepaddle.org.cn)

Also applies to: 506-512

deepmd/pd/utils/preprocess.py (1)

13-18: The clipped smooth-switch refactor looks good and keeps boundary behavior.

Clipping distance to [rmin, rmax] simplifies the previous mask-based logic while preserving C2 continuity and expected values at rmin→1 and rmax→0. Nice clean-up.
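
For reference, a clip-based switch of this shape looks roughly like the sketch below (assuming the usual DeePMD fifth-order polynomial; not copied from the PR):

import paddle

def smooth_weight_sketch(distance: paddle.Tensor, rmin: float, rmax: float) -> paddle.Tensor:
    # 1 for r <= rmin, 0 for r >= rmax, C2-continuous in between
    if rmin >= rmax:
        raise ValueError("rmin should be less than rmax.")
    d = paddle.clip(distance, min=rmin, max=rmax)
    uu = (d - rmin) / (rmax - rmin)
    return uu * uu * uu * (-6.0 * uu * uu + 15.0 * uu - 10.0) + 1.0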

deepmd/pd/model/model/make_model.py (2)

230-235: Check that mapping/placements align with communicate_extended_output expectations.

communicate_extended_output assumes a particular sharding on mapping and on the model_predict_lower leaves (often Shard over batch/frame). Verify wrapped_func_2 outputs match those assumptions; otherwise, a re-reshard before communication may be needed.

Would you like a small probe test that asserts placements per leaf before and after communicate_extended_output?


215-229: The scripts above will show the full FittingOutputDef implementation and the get_data method so we can confirm how many leaves are produced and the exact structure of the output dict. Once we have that, we can verify whether hard-coding out_placements to length 6 is incorrect and update the review accordingly.

deepmd/pd/utils/utils.py (3)

38-83: Well-structured SiLU activation implementation

The forward, backward, and double backward functions are cleanly implemented with proper gradient handling and branch logic for the threshold-based switching between SiLU and tanh.
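
As a rough illustration of that threshold-based branch (the tanh tail and its constants here are chosen by me for C1 continuity, not taken from the PR's silut_forward):

import math

import paddle

def silut_forward_sketch(x: paddle.Tensor, threshold: float = 3.0) -> paddle.Tensor:
    sig_t = 1.0 / (1.0 + math.exp(-threshold))
    silu_t = threshold * sig_t                          # SiLU value at the threshold
    slope = sig_t * (1.0 + threshold * (1.0 - sig_t))   # SiLU derivative at the threshold
    silu = x * paddle.nn.functional.sigmoid(x)
    tanh_tail = paddle.tanh(slope * (x - threshold)) + silu_t  # bounded continuation above threshold
    return paddle.where(x >= threshold, tanh_tail, silu)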


193-197: Conditional JIT usage based on environment flag

Good approach to conditionally enable JIT compilation via the environment flag, allowing for flexibility between training efficiency and debugging needs.
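
The selection itself reduces to something like this sketch (class and flag names are from this PR; the import location and constructor arguments are assumptions):

from deepmd.pd.utils import env
from deepmd.pd.utils.utils import SiLUT, SiLUTScript  # assumed export location

def make_silut(threshold: float = 3.0):
    # main.train sets env.CUSTOM_OP_USE_JIT = False, so training takes the
    # plain-Python path unless the flag is switched back on
    if getattr(env, "CUSTOM_OP_USE_JIT", False):
        return SiLUTScript(threshold=threshold)
    return SiLUT(threshold=threshold)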


116-117: Incorrect API usage: saved_tensor() should be saved_tensors

The PaddlePaddle PyLayer uses saved_tensors (plural) to retrieve saved tensors from context, not saved_tensor(). This will cause an AttributeError at runtime.

Apply this diff to fix the API usage:

-                (x,) = ctx.saved_tensor()
+                (x,) = ctx.saved_tensors
-                (x, grad_output) = ctx.saved_tensor()
+                (x, grad_output) = ctx.saved_tensors

Likely an incorrect or invalid review comment.

deepmd/pd/entrypoints/main.py (1)

239-239: Disabling JIT during training initialization

Setting env.CUSTOM_OP_USE_JIT = False at the start of training ensures consistent behavior. This aligns with the conditional JIT usage in the SiLUT activation function.

deepmd/pd/model/descriptor/dpa3.py (1)

498-501: Local mapping optimization for non-parallel execution

Good optimization to use local mapping when not in parallel mode, reducing memory usage by only applying type embedding to local atoms instead of all extended atoms.

deepmd/pd/train/training.py (2)

771-783: Input tensor sharding for distributed training

Good implementation of tensor sharding across the distributed mesh for efficient parallel training. The sharding is correctly applied to both input tensors and label tensors.
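
For readers unfamiliar with the API, sharding a batch tensor along the mesh's "dp" axis looks roughly like the sketch below (standalone, meant to run under paddle.distributed.launch; the mesh size and tensor shapes are placeholders, not the trainer's actual values):

import paddle
import paddle.distributed as dist

# a 1-D process mesh with a data-parallel axis named "dp", as in this PR
mesh = dist.ProcessMesh(list(range(4)), dim_names=["dp"])

coord = paddle.randn([8, 192, 3])                          # [batch, natoms, 3]
coord = dist.shard_tensor(coord, mesh, [dist.Shard(0)])    # split the batch dim across ranks
energy = paddle.randn([8, 1])
energy = dist.shard_tensor(energy, mesh, [dist.Shard(0)])  # labels sharded the same way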


861-863: Validation disabled with unclear reasoning

The validation step is commented out with only "no run valid!" as explanation. This could impact model evaluation during training.

Is validation intentionally disabled for a specific reason? If this is temporary, consider adding a more detailed comment explaining why validation is skipped and when it will be re-enabled.

deepmd/pd/model/descriptor/repflows.py (3)

115-141: Docstring for new switches and dynamic selection needs precision and guardrails

  • The exponential switch formula and its parameter guidance are user-facing. Please confirm they exactly match the implementation in prod_env_mat and any compute_exp_sw helper to avoid doc-code drift.
  • Consider briefly stating how use_exp_switch coexists with e_rcut_smth defaults and whether rcut_smth < rcut is enforced.
  • For dynamic selection, the text claims users can “safely set” very large e_sel/a_sel; add a practical upper bound note, because angle masks scale as O(a_sel^2) in memory even when later reduced.

Would you like me to sync the docstring with the current prod_env_mat logic and add a short “Performance notes” subsection?


191-196: New init options look good; defaulting use_loc_mapping=True is sensible

Defaulting to local mapping for non-parallel mode aligns with downstream usage and avoids accidental global-index assumptions in non-parallel inference. No change requested.


495-516: Dynamic path flattening: confirm consistency with get_graph_index semantics

The flattening is consistent, but relies on get_graph_index returning indices aligned with the pre/post local-mapping state. Since you conditionally convert nlist to local indices earlier, ensure use_loc_mapping passed to get_graph_index matches that conversion (it does here). No change requested.

deepmd/pd/model/descriptor/repflow_layer.py (3)

459-515: symmetrization_op_dynamic mirrors the static path well

Good factoring and re-use of _cal_grrg; keeps shapes consistent and avoids repeated matmuls. No change requested.


959-961: Angle compression via slicing: ensure dims align with a_compress_rate logic

Slicing is fine when a_compress_use_split=True. It assumes n_dim and e_dim ≥ computed compress dims; you already assert divisibility earlier. No change requested.


1213-1216: Serialization includes dynamic flags; consider bumping version or documenting backward-compat

You bumped @version to 2 and added use_dynamic_sel and sel_reduce_factor. Ensure deserialize's compatibility check (2, 1) handles older models properly and that omitting the new keys defaults them to False and 10.0, respectively.
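
One way to keep older checkpoints loadable is to default the new keys during deserialization, along these lines (a sketch; the surrounding deserialize code and key layout are assumed):

from deepmd.utils.version import check_version_compatibility

def normalize_repflow_layer_data(data: dict) -> dict:
    data = dict(data)
    check_version_compatibility(data.pop("@version", 1), 2, 1)
    # keys introduced in @version 2; fall back to the documented defaults
    data.setdefault("use_dynamic_sel", False)
    data.setdefault("sel_reduce_factor", 10.0)
    return data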

Comment on lines 209 to 214
logit = energy_pred - energy_label
logit = dist.reshard(tmp, tmp.process_mesh, [dist.Replicate()])

l2_ener_loss = paddle.mean(paddle.square(logit))
if not self.inference:

⚠️ Potential issue

Fix undefined name and guard reshard for non-DTensor tensors.

tmp is undefined. You likely intended to reshard logit. Also, guard reshard calls to avoid failures when running in non-distributed contexts where logit has no process_mesh.

-                logit = energy_pred - energy_label
-                logit = dist.reshard(tmp, tmp.process_mesh, [dist.Replicate()])
+                logit = energy_pred - energy_label
+                if hasattr(logit, "process_mesh"):
+                    logit = dist.reshard(logit, logit.process_mesh, [dist.Replicate()])

Outside this hunk, consider adding a tiny helper to reduce repetition:

def _ensure_replicated(x):
    return dist.reshard(x, x.process_mesh, [dist.Replicate()]) if hasattr(x, "process_mesh") else x

Then use: logit = _ensure_replicated(logit)

🧰 Tools
🪛 Ruff (0.12.2)

210-210: Undefined name tmp

(F821)


210-210: Undefined name tmp

(F821)

🤖 Prompt for AI Agents
In deepmd/pd/loss/ener.py around lines 209 to 213, the code calls
dist.reshard(tmp, tmp.process_mesh, ...) but tmp is undefined and reshard should
be applied to logit; also the reshard call must be guarded to avoid errors when
tensors are not DTensor/have no process_mesh. Replace the reshard target from
tmp to logit and wrap the reshard call with a check (e.g., if hasattr(logit,
"process_mesh") or similar) so reshard is only invoked for distributed DTensor
objects; optionally add a small helper like _ensure_replicated(x) that returns
dist.reshard(x, x.process_mesh, [dist.Replicate()]) when x has process_mesh else
x, then use it for logit.

Comment on lines 265 to 266
diff_f = dist.reshard(diff_f, diff_f.process_mesh, [dist.Replicate()])


🛠️ Refactor suggestion

Reshard should be conditional; avoid hard-failing on single-device runs.

diff_f may not carry DTensor metadata in non-mesh runs. Guard the reshard.

-            diff_f = dist.reshard(diff_f, diff_f.process_mesh, [dist.Replicate()])
+            if hasattr(diff_f, "process_mesh"):
+                diff_f = dist.reshard(diff_f, diff_f.process_mesh, [dist.Replicate()])
📝 Committable suggestion


Suggested change
diff_f = dist.reshard(diff_f, diff_f.process_mesh, [dist.Replicate()])
if hasattr(diff_f, "process_mesh"):
    diff_f = dist.reshard(diff_f, diff_f.process_mesh, [dist.Replicate()])
🤖 Prompt for AI Agents
In deepmd/pd/loss/ener.py around lines 265-266, the unconditional call to
dist.reshard(diff_f, ...) can fail when diff_f is a regular tensor without
DTensor/mesh metadata; wrap the reshard call in a guard that checks for DTensor
metadata (e.g., if hasattr(diff_f, "process_mesh") and diff_f.process_mesh is
not None or a similar check for DTensor type) and only call dist.reshard when
that check passes, otherwise skip resharding (or leave diff_f as-is).

Comment on lines +362 to 364
diff_v = dist.reshard(diff_v, diff_v.process_mesh, [dist.Replicate()])
l2_virial_loss = paddle.mean(paddle.square(diff_v))

🛠️ Refactor suggestion

Same conditional reshard for virial difference.

Align with the pattern above to avoid attribute errors off-mesh.

-            diff_v = dist.reshard(diff_v, diff_v.process_mesh, [dist.Replicate()])
+            if hasattr(diff_v, "process_mesh"):
+                diff_v = dist.reshard(diff_v, diff_v.process_mesh, [dist.Replicate()])
📝 Committable suggestion


Suggested change
diff_v = dist.reshard(diff_v, diff_v.process_mesh, [dist.Replicate()])
l2_virial_loss = paddle.mean(paddle.square(diff_v))
if hasattr(diff_v, "process_mesh"):
    diff_v = dist.reshard(diff_v, diff_v.process_mesh, [dist.Replicate()])
l2_virial_loss = paddle.mean(paddle.square(diff_v))
🤖 Prompt for AI Agents
In deepmd/pd/loss/ener.py around lines 362 to 363, the virial diff tensor is
reshared unconditionally causing potential attribute errors when off the process
mesh; follow the pattern used above by checking if diff_v has a process_mesh
attribute (or comparing diff_v.process_mesh against the target) and only call
dist.reshard when needed, then compute l2_virial_loss from the (possibly
reshared) diff_v.

Comment on lines 26 to 31
nlist = paddle.where(mask, nlist, nall - 1)
coord_l = coord[:, :natoms].reshape([bsz, -1, 1, 3])
index = nlist.reshape([bsz, -1]).unsqueeze(-1).expand([-1, -1, 3])
coord_r = paddle.take_along_axis(coord, axis=1, indices=index)
coord_pad = paddle.concat([coord, coord[:, -1:, :] + rcut], axis=1)
coord_r = paddle.take_along_axis(coord_pad, axis=1, indices=index)
coord_r = coord_r.reshape([bsz, natoms, nnei, 3])

⚠️ Potential issue

Padded sentinel is never indexed; nlist still redirects invalid entries to nall-1.

You append one extra coordinate (coord_pad has length nall+1), but invalid nlist entries are replaced with nall - 1, which points to the last real atom, not the sentinel at index nall. This makes the padding ineffective.

-    nlist = paddle.where(mask, nlist, nall - 1)
+    # Redirect masked neighbors to the padded sentinel at index nall
+    nlist = paddle.where(mask, nlist, paddle.full_like(nlist, nall))
     coord_l = coord[:, :natoms].reshape([bsz, -1, 1, 3])
     index = nlist.reshape([bsz, -1]).unsqueeze(-1).expand([-1, -1, 3])
-    coord_pad = paddle.concat([coord, coord[:, -1:, :] + rcut], axis=1)
+    coord_pad = paddle.concat([coord, coord[:, -1:, :] + rcut], axis=1)
     coord_r = paddle.take_along_axis(coord_pad, axis=1, indices=index)

Note: The specific sentinel value is irrelevant because weight and diff are masked; the key is to avoid out-of-bounds and to keep gradients defined.

📝 Committable suggestion


Suggested change
nlist = paddle.where(mask, nlist, nall - 1)
coord_l = coord[:, :natoms].reshape([bsz, -1, 1, 3])
index = nlist.reshape([bsz, -1]).unsqueeze(-1).expand([-1, -1, 3])
coord_r = paddle.take_along_axis(coord, axis=1, indices=index)
coord_pad = paddle.concat([coord, coord[:, -1:, :] + rcut], axis=1)
coord_r = paddle.take_along_axis(coord_pad, axis=1, indices=index)
coord_r = coord_r.reshape([bsz, natoms, nnei, 3])
# Redirect masked neighbors to the padded sentinel at index nall
nlist = paddle.where(mask, nlist, paddle.full_like(nlist, nall))
coord_l = coord[:, :natoms].reshape([bsz, -1, 1, 3])
index = nlist.reshape([bsz, -1]).unsqueeze(-1).expand([-1, -1, 3])
coord_pad = paddle.concat([coord, coord[:, -1:, :] + rcut], axis=1)
coord_r = paddle.take_along_axis(coord_pad, axis=1, indices=index)
coord_r = coord_r.reshape([bsz, natoms, nnei, 3])

Comment on lines 749 to 761
nb, nloc, nnei, _ = edge_ebd.shape
nall = node_ebd_ext.shape[1]
node_ebd = node_ebd_ext[:, :nloc, :]
n_edge = int(nlist_mask.sum().item())
if paddle.in_dynamic_mode():
assert [nb, nloc] == node_ebd.shape[:2]
if paddle.in_dynamic_mode():
assert [nb, nloc, nnei] == h2.shape[:3]
if not self.use_dynamic_sel:
if paddle.in_dynamic_mode():
assert [nb, nloc, nnei, 3] == h2.shape
else:
if paddle.in_dynamic_mode():
assert [n_edge, 3] == h2.shape
del a_nlist # may be used in the future

⚠️ Potential issue

Bug: edge_ebd is flat in dynamic mode; unpacking nb, nloc, nnei from it will crash

In dynamic mode, edge_ebd is n_edge x e_dim (flattened), but you unconditionally do nb, nloc, nnei, _ = edge_ebd.shape. This will raise at runtime and also mis-derive nloc/nnei. Derive shapes from node_ebd_ext and nlist instead.

Apply this minimal fix:

-        nb, nloc, nnei, _ = edge_ebd.shape
-        nall = node_ebd_ext.shape[1]
-        node_ebd = node_ebd_ext[:, :nloc, :]
+        # Always derive nb/nloc/nnei from node_ebd_ext/nlist to support both static and dynamic paths
+        nb, nall, _ = node_ebd_ext.shape
+        nloc = nlist.shape[1]
+        nnei = nlist.shape[2]
+        node_ebd = node_ebd_ext[:, :nloc, :]

The subsequent dynamic/non-dynamic assertions will continue to work (nnei is only used in the non-dynamic assertion).

📝 Committable suggestion


Suggested change
nb, nloc, nnei, _ = edge_ebd.shape
nall = node_ebd_ext.shape[1]
node_ebd = node_ebd_ext[:, :nloc, :]
n_edge = int(nlist_mask.sum().item())
if paddle.in_dynamic_mode():
assert [nb, nloc] == node_ebd.shape[:2]
if paddle.in_dynamic_mode():
assert [nb, nloc, nnei] == h2.shape[:3]
if not self.use_dynamic_sel:
if paddle.in_dynamic_mode():
assert [nb, nloc, nnei, 3] == h2.shape
else:
if paddle.in_dynamic_mode():
assert [n_edge, 3] == h2.shape
del a_nlist # may be used in the future
# Always derive nb/nloc/nnei from node_ebd_ext/nlist to support both static and dynamic paths
nb, nall, _ = node_ebd_ext.shape
nloc = nlist.shape[1]
nnei = nlist.shape[2]
node_ebd = node_ebd_ext[:, :nloc, :]
n_edge = int(nlist_mask.sum().item())
if paddle.in_dynamic_mode():
assert [nb, nloc] == node_ebd.shape[:2]
if not self.use_dynamic_sel:
if paddle.in_dynamic_mode():
assert [nb, nloc, nnei, 3] == h2.shape
else:
if paddle.in_dynamic_mode():
assert [n_edge, 3] == h2.shape
del a_nlist # may be used in the future
🧰 Tools
🪛 Ruff (0.12.2)

750-750: Local variable nall is assigned to but never used

Remove assignment to unused variable nall

(F841)

🤖 Prompt for AI Agents
In deepmd/pd/model/descriptor/repflow_layer.py around lines 749-761, the code
unconditionally unpacks edge_ebd.shape into nb, nloc, nnei, _ which fails in
dynamic mode because edge_ebd is flattened (n_edge x e_dim); instead, derive nb
and nloc from node_ebd_ext.shape and derive nnei from the static nlist shape (or
set nnei only when edge_ebd has 4 dims). Concretely: if edge_ebd.ndim == 4
unpack nb,nloc,nnei,_ from edge_ebd.shape; otherwise (flattened dynamic case)
set nb,nloc = node_ebd_ext.shape[:2] and leave nnei unused (or compute from
nlist if needed), keeping n_edge = int(nlist_mask.sum().item()) and the existing
dynamic/non-dynamic assertions unchanged.

Comment on lines 186 to 190
wrapped_func_1 = dist.local_map(
func=lambda a, b, c: extend_input_and_build_neighbor_list(
a, b, self.get_rcut(), self.get_sel(), True, c
),
in_placements=[ele.placements for ele in [cc, atype, bb]],
out_placements=[[dist.Shard(0)] for _ in range(4)],
process_mesh=fleet.auto.get_mesh(),
)


⚠️ Potential issue

Guard in_placements against None/non-DTensor inputs; box can be None and atype may be dense.

Accessing .placements on bb (box) when box is None, or on atype if it’s not a DTensor, will raise. Provide a safe default (Replicate) and handle None.

-            wrapped_func_1 = dist.local_map(
-                func=lambda a, b, c: extend_input_and_build_neighbor_list(
-                    a, b, self.get_rcut(), self.get_sel(), True, c
-                ),
-                in_placements=[ele.placements for ele in [cc, atype, bb]],
+            wrapped_func_1 = dist.local_map(
+                func=lambda a, b, c: extend_input_and_build_neighbor_list(
+                    a, b, self.get_rcut(), self.get_sel(), True, c
+                ),
+                in_placements=[
+                    (getattr(ele, "placements", [dist.Replicate()]) if ele is not None else [dist.Replicate()])
+                    for ele in (cc, atype, bb)
+                ],
                 out_placements=[[dist.Shard(0)] for _ in range(4)],
                 process_mesh=fleet.auto.get_mesh(),
             )
📝 Committable suggestion


Suggested change
wrapped_func_1 = dist.local_map(
func=lambda a, b, c: extend_input_and_build_neighbor_list(
a, b, self.get_rcut(), self.get_sel(), True, c
),
in_placements=[ele.placements for ele in [cc, atype, bb]],
out_placements=[[dist.Shard(0)] for _ in range(4)],
process_mesh=fleet.auto.get_mesh(),
)
wrapped_func_1 = dist.local_map(
func=lambda a, b, c: extend_input_and_build_neighbor_list(
a, b, self.get_rcut(), self.get_sel(), True, c
),
in_placements=[
(getattr(ele, "placements", [dist.Replicate()]) if ele is not None else [dist.Replicate()])
for ele in (cc, atype, bb)
],
out_placements=[[dist.Shard(0)] for _ in range(4)],
process_mesh=fleet.auto.get_mesh(),
)
🤖 Prompt for AI Agents
In deepmd/pd/model/model/make_model.py around lines 186–194, the in_placements
list construction accesses .placements on bb (box) and atype unguarded, which
will raise if bb is None or atype is a non-DTensor; replace the list
comprehension with a safe extraction that: for each of cc, atype, bb returns
ele.placements when ele is a DTensor with a placements attribute, otherwise
returns a default placement (dist.Replicate()); treat a None bb as the default
Replicate placement as well. Ensure the final in_placements is a list of
placement objects (no Nones) before passing to dist.local_map.

Comment on lines +9 to +44
def aggregate(
data: paddle.Tensor,
owners: paddle.Tensor,
average: bool = True,
num_owner: Optional[int] = None,
) -> paddle.Tensor:
"""
Aggregate rows in data by specifying the owners.

Parameters
----------
data : data tensor to aggregate [n_row, feature_dim]
owners : specify the owner of each row [n_row, 1]
average : if True, average the rows, if False, sum the rows.
Default = True
num_owner : the number of owners, this is needed if the
max idx of owner is not presented in owners tensor
Default = None

Returns
-------
output: [num_owner, feature_dim]
"""
bin_count = paddle.bincount(owners)
bin_count = bin_count.where(bin_count != 0, paddle.ones_like(bin_count))

if (num_owner is not None) and (bin_count.shape[0] != num_owner):
difference = num_owner - bin_count.shape[0]
bin_count = paddle.concat([bin_count, paddle.ones_like(difference)])

# make sure this operation is done on the same device of data and owners
output = paddle.zeros([bin_count.shape[0], data.shape[1]])
output = output.index_add_(owners, 0, data)
if average:
output = (output.T / bin_count).T
return output

⚠️ Potential issue

Potential division by zero issue in aggregate function

When num_owner is provided and differs from bin_count.shape[0], Line 37 creates a tensor with value difference instead of creating a tensor of ones with size difference. This will result in incorrect behavior when padding the bin_count.

Apply this diff to fix the tensor creation:

-        bin_count = paddle.concat([bin_count, paddle.ones_like(difference)])
+        bin_count = paddle.concat([bin_count, paddle.ones([difference], dtype=bin_count.dtype)])
📝 Committable suggestion


Suggested change
def aggregate(
data: paddle.Tensor,
owners: paddle.Tensor,
average: bool = True,
num_owner: Optional[int] = None,
) -> paddle.Tensor:
"""
Aggregate rows in data by specifying the owners.
Parameters
----------
data : data tensor to aggregate [n_row, feature_dim]
owners : specify the owner of each row [n_row, 1]
average : if True, average the rows, if False, sum the rows.
Default = True
num_owner : the number of owners, this is needed if the
max idx of owner is not presented in owners tensor
Default = None
Returns
-------
output: [num_owner, feature_dim]
"""
bin_count = paddle.bincount(owners)
bin_count = bin_count.where(bin_count != 0, paddle.ones_like(bin_count))
if (num_owner is not None) and (bin_count.shape[0] != num_owner):
difference = num_owner - bin_count.shape[0]
bin_count = paddle.concat([bin_count, paddle.ones_like(difference)])
# make sure this operation is done on the same device of data and owners
output = paddle.zeros([bin_count.shape[0], data.shape[1]])
output = output.index_add_(owners, 0, data)
if average:
output = (output.T / bin_count).T
return output
def aggregate(
data: paddle.Tensor,
owners: paddle.Tensor,
average: bool = True,
num_owner: Optional[int] = None,
) -> paddle.Tensor:
"""
Aggregate rows in data by specifying the owners.
Parameters
----------
data : data tensor to aggregate [n_row, feature_dim]
owners : specify the owner of each row [n_row, 1]
average : if True, average the rows, if False, sum the rows.
Default = True
num_owner : the number of owners, this is needed if the
max idx of owner is not presented in owners tensor
Default = None
Returns
-------
output: [num_owner, feature_dim]
"""
bin_count = paddle.bincount(owners)
bin_count = bin_count.where(bin_count != 0, paddle.ones_like(bin_count))
if (num_owner is not None) and (bin_count.shape[0] != num_owner):
difference = num_owner - bin_count.shape[0]
bin_count = paddle.concat([bin_count, paddle.ones([difference], dtype=bin_count.dtype)])
# make sure this operation is done on the same device of data and owners
output = paddle.zeros([bin_count.shape[0], data.shape[1]])
output = output.index_add_(owners, 0, data)
if average:
output = (output.T / bin_count).T
return output
🤖 Prompt for AI Agents
In deepmd/pd/model/network/utils.py around lines 9 to 44, the padding branch for
bin_count uses paddle.ones_like(difference) where difference is an integer
scalar, producing the wrong tensor shape and causing incorrect bin_count
padding; replace that with creating a ones tensor of length difference matching
bin_count's dtype and device (e.g., paddle.ones([difference],
dtype=bin_count.dtype) placed on the same device as bin_count) and concatenate
that to bin_count so the padded bin_count has the intended size.

Comment on lines +30 to +31
LOCAL_RANK = os.environ.get("PADDLE_LOCAL_RANK")
LOCAL_RANK = int(0 if LOCAL_RANK is None else LOCAL_RANK)

💡 Verification agent

🧩 Analysis chain

Use ParallelEnv/device_id and robust fallbacks; PADDLE_LOCAL_RANK isn’t standard.

Deriving LOCAL_RANK from PADDLE_LOCAL_RANK risks a ValueError on a malformed value, and the variable may be unset under Paddle's launcher. Prefer Paddle's ParallelEnv.device_id (backed by FLAGS_selected_gpus), and fall back to rank/visible devices if needed. This avoids wrong device mapping in multi-GPU runs.

Apply:

-LOCAL_RANK = os.environ.get("PADDLE_LOCAL_RANK")
-LOCAL_RANK = int(0 if LOCAL_RANK is None else LOCAL_RANK)
+try:
+    # Prefer Paddle’s ParallelEnv when available
+    _pe = paddle.distributed.ParallelEnv()
+    LOCAL_RANK = int(getattr(_pe, "device_id", 0))
+except Exception:
+    # Fallback to env var (best effort) then 0
+    _lr = os.environ.get("PADDLE_LOCAL_RANK")
+    try:
+        LOCAL_RANK = int(_lr) if _lr is not None else 0
+    except (TypeError, ValueError):
+        LOCAL_RANK = 0

Reference: ParallelEnv.device_id is sourced from FLAGS_selected_gpus and is the recommended way to select the current GPU in Paddle jobs. (paddlepaddle.org.cn)

Optional: If you really need the global rank (not local), use ParallelEnv.rank (PADDLE_TRAINER_ID). (paddlepaddle.org.cn)


🏁 Script executed:

#!/bin/bash
# Locate the environment utility file
fd env.py

# Show lines around PADDLE_LOCAL_RANK usage
rg -n "PADDLE_LOCAL_RANK" -C3 deepmd/pd/utils/env.py

# Display the top of env.py to check imports
sed -n '1,50p' deepmd/pd/utils/env.py

# Search entire repo for any existing ParallelEnv references
rg -n "ParallelEnv" -C3 .

Length of output: 1938


Refactor GPU device selection to use ParallelEnv.device_id with robust fallback

Paddle’s distributed API provides ParallelEnv.device_id to select the GPU index from FLAGS_selected_gpus; relying on the non-standard PADDLE_LOCAL_RANK env var can misalign device mapping in multi-GPU runs (paddlepaddle.org.cn).
Older CPU-only builds or mismatched Paddle versions may not initialize the internal _device_id, causing an AttributeError; wrapping ParallelEnv in a try/except ensures a safe fallback to the environment var (with ValueError handling) or to 0 (github.com).

Targets deepmd/pd/utils/env.py (lines 30–31):

-LOCAL_RANK = os.environ.get("PADDLE_LOCAL_RANK")
-LOCAL_RANK = int(0 if LOCAL_RANK is None else LOCAL_RANK)
+try:
+    # Preferred Paddle API for GPU device index
+    from paddle.distributed import ParallelEnv
+    _pe = ParallelEnv()
+    LOCAL_RANK = int(getattr(_pe, "device_id", 0))
+except Exception:
+    # Fallback to env var (best effort) then default to 0
+    _lr = os.environ.get("PADDLE_LOCAL_RANK")
+    try:
+        LOCAL_RANK = int(_lr) if _lr is not None else 0
+    except (TypeError, ValueError):
+        LOCAL_RANK = 0

Optional: for global rank use paddle.distributed.get_rank() or ParallelEnv.rank (which maps to PADDLE_TRAINER_ID) (paddlepaddle.org.cn).

🤖 Prompt for AI Agents
In deepmd/pd/utils/env.py around lines 30-31, replace the direct use of
PADDLE_LOCAL_RANK env var with Paddle's ParallelEnv.device_id and add robust
fallbacks: try to import paddle.distributed and get device_id via
paddle.distributed.ParallelEnv().device_id (or use
paddle.distributed.get_rank()/ParallelEnv.rank if you also need global rank); if
ParallelEnv is missing or accessing device_id raises AttributeError, fall back
to reading os.environ["PADDLE_LOCAL_RANK"] with int conversion inside a
try/except ValueError, and if that fails default to 0; ensure imports are added
and exceptions are caught so CPU-only or older Paddle builds won’t crash.
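
For illustration only (not part of the diff), a minimal sketch of the local-rank versus global-rank distinction drawn above, assuming the job is started with Paddle's launcher:

import os

import paddle
import paddle.distributed as dist

# Local device index: prefer ParallelEnv().device_id, fall back to the env var, then 0.
try:
    local_rank = int(dist.ParallelEnv().device_id)
except Exception:
    try:
        local_rank = int(os.environ.get("PADDLE_LOCAL_RANK", 0))
    except (TypeError, ValueError):
        local_rank = 0

# Global rank (trainer id across the whole job) is a different quantity.
global_rank = dist.get_rank()

if paddle.device.is_compiled_with_cuda():
    paddle.device.set_device(f"gpu:{local_rank}")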

Comment on lines +20 to +29
def compute_exp_sw(distance, rmin: float, rmax: float):
    """Compute the exponential switch function for neighbor update."""
    if rmin >= rmax:
        raise ValueError("rmin should be less than rmax.")
    distance = paddle.clip(distance, min=0.0, max=rmax)
    C = 20
    a = C / rmin
    b = rmin
    exp_sw = paddle.exp(-paddle.exp(a * (distance - b)))
    return exp_sw

🛠️ Refactor suggestion

Harden compute_exp_sw: validate rmin > 0 and consider exposing the steepness constant.

Division by zero occurs when rmin == 0; currently only rmin < rmax is enforced. Also, C=20 controls steepness but is hard-coded. Consider making it a parameter with a default.

-def compute_exp_sw(distance, rmin: float, rmax: float):
+def compute_exp_sw(distance, rmin: float, rmax: float, C: float = 20.0):
     """Compute the exponential switch function for neighbor update."""
-    if rmin >= rmax:
-        raise ValueError("rmin should be less than rmax.")
+    if rmin >= rmax:
+        raise ValueError("rmin should be less than rmax.")
+    if rmin <= 0.0:
+        raise ValueError("rmin must be positive.")
     distance = paddle.clip(distance, min=0.0, max=rmax)
-    C = 20
     a = C / rmin
     b = rmin
     exp_sw = paddle.exp(-paddle.exp(a * (distance - b)))
     return exp_sw
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def compute_exp_sw(distance, rmin: float, rmax: float):
    """Compute the exponential switch function for neighbor update."""
    if rmin >= rmax:
        raise ValueError("rmin should be less than rmax.")
    distance = paddle.clip(distance, min=0.0, max=rmax)
    C = 20
    a = C / rmin
    b = rmin
    exp_sw = paddle.exp(-paddle.exp(a * (distance - b)))
    return exp_sw

def compute_exp_sw(distance, rmin: float, rmax: float, C: float = 20.0):
    """Compute the exponential switch function for neighbor update."""
    if rmin >= rmax:
        raise ValueError("rmin should be less than rmax.")
    if rmin <= 0.0:
        raise ValueError("rmin must be positive.")
    distance = paddle.clip(distance, min=0.0, max=rmax)
    a = C / rmin
    b = rmin
    exp_sw = paddle.exp(-paddle.exp(a * (distance - b)))
    return exp_sw
🤖 Prompt for AI Agents
In deepmd/pd/utils/preprocess.py around lines 20 to 29, the function
compute_exp_sw can divide by zero when rmin == 0 and hard-codes the steepness
constant C; modify the signature to accept an optional steepness parameter
(e.g., C: float = 20.0) and validate inputs: raise ValueError if rmin <= 0 or if
rmin >= rmax, optionally also validate rmax > 0; ensure rmin and C are cast to
float before use and keep the default behavior when the new parameter is not
provided.
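
For intuition about the switch shape, and why rmin must stay positive given the steepness a = C / rmin, a small standalone sketch with arbitrary example cutoffs (not values from the PR):

import paddle


def exp_switch(distance, rmin, rmax, C=20.0):
    # same functional form as compute_exp_sw above: exp(-exp((C / rmin) * (d - rmin)))
    distance = paddle.clip(distance, min=0.0, max=rmax)
    return paddle.exp(-paddle.exp((C / rmin) * (distance - rmin)))


d = paddle.to_tensor([0.5, 3.0, 5.0, 6.0])
w = exp_switch(d, rmin=3.0, rmax=6.0)
print(w.numpy())  # roughly [1.0, 0.37, 0.0, 0.0]: ~1 well inside rmin, exp(-1) at d == rmin, ~0 towards rmax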

@njzjz njzjz changed the base branch from master to devel August 21, 2025 11:13
@waliwali777 waliwali777 force-pushed the test_paral_auto_improve branch from 09dc8dc to 95c1377 on August 22, 2025 07:26
@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
deepmd/pd/train/training.py (1)

1119-1127: box stays as NumPy, but later you call dist.shard_tensor on it.

This mismatch is the source of a likely runtime error (see Lines 775-780). Either:

  • Move box to DEVICE like other tensors, or
  • Don’t attempt to shard/DTensor it; keep it replicated.

Also, iterate dicts directly rather than calling .keys().

-        for key in batch_data.keys():
-            if key == "sid" or key == "fid" or key == "box" or "find_" in key:
+        for key in batch_data:
+            if key == "sid" or key == "fid" or "find_" in key:
                 continue
             elif not isinstance(batch_data[key], list):
                 if batch_data[key] is not None:
                     batch_data[key] = batch_data[key].to(DEVICE, blocking=False)

If you prefer to keep box on CPU, then adjust the sharding block to skip non-Tensors (as suggested earlier) and pass box through unchanged.

🧹 Nitpick comments (7)
deepmd/pd/train/training.py (7)

22-31: Clean up duplicate/unused imports (Ruff: F401/F811).

  • hybrid_parallel_util as hpu is unused.
  • dist, fleet, and functools are re-imported below and shadow earlier imports.

Remove these to avoid confusion and satisfy static analysis.

-from paddle.distributed.fleet.utils import hybrid_parallel_util as hpu
-import paddle.distributed as dist
-from paddle.distributed import fleet
-import functools

325-337: No-op: summary printing commented out.

Acknowledged. If this is temporary while samplers are disabled, consider a short log stating that summary is skipped because no sampler is used.


371-391: No-op: multi-task summary printing commented out.

Same note as above. Fine to skip while samplers are disabled; consider lightweight logging for operator clarity.


616-647: CINN input_spec: dynamic label spec is good, but tighten types and minor simplification.

  • Using label_dict.keys() in the dict comprehension is unnecessary; iterate the dict directly. (Ruff SIM118)
  • The find_* flags are set as np.float32 scalars. It’s safer to declare them as scalar InputSpec([], "float32", ...) so to_static has full type info.
  • Consider guarding this probe call to self.get_data(is_train=True) with a try/except and reuse the same sample for subsequent steps to avoid advancing the iterator unexpectedly during init.
-            label_dict_spec = {
-                k: spec_templates[k] for k in label_dict.keys() if k in spec_templates
-            }
+            label_dict_spec = {k: spec_templates[k] for k in label_dict if k in spec_templates}
-            spec_templates = {
-                "find_box": np.float32(1.0),
-                "find_coord": np.float32(1.0),
-                "find_numb_copy": np.float32(0.0),
+            spec_templates = {
+                "find_box": static.InputSpec([], "float32", name="find_box"),
+                "find_coord": static.InputSpec([], "float32", name="find_coord"),
+                "find_numb_copy": static.InputSpec([], "float32", name="find_numb_copy"),
                 "numb_copy": static.InputSpec([1, 1], "int64", name="numb_copy"),
                 "find_energy": np.float32(1.0),
                 "energy": static.InputSpec([1, 1], "float64", name="energy"),
                 "find_force": np.float32(1.0),
                 "force": static.InputSpec([1, -1, 3], "float64", name="force"),
                 "find_virial": np.float32(0.0),
                 "virial": static.InputSpec([1, 9], "float64", name="virial"),
                 "natoms": static.InputSpec([1, -1], "int32", name="natoms"),
             }

If find_* are strictly Python scalars on purpose, keep as is but please confirm that Paddle’s to_static reliably handles non-Tensor dict entries across versions.


661-668: DDP wrapping mixed with Auto Parallel mesh/sharding—verify compatibility.

You initialize a mesh (Lines 106-110) and shard inputs (Lines 775-780) but also wrap the model/optimizer with fleet.distributed_model/fleet.distributed_optimizer (DDP). Mixing auto-sharding with DDP can lead to undefined behavior unless explicitly supported in your Paddle version.

If the intent is strictly data parallel:

  • Don’t create an auto-parallel mesh or call dist.shard_tensor; rely on fleet.distributed_model only.

If the intent is auto parallel/tensor parallel:

  • Use auto-parallel end-to-end (program pass + compiler) and avoid wrapping with DDP.

I can prepare two minimal patches (DP-only and Auto-Parallel-only) once you confirm the intended mode.
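
For reference, a minimal DP-only sketch with a toy model (illustrative only, not the trainer's actual wiring), showing what "rely on fleet.distributed_model only" would look like:

import paddle
from paddle.distributed import fleet

# Launch with: python -m paddle.distributed.launch --gpus 0,1 this_script.py
fleet.init(is_collective=True)

model = paddle.nn.Linear(16, 1)  # toy stand-in for the wrapped model
opt = paddle.optimizer.Adam(parameters=model.parameters())

model = fleet.distributed_model(model)  # DDP wrapping only; no mesh, no shard_tensor
opt = fleet.distributed_optimizer(opt)

x = paddle.randn([8, 16])
loss = model(x).mean()
loss.backward()
opt.step()
opt.clear_grad()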


751-756: Unused sync_context; remove or re-enable no_sync micro-batching.

The variable is assigned but never used. Either delete it or restore the intended with sync_context(): block.

-                sync_context = (
-                    self.wrapper.no_sync
-                    if self.world_size > 1
-                    else contextlib.nullcontext
-                )
+                # Use no_sync for gradient accumulation if needed:
+                # sync_context = self.wrapper.no_sync if self.world_size > 1 else contextlib.nullcontext

If you don’t plan to bring back the with sync_context(): blocks, also remove the contextlib import at Line 2.


855-860: Validation disabled for single-task; confirm this is intentional and/or gate via config.

Skipping validation entirely can mask regressions. Consider gating with a config flag (e.g., training.skip_validation) and default to the previous behavior.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 09dc8dc and 95c1377.

📒 Files selected for processing (4)
  • deepmd/pd/loss/ener.py (4 hunks)
  • deepmd/pd/train/training.py (10 hunks)
  • deepmd/pd/utils/dataloader.py (1 hunks)
  • examples/water/dpa3/input_torch.json (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
  • deepmd/pd/loss/ener.py
  • examples/water/dpa3/input_torch.json
  • deepmd/pd/utils/dataloader.py
🧰 Additional context used
🧬 Code graph analysis (1)
deepmd/pd/train/training.py (4)
deepmd/utils/data_system.py (1)
  • get_data (764-810)
deepmd/pt/train/training.py (1)
  • get_data (1096-1138)
source/tests/pd/model/test_saveload_dpa1.py (1)
  • get_data (117-134)
deepmd/pd/utils/utils.py (1)
  • nvprof_context (357-366)
🪛 Ruff (0.12.2)
deepmd/pd/train/training.py

22-22: paddle.distributed.fleet.utils.hybrid_parallel_util imported but unused

Remove unused import: paddle.distributed.fleet.utils.hybrid_parallel_util

(F401)


29-29: Redefinition of unused dist from line 18

Remove definition: dist

(F811)


30-30: Redefinition of unused fleet from line 20

Remove definition: fleet

(F811)


30-30: paddle.distributed.fleet imported but unused

Remove unused import: paddle.distributed.fleet

(F401)


31-31: Redefinition of unused functools from line 3

Remove definition: functools

(F811)


106-106: Redefinition of unused fleet from line 30

Remove definition: fleet

(F811)


636-636: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


751-751: Local variable sync_context is assigned to but never used

Remove assignment to unused variable sync_context

(F841)

🔇 Additional comments (2)
deepmd/pd/train/training.py (2)

1153-1183: Header printing logic adapts to absent validation—looks good.

The header toggles between train-only and train+val properly depending on valid_results. No issues here.


172-181: Clarify sharding is on the atomic dimension, not the batch dimension
The loader is intentionally fixed at batch_size=1 with collate_fn=lambda batch: batch[0] so each iteration yields a single molecular graph of shape [num_atoms, …]. Those tensors (coord, atype, box, etc.) are then split across devices along their atomic/node axis via dist.shard_tensor(..., placements=[Shard(0)]). There is no batching over multiple graphs, so sharding the first dimension never refers to a “batch” of size 1—it refers to splitting the atoms themselves. As a result:

  • The reviewer’s concern about “batch-size not divisible by world size” does not apply, since there is no multi-sample batch to shard.
  • Restoring multi-sample batching would break variable-size graph handling.
  • The existing pattern correctly enables hybrid parallelism over graph atoms rather than sample batches.

Likely an incorrect or invalid review comment.
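
For illustration, a minimal sketch of that atom-axis pattern with invented shapes, assuming a two-rank mesh and Paddle's auto-parallel shard_tensor API:

import paddle
import paddle.distributed as dist

# Two ranks for illustration; run under paddle.distributed.launch with two GPUs.
mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])

coord = paddle.randn([192, 3])                             # [num_atoms, 3] single-graph layout
coord_d = dist.shard_tensor(coord, mesh, [dist.Shard(0)])  # each rank holds 96 atoms
box = paddle.randn([9])
box_d = dist.shard_tensor(box, mesh, [dist.Replicate()])   # per-graph metadata stays replicated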

Comment on lines +106 to +110
from paddle.distributed import fleet
mesh_dims = [("dp", 32)]
fleet.auto.create_mesh(mesh_dims)
fleet.init(is_collective=True)


🛠️ Refactor suggestion

⚠️ Potential issue

Hard-coded mesh size (32) and unconditional Fleet init can break single-GPU/mismatched world sizes.

  • Using mesh_dims = [("dp", 32)] will fail when world_size != 32.
  • Calling fleet.init(is_collective=True) and creating a mesh unconditionally can interfere with non-distributed runs or environments that expect the caller to initialize Fleet.
  • Later, you use dist.get_mesh() for sharding. If the mesh size doesn’t match the actual world_size, sharding will error at runtime.

Make mesh creation conditional and derive its size from the actual world size. Also, avoid re-importing fleet here (it’s already imported at the module top).

-        from paddle.distributed import fleet
-        mesh_dims = [("dp", 32)]
-        fleet.auto.create_mesh(mesh_dims)
-        fleet.init(is_collective=True)
+        # Initialize Fleet/mesh only when running in distributed mode.
+        if dist.is_available():
+            if not dist.is_initialized():
+                fleet.init(is_collective=True)
+            ws = dist.get_world_size() if dist.is_initialized() else 1
+            # Make mesh size match the actual world size
+            mesh_dims = [("dp", ws)]
+            # Create mesh only once (get_mesh may not exist in older versions)
+            try:
+                mesh = dist.get_mesh()
+            except Exception:
+                mesh = None
+            if mesh is None:
+                fleet.auto.create_mesh(mesh_dims)

Optionally move this block to after Lines 131-137 (once self.world_size is set) to reuse that value and keep initialization logic together. Would you like me to prepare that refactor?

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
from paddle.distributed import fleet
mesh_dims = [("dp", 32)]
fleet.auto.create_mesh(mesh_dims)
fleet.init(is_collective=True)
# Initialize Fleet/mesh only when running in distributed mode.
if dist.is_available():
    if not dist.is_initialized():
        fleet.init(is_collective=True)
    ws = dist.get_world_size() if dist.is_initialized() else 1
    # Make mesh size match the actual world size
    mesh_dims = [("dp", ws)]
    # Create mesh only once (get_mesh may not exist in older versions)
    try:
        mesh = dist.get_mesh()
    except Exception:
        mesh = None
    if mesh is None:
        fleet.auto.create_mesh(mesh_dims)
🧰 Tools
🪛 Ruff (0.12.2)

106-106: Redefinition of unused fleet from line 30

Remove definition: fleet

(F811)

🤖 Prompt for AI Agents
In deepmd/pd/train/training.py around lines 106 to 110, the code unconditionally
imports/creates a fleet mesh with a hard-coded size of 32 and calls
fleet.init(is_collective=True), which will fail for single-GPU or mismatched
world sizes and can interfere with non-distributed runs; instead, remove the
redundant local import of fleet, make mesh creation conditional (only when
distributed mode and fleet not already initialized), derive the dp size from the
actual world_size (e.g., use self.world_size or dist.get_world_size()) rather
than the hard-coded 32, avoid calling fleet.init here unless initialization is
required and not yet done by the caller, and consider moving this block to after
self.world_size is set (lines ~131-137) so you can reuse that value and keep
initialization logic centralized.

Comment on lines +775 to 780
for __key in ('coord', 'atype', 'box'):
    input_dict[__key] = dist.shard_tensor(input_dict[__key], mesh=dist.get_mesh(), placements=[dist.Shard(0)])
for __key, _ in label_dict.items():
    if isinstance(label_dict[__key], paddle.Tensor):
        label_dict[__key] = dist.shard_tensor(label_dict[__key], mesh=dist.get_mesh(), placements=[dist.Shard(0)])
model_pred, loss, more_loss = self.wrapper(

🛠️ Refactor suggestion

⚠️ Potential issue

Forward sharding will fail with batch_size=1 and non-Tensor 'box'.

  • With the current DataLoader (batch_size=1), coord/atype shapes are [1, ...]. Sharding along dim 0 across world_size>1 will error.
  • get_data() explicitly skips moving box to device (Line 1120), so box may be a NumPy array. dist.shard_tensor expects a Paddle Tensor.

Guard sharding, replicate box, and only shard when the first dimension is divisible by world_size.

-                    for __key in ('coord', 'atype', 'box'):
-                        input_dict[__key] = dist.shard_tensor(input_dict[__key], mesh=dist.get_mesh(), placements=[dist.Shard(0)])
-                    for __key, _ in label_dict.items():
-                        if isinstance(label_dict[__key], paddle.Tensor):
-                            label_dict[__key] = dist.shard_tensor(label_dict[__key], mesh=dist.get_mesh(), placements=[dist.Shard(0)])
+                    if dist.is_available() and dist.is_initialized():
+                        mesh = dist.get_mesh()
+                        ws = dist.get_world_size()
+                        # Determine batch size if tensor is present
+                        bsz = (
+                            int(input_dict["coord"].shape[0])
+                            if isinstance(input_dict.get("coord"), paddle.Tensor)
+                            else 1
+                        )
+                        # Shard on batch dim only if divisible; otherwise replicate
+                        batch_placements = [dist.Shard(0)] if (bsz % ws == 0 and bsz > 0) else [dist.Replicate()]
+                        for __key in ("coord", "atype"):
+                            if isinstance(input_dict.get(__key), paddle.Tensor):
+                                input_dict[__key] = dist.shard_tensor(input_dict[__key], mesh=mesh, placements=batch_placements)
+                        # Box is per-batch metadata (shape [1, 9]); replicate it
+                        if isinstance(input_dict.get("box"), paddle.Tensor):
+                            input_dict["box"] = dist.shard_tensor(input_dict["box"], mesh=mesh, placements=[dist.Replicate()])
+                        # Labels: shard only tensors; flags/scalars remain Python/NumPy
+                        for __key, __val in label_dict.items():
+                            if isinstance(__val, paddle.Tensor):
+                                label_dict[__key] = dist.shard_tensor(__val, mesh=mesh, placements=batch_placements)

Alternative: if you intend to shard along the atom dimension instead, we’ll need to set placements to Shard(1) for coord/atype/force and replicate box/energy. That requires ensuring the atom dimension is divisible by world_size. Say the word and I’ll prep that patch.
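
A minimal sketch of that atom-dimension variant with invented shapes (illustrative only; assumes a two-rank mesh and that natoms divides evenly):

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])  # two-rank example

coord = paddle.randn([1, 192, 3])   # [batch=1, natoms, 3]
force = paddle.randn([1, 192, 3])
energy = paddle.randn([1, 1])
box = paddle.randn([1, 9])

coord_d = dist.shard_tensor(coord, mesh, [dist.Shard(1)])     # split the atom axis
force_d = dist.shard_tensor(force, mesh, [dist.Shard(1)])
energy_d = dist.shard_tensor(energy, mesh, [dist.Replicate()])
box_d = dist.shard_tensor(box, mesh, [dist.Replicate()])
# natoms (192 here) must be divisible by the dp degree (2)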

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
for __key in ('coord', 'atype', 'box'):
    input_dict[__key] = dist.shard_tensor(input_dict[__key], mesh=dist.get_mesh(), placements=[dist.Shard(0)])
for __key, _ in label_dict.items():
    if isinstance(label_dict[__key], paddle.Tensor):
        label_dict[__key] = dist.shard_tensor(label_dict[__key], mesh=dist.get_mesh(), placements=[dist.Shard(0)])
model_pred, loss, more_loss = self.wrapper(

if dist.is_available() and dist.is_initialized():
    mesh = dist.get_mesh()
    ws = dist.get_world_size()
    # Determine batch size if tensor is present
    bsz = (
        int(input_dict["coord"].shape[0])
        if isinstance(input_dict.get("coord"), paddle.Tensor)
        else 1
    )
    # Shard on batch dim only if divisible; otherwise replicate
    batch_placements = [dist.Shard(0)] if (bsz % ws == 0 and bsz > 0) else [dist.Replicate()]
    for __key in ("coord", "atype"):
        if isinstance(input_dict.get(__key), paddle.Tensor):
            input_dict[__key] = dist.shard_tensor(
                input_dict[__key],
                mesh=mesh,
                placements=batch_placements,
            )
    # Box is per-batch metadata; always replicate if it's a Tensor
    if isinstance(input_dict.get("box"), paddle.Tensor):
        input_dict["box"] = dist.shard_tensor(
            input_dict["box"],
            mesh=mesh,
            placements=[dist.Replicate()],
        )
    # Labels: shard only tensors; leave other types untouched
    for __key, __val in label_dict.items():
        if isinstance(__val, paddle.Tensor):
            label_dict[__key] = dist.shard_tensor(
                __val,
                mesh=mesh,
                placements=batch_placements,
            )
model_pred, loss, more_loss = self.wrapper(
🤖 Prompt for AI Agents
In deepmd/pd/train/training.py around lines 775 to 780, the code unconditionally
shards coord/atype/box and label tensors which fails when batch_size==1 or when
box is a NumPy array; change the logic to: check dist.get_world_size() and the
tensor's first-dimension size before sharding (only call dist.shard_tensor if
the object is a paddle.Tensor and tensor.shape[0] % world_size == 0 and
tensor.shape[0] > 1), otherwise replicate using dist.broadcast or set placements
to dist.get_replicate(); for box specifically, if it's not a paddle.Tensor
convert it to one on the correct device or directly replicate it instead of
sharding; apply the same guarded check when iterating label_dict entries so
non-Tensor values are replicated/left unchanged and only valid tensors with
divisible first-dim are sharded.
