15 changes: 14 additions & 1 deletion deepmd/pt/model/atomic_model/dp_atomic_model.py
@@ -8,6 +8,7 @@
)

import torch
import numpy as np

from deepmd.dpmodel import (
FittingOutputDef,
@@ -325,11 +326,20 @@ def wrapped_sampler() -> list[dict]:
atom_exclude_types = self.atom_excl.get_exclude_types()
for sample in sampled:
sample["atom_exclude_types"] = list(atom_exclude_types)
if (
"find_fparam" not in sampled[0]
and "fparam" not in sampled[0]
and self.has_default_fparam()
):
default_fparam = self.get_default_fparam()
for sample in sampled:
nframe = sample["atype"].shape[0]
sample["fparam"] = default_fparam.repeat(nframe, 1)
return sampled

self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
self.fitting_net.compute_input_stats(
wrapped_sampler, protection=self.data_stat_protect
wrapped_sampler, protection=self.data_stat_protect, stat_file_path=stat_file_path
)
if compute_or_load_out_stat:
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
@@ -342,6 +352,9 @@ def has_default_fparam(self) -> bool:
"""Check if the model has default frame parameters."""
return self.fitting_net.has_default_fparam()

def get_default_fparam(self) -> Optional[torch.Tensor]:
"""Get the default frame parameters of this atomic model, or None if not set."""
return self.fitting_net.get_default_fparam()

def get_dim_aparam(self) -> int:
"""Get the number (dimension) of atomic parameters of this atomic model."""
return self.fitting_net.get_dim_aparam()
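Note on the default-fparam handling above: when the sampled statistics data carries neither "find_fparam" nor "fparam" and the fitting defines a default, wrapped_sampler broadcasts the default to every frame of every system before the statistics are computed. A minimal sketch of that broadcast outside the deepmd classes (the fill_default_fparam helper and the toy shapes are illustrative, not part of the PR):

import torch

def fill_default_fparam(samples: list[dict], default_fparam: torch.Tensor) -> None:
    """Add a per-frame "fparam" to samples that do not provide one (illustrative helper)."""
    for sample in samples:
        if "fparam" in sample:
            continue  # this system already supplies frame parameters
        nframe = sample["atype"].shape[0]  # atype is (nframes, natoms)
        # (numb_fparam,) -> (nframe, numb_fparam), same as default_fparam.repeat(nframe, 1) above
        sample["fparam"] = default_fparam.repeat(nframe, 1)

# toy usage: one system with 2 frames of 3 atoms, default fparam of dimension 2
samples = [{"atype": torch.zeros(2, 3, dtype=torch.long)}]
fill_default_fparam(samples, torch.tensor([300.0, 1.0]))
print(samples[0]["fparam"].shape)  # torch.Size([2, 2])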
5 changes: 5 additions & 0 deletions deepmd/pt/model/model/make_model.py
@@ -6,6 +6,7 @@
)

import torch
import numpy as np

from deepmd.dpmodel import (
ModelOutputDef,
@@ -530,6 +531,10 @@ def has_default_fparam(self) -> bool:
"""Check if the model has default frame parameters."""
return self.atomic_model.has_default_fparam()

@torch.jit.export
def get_default_fparam(self) -> Optional[torch.Tensor]:
"""Get the default frame parameters of the model, or None if not set."""
return self.atomic_model.get_default_fparam()

@torch.jit.export
def get_dim_aparam(self) -> int:
"""Get the number (dimension) of atomic parameters of this atomic model."""
258 changes: 207 additions & 51 deletions deepmd/pt/model/task/fitting.py
@@ -8,6 +8,7 @@
Callable,
Optional,
Union,
List,
)

import numpy as np
@@ -41,6 +42,12 @@
get_index_between_two_maps,
map_atom_exclude_types,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.env_mat_stat import (
StatItem,
)

dtype = env.GLOBAL_PT_FLOAT_PRECISION
device = env.DEVICE
@@ -57,7 +64,7 @@ def __new__(cls, *args: Any, **kwargs: Any) -> "Fitting":
return super().__new__(cls)

def share_params(
self, base_class: "Fitting", shared_level: int, resume: bool = False
self, base_class: "Fitting", shared_level: int, model_prob: float = 1.0, protection: float = 1e-2, resume: bool = False
) -> None:
"""
Share the parameters of self to the base_class with shared_level during multitask training.
@@ -69,16 +76,140 @@ def share_params(
)
if shared_level == 0:
# share all parameters except the bias_atom_e and the case_embd
# link fparam buffers
if self.numb_fparam > 0:
if not resume:
base_fparam = base_class.stats["fparam"]
assert len(base_fparam) == self.numb_fparam
for ii in range(self.numb_fparam):
base_fparam[ii] += self.get_stats()["fparam"][ii] * model_prob
fparam_avg = np.array([ii.compute_avg() for ii in base_fparam])
fparam_std = np.array([ii.compute_std(protection=protection) for ii in base_fparam])
fparam_inv_std = 1.0 / fparam_std
base_class.fparam_avg.copy_(
torch.tensor(
fparam_avg, device=env.DEVICE, dtype=base_class.fparam_avg.dtype
)
)
base_class.fparam_inv_std.copy_(
torch.tensor(
fparam_inv_std, device=env.DEVICE, dtype=base_class.fparam_inv_std.dtype
)
)
self.fparam_avg = base_class.fparam_avg
self.fparam_inv_std = base_class.fparam_inv_std

# link aparam buffers
if self.numb_aparam > 0:
if not resume:
base_aparam = base_class.stats["aparam"]
assert len(base_aparam) == self.numb_aparam
for ii in range(self.numb_aparam):
base_aparam[ii] += self.get_stats()["aparam"][ii] * model_prob
aparam_avg = np.array([ii.compute_avg() for ii in base_aparam])
aparam_std = np.array([ii.compute_std(protection=protection) for ii in base_aparam])
aparam_inv_std = 1.0 / aparam_std
base_class.aparam_avg.copy_(
torch.tensor(
aparam_avg, device=env.DEVICE, dtype=base_class.aparam_avg.dtype
)
)
base_class.aparam_inv_std.copy_(
torch.tensor(
aparam_inv_std, device=env.DEVICE, dtype=base_class.aparam_inv_std.dtype
)
)
self.aparam_avg = base_class.aparam_avg
self.aparam_inv_std = base_class.aparam_inv_std
# the following links all the params except buffers, which need to be linked manually.
for item in self._modules:
self._modules[item] = base_class._modules[item]
else:
raise NotImplementedError
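The multitask branch of share_params above merges each task's fparam/aparam statistics into the base class before recomputing the normalization buffers: every per-dimension StatItem is scaled by model_prob, accumulated, and avg/inv_std are rebuilt with the same protection floor. A self-contained stand-in for StatItem (the real class lives in deepmd.utils.env_mat_stat; this toy version only mirrors the fields and methods used here, and the numbers are made up) to show the arithmetic:

from dataclasses import dataclass
import numpy as np

@dataclass
class ToyStatItem:  # stand-in for deepmd.utils.env_mat_stat.StatItem
    number: float = 0.0
    sum: float = 0.0
    squared_sum: float = 0.0

    def __add__(self, other: "ToyStatItem") -> "ToyStatItem":
        return ToyStatItem(self.number + other.number, self.sum + other.sum,
                           self.squared_sum + other.squared_sum)

    def __mul__(self, w: float) -> "ToyStatItem":
        # scale every accumulator by the model probability before merging
        return ToyStatItem(self.number * w, self.sum * w, self.squared_sum * w)

    def compute_avg(self) -> float:
        return self.sum / self.number

    def compute_std(self, protection: float = 1e-2) -> float:
        var = self.squared_sum / self.number - (self.sum / self.number) ** 2
        return max(float(np.sqrt(max(var, 0.0))), protection)

# merge one fparam dimension of two tasks, the second weighted by model_prob = 0.5
base = ToyStatItem(number=100.0, sum=250.0, squared_sum=700.0)
other = ToyStatItem(number=40.0, sum=90.0, squared_sum=220.0)
merged = base + other * 0.5
fparam_avg = merged.compute_avg()
fparam_inv_std = 1.0 / merged.compute_std(protection=1e-2)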

def save_to_file_fparam(
self,
stat_file_path: DPPath,
) -> None:
"""Save the statistics of fparam.
Parameters
----------
stat_file_path : DPPath
The path to save the statistics of fparam.
"""
assert stat_file_path is not None
stat_file_path.mkdir(exist_ok=True, parents=True)
if len(self.stats) == 0:
raise ValueError("The statistics hasn't been computed.")
fp = stat_file_path / "fparam"
_fparam_stat = []
for ii in range(self.numb_fparam):
_tmp_stat = self.stats["fparam"][ii]
_fparam_stat.append([_tmp_stat.number, _tmp_stat.sum, _tmp_stat.squared_sum])
_fparam_stat = np.array(_fparam_stat)
fp.save_numpy(_fparam_stat)
log.info(f"Save fparam stats to {fp}.")

def save_to_file_aparam(
self,
stat_file_path: DPPath,
) -> None:
"""Save the statistics of aparam.
Parameters
----------
stat_file_path : DPPath
The path to save the statistics of aparam.
"""
assert stat_file_path is not None
stat_file_path.mkdir(exist_ok=True, parents=True)
if len(self.stats) == 0:
raise ValueError("The statistics hasn't been computed.")
fp = stat_file_path / "aparam"
_aparam_stat = []
for ii in range(self.numb_aparam):
_tmp_stat = self.stats["aparam"][ii]
_aparam_stat.append([_tmp_stat.number, _tmp_stat.sum, _tmp_stat.squared_sum])
_aparam_stat = np.array(_aparam_stat)
fp.save_numpy(_aparam_stat)
log.info(f"Save aparam stats to {fp}.")

def restore_fparam_from_file(self, stat_file_path: DPPath) -> None:
"""Load the statistics of fparam.
Parameters
----------
stat_file_path : DPPath
The path to load the statistics of fparam.
"""
fp = stat_file_path / "fparam"
arr = fp.load_numpy()
assert arr.shape == (self.numb_fparam, 3)
_fparam_stat = []
for ii in range(self.numb_fparam):
_fparam_stat.append(StatItem(number=arr[ii][0], sum=arr[ii][1], squared_sum=arr[ii][2]))
self.stats["fparam"] = _fparam_stat
log.info(f"Load fparam stats from {fp}.")

def restore_aparam_from_file(self, stat_file_path: DPPath) -> None:
"""Load the statistics of aparam.
Parameters
----------
stat_file_path : DPPath
The path to load the statistics of aparam.
"""
fp = stat_file_path / "aparam"
arr = fp.load_numpy()
assert arr.shape == (self.numb_aparam, 3)
_aparam_stat = []
for ii in range(self.numb_aparam):
_aparam_stat.append(StatItem(number=arr[ii][0], sum=arr[ii][1], squared_sum=arr[ii][2]))
self.stats["aparam"] = _aparam_stat
log.info(f"Load aparam stats from {fp}.")

def compute_input_stats(
self,
merged: Union[Callable[[], list[dict]], list[dict]],
protection: float = 1e-2,
stat_file_path: Optional[DPPath] = None,
) -> None:
"""
Compute the input statistics (e.g. mean and stddev) for the fittings from packed data.
@@ -94,67 +225,89 @@ def compute_input_stats(
the lazy function helps by only sampling once.
protection : float
Divided-by-zero protection
stat_file_path : Optional[DPPath]
The path to the stat file.
"""
if self.numb_fparam == 0 and self.numb_aparam == 0:
# skip data statistics
self.stats = None
return
if callable(merged):
sampled = merged()
else:
sampled = merged

self.stats = {}

# stat fparam
if self.numb_fparam > 0:
cat_data = torch.cat([frame["fparam"] for frame in sampled], dim=0)
cat_data = torch.reshape(cat_data, [-1, self.numb_fparam])
fparam_avg = torch.mean(cat_data, dim=0)
fparam_std = torch.std(cat_data, dim=0, unbiased=False)
fparam_std = torch.where(
fparam_std < protection,
torch.tensor(
protection, dtype=fparam_std.dtype, device=fparam_std.device
),
fparam_std,
)
if stat_file_path is not None and stat_file_path.is_dir():
self.restore_fparam_from_file(stat_file_path)
else:
sampled = merged() if callable(merged) else merged
self.stats["fparam"] = []
cat_data = to_numpy_array(torch.cat([frame["fparam"] for frame in sampled], dim=0))
cat_data = np.reshape(cat_data, [-1, self.numb_fparam])
sumv = np.sum(cat_data, axis=0)
sumv2 = np.sum(cat_data * cat_data, axis=0)
sumn = cat_data.shape[0]
for ii in range(self.numb_fparam):
self.stats["fparam"].append(
StatItem(
number=sumn,
sum=sumv[ii],
squared_sum=sumv2[ii],
)
)
if stat_file_path is not None:
self.save_to_file_fparam(stat_file_path)

fparam_avg = np.array([ii.compute_avg() for ii in self.stats["fparam"]])
fparam_std = np.array([ii.compute_std(protection=protection) for ii in self.stats["fparam"]])
fparam_inv_std = 1.0 / fparam_std
self.fparam_avg.copy_(
torch.tensor(fparam_avg, device=env.DEVICE, dtype=self.fparam_avg.dtype)
)
self.fparam_inv_std.copy_(
torch.tensor(
fparam_inv_std, device=env.DEVICE, dtype=self.fparam_inv_std.dtype
)
)
log.info(f"fparam_avg is {fparam_avg}, fparam_inv_std is {fparam_inv_std}")
self.fparam_avg.copy_(to_torch_tensor(fparam_avg))
self.fparam_inv_std.copy_(to_torch_tensor(fparam_inv_std))

# stat aparam
if self.numb_aparam > 0:
sys_sumv = []
sys_sumv2 = []
sys_sumn = []
for ss_ in [frame["aparam"] for frame in sampled]:
ss = torch.reshape(ss_, [-1, self.numb_aparam])
sys_sumv.append(torch.sum(ss, dim=0))
sys_sumv2.append(torch.sum(ss * ss, dim=0))
sys_sumn.append(ss.shape[0])
sumv = torch.sum(torch.stack(sys_sumv), dim=0)
sumv2 = torch.sum(torch.stack(sys_sumv2), dim=0)
sumn = sum(sys_sumn)
aparam_avg = sumv / sumn
aparam_std = torch.sqrt(sumv2 / sumn - (sumv / sumn) ** 2)
aparam_std = torch.where(
aparam_std < protection,
torch.tensor(
protection, dtype=aparam_std.dtype, device=aparam_std.device
),
aparam_std,
)
if stat_file_path is not None and stat_file_path.is_dir():
self.restore_aparam_from_file(stat_file_path)
else:
sampled = merged() if callable(merged) else merged
self.stats["aparam"] = []
sys_sumv = []
sys_sumv2 = []
sys_sumn = []
for ss_ in [frame["aparam"] for frame in sampled]:
ss = np.reshape(to_numpy_array(ss_), [-1, self.numb_aparam])
sys_sumv.append(np.sum(ss, axis=0))
sys_sumv2.append(np.sum(ss * ss, axis=0))
sys_sumn.append(ss.shape[0])
sumv = np.sum(np.stack(sys_sumv), axis=0)
sumv2 = np.sum(np.stack(sys_sumv2), axis=0)
sumn = sum(sys_sumn)
for ii in range(self.numb_aparam):
self.stats["aparam"].append(
StatItem(
number=sumn,
sum=sumv[ii],
squared_sum=sumv2[ii],
)
)
if stat_file_path is not None:
self.save_to_file_aparam(stat_file_path)

aparam_avg = np.array([ii.compute_avg() for ii in self.stats["aparam"]])
aparam_std = np.array([ii.compute_std(protection=protection) for ii in self.stats["aparam"]])
aparam_inv_std = 1.0 / aparam_std
self.aparam_avg.copy_(
torch.tensor(aparam_avg, device=env.DEVICE, dtype=self.aparam_avg.dtype)
)
self.aparam_inv_std.copy_(
torch.tensor(
aparam_inv_std, device=env.DEVICE, dtype=self.aparam_inv_std.dtype
)
)
log.info(f"aparam_avg is {aparam_avg}, aparam_inv_std is {aparam_inv_std}")
self.aparam_avg.copy_(to_torch_tensor(aparam_avg))
self.aparam_inv_std.copy_(to_torch_tensor(aparam_inv_std))
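For reference, the buffers written at the end of compute_input_stats follow directly from the accumulated triplets: avg = sum / number, var = squared_sum / number - avg**2, and the std is clipped from below by the protection value before inverting. A short numpy check of that arithmetic (the function name and the sample numbers are made up):

import numpy as np

def stats_to_avg_inv_std(number, total, squared_total, protection=1e-2):
    avg = total / number
    var = squared_total / number - avg ** 2
    std = np.sqrt(np.clip(var, 0.0, None))
    std = np.where(std < protection, protection, std)  # divided-by-zero protection
    return avg, 1.0 / std

# two fparam dimensions accumulated over 50 frames
avg, inv_std = stats_to_avg_inv_std(
    number=50.0,
    total=np.array([15000.0, 25.0]),
    squared_total=np.array([4.6e6, 13.0]),
)
# avg -> [300.0, 0.5]; inv_std -> [~0.0224, 10.0]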

def get_stats(self) -> dict[str, List[StatItem]]:
"""Get the statistics of the fitting_net."""
if self.stats is None:
raise RuntimeError(
"The statistics of fitting net has not been computed."
)
return self.stats


class GeneralFitting(Fitting):
@@ -447,6 +600,9 @@ def has_default_fparam(self) -> bool:
"""Check if the fitting has default frame parameters."""
return self.default_fparam is not None

def get_default_fparam(self) -> Optional[torch.Tensor]:
"""Get the default frame parameters of this fitting, or None if not set."""
return self.default_fparam_tensor

def get_dim_aparam(self) -> int:
"""Get the number (dimension) of atomic parameters of this atomic model."""
return self.numb_aparam