Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions deepmd/dpmodel/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Any,
Callable,
Optional,
Union,
overload,
)

Expand Down Expand Up @@ -220,11 +221,88 @@ def safe_cast_array(
return input


class ComputeInputStatsMixin:
    """Mixin providing a shared ``compute_input_stats`` implementation.

    The mixin holds the backend-independent part of computing descriptor
    input statistics. The backend-specific step of storing the resulting
    mean and standard deviation is delegated to
    :meth:`_set_stat_mean_and_stddev`, which each host class must implement.

    The host class is expected to be accepted by ``EnvMatStatSe`` (the
    mixin passes ``self`` to it). The mixin writes the computed statistics
    to ``self.stats``.
    """

    def compute_input_stats(
        self,
        merged: Union[Callable[[], list[dict]], list[dict]],
        path: Optional[Any] = None,
    ) -> None:
        """
        Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data.

        Parameters
        ----------
        merged : Union[Callable[[], list[dict]], list[dict]]
            - list[dict]: A list of data samples from various data systems.
                Each element, `merged[i]`, is a data dictionary containing `keys`: tensor
                originating from the `i`-th data system.
            - Callable[[], list[dict]]: A lazy function that returns data samples in the above format
                only when needed. Since the sampling process can be slow and memory-intensive,
                the lazy function helps by only sampling once.
        path : Optional[DPPath]
            The path to the stat file. When given, it must support the ``/``
            operator and ``is_dir()``; the statistics hash is appended to it.

        """
        # Imported lazily inside the method, as in the original code
        # (presumably to avoid a circular import -- TODO confirm).
        from deepmd.dpmodel.utils.env_mat_stat import (
            EnvMatStatSe,
        )

        env_mat_stat = EnvMatStatSe(self)
        if path is not None:
            # Statistics are stored under a sub-path named by their hash.
            path = path / env_mat_stat.get_hash()
        if path is None or not path.is_dir():
            # No stat directory to reuse: the samples are required.
            # A lazy sampler is invoked exactly once, here.
            sampled = merged() if callable(merged) else merged
        else:
            # The hashed stat directory exists, so sampling is skipped and
            # an empty sample list is handed over instead.
            sampled = []
        env_mat_stat.load_or_compute_stats(sampled, path)
        self.stats = env_mat_stat.stats
        mean, stddev = env_mat_stat()

        # Backend-specific tensor assignment
        self._set_stat_mean_and_stddev(mean, stddev)

    @abstractmethod
    def _set_stat_mean_and_stddev(self, mean, stddev) -> None:
        """Set the computed statistics to the descriptor's mean and stddev attributes.

        This method should be implemented by each backend to handle the specific
        tensor assignment logic for that backend.

        Parameters
        ----------
        mean : array-like
            The computed mean values
        stddev : array-like
            The computed standard deviation values

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError

    def get_stats(self) -> dict[str, Any]:
        """Get the statistics of the descriptor.

        Returns
        -------
        dict[str, Any]
            The statistics stored by :meth:`compute_input_stats`.

        Raises
        ------
        RuntimeError
            If the statistics have not been computed, i.e. ``self.stats`` is
            ``None`` or was never set on this instance.
        """
        # getattr keeps the documented RuntimeError even when the host
        # class never initialised ``stats`` (otherwise: AttributeError).
        stats = getattr(self, "stats", None)
        if stats is None:
            raise RuntimeError(
                "The statistics of the descriptor has not been computed."
            )
        return stats


# Public names of this module: these are what `from <module> import *`
# exports. `ComputeInputStatsMixin` is listed so descriptor modules can
# import the shared statistics mixin from here.
__all__ = [
"DEFAULT_PRECISION",
"GLOBAL_ENER_FLOAT_PRECISION",
"GLOBAL_NP_FLOAT_PRECISION",
"PRECISION_DICT",
"RESERVED_PRECISION_DICT",
"ComputeInputStatsMixin",
"NativeOP",
]
57 changes: 5 additions & 52 deletions deepmd/dpmodel/descriptor/repformers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Callable,
Optional,
Union,
)
Expand All @@ -16,15 +15,13 @@
xp_take_along_axis,
)
from deepmd.dpmodel.common import (
ComputeInputStatsMixin,
to_numpy_array,
)
from deepmd.dpmodel.utils import (
EnvMat,
PairExcludeMask,
)
from deepmd.dpmodel.utils.env_mat_stat import (
EnvMatStatSe,
)
from deepmd.dpmodel.utils.network import (
LayerNorm,
NativeLayer,
Expand All @@ -36,12 +33,6 @@
from deepmd.dpmodel.utils.seed import (
child_seed,
)
from deepmd.utils.env_mat_stat import (
StatItem,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (
check_version_compatibility,
)
Expand Down Expand Up @@ -78,7 +69,7 @@ def xp_transpose_01342(x):

@DescriptorBlock.register("se_repformer")
@DescriptorBlock.register("se_uni")
class DescrptBlockRepformers(NativeOP, DescriptorBlock):
class DescrptBlockRepformers(NativeOP, DescriptorBlock, ComputeInputStatsMixin):
r"""
The repformer descriptor block.

Expand Down Expand Up @@ -379,54 +370,16 @@ def dim_emb(self):
"""Returns the embedding dimension g2."""
return self.get_dim_emb()

def compute_input_stats(
self,
merged: Union[Callable[[], list[dict]], list[dict]],
path: Optional[DPPath] = None,
) -> None:
"""
Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data.

Parameters
----------
merged : Union[Callable[[], list[dict]], list[dict]]
- list[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `paddle.Tensor`
originating from the `i`-th data system.
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
path : Optional[DPPath]
The path to the stat file.
def _set_stat_mean_and_stddev(self, mean, stddev) -> None:
"""Set the computed statistics to the descriptor's mean and stddev attributes.

This is the dpmodel backend-specific implementation using array_api_compat.
"""
env_mat_stat = EnvMatStatSe(self)
if path is not None:
path = path / env_mat_stat.get_hash()
if path is None or not path.is_dir():
if callable(merged):
# only get data for once
sampled = merged()
else:
sampled = merged
else:
sampled = []
env_mat_stat.load_or_compute_stats(sampled, path)
self.stats = env_mat_stat.stats
mean, stddev = env_mat_stat()
xp = array_api_compat.array_namespace(self.stddev)
if not self.set_davg_zero:
self.mean = xp.asarray(mean, dtype=self.mean.dtype, copy=True)
self.stddev = xp.asarray(stddev, dtype=self.stddev.dtype, copy=True)

def get_stats(self) -> dict[str, StatItem]:
"""Get the statistics of the descriptor."""
if self.stats is None:
raise RuntimeError(
"The statistics of the descriptor has not been computed."
)
return self.stats

def reinit_exclude(
self,
exclude_types: list[tuple[int, int]] = [],
Expand Down
46 changes: 5 additions & 41 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import itertools
from typing import (
Any,
Callable,
NoReturn,
Optional,
Union,
Expand All @@ -17,6 +16,7 @@
NativeOP,
)
from deepmd.dpmodel.common import (
ComputeInputStatsMixin,
cast_precision,
to_numpy_array,
)
Expand All @@ -26,9 +26,6 @@
NetworkCollection,
PairExcludeMask,
)
from deepmd.dpmodel.utils.env_mat_stat import (
EnvMatStatSe,
)
from deepmd.dpmodel.utils.seed import (
child_seed,
)
Expand All @@ -38,9 +35,6 @@
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (
check_version_compatibility,
)
Expand All @@ -52,7 +46,7 @@

@BaseDescriptor.register("se_e2_a")
@BaseDescriptor.register("se_a")
class DescrptSeA(NativeOP, BaseDescriptor):
class DescrptSeA(NativeOP, BaseDescriptor, ComputeInputStatsMixin):
r"""DeepPot-SE constructed from all information (both angular and radial) of
atomic configurations. The embedding takes the distance between atoms as input.

Expand Down Expand Up @@ -309,41 +303,11 @@ def get_type_map(self) -> list[str]:
"""Get the name to each type of atoms."""
return self.type_map

def compute_input_stats(
self,
merged: Union[Callable[[], list[dict]], list[dict]],
path: Optional[DPPath] = None,
) -> None:
"""
Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data.

Parameters
----------
merged : Union[Callable[[], list[dict]], list[dict]]
- list[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `paddle.Tensor`
originating from the `i`-th data system.
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
path : Optional[DPPath]
The path to the stat file.
def _set_stat_mean_and_stddev(self, mean, stddev) -> None:
"""Set the computed statistics to the descriptor's mean and stddev attributes.

This is the dpmodel backend-specific implementation using array_api_compat.
"""
env_mat_stat = EnvMatStatSe(self)
if path is not None:
path = path / env_mat_stat.get_hash()
if path is None or not path.is_dir():
if callable(merged):
# only get data for once
sampled = merged()
else:
sampled = merged
else:
sampled = []
env_mat_stat.load_or_compute_stats(sampled, path)
self.stats = env_mat_stat.stats
mean, stddev = env_mat_stat()
xp = array_api_compat.array_namespace(self.dstd)
if not self.set_davg_zero:
self.davg = xp.asarray(mean, dtype=self.davg.dtype, copy=True)
Expand Down
87 changes: 87 additions & 0 deletions deepmd/pd/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Common functionality shared across Paddle descriptor implementations."""

from abc import (
abstractmethod,
)
from typing import (
Any,
Callable,
Optional,
Union,
)


class ComputeInputStatsMixin:
    """Mixin providing a shared ``compute_input_stats`` implementation for the Paddle backend.

    The mixin holds the common part of computing descriptor input
    statistics. Storing the resulting mean and standard deviation is
    delegated to :meth:`_set_stat_mean_and_stddev`, which each descriptor
    must implement.

    The host class is expected to be accepted by ``EnvMatStatSe`` (the
    mixin passes ``self`` to it). The mixin writes the computed statistics
    to ``self.stats``.
    """

    def compute_input_stats(
        self,
        merged: Union[Callable[[], list[dict]], list[dict]],
        path: Optional[Any] = None,
    ) -> None:
        """
        Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data.

        Parameters
        ----------
        merged : Union[Callable[[], list[dict]], list[dict]]
            - list[dict]: A list of data samples from various data systems.
                Each element, `merged[i]`, is a data dictionary containing `keys`: paddle.Tensor
                originating from the `i`-th data system.
            - Callable[[], list[dict]]: A lazy function that returns data samples in the above format
                only when needed. Since the sampling process can be slow and memory-intensive,
                the lazy function helps by only sampling once.
        path : Optional[DPPath]
            The path to the stat file. When given, it must support the ``/``
            operator and ``is_dir()``; the statistics hash is appended to it.

        """
        # Imported lazily inside the method, as in the original code
        # (presumably to avoid a circular import -- TODO confirm).
        from deepmd.pd.utils.env_mat_stat import (
            EnvMatStatSe,
        )

        env_mat_stat = EnvMatStatSe(self)
        if path is not None:
            # Statistics are stored under a sub-path named by their hash.
            path = path / env_mat_stat.get_hash()
        if path is None or not path.is_dir():
            # No stat directory to reuse: the samples are required.
            # A lazy sampler is invoked exactly once, here.
            sampled = merged() if callable(merged) else merged
        else:
            # The hashed stat directory exists, so sampling is skipped and
            # an empty sample list is handed over instead.
            sampled = []
        env_mat_stat.load_or_compute_stats(sampled, path)
        self.stats = env_mat_stat.stats
        mean, stddev = env_mat_stat()

        # Backend-specific tensor assignment
        self._set_stat_mean_and_stddev(mean, stddev)

    @abstractmethod
    def _set_stat_mean_and_stddev(self, mean, stddev) -> None:
        """Set the computed statistics to the descriptor's mean and stddev attributes.

        This method should be implemented by each descriptor to handle the specific
        tensor assignment logic for Paddle backend.

        Parameters
        ----------
        mean : array-like
            The computed mean values
        stddev : array-like
            The computed standard deviation values

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError

    def get_stats(self) -> dict[str, Any]:
        """Get the statistics of the descriptor.

        Returns
        -------
        dict[str, Any]
            The statistics stored by :meth:`compute_input_stats`.

        Raises
        ------
        RuntimeError
            If the statistics have not been computed, i.e. ``self.stats`` is
            ``None`` or was never set on this instance.
        """
        # getattr keeps the documented RuntimeError even when the host
        # class never initialised ``stats`` (otherwise: AttributeError).
        stats = getattr(self, "stats", None)
        if stats is None:
            raise RuntimeError(
                "The statistics of the descriptor has not been computed."
            )
        return stats
Loading