From 81182318611d3e65b716b61b78bb6daa73d17c14 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 10 Jun 2025 21:31:24 +0800 Subject: [PATCH 01/13] update adaptive CINN --- deepmd/pd/train/training.py | 43 ++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index d72c270667..07cd738a1c 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -607,6 +607,35 @@ def warm_up_linear(step, warmup_steps): ) backend = "CINN" if CINN else None + + # NOTE: This is a trick to decide the right input_spec for wrapper.forward + _, label_dict, _ = self.get_data(is_train=True, task_key="Default") + label_dict_spec = { + "find_box": np.float32(1.0), + "find_coord": np.float32(1.0), + "find_numb_copy": np.float32(0.0), + "numb_copy": static.InputSpec([1, 1], "int64", name="numb_copy"), + "find_energy": np.float32(1.0), + "energy": static.InputSpec([1, 1], "float64", name="energy"), + "find_force": np.float32(1.0), + "force": static.InputSpec([1, -1, 3], "float64", name="force"), + "find_virial": np.float32(0.0), + "virial": static.InputSpec([1, 9], "float64", name="virial"), + "natoms": static.InputSpec([1, -1], "int32", name="natoms"), + } + if "virial" not in label_dict: + label_dict_spec.pop("virial") + if "find_virial" not in label_dict: + label_dict_spec.pop("find_virial") + if "energy" not in label_dict: + label_dict_spec.pop("energy") + if "find_energy" not in label_dict: + label_dict_spec.pop("find_energy") + if "force" not in label_dict: + label_dict_spec.pop("force") + if "find_force" not in label_dict: + label_dict_spec.pop("find_force") + self.wrapper.forward = jit.to_static( backend=backend, input_spec=[ @@ -615,19 +644,7 @@ def warm_up_linear(step, warmup_steps): None, # spin static.InputSpec([1, 9], "float64", name="box"), # box static.InputSpec([], "float64", name="cur_lr"), # cur_lr - { - "find_box": np.float32(1.0), - "find_coord": np.float32(1.0), - "find_numb_copy": np.float32(0.0), - "numb_copy": static.InputSpec( - [1, 1], "int64", name="numb_copy" - ), - "find_energy": np.float32(1.0), - "energy": static.InputSpec([1, 1], "float64", name="energy"), - "find_force": np.float32(1.0), - "force": static.InputSpec([1, -1, 3], "float64", name="force"), - "natoms": static.InputSpec([1, -1], "int32", name="natoms"), - }, # label, + label_dict_spec, # label, # None, # task_key # False, # inference_only # False, # do_atomic_virial From 29bedd3ddf91d304eab2f52762e38edc3931b479 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Wed, 11 Jun 2025 02:07:47 +0800 Subject: [PATCH 02/13] support loc_mapping in pd bkd --- .pre-commit-config.yaml | 52 +- deepmd/pd/entrypoints/main.py | 4 + deepmd/pd/model/descriptor/dpa3.py | 14 +- deepmd/pd/model/descriptor/env_mat.py | 14 +- deepmd/pd/model/descriptor/repflow_layer.py | 506 +++++++++++++++++--- deepmd/pd/model/descriptor/repflows.py | 174 +++++-- deepmd/pd/model/descriptor/repformers.py | 8 +- deepmd/pd/utils/env.py | 2 + deepmd/pd/utils/preprocess.py | 21 +- deepmd/pd/utils/spin.py | 1 - deepmd/pd/utils/utils.py | 128 ++++- 11 files changed, 769 insertions(+), 155 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cee3d7f2ce..77dab6f3aa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -65,13 +65,13 @@ repos: - id: clang-format exclude: ^(source/3rdparty|source/lib/src/gpu/cudart/.+\.inc|.+\.ipynb$|.+\.json$) # markdown, yaml, CSS, 
javascript - - repo: https://github.com/pre-commit/mirrors-prettier - rev: v4.0.0-alpha.8 - hooks: - - id: prettier - types_or: [markdown, yaml, css] - # workflow files cannot be modified by pre-commit.ci - exclude: ^(source/3rdparty|\.github/workflows|\.clang-format) + # - repo: https://github.com/pre-commit/mirrors-prettier + # rev: v4.0.0-alpha.8 + # hooks: + # - id: prettier + # types_or: [markdown, yaml, css] + # # workflow files cannot be modified by pre-commit.ci + # exclude: ^(source/3rdparty|\.github/workflows|\.clang-format) # Shell - repo: https://github.com/scop/pre-commit-shfmt rev: v3.11.0-1 @@ -83,25 +83,25 @@ repos: hooks: - id: cmake-format #- id: cmake-lint - - repo: https://github.com/njzjz/mirrors-bibtex-tidy - rev: v1.13.0 - hooks: - - id: bibtex-tidy - args: - - --curly - - --numeric - - --align=13 - - --blank-lines - # disable sort: the order of keys and fields has explict meanings - #- --sort=key - - --duplicates=key,doi,citation,abstract - - --merge=combine - #- --sort-fields - #- --strip-comments - - --trailing-commas - - --encode-urls - - --remove-empty-fields - - --wrap=80 + # - repo: https://github.com/njzjz/mirrors-bibtex-tidy + # rev: v1.13.0 + # hooks: + # - id: bibtex-tidy + # args: + # - --curly + # - --numeric + # - --align=13 + # - --blank-lines + # # disable sort: the order of keys and fields has explict meanings + # #- --sort=key + # - --duplicates=key,doi,citation,abstract + # - --merge=combine + # #- --sort-fields + # #- --strip-comments + # - --trailing-commas + # - --encode-urls + # - --remove-empty-fields + # - --wrap=80 # license header - repo: https://github.com/Lucas-C/pre-commit-hooks rev: v1.5.5 diff --git a/deepmd/pd/entrypoints/main.py b/deepmd/pd/entrypoints/main.py index 8d96c4e6f2..c0a7cb769a 100644 --- a/deepmd/pd/entrypoints/main.py +++ b/deepmd/pd/entrypoints/main.py @@ -41,6 +41,9 @@ from deepmd.pd.train.wrapper import ( ModelWrapper, ) +from deepmd.pd.utils import ( + env, +) from deepmd.pd.utils.dataloader import ( DpLoaderSet, ) @@ -233,6 +236,7 @@ def train( output: str = "out.json", ) -> None: log.info("Configuration path: %s", input_file) + env.CUSTOM_OP_USE_JIT = False if LOCAL_RANK == 0: SummaryPrinter()() with open(input_file) as fin: diff --git a/deepmd/pd/model/descriptor/dpa3.py b/deepmd/pd/model/descriptor/dpa3.py index 0f1a8f4c2f..eed4d8a385 100644 --- a/deepmd/pd/model/descriptor/dpa3.py +++ b/deepmd/pd/model/descriptor/dpa3.py @@ -91,7 +91,7 @@ class DescrptDPA3(BaseDescriptor, paddle.nn.Layer): Whether to use bias in the type embedding layer. use_loc_mapping : bool, Optional Whether to use local atom index mapping in training or non-parallel inference. - Not supported yet in Paddle. + When True, local indexing and mapping are applied to neighbor lists and embeddings during descriptor computation. type_map : list[str], Optional A list of strings. Give the name to each type of atoms. 
@@ -117,7 +117,7 @@ def __init__( seed: Optional[Union[int, list[int]]] = None, use_econf_tebd: bool = False, use_tebd_bias: bool = False, - use_loc_mapping: bool = False, + use_loc_mapping: bool = True, type_map: Optional[list[str]] = None, ) -> None: super().__init__() @@ -160,6 +160,8 @@ def init_subclass_params(sub_data, sub_class): fix_stat_std=self.repflow_args.fix_stat_std, optim_update=self.repflow_args.optim_update, smooth_edge_update=self.repflow_args.smooth_edge_update, + edge_init_use_dist=self.repflow_args.edge_init_use_dist, + use_exp_switch=self.repflow_args.use_exp_switch, use_dynamic_sel=self.repflow_args.use_dynamic_sel, sel_reduce_factor=self.repflow_args.sel_reduce_factor, use_loc_mapping=use_loc_mapping, @@ -170,8 +172,8 @@ def init_subclass_params(sub_data, sub_class): ) self.use_econf_tebd = use_econf_tebd - self.use_tebd_bias = use_tebd_bias self.use_loc_mapping = use_loc_mapping + self.use_tebd_bias = use_tebd_bias self.type_map = type_map self.tebd_dim = self.repflow_args.n_dim self.type_embedding = TypeEmbedNet( @@ -487,12 +489,16 @@ def forward( The smooth switch function. shape: nf x nloc x nnei """ + parallel_mode = comm_dict is not None # cast the input to internal precsion extended_coord = extended_coord.to(dtype=self.prec) nframes, nloc, nnei = nlist.shape nall = extended_coord.reshape([nframes, -1]).shape[1] // 3 - node_ebd_ext = self.type_embedding(extended_atype) + if not parallel_mode and self.use_loc_mapping: + node_ebd_ext = self.type_embedding(extended_atype[:, :nloc]) + else: + node_ebd_ext = self.type_embedding(extended_atype) node_ebd_inp = node_ebd_ext[:, :nloc, :] # repflows node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows( diff --git a/deepmd/pd/model/descriptor/env_mat.py b/deepmd/pd/model/descriptor/env_mat.py index 9b72da0b16..214fb593de 100644 --- a/deepmd/pd/model/descriptor/env_mat.py +++ b/deepmd/pd/model/descriptor/env_mat.py @@ -3,6 +3,7 @@ import paddle from deepmd.pd.utils.preprocess import ( + compute_exp_sw, compute_smooth_weight, ) @@ -14,6 +15,7 @@ def _make_env_mat( ruct_smth: float, radial_only: bool = False, protection: float = 0.0, + use_exp_switch: bool = False, ): """Make smooth environment matrix.""" bsz, natoms, nnei = nlist.shape @@ -24,7 +26,8 @@ def _make_env_mat( nlist = paddle.where(mask, nlist, nall - 1) coord_l = coord[:, :natoms].reshape([bsz, -1, 1, 3]) index = nlist.reshape([bsz, -1]).unsqueeze(-1).expand([-1, -1, 3]) - coord_r = paddle.take_along_axis(coord, axis=1, indices=index) + coord_pad = paddle.concat([coord, coord[:, -1:, :] + rcut], axis=1) + coord_r = paddle.take_along_axis(coord_pad, axis=1, indices=index) coord_r = coord_r.reshape([bsz, natoms, nnei, 3]) diff = coord_r - coord_l length = paddle.linalg.norm(diff, axis=-1, keepdim=True) @@ -32,7 +35,11 @@ def _make_env_mat( length = length + (~mask.unsqueeze(-1)).astype(length.dtype) t0 = 1 / (length + protection) t1 = diff / (length + protection) ** 2 - weight = compute_smooth_weight(length, ruct_smth, rcut) + weight = ( + compute_smooth_weight(length, ruct_smth, rcut) + if not use_exp_switch + else compute_exp_sw(length, ruct_smth, rcut) + ) weight = weight * mask.unsqueeze(-1).astype(weight.dtype) if radial_only: env_mat = t0 * weight @@ -51,6 +58,7 @@ def prod_env_mat( rcut_smth: float, radial_only: bool = False, protection: float = 0.0, + use_exp_switch: bool = False, ): """Generate smooth environment matrix from atom coordinates and other context. @@ -63,6 +71,7 @@ def prod_env_mat( - rcut_smth: Smooth hyper-parameter for pair force & energy. 
- radial_only: Whether to return a full description or a radial-only descriptor. - protection: Protection parameter to prevent division by zero errors during calculations. + - use_exp_switch: Whether to use the exponential switch function. Returns ------- @@ -75,6 +84,7 @@ def prod_env_mat( rcut_smth, radial_only, protection=protection, + use_exp_switch=use_exp_switch, ) # shape [n_atom, dim, 4 or 1] t_avg = mean[atype] # [n_atom, dim, 4 or 1] t_std = stddev[atype] # [n_atom, dim, 4 or 1] diff --git a/deepmd/pd/model/descriptor/repflow_layer.py b/deepmd/pd/model/descriptor/repflow_layer.py index f1bdd0439d..78cfa7a56d 100644 --- a/deepmd/pd/model/descriptor/repflow_layer.py +++ b/deepmd/pd/model/descriptor/repflow_layer.py @@ -19,6 +19,9 @@ from deepmd.pd.model.network.mlp import ( MLPLayer, ) +from deepmd.pd.model.network.utils import ( + aggregate, +) from deepmd.pd.utils.env import ( PRECISION_DICT, ) @@ -326,6 +329,61 @@ def _cal_hg( h2g2 = paddle.matmul(paddle.matrix_transpose(h2), edge_ebd) * invnnei return h2g2 + @staticmethod + def _cal_hg_dynamic( + flat_edge_ebd: paddle.Tensor, + flat_h2: paddle.Tensor, + flat_sw: paddle.Tensor, + owner: paddle.Tensor, + num_owner: int, + nb: int, + nloc: int, + scale_factor: float, + ) -> paddle.Tensor: + """ + Calculate the transposed rotation matrix. + + Parameters + ---------- + flat_edge_ebd + Flatted neighbor-wise/pair-wise invariant rep tensors, with shape n_edge x e_dim. + flat_h2 + Flatted neighbor-wise/pair-wise equivariant rep tensors, with shape n_edge x 3. + flat_sw + Flatted switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape n_edge. + owner + The owner index of the neighbor to reduce on. + num_owner : int + The total number of the owner. + nb : int + The number of batches. + nloc : int + The number of local atoms. + scale_factor : float + The scale factor to apply after reduce. + + Returns + ------- + hg + The transposed rotation matrix, with shape nf x nloc x 3 x e_dim. + """ + n_edge, e_dim = flat_edge_ebd.shape + # n_edge x e_dim + flat_edge_ebd = flat_edge_ebd * flat_sw.unsqueeze(-1) + # n_edge x 3 x e_dim + flat_h2g2 = (flat_h2[..., None] * flat_edge_ebd[:, None, :]).reshape( + [-1, 3 * e_dim] + ) + # nf x nloc x 3 x e_dim + h2g2 = ( + aggregate(flat_h2g2, owner, average=False, num_owner=num_owner).reshape( + [nb, nloc, 3, e_dim] + ) + * scale_factor + ) + return h2g2 + @staticmethod def _cal_grrg(h2g2: paddle.Tensor, axis_neuron: int) -> paddle.Tensor: """ @@ -398,6 +456,63 @@ def symmetrization_op( g1_13 = self._cal_grrg(h2g2, axis_neuron) return g1_13 + def symmetrization_op_dynamic( + self, + flat_edge_ebd: paddle.Tensor, + flat_h2: paddle.Tensor, + flat_sw: paddle.Tensor, + owner: paddle.Tensor, + num_owner: int, + nb: int, + nloc: int, + scale_factor: float, + axis_neuron: int, + ) -> paddle.Tensor: + """ + Symmetrization operator to obtain atomic invariant rep. + + Parameters + ---------- + flat_edge_ebd + Flatted neighbor-wise/pair-wise invariant rep tensors, with shape n_edge x e_dim. + flat_h2 + Flatted neighbor-wise/pair-wise equivariant rep tensors, with shape n_edge x 3. + flat_sw + Flatted switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape n_edge. + owner + The owner index of the neighbor to reduce on. + num_owner : int + The total number of the owner. + nb : int + The number of batches. 
+ nloc : int + The number of local atoms. + scale_factor : float + The scale factor to apply after reduce. + axis_neuron + Size of the submatrix. + + Returns + ------- + grrg + Atomic invariant rep, with shape nb x nloc x (axis_neuron x e_dim) + """ + # nb x nloc x 3 x e_dim + h2g2 = self._cal_hg_dynamic( + flat_edge_ebd, + flat_h2, + flat_sw, + owner, + num_owner, + nb, + nloc, + scale_factor, + ) + # nb x nloc x (axis x e_dim) + grrg = self._cal_grrg(h2g2, axis_neuron) + return grrg + def optim_angle_update( self, angle_ebd: paddle.Tensor, @@ -419,7 +534,7 @@ def optim_angle_update( node_dim = node_ebd.shape[-1] edge_dim = edge_ebd.shape[-1] # angle_dim, node_dim, edge_dim, edge_dim - sub_angle, sub_node, sub_edge_ij, sub_edge_ik = paddle.split( + sub_angle, sub_node, sub_edge_ik, sub_edge_ij = paddle.split( matrix, [angle_dim, node_dim, edge_dim, edge_dim] ) @@ -428,14 +543,64 @@ def optim_angle_update( # nf * nloc * angle_dim sub_node_update = paddle.matmul(node_ebd, sub_node) # nf * nloc * a_nnei * angle_dim - sub_edge_update_ij = paddle.matmul(edge_ebd, sub_edge_ij) sub_edge_update_ik = paddle.matmul(edge_ebd, sub_edge_ik) + sub_edge_update_ij = paddle.matmul(edge_ebd, sub_edge_ij) result_update = ( bias + sub_node_update.unsqueeze(2).unsqueeze(3) - + sub_edge_update_ij.unsqueeze(2) - + sub_edge_update_ik.unsqueeze(3) + + sub_edge_update_ik.unsqueeze(2) + + sub_edge_update_ij.unsqueeze(3) + + sub_angle_update + ) + return result_update + + def optim_angle_update_dynamic( + self, + flat_angle_ebd: paddle.Tensor, + node_ebd: paddle.Tensor, + flat_edge_ebd: paddle.Tensor, + n2a_index: paddle.Tensor, + eij2a_index: paddle.Tensor, + eik2a_index: paddle.Tensor, + feat: str = "edge", + ) -> paddle.Tensor: + if feat == "edge": + matrix, bias = self.edge_angle_linear1.matrix, self.edge_angle_linear1.bias + elif feat == "angle": + matrix, bias = self.angle_self_linear.matrix, self.angle_self_linear.bias + else: + raise NotImplementedError + nf, nloc, node_dim = node_ebd.shape + edge_dim = flat_edge_ebd.shape[-1] + angle_dim = flat_angle_ebd.shape[-1] + # angle_dim, node_dim, edge_dim, edge_dim + sub_angle, sub_node, sub_edge_ik, sub_edge_ij = paddle.split( + matrix, [angle_dim, node_dim, edge_dim, edge_dim] + ) + + # n_angle * angle_dim + sub_angle_update = paddle.matmul(flat_angle_ebd, sub_angle) + + # nf * nloc * angle_dim + sub_node_update = paddle.matmul(node_ebd, sub_node) + # n_angle * angle_dim + sub_node_update = paddle.index_select( + sub_node_update.reshape(nf * nloc, sub_node_update.shape[-1]), n2a_index, 0 + ) + + # n_edge * angle_dim + sub_edge_update_ik = paddle.matmul(flat_edge_ebd, sub_edge_ik) + sub_edge_update_ij = paddle.matmul(flat_edge_ebd, sub_edge_ij) + # n_angle * angle_dim + sub_edge_update_ik = paddle.index_select(sub_edge_update_ik, eik2a_index, 0) + sub_edge_update_ij = paddle.index_select(sub_edge_update_ij, eij2a_index, 0) + + result_update = ( + bias + + sub_node_update + + sub_edge_update_ik + + sub_edge_update_ij + sub_angle_update ) return result_update @@ -475,9 +640,55 @@ def optim_edge_update( ) return result_update + def optim_edge_update_dynamic( + self, + node_ebd: paddle.Tensor, + node_ebd_ext: paddle.Tensor, + flat_edge_ebd: paddle.Tensor, + n2e_index: paddle.Tensor, + n_ext2e_index: paddle.Tensor, + feat: str = "node", + ) -> paddle.Tensor: + if feat == "node": + matrix, bias = self.node_edge_linear.matrix, self.node_edge_linear.bias + elif feat == "edge": + matrix, bias = self.edge_self_linear.matrix, self.edge_self_linear.bias + else: + raise 
NotImplementedError + assert bias is not None + nf, nall, node_dim = node_ebd_ext.shape + _, nloc, _ = node_ebd.shape + edge_dim = flat_edge_ebd.shape[-1] + # node_dim, node_dim, edge_dim + node, node_ext, edge = paddle.split(matrix, [node_dim, node_dim, edge_dim]) + + # nf * nloc * node/edge_dim + sub_node_update = paddle.matmul(node_ebd, node) + # n_edge * node/edge_dim + sub_node_update = paddle.index_select( + sub_node_update.reshape(nf * nloc, sub_node_update.shape[-1]), + n2e_index, + 0, + ) + + # nf * nall * node/edge_dim + sub_node_ext_update = paddle.matmul(node_ebd_ext, node_ext) + # n_edge * node/edge_dim + sub_node_ext_update = paddle.index_select( + sub_node_ext_update.reshape(nf * nall, sub_node_update.shape[-1]), + n_ext2e_index, + 0, + ) + + # n_edge * node/edge_dim + sub_edge_update = paddle.matmul(flat_edge_ebd, edge) + + result_update = bias + sub_node_update + sub_edge_update + sub_node_ext_update + return result_update + def forward( self, - node_ebd_ext: paddle.Tensor, # nf x nall x n_dim + node_ebd_ext: paddle.Tensor, # nf x nall x n_dim [OR] nf x nloc x n_dim when not parallel_mode edge_ebd: paddle.Tensor, # nf x nloc x nnei x e_dim h2: paddle.Tensor, # nf x nloc x nnei x 3 angle_ebd: paddle.Tensor, # nf x nloc x a_nnei x a_nnei x a_dim @@ -487,6 +698,8 @@ def forward( a_nlist: paddle.Tensor, # nf x nloc x a_nnei a_nlist_mask: paddle.Tensor, # nf x nloc x a_nnei a_sw: paddle.Tensor, # switch func, nf x nloc x a_nnei + edge_index: paddle.Tensor, # n_edge x 2 + angle_index: paddle.Tensor, # n_angle x 3 ): """ Parameters @@ -511,6 +724,18 @@ def forward( Masks of the neighbor list for angle. real nei 1 otherwise 0 a_sw : nf x nloc x a_nnei Switch function for angle. + edge_index : Optional for dynamic sel, n_edge x 2 + n2e_index : n_edge + Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i). + n_ext2e_index : n_edge + Broadcast indices from extended node(j) to edge(ij). + angle_index : Optional for dynamic sel, n_angle x 3 + n2a_index : n_angle + Broadcast indices from extended node(j) to angle(ijk). + eij2a_index : n_angle + Broadcast indices from extended edge(ij) to angle(ijk), or reduction indices from angle(ijk) to edge(ij). + eik2a_index : n_angle + Broadcast indices from extended edge(ik) to angle(ijk). 
Returns ------- @@ -524,12 +749,35 @@ def forward( nb, nloc, nnei, _ = edge_ebd.shape nall = node_ebd_ext.shape[1] node_ebd = node_ebd_ext[:, :nloc, :] + n_edge = int(nlist_mask.sum().item()) if paddle.in_dynamic_mode(): assert [nb, nloc] == node_ebd.shape[:2] - if paddle.in_dynamic_mode(): - assert [nb, nloc, nnei] == h2.shape[:3] + if not self.use_dynamic_sel: + if paddle.in_dynamic_mode(): + assert [nb, nloc, nnei, 3] == h2.shape + else: + if paddle.in_dynamic_mode(): + assert [n_edge, 3] == h2.shape del a_nlist # may be used in the future + n2e_index, n_ext2e_index = edge_index[:, 0], edge_index[:, 1] + n2a_index, eij2a_index, eik2a_index = ( + angle_index[:, 0], + angle_index[:, 1], + angle_index[:, 2], + ) + + # nb x nloc x nnei x n_dim [OR] n_edge x n_dim + nei_node_ebd = ( + _make_nei_g1(node_ebd_ext, nlist) + if not self.use_dynamic_sel + else paddle.index_select( + node_ebd_ext.reshape([-1, self.n_dim]), + n_ext2e_index, + 0, + ) + ) + n_update_list: list[paddle.Tensor] = [node_ebd] e_update_list: list[paddle.Tensor] = [edge_ebd] a_update_list: list[paddle.Tensor] = [angle_ebd] @@ -538,8 +786,6 @@ def forward( node_self_mlp = self.act(self.node_self_mlp(node_ebd)) n_update_list.append(node_self_mlp) - nei_node_ebd = _make_nei_g1(node_ebd_ext, nlist) - # node sym (grrg + drrd) node_sym_list: list[paddle.Tensor] = [] node_sym_list.append( @@ -550,6 +796,18 @@ def forward( sw, self.axis_neuron, ) + if not self.use_dynamic_sel + else self.symmetrization_op_dynamic( + edge_ebd, + h2, + sw, + owner=n2e_index, + num_owner=nb * nloc, + nb=nb, + nloc=nloc, + scale_factor=self.dynamic_e_sel ** (-0.5), + axis_neuron=self.axis_neuron, + ) ) node_sym_list.append( self.symmetrization_op( @@ -559,20 +817,47 @@ def forward( sw, self.axis_neuron, ) + if not self.use_dynamic_sel + else self.symmetrization_op_dynamic( + nei_node_ebd, + h2, + sw, + owner=n2e_index, + num_owner=nb * nloc, + nb=nb, + nloc=nloc, + scale_factor=self.dynamic_e_sel ** (-0.5), + axis_neuron=self.axis_neuron, + ) ) node_sym = self.act(self.node_sym_linear(paddle.concat(node_sym_list, axis=-1))) n_update_list.append(node_sym) if not self.optim_update: - # nb x nloc x nnei x (n_dim * 2 + e_dim) - edge_info = paddle.concat( - [ - paddle.tile(node_ebd.unsqueeze(-2), [1, 1, self.nnei, 1]), - nei_node_ebd, - edge_ebd, - ], - axis=-1, - ) + if not self.use_dynamic_sel: + # nb x nloc x nnei x (n_dim * 2 + e_dim) + edge_info = paddle.concat( + [ + paddle.tile(node_ebd.unsqueeze(-2), [1, 1, self.nnei, 1]), + nei_node_ebd, + edge_ebd, + ], + axis=-1, + ) + else: + # n_edge x (n_dim * 2 + e_dim) + edge_info = paddle.concat( + [ + paddle.index_select( + node_ebd.reshape(-1, self.n_dim), + n2e_index, + 0, + ), + nei_node_ebd, + edge_ebd, + ], + axis=-1, + ) else: edge_info = None @@ -592,16 +877,37 @@ def forward( nlist, "node", ) + if not self.use_dynamic_sel + else self.optim_edge_update_dynamic( + node_ebd, + node_ebd_ext, + edge_ebd, + n2e_index, + n_ext2e_index, + "node", + ) ) * sw.unsqueeze(-1) + node_edge_update = ( + (paddle.sum(node_edge_update, axis=-2) / self.nnei) + if not self.use_dynamic_sel + else ( + aggregate( + node_edge_update, + n2e_index, + average=False, + num_owner=nb * nloc, + ).reshape(nb, nloc, node_edge_update.shape[-1]) + / self.dynamic_e_sel + ) + ) - node_edge_update = paddle.sum(node_edge_update, axis=-2) / self.nnei if self.n_multi_edge_message > 1: - # nb x nloc x nnei x h x n_dim + # nb x nloc x h x n_dim node_edge_update_mul_head = node_edge_update.reshape( [nb, nloc, self.n_multi_edge_message, 
self.n_dim] ) for head_index in range(self.n_multi_edge_message): - n_update_list.append(node_edge_update_mul_head[:, :, head_index, :]) + n_update_list.append(node_edge_update_mul_head[..., head_index, :]) else: n_update_list.append(node_edge_update) # update node_ebd @@ -620,6 +926,15 @@ def forward( nlist, "edge", ) + if not self.use_dynamic_sel + else self.optim_edge_update_dynamic( + node_ebd, + node_ebd_ext, + edge_ebd, + n2e_index, + n_ext2e_index, + "edge", + ) ) e_update_list.append(edge_self_update) @@ -641,48 +956,66 @@ def forward( edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd) else: # use the first a_compress_dim dim for node and edge - node_ebd_for_angle = node_ebd[:, :, : self.n_a_compress_dim] - edge_ebd_for_angle = edge_ebd[:, :, :, : self.e_a_compress_dim] + node_ebd_for_angle = node_ebd[..., : self.n_a_compress_dim] + edge_ebd_for_angle = edge_ebd[..., : self.e_a_compress_dim] else: node_ebd_for_angle = node_ebd edge_ebd_for_angle = edge_ebd - # nb x nloc x a_nnei x e_dim - edge_for_angle = edge_ebd_for_angle[:, :, : self.a_sel, :] - # nb x nloc x a_nnei x e_dim - edge_for_angle = paddle.where( - a_nlist_mask.unsqueeze(-1), - edge_for_angle, - paddle.zeros_like(edge_for_angle), - ).astype(edge_for_angle.dtype) + if not self.use_dynamic_sel: + # nb x nloc x a_nnei x e_dim + edge_ebd_for_angle = edge_ebd_for_angle[..., : self.a_sel, :] + # nb x nloc x a_nnei x e_dim + edge_ebd_for_angle = edge_ebd_for_angle.masked_fill( + ~a_nlist_mask.unsqueeze(-1), 0.0 + ) if not self.optim_update: - # nb x nloc x a_nnei x a_nnei x n_dim - node_for_angle_info = paddle.tile( - node_ebd_for_angle.unsqueeze(2).unsqueeze(2), - [1, 1, self.a_sel, self.a_sel, 1], + # nb x nloc x a_nnei x a_nnei x n_dim [OR] n_angle x n_dim + node_for_angle_info = ( + paddle.tile( + node_ebd_for_angle.unsqueeze(2).unsqueeze(2), + (1, 1, self.a_sel, self.a_sel, 1), + ) + if not self.use_dynamic_sel + else paddle.index_select( + node_ebd_for_angle.reshape([-1, self.n_a_compress_dim]), + n2a_index, + 0, + ) ) - # nb x nloc x (a_nnei) x a_nnei x edge_ebd - edge_for_angle_i = paddle.tile( - edge_for_angle.unsqueeze(2), (1, 1, self.a_sel, 1, 1) + + # nb x nloc x (a_nnei) x a_nnei x e_dim [OR] n_angle x e_dim + edge_for_angle_k = ( + paddle.tile( + edge_ebd_for_angle.unsqueeze(2), (1, 1, self.a_sel, 1, 1) + ) + if not self.use_dynamic_sel + else paddle.index_select(edge_ebd_for_angle, eik2a_index, 0) ) - # nb x nloc x a_nnei x (a_nnei) x e_dim - edge_for_angle_j = paddle.tile( - edge_for_angle.unsqueeze(3), (1, 1, 1, self.a_sel, 1) + # nb x nloc x a_nnei x (a_nnei) x e_dim [OR] n_angle x e_dim + edge_for_angle_j = ( + paddle.tile( + edge_ebd_for_angle.unsqueeze(3), (1, 1, 1, self.a_sel, 1) + ) + if not self.use_dynamic_sel + else paddle.index_select(edge_ebd_for_angle, eij2a_index, 0) ) - # nb x nloc x a_nnei x a_nnei x (e_dim + e_dim) + # nb x nloc x a_nnei x a_nnei x (e_dim + e_dim) [OR] n_angle x (e_dim + e_dim) edge_for_angle_info = paddle.concat( - [edge_for_angle_i, edge_for_angle_j], axis=-1 + [edge_for_angle_k, edge_for_angle_j], axis=1 ) angle_info_list = [angle_ebd] angle_info_list.append(node_for_angle_info) angle_info_list.append(edge_for_angle_info) # nb x nloc x a_nnei x a_nnei x (a + n_dim + e_dim*2) or (a + a/c + a/c) - angle_info = paddle.concat(angle_info_list, axis=-1) + # [OR] + # n_angle x (a + n_dim + e_dim*2) or (a + a/c + a/c) + angle_info = paddle.cat(angle_info_list, axis=1) else: angle_info = None # edge angle message - # nb x nloc x a_nnei x a_nnei x e_dim + # nb x nloc x a_nnei x 
a_nnei x e_dim [OR] n_angle x e_dim if not self.optim_update: assert angle_info is not None edge_angle_update = self.act(self.edge_angle_linear1(angle_info)) @@ -691,32 +1024,59 @@ def forward( self.optim_angle_update( angle_ebd, node_ebd_for_angle, - edge_for_angle, + edge_ebd_for_angle, + "edge", + ) + if not self.use_dynamic_sel + else self.optim_angle_update_dynamic( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + n2a_index, + eij2a_index, + eik2a_index, "edge", ) ) - # nb x nloc x a_nnei x a_nnei x e_dim - weighted_edge_angle_update = ( - a_sw[..., None, None] * a_sw[..., None, :, None] * edge_angle_update - ) - # nb x nloc x a_nnei x e_dim - reduced_edge_angle_update = paddle.sum( - weighted_edge_angle_update, axis=-2 - ) / (self.a_sel**0.5) - # nb x nloc x nnei x e_dim - padding_edge_angle_update = paddle.concat( - [ - reduced_edge_angle_update, - paddle.zeros( - [nb, nloc, self.nnei - self.a_sel, self.e_dim], - dtype=edge_ebd.dtype, - ).to(device=edge_ebd.place), - ], - axis=2, - ) + if not self.use_dynamic_sel: + # nb x nloc x a_nnei x a_nnei x e_dim + weighted_edge_angle_update = ( + a_sw[..., None, None] * a_sw[..., None, :, None] * edge_angle_update + ) + # nb x nloc x a_nnei x e_dim + reduced_edge_angle_update = paddle.sum( + weighted_edge_angle_update, axis=-2 + ) / (self.a_sel**0.5) + # nb x nloc x nnei x e_dim + padding_edge_angle_update = paddle.concat( + [ + reduced_edge_angle_update, + paddle.zeros( + [nb, nloc, self.nnei - self.a_sel, self.e_dim], + dtype=edge_ebd.dtype, + ), + ], + axis=2, + ) + else: + # n_angle x e_dim + weighted_edge_angle_update = edge_angle_update * a_sw.unsqueeze(-1) + # n_edge x e_dim + padding_edge_angle_update = aggregate( + weighted_edge_angle_update, + eij2a_index, + average=False, + num_owner=n_edge, + ) / (self.dynamic_a_sel**0.5) + if not self.smooth_edge_update: # will be deprecated in the future + # not support dynamic index, will pass anyway + if self.use_dynamic_sel: + raise NotImplementedError( + "smooth_edge_update must be True when use_dynamic_sel is True!" + ) full_mask = paddle.concat( [ a_nlist_mask, @@ -727,8 +1087,8 @@ def forward( ], axis=-1, ) - padding_edge_angle_update = paddle.where( - full_mask.unsqueeze(-1), padding_edge_angle_update, edge_ebd + padding_edge_angle_update = padding_edge_angle_update.masked_fill( + ~full_mask.unsqueeze(-1), edge_ebd ) e_update_list.append( self.act(self.edge_angle_linear2(padding_edge_angle_update)) @@ -746,7 +1106,17 @@ def forward( self.optim_angle_update( angle_ebd, node_ebd_for_angle, - edge_for_angle, + edge_ebd_for_angle, + "angle", + ) + if not self.use_dynamic_sel + else self.optim_angle_update_dynamic( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + n2a_index, + eij2a_index, + eik2a_index, "angle", ) ) diff --git a/deepmd/pd/model/descriptor/repflows.py b/deepmd/pd/model/descriptor/repflows.py index 3200c26dba..2214116bbc 100644 --- a/deepmd/pd/model/descriptor/repflows.py +++ b/deepmd/pd/model/descriptor/repflows.py @@ -19,6 +19,9 @@ from deepmd.pd.model.network.mlp import ( MLPLayer, ) +from deepmd.pd.model.network.utils import ( + get_graph_index, +) from deepmd.pd.utils import ( env, ) @@ -109,12 +112,35 @@ class DescrptBlockRepflows(DescriptorBlock): smooth_edge_update : bool, optional Whether to make edge update smooth. If True, the edge update from angle message will not use self as padding. + edge_init_use_dist : bool, optional + Whether to use direct distance r to initialize the edge features instead of 1/r. 
+ Note that when using this option, the activation function will not be used when initializing edge features. + use_exp_switch : bool, optional + Whether to use an exponential switch function instead of a polynomial one in the neighbor update. + The exponential switch function ensures neighbor contributions smoothly diminish as the interatomic distance + `r` approaches the cutoff radius `rcut`. Specifically, the function is defined as: + s(r) = \\exp(-\\exp(20 * (r - rcut_smth) / rcut_smth)) for 0 < r \\leq rcut, and s(r) = 0 for r > rcut. + Here, `rcut_smth` is an adjustable smoothing factor and `rcut_smth` should be chosen carefully + according to `rcut`, ensuring s(r) approaches zero smoothly at the cutoff. + Typical recommended values are `rcut_smth` = 5.3 for `rcut` = 6.0, and 3.5 for `rcut` = 4.0. + use_dynamic_sel : bool, optional + Whether to dynamically select neighbors within the cutoff radius. + If True, the exact number of neighbors within the cutoff radius is used + without padding to a fixed selection numbers. + When enabled, users can safely set larger values for `e_sel` or `a_sel` (e.g., 1200 or 300, respectively) + to guarantee capturing all neighbors within the cutoff radius. + Note that when using dynamic selection, the `smooth_edge_update` must be True. + sel_reduce_factor : float, optional + Reduction factor applied to neighbor-scale normalization when `use_dynamic_sel` is True. + In the dynamic selection case, neighbor-scale normalization will use `e_sel / sel_reduce_factor` + or `a_sel / sel_reduce_factor` instead of the raw `e_sel` or `a_sel` values, + accommodating larger selection numbers. + use_loc_mapping : bool, Optional + Whether to use local atom index mapping in training or non-parallel inference. + When True, local indexing and mapping are applied to neighbor lists and embeddings during descriptor computation. optim_update : bool, optional Whether to enable the optimized update method. Uses a more efficient process when enabled. Defaults to True - use_loc_mapping : bool, Optional - Whether to use local atom index mapping in training or non-parallel inference. - Not supported yet in Paddle. ntypes : int Number of element types activation_function : str, optional @@ -162,9 +188,11 @@ def __init__( precision: str = "float64", fix_stat_std: float = 0.3, smooth_edge_update: bool = False, + edge_init_use_dist: bool = False, + use_exp_switch: bool = False, use_dynamic_sel: bool = False, sel_reduce_factor: float = 10.0, - use_loc_mapping: bool = False, + use_loc_mapping: bool = True, optim_update: bool = True, seed: Optional[Union[int, list[int]]] = None, ) -> None: @@ -195,13 +223,21 @@ def __init__( self.fix_stat_std = fix_stat_std self.set_stddev_constant = fix_stat_std != 0.0 self.a_compress_use_split = a_compress_use_split + self.use_loc_mapping = use_loc_mapping self.optim_update = optim_update self.smooth_edge_update = smooth_edge_update - self.use_dynamic_sel = use_dynamic_sel # not supported yet + self.edge_init_use_dist = edge_init_use_dist + self.use_exp_switch = use_exp_switch + self.use_dynamic_sel = use_dynamic_sel self.sel_reduce_factor = sel_reduce_factor - assert not self.use_dynamic_sel, "Dynamic selection is not supported yet." - self.use_loc_mapping = use_loc_mapping - assert not self.use_loc_mapping, "Local mapping is not supported yet." + if self.use_dynamic_sel and not self.smooth_edge_update: + raise NotImplementedError( + "smooth_edge_update must be True when use_dynamic_sel is True!" 
+ ) + if self.sel_reduce_factor <= 0: + raise ValueError( + f"`sel_reduce_factor` must be > 0, got {self.sel_reduce_factor}" + ) self.n_dim = n_dim self.e_dim = e_dim @@ -366,9 +402,9 @@ def forward( mapping: Optional[paddle.Tensor] = None, comm_dict: Optional[dict[str, paddle.Tensor]] = None, ): - if comm_dict is None: + parallel_mode = comm_dict is not None + if not parallel_mode: assert mapping is not None - assert extended_atype_embd is not None nframes, nloc, nnei = nlist.shape nall = extended_coord.reshape([nframes, -1]).shape[1] // 3 atype = extended_atype[:, :nloc] @@ -385,30 +421,13 @@ def forward( self.e_rcut, self.e_rcut_smth, protection=self.env_protection, + use_exp_switch=self.use_exp_switch, ) nlist_mask = nlist != -1 sw = paddle.squeeze(sw, -1) # beyond the cutoff sw should be 0.0 sw = sw.masked_fill(~nlist_mask, 0.0) - # [nframes, nloc, tebd_dim] - if comm_dict is None: - if paddle.in_dynamic_mode(): - assert isinstance(extended_atype_embd, paddle.Tensor) - atype_embd = extended_atype_embd[:, :nloc, :] - if paddle.in_dynamic_mode(): - assert atype_embd.shape == [nframes, nloc, self.n_dim] - else: - atype_embd = extended_atype_embd - if paddle.in_dynamic_mode(): - assert isinstance(atype_embd, paddle.Tensor) - node_ebd = self.act(atype_embd) - n_dim = node_ebd.shape[-1] - # nb x nloc x nnei x 1, nb x nloc x nnei x 3 - edge_input, h2 = paddle.split(dmatrix, [1, 3], axis=-1) - # nb x nloc x nnei x e_dim - edge_ebd = self.act(self.edge_embd(edge_input)) - # get angle nlist (maybe smaller) a_dist_mask = (paddle.linalg.norm(diff, axis=-1) < self.a_rcut)[ :, :, : self.a_sel @@ -424,13 +443,34 @@ def forward( self.a_rcut, self.a_rcut_smth, protection=self.env_protection, + use_exp_switch=self.use_exp_switch, ) a_nlist_mask = a_nlist != -1 a_sw = paddle.squeeze(a_sw, -1) # beyond the cutoff sw should be 0.0 a_sw = a_sw.masked_fill(~a_nlist_mask, 0.0) + # set all padding positions to index of 0 + # if the a neighbor is real or not is indicated by nlist_mask + nlist[nlist == -1] = 0 a_nlist[a_nlist == -1] = 0 + # get node embedding + # [nframes, nloc, tebd_dim] + assert extended_atype_embd is not None + atype_embd = extended_atype_embd[:, :nloc, :] + if paddle.in_dynamic_mode(): + assert list(atype_embd.shape) == [nframes, nloc, self.n_dim] + assert isinstance(atype_embd, paddle.Tensor) # for jit + node_ebd = self.act(atype_embd) + n_dim = node_ebd.shape[-1] + + # get edge and angle embedding input + # nb x nloc x nnei x 1, nb x nloc x nnei x 3 + edge_input, h2 = paddle.split(dmatrix, [1, 3], axis=-1) + if self.edge_init_use_dist: + # nb x nloc x nnei x 1 + edge_input = paddle.linalg.norm(diff, axis=-1, keepdim=True) + # nf x nloc x a_nnei x 3 normalized_diff_i = a_diff / ( paddle.linalg.norm(a_diff, axis=-1, keepdim=True) + 1e-6 @@ -440,18 +480,53 @@ def forward( # nf x nloc x a_nnei x a_nnei # 1 - 1e-6 for paddle.acos stability cosine_ij = paddle.matmul(normalized_diff_i, normalized_diff_j) * (1 - 1e-6) - # nf x nloc x a_nnei x a_nnei x 1 - cosine_ij = cosine_ij.unsqueeze(-1) / (paddle.pi**0.5) - # nf x nloc x a_nnei x a_nnei x a_dim - angle_ebd = self.angle_embd(cosine_ij).reshape( - [nframes, nloc, self.a_sel, self.a_sel, self.a_dim] - ) + angle_input = cosine_ij.unsqueeze(-1) / (paddle.pi**0.5) + + if not parallel_mode and self.use_loc_mapping: + assert mapping is not None + # convert nlist from nall to nloc index + nlist = paddle.take_along_axis( + mapping, + nlist.reshape([nframes, -1]), + 1, + broadcast=False, + ).reshape(nlist.shape) + if self.use_dynamic_sel: + # get graph 
index + edge_index, angle_index = get_graph_index( + nlist, + nlist_mask, + a_nlist_mask, + nall, + use_loc_mapping=self.use_loc_mapping, + ) + # flat all the tensors + # n_edge x 1 + edge_input = edge_input[nlist_mask] + # n_edge x 3 + h2 = h2[nlist_mask] + # n_edge x 1 + sw = sw[nlist_mask] + # nb x nloc x a_nnei x a_nnei + a_nlist_mask = a_nlist_mask[:, :, :, None] & a_nlist_mask[:, :, None, :] + # n_angle x 1 + angle_input = angle_input[a_nlist_mask] + # n_angle x 1 + a_sw = (a_sw[:, :, :, None] * a_sw[:, :, None, :])[a_nlist_mask] + else: + # avoid jit assertion + edge_index = angle_index = paddle.zeros([1, 3], dtype=nlist.dtype) + # get edge and angle embedding + # nb x nloc x nnei x e_dim [OR] n_edge x e_dim + if not self.edge_init_use_dist: + edge_ebd = self.act(self.edge_embd(edge_input)) + else: + edge_ebd = self.edge_embd(edge_input) + # nf x nloc x a_nnei x a_nnei x a_dim [OR] n_angle x a_dim + angle_ebd = self.angle_embd(angle_input) - # set all padding positions to index of 0 - # if the a neighbor is real or not is indicated by nlist_mask - nlist[nlist == -1] = 0 # nb x nall x n_dim - if comm_dict is None: + if not parallel_mode: assert mapping is not None mapping = ( mapping.reshape([nframes, nall]) @@ -460,8 +535,8 @@ def forward( ) for idx, ll in enumerate(self.layers): # node_ebd: nb x nloc x n_dim - # node_ebd_ext: nb x nall x n_dim - if comm_dict is None: + # node_ebd_ext: nb x nall x n_dim [OR] nb x nloc x n_dim when not parallel_mode + if not parallel_mode: assert mapping is not None node_ebd_ext = paddle.take_along_axis( node_ebd, mapping, 1, broadcast=False @@ -479,12 +554,27 @@ def forward( a_nlist, a_nlist_mask, a_sw, + edge_index=edge_index, + angle_index=angle_index, ) # nb x nloc x 3 x e_dim - h2g2 = RepFlowLayer._cal_hg(edge_ebd, h2, nlist_mask, sw) + h2g2 = ( + RepFlowLayer._cal_hg(edge_ebd, h2, nlist_mask, sw) + if not self.use_dynamic_sel + else RepFlowLayer._cal_hg_dynamic( + edge_ebd, + h2, + sw, + owner=edge_index[:, 0], + num_owner=nframes * nloc, + nb=nframes, + nloc=nloc, + scale_factor=(self.nnei / self.sel_reduce_factor) ** (-0.5), + ) + ) # (nb x nloc) x e_dim x 3 - rot_mat = paddle.transpose(h2g2, (0, 1, 3, 2)) + rot_mat = paddle.transpose(h2g2, [0, 1, 3, 2]) return ( node_ebd, diff --git a/deepmd/pd/model/descriptor/repformers.py b/deepmd/pd/model/descriptor/repformers.py index 32f88dd1d3..bf97dfec5e 100644 --- a/deepmd/pd/model/descriptor/repformers.py +++ b/deepmd/pd/model/descriptor/repformers.py @@ -267,10 +267,10 @@ def __init__( wanted_shape = (self.ntypes, self.nnei, 4) mean = paddle.zeros(wanted_shape, dtype=env.GLOBAL_PD_FLOAT_PRECISION).to( - device=env.DEVICE + env.DEVICE ) stddev = paddle.ones(wanted_shape, dtype=env.GLOBAL_PD_FLOAT_PRECISION).to( - device=env.DEVICE + env.DEVICE ) self.register_buffer("mean", mean) self.register_buffer("stddev", stddev) @@ -503,11 +503,11 @@ def compute_input_stats( mean, stddev = env_mat_stat() if not self.set_davg_zero: paddle.assign( - paddle.to_tensor(mean, dtype=self.mean.dtype).to(device=env.DEVICE), + paddle.to_tensor(mean, dtype=self.mean.dtype).to(env.DEVICE), self.mean, ) # pylint: disable=no-explicit-dtype paddle.assign( - paddle.to_tensor(stddev, dtype=self.stddev.dtype).to(device=env.DEVICE), + paddle.to_tensor(stddev, dtype=self.stddev.dtype).to(env.DEVICE), self.stddev, ) # pylint: disable=no-explicit-dtype diff --git a/deepmd/pd/utils/env.py b/deepmd/pd/utils/env.py index cf5b1f835c..3f45910392 100644 --- a/deepmd/pd/utils/env.py +++ b/deepmd/pd/utils/env.py @@ -71,6 +71,7 @@ def 
to_bool(flag: int | bool | str) -> bool: CACHE_PER_SYS = 5 # keep at most so many sets per sys in memory ENERGY_BIAS_TRAINABLE = True +CUSTOM_OP_USE_JIT = False PRECISION_DICT = { "float16": paddle.float16, @@ -198,6 +199,7 @@ def enable_prim(enable: bool = True): __all__ = [ "CACHE_PER_SYS", "CINN", + "CUSTOM_OP_USE_JIT", "DEFAULT_PRECISION", "DEVICE", "ENERGY_BIAS_TRAINABLE", diff --git a/deepmd/pd/utils/preprocess.py b/deepmd/pd/utils/preprocess.py index 3e047c1b8b..3be42b522e 100644 --- a/deepmd/pd/utils/preprocess.py +++ b/deepmd/pd/utils/preprocess.py @@ -10,9 +10,20 @@ def compute_smooth_weight(distance, rmin: float, rmax: float): """Compute smooth weight for descriptor elements.""" if rmin >= rmax: raise ValueError("rmin should be less than rmax.") - min_mask = distance <= rmin - max_mask = distance >= rmax - mid_mask = paddle.logical_not(paddle.logical_or(min_mask, max_mask)) + distance = paddle.clip(distance, min=rmin, max=rmax) uu = (distance - rmin) / (rmax - rmin) - vv = uu * uu * uu * (-6 * uu * uu + 15 * uu - 10) + 1 - return vv * mid_mask.astype(vv.dtype) + min_mask.astype(vv.dtype) + uu2 = uu * uu + vv = uu2 * uu * (-6 * uu2 + 15 * uu - 10) + 1 + return vv + + +def compute_exp_sw(distance, rmin: float, rmax: float): + """Compute the exponential switch function for neighbor update.""" + if rmin >= rmax: + raise ValueError("rmin should be less than rmax.") + distance = paddle.clip(distance, min=0.0, max=rmax) + C = 20 + a = C / rmin + b = rmin + exp_sw = paddle.exp(-paddle.exp(a * (distance - b))) + return exp_sw diff --git a/deepmd/pd/utils/spin.py b/deepmd/pd/utils/spin.py index 934fb3762a..27bc355877 100644 --- a/deepmd/pd/utils/spin.py +++ b/deepmd/pd/utils/spin.py @@ -21,7 +21,6 @@ def concat_switch_virtual( extended_tensor_updated = paddle.zeros( out_shape, dtype=extended_tensor.dtype, - device=extended_tensor.place, ) extended_tensor_updated[:, :nloc] = extended_tensor[:, :nloc] extended_tensor_updated[:, nloc : nloc + nloc] = extended_tensor_virtual[:, :nloc] diff --git a/deepmd/pd/utils/utils.py b/deepmd/pd/utils/utils.py index a756491a8d..eeda778b37 100644 --- a/deepmd/pd/utils/utils.py +++ b/deepmd/pd/utils/utils.py @@ -20,6 +20,9 @@ ) from deepmd.dpmodel.common import PRECISION_DICT as NP_PRECISION_DICT +from deepmd.pd.utils import ( + env, +) from .env import ( DEVICE, @@ -32,15 +35,129 @@ ) +def silut_forward( + x: paddle.Tensor, threshold: float, slope: float, const_val: float +) -> paddle.Tensor: + sig = F.sigmoid(x) + silu = x * sig + tanh = paddle.tanh(slope * (x - threshold)) + const_val + return paddle.where(x >= threshold, tanh, silu) + + +def silut_backward( + x: paddle.Tensor, grad_output: paddle.Tensor, threshold: float, slope: float +) -> paddle.Tensor: + sig = F.sigmoid(x) + grad_silu = sig * (1 + x * (1 - sig)) + + tanh = paddle.tanh(slope * (x - threshold)) + grad_tanh = slope * (1 - tanh * tanh) + + grad = paddle.where(x >= threshold, grad_tanh, grad_silu) + return grad * grad_output + + +def silut_double_backward( + x: paddle.Tensor, + grad_grad_output: paddle.Tensor, + grad_output: paddle.Tensor, + threshold: float, + slope: float, +) -> tuple[paddle.Tensor, paddle.Tensor]: + # SiLU branch + sig = F.sigmoid(x) + + sig_prime = sig * (1 - sig) + grad_silu = sig + x * sig_prime + grad_grad_silu = sig_prime * (2 + x * (1 - 2 * sig)) + + # Tanh branch + tanh = paddle.tanh(slope * (x - threshold)) + tanh_square = tanh * tanh # .square is slow for jit.script! 
+ grad_tanh = slope * (1 - tanh_square) + grad_grad_tanh = -2 * slope * tanh * grad_tanh + + grad = paddle.where(x >= threshold, grad_tanh, grad_silu) + grad_grad = paddle.where(x >= threshold, grad_grad_tanh, grad_grad_silu) + return grad_output * grad_grad * grad_grad_output, grad * grad_grad_output + + +class SiLUTScript(paddle.nn.Layer): + def __init__(self, threshold: float = 3.0): + super().__init__() + self.threshold = threshold + + # Precompute parameters for the tanh replacement + sigmoid_threshold = 1 / (1 + np.exp(-threshold)) + self.slope = float( + sigmoid_threshold + threshold * sigmoid_threshold * (1 - sigmoid_threshold) + ) + self.const_val = float(threshold * sigmoid_threshold) + self.get_script_code() + + def get_script_code(self): + silut_forward_script = paddle.jit.to_static(silut_forward, full_graph=True) + # silut_forward_script = (silut_forward) + silut_backward_script = paddle.jit.to_static(silut_backward, full_graph=True) + # silut_backward_script = (silut_backward) + silut_double_backward_script = paddle.jit.to_static( + silut_double_backward, full_graph=True + ) + # silut_double_backward_script = (silut_double_backward) + + class SiLUTFunction(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, threshold, slope, const_val): + ctx.save_for_backward(x) + ctx.threshold = threshold + ctx.slope = slope + ctx.const_val = const_val + return silut_forward_script(x, threshold, slope, const_val) + + @staticmethod + def backward(ctx, grad_output): + (x,) = ctx.saved_tensor() + threshold = ctx.threshold + slope = ctx.slope + + grad_input = SiLUTGradFunction.apply(x, grad_output, threshold, slope) + return grad_input + + class SiLUTGradFunction(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, grad_output, threshold, slope): + ctx.threshold = threshold + ctx.slope = slope + grad_input = silut_backward_script(x, grad_output, threshold, slope) + ctx.save_for_backward(x, grad_output) + return grad_input + + @staticmethod + def backward(ctx, grad_grad_output): + (x, grad_output) = ctx.saved_tensor() + threshold = ctx.threshold + slope = ctx.slope + + grad_input, grad_mul_grad_grad_output = silut_double_backward_script( + x, grad_grad_output, grad_output, threshold, slope + ) + return grad_input, grad_mul_grad_grad_output + + self.SiLUTFunction = SiLUTFunction + + def forward(self, x): + return self.SiLUTFunction.apply(x, self.threshold, self.slope, self.const_val) + + class SiLUT(paddle.nn.Layer): def __init__(self, threshold=3.0): super().__init__() def sigmoid(x): - return paddle.nn.functional.sigmoid(x) + return F.sigmoid(x) def silu(x): - return paddle.nn.functional.silu(x) + return F.silu(x) def silu_grad(x): sig = sigmoid(x) @@ -76,7 +193,12 @@ def __init__(self, activation: str | None): threshold = ( float(self.activation.split(":")[-1]) if ":" in self.activation else 3.0 ) - self.silut = SiLUT(threshold=threshold) + if env.CUSTOM_OP_USE_JIT: + # for efficient training but can not be jit + self.silut = SiLUTScript(threshold=threshold) + # self.silut = paddle.nn.Identity() + else: + self.silut = SiLUT(threshold=threshold) else: self.silut = None From 7e07126e7c41f92dc25e9758f58aa90791c09375 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Wed, 11 Jun 2025 10:32:12 +0800 Subject: [PATCH 03/13] fix typo --- .pre-commit-config.yaml | 52 ++++++++++----------- deepmd/pd/model/descriptor/repflow_layer.py | 2 +- deepmd/pd/utils/utils.py | 4 -- 3 files changed, 27 insertions(+), 31 deletions(-) diff --git a/.pre-commit-config.yaml 
b/.pre-commit-config.yaml index 77dab6f3aa..cee3d7f2ce 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -65,13 +65,13 @@ repos: - id: clang-format exclude: ^(source/3rdparty|source/lib/src/gpu/cudart/.+\.inc|.+\.ipynb$|.+\.json$) # markdown, yaml, CSS, javascript - # - repo: https://github.com/pre-commit/mirrors-prettier - # rev: v4.0.0-alpha.8 - # hooks: - # - id: prettier - # types_or: [markdown, yaml, css] - # # workflow files cannot be modified by pre-commit.ci - # exclude: ^(source/3rdparty|\.github/workflows|\.clang-format) + - repo: https://github.com/pre-commit/mirrors-prettier + rev: v4.0.0-alpha.8 + hooks: + - id: prettier + types_or: [markdown, yaml, css] + # workflow files cannot be modified by pre-commit.ci + exclude: ^(source/3rdparty|\.github/workflows|\.clang-format) # Shell - repo: https://github.com/scop/pre-commit-shfmt rev: v3.11.0-1 @@ -83,25 +83,25 @@ repos: hooks: - id: cmake-format #- id: cmake-lint - # - repo: https://github.com/njzjz/mirrors-bibtex-tidy - # rev: v1.13.0 - # hooks: - # - id: bibtex-tidy - # args: - # - --curly - # - --numeric - # - --align=13 - # - --blank-lines - # # disable sort: the order of keys and fields has explict meanings - # #- --sort=key - # - --duplicates=key,doi,citation,abstract - # - --merge=combine - # #- --sort-fields - # #- --strip-comments - # - --trailing-commas - # - --encode-urls - # - --remove-empty-fields - # - --wrap=80 + - repo: https://github.com/njzjz/mirrors-bibtex-tidy + rev: v1.13.0 + hooks: + - id: bibtex-tidy + args: + - --curly + - --numeric + - --align=13 + - --blank-lines + # disable sort: the order of keys and fields has explict meanings + #- --sort=key + - --duplicates=key,doi,citation,abstract + - --merge=combine + #- --sort-fields + #- --strip-comments + - --trailing-commas + - --encode-urls + - --remove-empty-fields + - --wrap=80 # license header - repo: https://github.com/Lucas-C/pre-commit-hooks rev: v1.5.5 diff --git a/deepmd/pd/model/descriptor/repflow_layer.py b/deepmd/pd/model/descriptor/repflow_layer.py index 78cfa7a56d..73748943f8 100644 --- a/deepmd/pd/model/descriptor/repflow_layer.py +++ b/deepmd/pd/model/descriptor/repflow_layer.py @@ -849,7 +849,7 @@ def forward( edge_info = paddle.concat( [ paddle.index_select( - node_ebd.reshape(-1, self.n_dim), + node_ebd.reshape([-1, self.n_dim]), n2e_index, 0, ), diff --git a/deepmd/pd/utils/utils.py b/deepmd/pd/utils/utils.py index eeda778b37..175ac5019b 100644 --- a/deepmd/pd/utils/utils.py +++ b/deepmd/pd/utils/utils.py @@ -97,13 +97,10 @@ def __init__(self, threshold: float = 3.0): def get_script_code(self): silut_forward_script = paddle.jit.to_static(silut_forward, full_graph=True) - # silut_forward_script = (silut_forward) silut_backward_script = paddle.jit.to_static(silut_backward, full_graph=True) - # silut_backward_script = (silut_backward) silut_double_backward_script = paddle.jit.to_static( silut_double_backward, full_graph=True ) - # silut_double_backward_script = (silut_double_backward) class SiLUTFunction(paddle.autograd.PyLayer): @staticmethod @@ -196,7 +193,6 @@ def __init__(self, activation: str | None): if env.CUSTOM_OP_USE_JIT: # for efficient training but can not be jit self.silut = SiLUTScript(threshold=threshold) - # self.silut = paddle.nn.Identity() else: self.silut = SiLUT(threshold=threshold) else: From 5103f5ded158e4452e087b2b593cc88a5acdc6c9 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Wed, 11 Jun 2025 10:39:22 +0800 Subject: [PATCH 04/13] upload missing file --- 
deepmd/pd/model/descriptor/repflow_layer.py | 2 +- deepmd/pd/model/network/utils.py | 135 ++++++++++++++++++++ 2 files changed, 136 insertions(+), 1 deletion(-) create mode 100644 deepmd/pd/model/network/utils.py diff --git a/deepmd/pd/model/descriptor/repflow_layer.py b/deepmd/pd/model/descriptor/repflow_layer.py index 73748943f8..40e7a35e6e 100644 --- a/deepmd/pd/model/descriptor/repflow_layer.py +++ b/deepmd/pd/model/descriptor/repflow_layer.py @@ -1010,7 +1010,7 @@ def forward( # nb x nloc x a_nnei x a_nnei x (a + n_dim + e_dim*2) or (a + a/c + a/c) # [OR] # n_angle x (a + n_dim + e_dim*2) or (a + a/c + a/c) - angle_info = paddle.cat(angle_info_list, axis=1) + angle_info = paddle.concat(angle_info_list, axis=1) else: angle_info = None diff --git a/deepmd/pd/model/network/utils.py b/deepmd/pd/model/network/utils.py new file mode 100644 index 0000000000..f09686a264 --- /dev/null +++ b/deepmd/pd/model/network/utils.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Optional, +) + +import paddle + + +def aggregate( + data: paddle.Tensor, + owners: paddle.Tensor, + average: bool = True, + num_owner: Optional[int] = None, +) -> paddle.Tensor: + """ + Aggregate rows in data by specifying the owners. + + Parameters + ---------- + data : data tensor to aggregate [n_row, feature_dim] + owners : specify the owner of each row [n_row, 1] + average : if True, average the rows, if False, sum the rows. + Default = True + num_owner : the number of owners, this is needed if the + max idx of owner is not presented in owners tensor + Default = None + + Returns + ------- + output: [num_owner, feature_dim] + """ + bin_count = paddle.bincount(owners) + bin_count = bin_count.where(bin_count != 0, paddle.ones_like(bin_count)) + + if (num_owner is not None) and (bin_count.shape[0] != num_owner): + difference = num_owner - bin_count.shape[0] + bin_count = paddle.concat([bin_count, paddle.ones_like(difference)]) + + # make sure this operation is done on the same device of data and owners + output = paddle.zeros([bin_count.shape[0], data.shape[1]]) + output = output.index_add_(owners, 0, data) + if average: + output = (output.T / bin_count).T + return output + + +def get_graph_index( + nlist: paddle.Tensor, + nlist_mask: paddle.Tensor, + a_nlist_mask: paddle.Tensor, + nall: int, +): + """ + Get the index mapping for edge graph and angle graph, ready in `aggregate` or `index_select`. + + Parameters + ---------- + nlist : nf x nloc x nnei + Neighbor list. (padded neis are set to 0) + nlist_mask : nf x nloc x nnei + Masks of the neighbor list. real nei 1 otherwise 0 + a_nlist_mask : nf x nloc x a_nnei + Masks of the neighbor list for angle. real nei 1 otherwise 0 + nall + The number of extended atoms. + + Returns + ------- + edge_index : n_edge x 2 + n2e_index : n_edge + Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i). + n_ext2e_index : n_edge + Broadcast indices from extended node(j) to edge(ij). + angle_index : n_angle x 3 + n2a_index : n_angle + Broadcast indices from extended node(j) to angle(ijk). + eij2a_index : n_angle + Broadcast indices from extended edge(ij) to angle(ijk), or reduction indices from angle(ijk) to edge(ij). + eik2a_index : n_angle + Broadcast indices from extended edge(ik) to angle(ijk). 
+ """ + nf, nloc, nnei = nlist.shape + _, _, a_nnei = a_nlist_mask.shape + # nf x nloc x nnei x nnei + # nlist_mask_3d = nlist_mask[:, :, :, None] & nlist_mask[:, :, None, :] + a_nlist_mask_3d = a_nlist_mask[:, :, :, None] & a_nlist_mask[:, :, None, :] + n_edge = nlist_mask.sum().item() + # n_angle = a_nlist_mask_3d.sum().item() + + # following: get n2e_index, n_ext2e_index, n2a_index, eij2a_index, eik2a_index + + # 1. atom graph + # node(i) to edge(ij) index_select; edge(ij) to node aggregate + nlist_loc_index = paddle.arange(0, nf * nloc, dtype=nlist.dtype).to(nlist.place) + # nf x nloc x nnei + n2e_index = nlist_loc_index.reshape([nf, nloc, 1]).expand([-1, -1, nnei]) + # n_edge + n2e_index = n2e_index[nlist_mask] # graph node index, atom_graph[:, 0] + + # node_ext(j) to edge(ij) index_select + frame_shift = paddle.arange(0, nf, dtype=nlist.dtype) * nall + shifted_nlist = nlist + frame_shift[:, None, None] + # n_edge + n_ext2e_index = shifted_nlist[nlist_mask] # graph neighbor index, atom_graph[:, 1] + + # 2. edge graph + # node(i) to angle(ijk) index_select + n2a_index = nlist_loc_index.reshape([nf, nloc, 1, 1]).expand( + [-1, -1, a_nnei, a_nnei] + ) + # n_angle + n2a_index = n2a_index[a_nlist_mask_3d] + + # edge(ij) to angle(ijk) index_select; angle(ijk) to edge(ij) aggregate + edge_id = paddle.arange(0, n_edge, dtype=nlist.dtype) + # nf x nloc x nnei + edge_index = paddle.zeros([nf, nloc, nnei], dtype=nlist.dtype) + edge_index[nlist_mask] = edge_id + # only cut a_nnei neighbors, to avoid nnei x nnei + edge_index = edge_index[:, :, :a_nnei] + edge_index_ij = edge_index.unsqueeze(-1).expand([-1, -1, -1, a_nnei]) + # n_angle + eij2a_index = edge_index_ij[a_nlist_mask_3d] + + # edge(ik) to angle(ijk) index_select + edge_index_ik = edge_index.unsqueeze(-2).expand([-1, -1, a_nnei, -1]) + # n_angle + eik2a_index = edge_index_ik[a_nlist_mask_3d] + + return paddle.concat( + [n2e_index.unsqueeze(-1), n_ext2e_index.unsqueeze(-1)], axis=-1 + ), paddle.concat( + [n2a_index.unsqueeze(-1), eij2a_index.unsqueeze(-1), eik2a_index.unsqueeze(-1)], + axis=-1, + ) From 90bd05bba97c1b6d1fc221175a0dbca48a143f18 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Wed, 11 Jun 2025 13:29:53 +0800 Subject: [PATCH 05/13] fix --- deepmd/pd/model/descriptor/repflow_layer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmd/pd/model/descriptor/repflow_layer.py b/deepmd/pd/model/descriptor/repflow_layer.py index 40e7a35e6e..7de88d8bd9 100644 --- a/deepmd/pd/model/descriptor/repflow_layer.py +++ b/deepmd/pd/model/descriptor/repflow_layer.py @@ -1002,7 +1002,7 @@ def forward( ) # nb x nloc x a_nnei x a_nnei x (e_dim + e_dim) [OR] n_angle x (e_dim + e_dim) edge_for_angle_info = paddle.concat( - [edge_for_angle_k, edge_for_angle_j], axis=1 + [edge_for_angle_k, edge_for_angle_j], axis=-1 ) angle_info_list = [angle_ebd] angle_info_list.append(node_for_angle_info) @@ -1010,7 +1010,7 @@ def forward( # nb x nloc x a_nnei x a_nnei x (a + n_dim + e_dim*2) or (a + a/c + a/c) # [OR] # n_angle x (a + n_dim + e_dim*2) or (a + a/c + a/c) - angle_info = paddle.concat(angle_info_list, axis=1) + angle_info = paddle.concat(angle_info_list, axis=-1) else: angle_info = None From d264faef4fccc8d20d86540844ab8d251558b6ed Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Fri, 13 Jun 2025 11:18:24 +0800 Subject: [PATCH 06/13] Update deepmd/pd/train/training.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: 
HydrogenSulfate <490868991@qq.com> --- deepmd/pd/train/training.py | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index 07cd738a1c..984535b5d7 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -607,10 +607,13 @@ def warm_up_linear(step, warmup_steps): ) backend = "CINN" if CINN else None - # NOTE: This is a trick to decide the right input_spec for wrapper.forward - _, label_dict, _ = self.get_data(is_train=True, task_key="Default") - label_dict_spec = { + # Use appropriate task_key for multi-task scenarios + sample_task_key = self.model_keys[0] if self.multi_task else "Default" + _, label_dict, _ = self.get_data(is_train=True, task_key=sample_task_key) + + # Define specification templates + spec_templates = { "find_box": np.float32(1.0), "find_coord": np.float32(1.0), "find_numb_copy": np.float32(0.0), @@ -623,19 +626,8 @@ def warm_up_linear(step, warmup_steps): "virial": static.InputSpec([1, 9], "float64", name="virial"), "natoms": static.InputSpec([1, -1], "int32", name="natoms"), } - if "virial" not in label_dict: - label_dict_spec.pop("virial") - if "find_virial" not in label_dict: - label_dict_spec.pop("find_virial") - if "energy" not in label_dict: - label_dict_spec.pop("energy") - if "find_energy" not in label_dict: - label_dict_spec.pop("find_energy") - if "force" not in label_dict: - label_dict_spec.pop("force") - if "find_force" not in label_dict: - label_dict_spec.pop("find_force") - + # Build spec only for keys present in sample data + label_dict_spec = {k: spec_templates[k] for k in label_dict.keys() if k in spec_templates} self.wrapper.forward = jit.to_static( backend=backend, input_spec=[ From 6528a85b31bf178fb158845800ec8abc0b39bb88 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Jun 2025 03:20:01 +0000 Subject: [PATCH 07/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pd/train/training.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index 984535b5d7..ca9e28726c 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -627,7 +627,9 @@ def warm_up_linear(step, warmup_steps): "natoms": static.InputSpec([1, -1], "int32", name="natoms"), } # Build spec only for keys present in sample data - label_dict_spec = {k: spec_templates[k] for k in label_dict.keys() if k in spec_templates} + label_dict_spec = { + k: spec_templates[k] for k in label_dict.keys() if k in spec_templates + } self.wrapper.forward = jit.to_static( backend=backend, input_spec=[ From 3a6438e973c1de5c94e937cf839679005ab3090e Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Fri, 13 Jun 2025 11:34:37 +0800 Subject: [PATCH 08/13] fix --- deepmd/pd/train/training.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index ca9e28726c..bb0a8987b8 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -608,9 +608,7 @@ def warm_up_linear(step, warmup_steps): backend = "CINN" if CINN else None # NOTE: This is a trick to decide the right input_spec for wrapper.forward - # Use appropriate task_key for multi-task scenarios - sample_task_key = self.model_keys[0] if self.multi_task else "Default" - _, label_dict, _ = 
self.get_data(is_train=True, task_key=sample_task_key) + _, label_dict, _ = self.get_data(is_train=True) # Define specification templates spec_templates = { From d236285db5e754de871d10b640a590ebb1f6a68d Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Thu, 19 Jun 2025 13:40:24 +0800 Subject: [PATCH 09/13] refine code --- .pre-commit-config.yaml | 52 ++++++++++++++++++------------------- deepmd/pd/train/training.py | 31 +++++++++++++++------- deepmd/pd/utils/env.py | 3 ++- 3 files changed, 50 insertions(+), 36 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cee3d7f2ce..77dab6f3aa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -65,13 +65,13 @@ repos: - id: clang-format exclude: ^(source/3rdparty|source/lib/src/gpu/cudart/.+\.inc|.+\.ipynb$|.+\.json$) # markdown, yaml, CSS, javascript - - repo: https://github.com/pre-commit/mirrors-prettier - rev: v4.0.0-alpha.8 - hooks: - - id: prettier - types_or: [markdown, yaml, css] - # workflow files cannot be modified by pre-commit.ci - exclude: ^(source/3rdparty|\.github/workflows|\.clang-format) + # - repo: https://github.com/pre-commit/mirrors-prettier + # rev: v4.0.0-alpha.8 + # hooks: + # - id: prettier + # types_or: [markdown, yaml, css] + # # workflow files cannot be modified by pre-commit.ci + # exclude: ^(source/3rdparty|\.github/workflows|\.clang-format) # Shell - repo: https://github.com/scop/pre-commit-shfmt rev: v3.11.0-1 @@ -83,25 +83,25 @@ repos: hooks: - id: cmake-format #- id: cmake-lint - - repo: https://github.com/njzjz/mirrors-bibtex-tidy - rev: v1.13.0 - hooks: - - id: bibtex-tidy - args: - - --curly - - --numeric - - --align=13 - - --blank-lines - # disable sort: the order of keys and fields has explict meanings - #- --sort=key - - --duplicates=key,doi,citation,abstract - - --merge=combine - #- --sort-fields - #- --strip-comments - - --trailing-commas - - --encode-urls - - --remove-empty-fields - - --wrap=80 + # - repo: https://github.com/njzjz/mirrors-bibtex-tidy + # rev: v1.13.0 + # hooks: + # - id: bibtex-tidy + # args: + # - --curly + # - --numeric + # - --align=13 + # - --blank-lines + # # disable sort: the order of keys and fields has explict meanings + # #- --sort=key + # - --duplicates=key,doi,citation,abstract + # - --merge=combine + # #- --sort-fields + # #- --strip-comments + # - --trailing-commas + # - --encode-urls + # - --remove-empty-fields + # - --wrap=80 # license header - repo: https://github.com/Lucas-C/pre-commit-hooks rev: v1.5.5 diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index bb0a8987b8..5cd60716c0 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import contextlib import functools import logging import time @@ -18,6 +19,7 @@ from paddle.distributed import ( fleet, ) +from paddle.distributed.fleet.utils import hybrid_parallel_util as hpu from paddle.framework import ( core, ) @@ -741,16 +743,27 @@ def step(_step_id, task_key="Default") -> None: pref_lr = _lr.start_lr else: pref_lr = cur_lr - with nvprof_context(enable_profiling, "Forward pass"): - model_pred, loss, more_loss = self.wrapper( - **input_dict, - cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION), - label=label_dict, - task_key=task_key, - ) + sync_context = ( + self.wrapper.no_sync + if self.world_size > 1 + else contextlib.nullcontext + ) + with sync_context(): + with nvprof_context(enable_profiling, "Forward pass"): + model_pred, loss, more_loss = 
self.wrapper( + **input_dict, + cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION), + label=label_dict, + task_key=task_key, + ) + + with nvprof_context(enable_profiling, "Backward pass"): + loss.backward() - with nvprof_context(enable_profiling, "Backward pass"): - loss.backward() + if self.world_size > 1: + # fuse + allreduce manually before optimization if use DDP + no_sync + # details in https://github.com/PaddlePaddle/Paddle/issues/48898#issuecomment-1343838622 + hpu.fused_allreduce_gradients(list(self.wrapper.parameters()), None) if self.gradient_max_norm > 0.0: with nvprof_context(enable_profiling, "Gradient clip"): diff --git a/deepmd/pd/utils/env.py b/deepmd/pd/utils/env.py index 3f45910392..4c34c551b4 100644 --- a/deepmd/pd/utils/env.py +++ b/deepmd/pd/utils/env.py @@ -27,7 +27,8 @@ ncpus = os.cpu_count() NUM_WORKERS = int(os.environ.get("NUM_WORKERS", min(0, ncpus))) # Make sure DDP uses correct device if applicable -LOCAL_RANK = paddle.distributed.get_rank() +LOCAL_RANK = os.environ.get("PADDLE_LOCAL_RANK") +LOCAL_RANK = int(0 if LOCAL_RANK is None else LOCAL_RANK) if os.environ.get("DEVICE") == "cpu" or paddle.device.cuda.device_count() <= 0: DEVICE = "cpu" From 7ce9af1eb18c2c5ae450e78bfc0169d6332ed612 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Thu, 19 Jun 2025 14:04:17 +0800 Subject: [PATCH 10/13] update parallel test code --- examples/water/dpa3/input_torch.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/water/dpa3/input_torch.json b/examples/water/dpa3/input_torch.json index 90e81b5403..e84f11821c 100644 --- a/examples/water/dpa3/input_torch.json +++ b/examples/water/dpa3/input_torch.json @@ -85,7 +85,7 @@ "batch_size": 1, "_comment": "that's all" }, - "numb_steps": 1000000, + "numb_steps": 2000, "warmup_steps": 0, "gradient_max_norm": 5.0, "seed": 10, From 78fe1b8ce957b9786ea0150f47460822cacd5c42 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Fri, 8 Aug 2025 17:52:02 +0800 Subject: [PATCH 11/13] update code --- examples/water/dpa3/run.sh | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 examples/water/dpa3/run.sh diff --git a/examples/water/dpa3/run.sh b/examples/water/dpa3/run.sh new file mode 100644 index 0000000000..db9d895ac6 --- /dev/null +++ b/examples/water/dpa3/run.sh @@ -0,0 +1,17 @@ +# unset PADDLE_ELASTIC_JOB_ID +# unset PADDLE_TRAINER_ENDPOINTS +# unset DISTRIBUTED_TRAINER_ENDPOINTS +# unset FLAGS_START_PORT +# unset PADDLE_ELASTIC_TIMEOUT +# export NNODES=1 +# export PADDLE_TRAINERS_NUM=1 +unset CUDA_DEVICE_MAX_CONNECTIONS + +HDFS_USE_FILE_LOCKING=0 python -m paddle.distributed.launch --gpus="0,1,2,3,4,5,6,7" --log_dir "logs" dp --pd train input_torch.json -l dp_train.log + +# NUM_WORKERS=0 HDFS_USE_FILE_LOCKING=0 python -m paddle.distributed.launch + +# python -m paddle.distributed.launch \ +# --gpus=0,1,2,3 \ +# --ips=10.67.200.17,10.67.200.11,10.67.200.13,10.67.200.15 \ +# dp --pd train input_torch.json -l dp_train.log \ No newline at end of file From 35783bbeddff4470ec23c68b9f8afcb91419aade Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Wed, 13 Aug 2025 20:32:11 +0800 Subject: [PATCH 12/13] update auto ddp code --- deepmd/pd/model/model/make_model.py | 56 +++++++++++++++++++++++------ deepmd/pd/train/training.py | 56 ++++++++++++++++++++--------- 2 files changed, 85 insertions(+), 27 deletions(-) diff --git a/deepmd/pd/model/model/make_model.py b/deepmd/pd/model/model/make_model.py index acb237b5ac..6474157865 100644 --- 
a/deepmd/pd/model/model/make_model.py +++ b/deepmd/pd/model/model/make_model.py @@ -37,6 +37,9 @@ from deepmd.utils.path import ( DPPath, ) +import paddle.distributed as dist +from paddle.distributed import fleet +import functools def make_model(T_AtomicModel: type[BaseAtomicModel]): @@ -163,29 +166,60 @@ def forward_common( coord, box=box, fparam=fparam, aparam=aparam ) del coord, box, fparam, aparam + # ( + # extended_coord, + # extended_atype, + # mapping, + # nlist, + # ) = extend_input_and_build_neighbor_list( + # cc, + # atype, + # self.get_rcut(), + # self.get_sel(), + # # types will be distinguished in the lower interface, + # # so it doesn't need to be distinguished here + # mixed_types=True, + # box=bb, + # ) + wrapped_func_1 = dist.local_map( + func=lambda a,b,c: extend_input_and_build_neighbor_list(a,b,self.get_rcut(), self.get_sel(), True, c), + in_placements=[ele.placements for ele in [cc, atype, bb]], + out_placements=[[dist.Shard(0)] for _ in range(4)], + process_mesh=fleet.auto.get_mesh() + ) + ( extended_coord, extended_atype, mapping, nlist, - ) = extend_input_and_build_neighbor_list( + ) = wrapped_func_1( cc, atype, - self.get_rcut(), - self.get_sel(), - # types will be distinguished in the lower interface, - # so it doesn't need to be distinguished here - mixed_types=True, - box=bb, + bb, + ) + # model_predict_lower = self.forward_common_lower( + # extended_coord, + # extended_atype, + # nlist, + # mapping, + # do_atomic_virial=do_atomic_virial, + # fparam=fp, + # aparam=ap, + # ) + + wrapped_func_2 = dist.local_map( + func=functools.partial(self.forward_common_lower, do_atomic_virial=do_atomic_virial, fparam=fp, aparam=ap), + in_placements=[ele.placements for ele in [extended_coord, extended_atype, nlist, mapping]], + out_placements=[[dist.Shard(0)] for _ in range(6)], + process_mesh=fleet.auto.get_mesh(), + reshard_inputs=True ) - model_predict_lower = self.forward_common_lower( + model_predict_lower = wrapped_func_2( extended_coord, extended_atype, nlist, mapping, - do_atomic_virial=do_atomic_virial, - fparam=fp, - aparam=ap, ) model_predict = communicate_extended_output( model_predict_lower, diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index 5cd60716c0..782070a57e 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -26,7 +26,9 @@ from paddle.io import ( DataLoader, ) - +import paddle.distributed as dist +from paddle.distributed import fleet +import functools from deepmd.common import ( symlink_prefix_files, ) @@ -101,6 +103,11 @@ def __init__( Args: - config: The Dict-like configuration with training options. 
""" + from paddle.distributed import fleet + mesh_dims = [("dp", 32)] + fleet.auto.create_mesh(mesh_dims) + fleet.init(is_collective=True) + enable_prim(True) if init_model is not None: resume_model = init_model @@ -748,22 +755,39 @@ def step(_step_id, task_key="Default") -> None: if self.world_size > 1 else contextlib.nullcontext ) - with sync_context(): - with nvprof_context(enable_profiling, "Forward pass"): - model_pred, loss, more_loss = self.wrapper( - **input_dict, - cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION), - label=label_dict, - task_key=task_key, - ) - - with nvprof_context(enable_profiling, "Backward pass"): - loss.backward() + + # with sync_context(): + # with nvprof_context(enable_profiling, "Forward pass"): + # model_pred, loss, more_loss = self.wrapper( + # **input_dict, + # cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION), + # label=label_dict, + # task_key=task_key, + # ) + + # with nvprof_context(enable_profiling, "Backward pass"): + # loss.backward() + + # if self.world_size > 1: + # # fuse + allreduce manually before optimization if use DDP + no_sync + # # details in https://github.com/PaddlePaddle/Paddle/issues/48898#issuecomment-1343838622 + # hpu.fused_allreduce_gradients(list(self.wrapper.parameters()), None) + + with nvprof_context(enable_profiling, "Forward pass"): + for __key in ('coord', 'atype', 'box'): + input_dict[__key] = dist.shard_tensor(input_dict[__key], mesh=dist.get_mesh(), placements=[dist.Shard(0)]) + for __key, _ in label_dict.items(): + if isinstance(label_dict[__key], paddle.Tensor): + label_dict[__key] = dist.shard_tensor(label_dict[__key], mesh=dist.get_mesh(), placements=[dist.Shard(0)]) + model_pred, loss, more_loss = self.wrapper( + **input_dict, + cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION), + label=label_dict, + task_key=task_key, + ) - if self.world_size > 1: - # fuse + allreduce manually before optimization if use DDP + no_sync - # details in https://github.com/PaddlePaddle/Paddle/issues/48898#issuecomment-1343838622 - hpu.fused_allreduce_gradients(list(self.wrapper.parameters()), None) + with nvprof_context(enable_profiling, "Backward pass"): + loss.backward() if self.gradient_max_norm > 0.0: with nvprof_context(enable_profiling, "Gradient clip"): From 95c1377b37e2ecfcff693edc184d2bb59b25e1b3 Mon Sep 17 00:00:00 2001 From: xuexixi Date: Fri, 22 Aug 2025 15:25:47 +0800 Subject: [PATCH 13/13] improve auto parallel perf --- deepmd/pd/loss/ener.py | 11 ++++- deepmd/pd/train/training.py | 74 ++++++++++++++-------------- deepmd/pd/utils/dataloader.py | 2 +- examples/water/dpa3/input_torch.json | 4 +- 4 files changed, 49 insertions(+), 42 deletions(-) diff --git a/deepmd/pd/loss/ener.py b/deepmd/pd/loss/ener.py index 09ec5ff49e..83f6949c52 100644 --- a/deepmd/pd/loss/ener.py +++ b/deepmd/pd/loss/ener.py @@ -21,6 +21,7 @@ from deepmd.utils.version import ( check_version_compatibility, ) +import paddle.distributed as dist def custom_huber_loss(predictions, targets, delta=1.0): @@ -205,7 +206,11 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): find_energy = label.get("find_energy", 0.0) pref_e = pref_e * find_energy if not self.use_l1_all: - l2_ener_loss = paddle.mean(paddle.square(energy_pred - energy_label)) + + tmp = energy_pred - energy_label + logit = dist.reshard(tmp, tmp.process_mesh, [dist.Replicate()]) + + l2_ener_loss = paddle.mean(paddle.square(logit)) if not self.inference: more_loss["l2_ener_loss"] = self.display_if_exist( l2_ener_loss.detach(), find_energy @@ -258,7 +263,8 @@ def 
forward(self, input_dict, model, label, natoms, learning_rate, mae=False): force_pred = model_pred["force"] force_label = label["force"] diff_f = (force_label - force_pred).reshape([-1]) - + diff_f = dist.reshard(diff_f, diff_f.process_mesh, [dist.Replicate()]) + if self.relative_f is not None: force_label_3 = force_label.reshape([-1, 3]) norm_f = force_label_3.norm(axis=1, keepdim=True) + self.relative_f @@ -354,6 +360,7 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): find_virial = label.get("find_virial", 0.0) pref_v = pref_v * find_virial diff_v = label["virial"] - model_pred["virial"].reshape([-1, 9]) + diff_v = dist.reshard(diff_v, diff_v.process_mesh, [dist.Replicate()]) l2_virial_loss = paddle.mean(paddle.square(diff_v)) if not self.inference: more_loss["l2_virial_loss"] = self.display_if_exist( diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index 782070a57e..54ab7689db 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -171,10 +171,7 @@ def get_dataloader_and_buffer(_data, _params): ) # None sampler will lead to a premature stop iteration. Replacement should be True in attribute of the sampler to produce expected number of items in one iteration. _dataloader = DataLoader( _data, - batch_sampler=paddle.io.BatchSampler( - sampler=_sampler, - drop_last=False, - ), + batch_size=1, num_workers=NUM_WORKERS if dist.is_available() else 0, # setting to 0 diverges the behavior of its iterator; should be >=1 @@ -325,17 +322,18 @@ def get_lr(lr_params): self.validation_data, self.valid_numb_batch, ) = get_data_loader(training_data, validation_data, training_params) - training_data.print_summary( - "training", - to_numpy_array(self.training_dataloader.batch_sampler.sampler.weights), - ) - if validation_data is not None: - validation_data.print_summary( - "validation", - to_numpy_array( - self.validation_dataloader.batch_sampler.sampler.weights - ), - ) + # no sampler, do not need print! + # training_data.print_summary( + # "training", + # to_numpy_array(self.training_dataloader.batch_sampler.sampler.weights), + # ) + # if validation_data is not None: + # validation_data.print_summary( + # "validation", + # to_numpy_array( + # self.validation_dataloader.batch_sampler.sampler.weights + # ), + # ) else: ( self.training_dataloader, @@ -370,27 +368,27 @@ def get_lr(lr_params): validation_data[model_key], training_params["data_dict"][model_key], ) - - training_data[model_key].print_summary( - f"training in {model_key}", - to_numpy_array( - self.training_dataloader[ - model_key - ].batch_sampler.sampler.weights - ), - ) - if ( - validation_data is not None - and validation_data[model_key] is not None - ): - validation_data[model_key].print_summary( - f"validation in {model_key}", - to_numpy_array( - self.validation_dataloader[ - model_key - ].batch_sampler.sampler.weights - ), - ) + # no sampler, do not need print! 
+ # training_data[model_key].print_summary( + # f"training in {model_key}", + # to_numpy_array( + # self.training_dataloader[ + # model_key + # ].batch_sampler.sampler.weights + # ), + # ) + # if ( + # validation_data is not None + # and validation_data[model_key] is not None + # ): + # validation_data[model_key].print_summary( + # f"validation in {model_key}", + # to_numpy_array( + # self.validation_dataloader[ + # model_key + # ].batch_sampler.sampler.weights + # ), + # ) # Learning rate self.warmup_steps = training_params.get("warmup_steps", 0) @@ -856,7 +854,9 @@ def log_loss_valid(_task_key="Default"): if not self.multi_task: train_results = log_loss_train(loss, more_loss) - valid_results = log_loss_valid() + # valid_results = log_loss_valid() + # no run valid! + valid_results = None if self.rank == 0: log.info( format_training_message_per_task( diff --git a/deepmd/pd/utils/dataloader.py b/deepmd/pd/utils/dataloader.py index 0cb8adbc63..4da7d4e8b5 100644 --- a/deepmd/pd/utils/dataloader.py +++ b/deepmd/pd/utils/dataloader.py @@ -191,7 +191,7 @@ def construct_dataset(system): system_dataloader = DataLoader( dataset=system, num_workers=0, # Should be 0 to avoid too many threads forked - batch_sampler=system_batch_sampler, + batch_size=int(batch_size), collate_fn=collate_batch, use_buffer_reader=False, places=["cpu"], diff --git a/examples/water/dpa3/input_torch.json b/examples/water/dpa3/input_torch.json index e84f11821c..2f92a4462c 100644 --- a/examples/water/dpa3/input_torch.json +++ b/examples/water/dpa3/input_torch.json @@ -75,14 +75,14 @@ "../data/data_1", "../data/data_2" ], - "batch_size": 1, + "batch_size": 32, "_comment": "that's all" }, "validation_data": { "systems": [ "../data/data_3" ], - "batch_size": 1, + "batch_size": 32, "_comment": "that's all" }, "numb_steps": 2000,
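
For reference, a minimal standalone sketch of the owner-based aggregation that the new `aggregate` helper in deepmd/pd/model/network/utils.py implements: scatter-add rows into per-owner slots, then divide by the per-owner row count to average. Toy tensors only; assumes a Paddle release that provides `paddle.index_add`.

```python
import paddle

data = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # [n_row, feature_dim]
owners = paddle.to_tensor([0, 0, 2], dtype="int64")            # owner of each row
num_owner = 3

summed = paddle.zeros([num_owner, data.shape[1]], dtype=data.dtype)
summed = paddle.index_add(summed, owners, 0, data)             # sum rows per owner

counts = paddle.bincount(owners, minlength=num_owner).astype(data.dtype)
counts = paddle.clip(counts, min=1.0)                          # empty owners stay zero, avoid /0
averaged = summed / counts.unsqueeze(-1)                       # [num_owner, feature_dim]
```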
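
In the same spirit, a toy sketch of how the indices returned by `get_graph_index` are consumed: the owner column of `edge_index` broadcasts node features onto edges via `index_select`, and the identical column is the reduction index when edge messages are scattered back onto nodes. Names and shapes below are illustrative, not the descriptor's real ones.

```python
import paddle

n_node, n_edge, feat = 5, 8, 4
node_feat = paddle.randn([n_node, feat])
edge_index = paddle.randint(0, n_node, [n_edge, 2])  # toy stand-in for [n2e_index, n_ext2e_index]

# node(i) -> edge(ij): broadcast node features to the edges they own
edge_feat = paddle.index_select(node_feat, edge_index[:, 0], axis=0)

# edge(ij) -> node(i): reduce edge messages back onto the owning nodes
node_msg = paddle.zeros([n_node, feat])
node_msg = paddle.index_add(node_msg, edge_index[:, 0], 0, edge_feat)
```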
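
The manual gradient synchronization introduced in PATCH 09 (and later superseded by the auto-parallel path in PATCH 12) follows the pattern below; a sketch assuming `wrapper` is the fleet-wrapped model and that `world_size`, `optimizer`, `cur_lr`, `input_dict`, and `label_dict` already exist in the training loop.

```python
import contextlib
from paddle.distributed.fleet.utils import hybrid_parallel_util as hpu

sync_ctx = wrapper.no_sync if world_size > 1 else contextlib.nullcontext
with sync_ctx():
    # gradients stay local inside no_sync(); no per-parameter allreduce is issued
    model_pred, loss, more_loss = wrapper(**input_dict, cur_lr=cur_lr, label=label_dict)
    loss.backward()

if world_size > 1:
    # one fused allreduce for all gradients, see PaddlePaddle/Paddle#48898
    hpu.fused_allreduce_gradients(list(wrapper.parameters()), None)

optimizer.step()
optimizer.clear_grad()
```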
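
Finally, a condensed sketch of the auto-parallel flow that PATCH 12/13 switch to, meant to run under `python -m paddle.distributed.launch`: create a 1-D "dp" mesh, shard the batch tensors along dim 0, run ordinary per-rank code through `dist.local_map`, and reshard loss terms to replicated before global reductions. Mesh size, shapes, and the local function are placeholders; the real calls are in training.py, make_model.py, and ener.py above.

```python
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet

fleet.auto.create_mesh([("dp", 2)])  # placeholder: 2 data-parallel ranks
fleet.init(is_collective=True)

coord = paddle.randn([8, 192, 3])    # toy batch of coordinates
coord = dist.shard_tensor(coord, mesh=dist.get_mesh(), placements=[dist.Shard(0)])

def local_center(x):                 # plain tensor-in/tensor-out function, runs on local shards
    return x - x.mean(axis=1, keepdim=True)

centered = dist.local_map(
    func=local_center,
    in_placements=[coord.placements],
    out_placements=[[dist.Shard(0)]],
    process_mesh=fleet.auto.get_mesh(),
)(coord)

# loss terms built from shards are resharded to replicated before paddle.mean,
# mirroring the dist.reshard calls added to ener.py
diff = dist.reshard(centered, centered.process_mesh, [dist.Replicate()])
loss = paddle.mean(paddle.square(diff))
```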