diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index 5b7d96560f..0271ad5a39 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -325,11 +325,22 @@ def wrapped_sampler() -> list[dict]: atom_exclude_types = self.atom_excl.get_exclude_types() for sample in sampled: sample["atom_exclude_types"] = list(atom_exclude_types) + if ( + "find_fparam" not in sampled[0] + and "fparam" not in sampled[0] + and self.has_default_fparam() + ): + default_fparam = self.get_default_fparam() + for sample in sampled: + nframe = sample["atype"].shape[0] + sample["fparam"] = default_fparam.repeat(nframe, 1) return sampled self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path) self.fitting_net.compute_input_stats( - wrapped_sampler, protection=self.data_stat_protect + wrapped_sampler, + protection=self.data_stat_protect, + stat_file_path=stat_file_path, ) if compute_or_load_out_stat: self.compute_or_load_out_stat(wrapped_sampler, stat_file_path) @@ -342,6 +353,9 @@ def has_default_fparam(self) -> bool: """Check if the model has default frame parameters.""" return self.fitting_net.has_default_fparam() + def get_default_fparam(self) -> Optional[torch.Tensor]: + return self.fitting_net.get_default_fparam() + def get_dim_aparam(self) -> int: """Get the number (dimension) of atomic parameters of this atomic model.""" return self.fitting_net.get_dim_aparam() diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 53d32977b0..193c5f7f63 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -530,6 +530,9 @@ def has_default_fparam(self) -> bool: """Check if the model has default frame parameters.""" return self.atomic_model.has_default_fparam() + def get_default_fparam(self) -> Optional[torch.Tensor]: + return self.atomic_model.get_default_fparam() + @torch.jit.export def get_dim_aparam(self) -> int: """Get the number (dimension) of atomic parameters of this atomic model.""" diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 4c8e90ef7c..578e1683ba 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -37,10 +37,16 @@ to_numpy_array, to_torch_tensor, ) +from deepmd.utils.env_mat_stat import ( + StatItem, +) from deepmd.utils.finetune import ( get_index_between_two_maps, map_atom_exclude_types, ) +from deepmd.utils.path import ( + DPPath, +) dtype = env.GLOBAL_PT_FLOAT_PRECISION device = env.DEVICE @@ -57,7 +63,12 @@ def __new__(cls, *args: Any, **kwargs: Any) -> "Fitting": return super().__new__(cls) def share_params( - self, base_class: "Fitting", shared_level: int, resume: bool = False + self, + base_class: "Fitting", + shared_level: int, + model_prob: float = 1.0, + protection: float = 1e-2, + resume: bool = False, ) -> None: """ Share the parameters of self to the base_class with shared_level during multitask training. 
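The `wrapped_sampler` hunk above fills in a missing `fparam` entry by tiling the fitting net's default frame parameter across every frame of each sampled system, so the input statistics computed downstream see a well-formed `(nframe, numb_fparam)` array. A minimal sketch of that broadcast, assuming `default_fparam` is the 1-D tensor of length `numb_fparam` returned by `get_default_fparam()` (the shapes and values here are illustrative):

```python
import torch

# Illustrative broadcast matching wrapped_sampler above: tile a 1-D default
# frame parameter into a per-frame (nframe, numb_fparam) array.
default_fparam = torch.tensor([1.0, 0.0])                  # numb_fparam = 2
sample = {"atype": torch.zeros(5, 192, dtype=torch.long)}  # 5 frames, 192 atoms

nframe = sample["atype"].shape[0]
sample["fparam"] = default_fparam.repeat(nframe, 1)        # shape (5, 2)
```

Note that such systems carry the identical default value in every frame, so their fparam statistics are degenerate (zero variance); the `protection` floor applied when computing `fparam_inv_std` is what keeps the normalization finite in that case.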
@@ -69,16 +80,164 @@ def share_params( ) if shared_level == 0: # only not share the bias_atom_e and the case_embd + # link fparam buffers + if self.numb_fparam > 0: + if not resume: + base_fparam = base_class.stats["fparam"] + assert len(base_fparam) == self.numb_fparam + for ii in range(self.numb_fparam): + base_fparam[ii] += self.get_stats()["fparam"][ii] * model_prob + fparam_avg = np.array([ii.compute_avg() for ii in base_fparam]) + fparam_std = np.array( + [ii.compute_std(protection=protection) for ii in base_fparam] + ) + fparam_inv_std = 1.0 / fparam_std + base_class.fparam_avg.copy_( + torch.tensor( + fparam_avg, + device=env.DEVICE, + dtype=base_class.fparam_avg.dtype, + ) + ) + base_class.fparam_inv_std.copy_( + torch.tensor( + fparam_inv_std, + device=env.DEVICE, + dtype=base_class.fparam_inv_std.dtype, + ) + ) + self.fparam_avg = base_class.fparam_avg + self.fparam_inv_std = base_class.fparam_inv_std + + # link aparam buffers + if self.numb_aparam > 0: + if not resume: + base_aparam = base_class.stats["aparam"] + assert len(base_aparam) == self.numb_aparam + for ii in range(self.numb_aparam): + base_aparam[ii] += self.get_stats()["aparam"][ii] * model_prob + aparam_avg = np.array([ii.compute_avg() for ii in base_aparam]) + aparam_std = np.array( + [ii.compute_std(protection=protection) for ii in base_aparam] + ) + aparam_inv_std = 1.0 / aparam_std + base_class.aparam_avg.copy_( + torch.tensor( + aparam_avg, + device=env.DEVICE, + dtype=base_class.aparam_avg.dtype, + ) + ) + base_class.aparam_inv_std.copy_( + torch.tensor( + aparam_inv_std, + device=env.DEVICE, + dtype=base_class.aparam_inv_std.dtype, + ) + ) + self.aparam_avg = base_class.aparam_avg + self.aparam_inv_std = base_class.aparam_inv_std # the following will successfully link all the params except buffers, which need manually link. for item in self._modules: self._modules[item] = base_class._modules[item] else: raise NotImplementedError + def save_to_file_fparam( + self, + stat_file_path: DPPath, + ) -> None: + """Save the statistics of fparam. + + Parameters + ---------- + stat_file_path : DPPath + The path to save the statistics of fparam. + """ + assert stat_file_path is not None + stat_file_path.mkdir(exist_ok=True, parents=True) + if len(self.stats) == 0: + raise ValueError("The statistics hasn't been computed.") + fp = stat_file_path / "fparam" + _fparam_stat = [] + for ii in range(self.numb_fparam): + _tmp_stat = self.stats["fparam"][ii] + _fparam_stat.append( + [_tmp_stat.number, _tmp_stat.sum, _tmp_stat.squared_sum] + ) + _fparam_stat = np.array(_fparam_stat) + fp.save_numpy(_fparam_stat) + log.info(f"Save fparam stats to {fp}.") + + def save_to_file_aparam( + self, + stat_file_path: DPPath, + ) -> None: + """Save the statistics of aparam. + + Parameters + ---------- + stat_file_path : DPPath + The path to save the statistics of aparam. + """ + assert stat_file_path is not None + stat_file_path.mkdir(exist_ok=True, parents=True) + if len(self.stats) == 0: + raise ValueError("The statistics hasn't been computed.") + fp = stat_file_path / "aparam" + _aparam_stat = [] + for ii in range(self.numb_aparam): + _tmp_stat = self.stats["aparam"][ii] + _aparam_stat.append( + [_tmp_stat.number, _tmp_stat.sum, _tmp_stat.squared_sum] + ) + _aparam_stat = np.array(_aparam_stat) + fp.save_numpy(_aparam_stat) + log.info(f"Save aparam stats to {fp}.") + + def restore_fparam_from_file(self, stat_file_path: DPPath) -> None: + """Load the statistics of fparam. 
+ + Parameters + ---------- + stat_file_path : DPPath + The path to load the statistics of fparam. + """ + fp = stat_file_path / "fparam" + arr = fp.load_numpy() + assert arr.shape == (self.numb_fparam, 3) + _fparam_stat = [] + for ii in range(self.numb_fparam): + _fparam_stat.append( + StatItem(number=arr[ii][0], sum=arr[ii][1], squared_sum=arr[ii][2]) + ) + self.stats["fparam"] = _fparam_stat + log.info(f"Load fparam stats from {fp}.") + + def restore_aparam_from_file(self, stat_file_path: DPPath) -> None: + """Load the statistics of aparam. + + Parameters + ---------- + stat_file_path : DPPath + The path to load the statistics of aparam. + """ + fp = stat_file_path / "aparam" + arr = fp.load_numpy() + assert arr.shape == (self.numb_aparam, 3) + _aparam_stat = [] + for ii in range(self.numb_aparam): + _aparam_stat.append( + StatItem(number=arr[ii][0], sum=arr[ii][1], squared_sum=arr[ii][2]) + ) + self.stats["aparam"] = _aparam_stat + log.info(f"Load aparam stats from {fp}.") + def compute_input_stats( self, merged: Union[Callable[[], list[dict]], list[dict]], protection: float = 1e-2, + stat_file_path: Optional[DPPath] = None, ) -> None: """ Compute the input statistics (e.g. mean and stddev) for the fittings from packed data. @@ -94,67 +253,101 @@ def compute_input_stats( the lazy function helps by only sampling once. protection : float Divided-by-zero protection + stat_file_path : Optional[DPPath] + The path to the stat file. """ if self.numb_fparam == 0 and self.numb_aparam == 0: # skip data statistics + self.stats = None return - if callable(merged): - sampled = merged() - else: - sampled = merged + + self.stats = {} + # stat fparam if self.numb_fparam > 0: - cat_data = torch.cat([frame["fparam"] for frame in sampled], dim=0) - cat_data = torch.reshape(cat_data, [-1, self.numb_fparam]) - fparam_avg = torch.mean(cat_data, dim=0) - fparam_std = torch.std(cat_data, dim=0, unbiased=False) - fparam_std = torch.where( - fparam_std < protection, - torch.tensor( - protection, dtype=fparam_std.dtype, device=fparam_std.device - ), - fparam_std, - ) - fparam_inv_std = 1.0 / fparam_std - self.fparam_avg.copy_( - torch.tensor(fparam_avg, device=env.DEVICE, dtype=self.fparam_avg.dtype) - ) - self.fparam_inv_std.copy_( - torch.tensor( - fparam_inv_std, device=env.DEVICE, dtype=self.fparam_inv_std.dtype + if ( + stat_file_path is not None + and stat_file_path.is_dir() + and (stat_file_path / "fparam").is_file() + ): + self.restore_fparam_from_file(stat_file_path) + else: + sampled = merged() if callable(merged) else merged + self.stats["fparam"] = [] + cat_data = to_numpy_array( + torch.cat([frame["fparam"] for frame in sampled], dim=0) ) + cat_data = np.reshape(cat_data, [-1, self.numb_fparam]) + sumv = np.sum(cat_data, axis=0) + sumv2 = np.sum(cat_data * cat_data, axis=0) + sumn = cat_data.shape[0] + for ii in range(self.numb_fparam): + self.stats["fparam"].append( + StatItem( + number=sumn, + sum=sumv[ii], + squared_sum=sumv2[ii], + ) + ) + if stat_file_path is not None: + self.save_to_file_fparam(stat_file_path) + + fparam_avg = np.array([ii.compute_avg() for ii in self.stats["fparam"]]) + fparam_std = np.array( + [ii.compute_std(protection=protection) for ii in self.stats["fparam"]] ) + fparam_inv_std = 1.0 / fparam_std + log.info(f"fparam_avg is {fparam_avg}, fparam_inv_std is {fparam_inv_std}") + self.fparam_avg.copy_(to_torch_tensor(fparam_avg)) + self.fparam_inv_std.copy_(to_torch_tensor(fparam_inv_std)) + # stat aparam if self.numb_aparam > 0: - sys_sumv = [] - sys_sumv2 = [] - sys_sumn = 
[] - for ss_ in [frame["aparam"] for frame in sampled]: - ss = torch.reshape(ss_, [-1, self.numb_aparam]) - sys_sumv.append(torch.sum(ss, dim=0)) - sys_sumv2.append(torch.sum(ss * ss, dim=0)) - sys_sumn.append(ss.shape[0]) - sumv = torch.sum(torch.stack(sys_sumv), dim=0) - sumv2 = torch.sum(torch.stack(sys_sumv2), dim=0) - sumn = sum(sys_sumn) - aparam_avg = sumv / sumn - aparam_std = torch.sqrt(sumv2 / sumn - (sumv / sumn) ** 2) - aparam_std = torch.where( - aparam_std < protection, - torch.tensor( - protection, dtype=aparam_std.dtype, device=aparam_std.device - ), - aparam_std, + if ( + stat_file_path is not None + and stat_file_path.is_dir() + and (stat_file_path / "aparam").is_file() + ): + self.restore_aparam_from_file(stat_file_path) + else: + sampled = merged() if callable(merged) else merged + self.stats["aparam"] = [] + sys_sumv = [] + sys_sumv2 = [] + sys_sumn = [] + for ss_ in [frame["aparam"] for frame in sampled]: + ss = np.reshape(to_numpy_array(ss_), [-1, self.numb_aparam]) + sys_sumv.append(np.sum(ss, axis=0)) + sys_sumv2.append(np.sum(ss * ss, axis=0)) + sys_sumn.append(ss.shape[0]) + sumv = np.sum(np.stack(sys_sumv), axis=0) + sumv2 = np.sum(np.stack(sys_sumv2), axis=0) + sumn = sum(sys_sumn) + for ii in range(self.numb_aparam): + self.stats["aparam"].append( + StatItem( + number=sumn, + sum=sumv[ii], + squared_sum=sumv2[ii], + ) + ) + if stat_file_path is not None: + self.save_to_file_aparam(stat_file_path) + + aparam_avg = np.array([ii.compute_avg() for ii in self.stats["aparam"]]) + aparam_std = np.array( + [ii.compute_std(protection=protection) for ii in self.stats["aparam"]] ) aparam_inv_std = 1.0 / aparam_std - self.aparam_avg.copy_( - torch.tensor(aparam_avg, device=env.DEVICE, dtype=self.aparam_avg.dtype) - ) - self.aparam_inv_std.copy_( - torch.tensor( - aparam_inv_std, device=env.DEVICE, dtype=self.aparam_inv_std.dtype - ) - ) + log.info(f"aparam_avg is {aparam_avg}, aparam_inv_std is {aparam_inv_std}") + self.aparam_avg.copy_(to_torch_tensor(aparam_avg)) + self.aparam_inv_std.copy_(to_torch_tensor(aparam_inv_std)) + + def get_stats(self) -> dict[str, list[StatItem]]: + """Get the statistics of the fitting_net.""" + if self.stats is None: + raise RuntimeError("The statistics of fitting net has not been computed.") + return self.stats class GeneralFitting(Fitting): @@ -447,6 +640,9 @@ def has_default_fparam(self) -> bool: """Check if the fitting has default frame parameters.""" return self.default_fparam is not None + def get_default_fparam(self) -> Optional[torch.Tensor]: + return self.default_fparam_tensor + def get_dim_aparam(self) -> int: """Get the number (dimension) of atomic parameters of this atomic model.""" return self.numb_aparam diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 52d2888081..d099b8b20b 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -616,11 +616,37 @@ def single_model_finetune( frz_model = torch.jit.load(init_frz_model, map_location=DEVICE) self.model.load_state_dict(frz_model.state_dict()) + # Get model prob for multi-task + if self.multi_task: + self.model_prob = np.array([0.0 for key in self.model_keys]) + if training_params.get("model_prob", None) is not None: + model_prob = training_params["model_prob"] + for ii, model_key in enumerate(self.model_keys): + if model_key in model_prob: + self.model_prob[ii] += float(model_prob[model_key]) + else: + for ii, model_key in enumerate(self.model_keys): + self.model_prob[ii] += float(len(self.training_data[model_key])) + sum_prob = 
np.sum(self.model_prob) + assert sum_prob > 0.0, "Sum of model prob must be larger than 0!" + self.model_prob = self.model_prob / sum_prob + # Multi-task share params if shared_links is not None: + _data_stat_protect = np.array( + [ + model_params["model_dict"][ii].get("data_stat_protect", 1e-2) + for ii in model_params["model_dict"] + ] + ) + assert np.allclose(_data_stat_protect, _data_stat_protect[0]), ( + "Model key 'data_stat_protect' must be the same in each branch when multitask!" + ) self.wrapper.share_params( shared_links, resume=(resuming and not self.finetune_update_stat) or self.rank != 0, + model_key_prob_map=dict(zip(self.model_keys, self.model_prob)), + data_stat_protect=_data_stat_protect[0], ) if dist.is_available() and dist.is_initialized(): @@ -670,21 +696,6 @@ def warm_up_linear(step: int, warmup_steps: int) -> float: else: raise ValueError(f"Not supported optimizer type '{self.opt_type}'") - # Get model prob for multi-task - if self.multi_task: - self.model_prob = np.array([0.0 for key in self.model_keys]) - if training_params.get("model_prob", None) is not None: - model_prob = training_params["model_prob"] - for ii, model_key in enumerate(self.model_keys): - if model_key in model_prob: - self.model_prob[ii] += float(model_prob[model_key]) - else: - for ii, model_key in enumerate(self.model_keys): - self.model_prob[ii] += float(len(self.training_data[model_key])) - sum_prob = np.sum(self.model_prob) - assert sum_prob > 0.0, "Sum of model prob must be larger than 0!" - self.model_prob = self.model_prob / sum_prob - # Tensorboard self.enable_tensorboard = training_params.get("tensorboard", False) self.tensorboard_log_dir = training_params.get("tensorboard_log_dir", "log") @@ -1337,12 +1348,18 @@ def print_on_training( def get_additional_data_requirement(_model: Any) -> list[DataRequirementItem]: additional_data_requirement = [] if _model.get_dim_fparam() > 0: + _fparam_default = ( + _model.get_default_fparam().cpu().numpy() + if _model.has_default_fparam() + else 0.0 + ) fparam_requirement_items = [ DataRequirementItem( "fparam", _model.get_dim_fparam(), atomic=False, must=not _model.has_default_fparam(), + default=_fparam_default, ) ] additional_data_requirement += fparam_requirement_items diff --git a/deepmd/pt/train/wrapper.py b/deepmd/pt/train/wrapper.py index 392f928b0d..c65787958a 100644 --- a/deepmd/pt/train/wrapper.py +++ b/deepmd/pt/train/wrapper.py @@ -60,7 +60,13 @@ def __init__( self.loss[task_key] = loss[task_key] self.inference_only = self.loss is None - def share_params(self, shared_links: dict[str, Any], resume: bool = False) -> None: + def share_params( + self, + shared_links: dict[str, Any], + model_key_prob_map: dict, + data_stat_protect: float = 1e-2, + resume: bool = False, + ) -> None: """ Share the parameters of classes following rules defined in shared_links during multitask training. If not start from checkpoint (resume is False), @@ -130,8 +136,16 @@ def share_params(self, shared_links: dict[str, Any], resume: bool = False) -> No link_class = self.model[ model_key_link ].atomic_model.__getattr__(class_type_link) + frac_prob = ( + model_key_prob_map[model_key_link] + / model_key_prob_map[model_key_base] + ) link_class.share_params( - base_class, shared_level_link, resume=resume + base_class, + shared_level_link, + model_prob=frac_prob, + protection=data_stat_protect, + resume=resume, ) log.warning( f"Shared params of {model_key_base}.{class_type_base} and {model_key_link}.{class_type_link}!" 
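With `StatItem.__mul__` added in the `env_mat_stat.py` hunk below, the linking in `share_params` reduces to sufficient-statistics arithmetic: each linked branch contributes its `(number, sum, squared_sum)` triple scaled by the fraction `p_link / p_base`, and the shared `fparam_avg`/`fparam_inv_std` buffers are recomputed from the merged triple. A standalone sketch of that arithmetic (the `Stat` class below is illustrative, mirroring the `StatItem` fields used here; the sample numbers are made up):

```python
import numpy as np
from dataclasses import dataclass

@dataclass
class Stat:
    number: float
    sum: float
    squared_sum: float

    def __add__(self, other: "Stat") -> "Stat":
        return Stat(self.number + other.number,
                    self.sum + other.sum,
                    self.squared_sum + other.squared_sum)

    def __mul__(self, w: float) -> "Stat":
        # Scaling all three fields at once keeps avg/std consistent.
        return Stat(self.number * w, self.sum * w, self.squared_sum * w)

    def avg(self) -> float:
        return self.sum / self.number

    def std(self, protection: float = 1e-2) -> float:
        var = self.squared_sum / self.number - self.avg() ** 2
        return max(float(np.sqrt(var)), protection)

p_base, p_link = 0.3, 0.7        # branch model probabilities
base = Stat(100, 50.0, 40.0)     # one fparam component, base branch
link = Stat(10, 9.0, 9.5)        # same component, linked branch

# share_params accumulates base += link * (p_link / p_base); dividing by the
# merged `number` then yields the probability-weighted mean over both branches.
merged = base + link * (p_link / p_base)
avg, inv_std = merged.avg(), 1.0 / merged.std(protection=1e-2)
```

Because only these triples are exchanged, the merge never touches the raw frame data, and it reproduces the same probability-weighted average that `get_weighted_fitting_stat` reconstructs in the tests below.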
diff --git a/deepmd/utils/env_mat_stat.py b/deepmd/utils/env_mat_stat.py index ecc0b7b62f..3fa4d7d410 100644 --- a/deepmd/utils/env_mat_stat.py +++ b/deepmd/utils/env_mat_stat.py @@ -48,6 +48,13 @@ def __add__(self, other: "StatItem") -> "StatItem": squared_sum=self.squared_sum + other.squared_sum, ) + def __mul__(self, scalar: float) -> "StatItem": + return StatItem( + number=self.number * scalar, + sum=self.sum * scalar, + squared_sum=self.squared_sum * scalar, + ) + def compute_avg(self, default: float = 0) -> float: """Compute the average of the environment matrix. diff --git a/source/tests/pt/model/water/data/data_1/set.000/box.npy b/source/tests/pt/model/water/data/data_1/set.000/box.npy new file mode 100644 index 0000000000..6ad2de625b Binary files /dev/null and b/source/tests/pt/model/water/data/data_1/set.000/box.npy differ diff --git a/source/tests/pt/model/water/data/data_1/set.000/coord.npy b/source/tests/pt/model/water/data/data_1/set.000/coord.npy new file mode 100644 index 0000000000..8bd448b125 Binary files /dev/null and b/source/tests/pt/model/water/data/data_1/set.000/coord.npy differ diff --git a/source/tests/pt/model/water/data/data_1/set.000/energy.npy b/source/tests/pt/model/water/data/data_1/set.000/energy.npy new file mode 100644 index 0000000000..d03db103f5 Binary files /dev/null and b/source/tests/pt/model/water/data/data_1/set.000/energy.npy differ diff --git a/source/tests/pt/model/water/data/data_1/set.000/force.npy b/source/tests/pt/model/water/data/data_1/set.000/force.npy new file mode 100644 index 0000000000..10b2ab83a2 Binary files /dev/null and b/source/tests/pt/model/water/data/data_1/set.000/force.npy differ diff --git a/source/tests/pt/model/water/data/data_1/type.raw b/source/tests/pt/model/water/data/data_1/type.raw new file mode 100644 index 0000000000..97e8fdfcf8 --- /dev/null +++ b/source/tests/pt/model/water/data/data_1/type.raw @@ -0,0 +1,192 @@ +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 diff --git a/source/tests/pt/model/water/data/data_1/type_map.raw b/source/tests/pt/model/water/data/data_1/type_map.raw new file mode 100644 index 0000000000..e900768b1d --- /dev/null +++ b/source/tests/pt/model/water/data/data_1/type_map.raw @@ -0,0 +1,2 @@ +O +H diff --git a/source/tests/pt/test_fitting_stat.py b/source/tests/pt/test_fitting_stat.py index bc02b539a0..7807523221 100644 --- a/source/tests/pt/test_fitting_stat.py +++ b/source/tests/pt/test_fitting_stat.py @@ -1,18 +1,52 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import json +import os +import shutil +import tempfile import unittest +from copy import ( + deepcopy, +) +from pathlib import ( + Path, +) +from typing import ( + NoReturn, +) +import h5py import numpy as np +import torch +from deepmd.pt.entrypoints.main import ( + get_trainer, +) from deepmd.pt.model.descriptor import ( DescrptSeA, ) from deepmd.pt.model.task import ( EnergyFittingNet, ) +from deepmd.pt.utils.multi_task import ( + preprocess_shared_params, +) from deepmd.pt.utils.utils import ( 
to_numpy_array, to_torch_tensor, ) +from deepmd.utils.argcheck import ( + normalize, +) +from deepmd.utils.compat import ( + update_deepmd_input, +) +from deepmd.utils.path import ( + DPPath, +) + +from .model.test_permutation import ( + model_se_e2_a, +) def _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds): @@ -71,16 +105,18 @@ def _brute_aparam_pt(data, ndim): class TestEnerFittingStat(unittest.TestCase): + def tearDown(self) -> None: + self.tempdir.cleanup() + + def setUp(self) -> None: + self.tempdir = tempfile.TemporaryDirectory() + h5file = str((Path(self.tempdir.name) / "testcase.h5").resolve()) + with h5py.File(h5file, "w") as f: + pass + self.stat_file_path = DPPath(h5file, "a") + def test(self) -> None: descrpt = DescrptSeA(6.0, 5.8, [46, 92], neuron=[25, 50, 100], axis_neuron=16) - fitting = EnergyFittingNet( - descrpt.get_ntypes(), - descrpt.get_dim_out(), - neuron=[240, 240, 240], - resnet_dt=True, - numb_fparam=3, - numb_aparam=3, - ) avgs = [0, 10, 100] stds = [2, 0.4, 0.00001] sys_natoms = [10, 100] @@ -88,11 +124,23 @@ def test(self) -> None: all_data = _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds) frefa, frefs = _brute_fparam_pt(all_data, len(avgs)) arefa, arefs = _brute_aparam_pt(all_data, len(avgs)) - fitting.compute_input_stats(all_data, protection=1e-2) frefs_inv = 1.0 / frefs arefs_inv = 1.0 / arefs frefs_inv[frefs_inv > 100] = 100 arefs_inv[arefs_inv > 100] = 100 + + # 1. test fitting stat is applied + fitting = EnergyFittingNet( + descrpt.get_ntypes(), + descrpt.get_dim_out(), + neuron=[240, 240, 240], + resnet_dt=True, + numb_fparam=3, + numb_aparam=3, + ) + fitting.compute_input_stats( + all_data, protection=1e-2, stat_file_path=self.stat_file_path + ) np.testing.assert_almost_equal(frefa, to_numpy_array(fitting.fparam_avg)) np.testing.assert_almost_equal( frefs_inv, to_numpy_array(fitting.fparam_inv_std) @@ -101,3 +149,347 @@ def test(self) -> None: np.testing.assert_almost_equal( arefs_inv, to_numpy_array(fitting.aparam_inv_std) ) + del fitting + + # 2. test fitting stat writing to file is correct + concat_fparam = np.concatenate( + [ + to_numpy_array(all_data[ii]["fparam"].reshape(-1, 3)) + for ii in range(len(sys_nframes)) + ] + ) + concat_aparam = np.concatenate( + [ + to_numpy_array(all_data[ii]["aparam"].reshape(-1, 3)) + for ii in range(len(sys_nframes)) + ] + ) + fparam_stat = (self.stat_file_path / "fparam").load_numpy() + aparam_stat = (self.stat_file_path / "aparam").load_numpy() + np.testing.assert_almost_equal( + fparam_stat[:, 0], np.array([concat_fparam.shape[0]] * 3) + ) + np.testing.assert_almost_equal(fparam_stat[:, 1], np.sum(concat_fparam, axis=0)) + np.testing.assert_almost_equal( + fparam_stat[:, 2], np.sum(concat_fparam**2, axis=0) + ) + np.testing.assert_almost_equal( + aparam_stat[:, 0], np.array([concat_aparam.shape[0]] * 3) + ) + np.testing.assert_almost_equal(aparam_stat[:, 1], np.sum(concat_aparam, axis=0)) + np.testing.assert_almost_equal( + aparam_stat[:, 2], np.sum(concat_aparam**2, axis=0) + ) + + # 3. 
test fitting stat load from file + def raise_error() -> NoReturn: + raise RuntimeError + + fitting = EnergyFittingNet( + descrpt.get_ntypes(), + descrpt.get_dim_out(), + neuron=[240, 240, 240], + resnet_dt=True, + numb_fparam=3, + numb_aparam=3, + ) + fitting.compute_input_stats( + raise_error, protection=1e-2, stat_file_path=self.stat_file_path + ) + np.testing.assert_almost_equal(frefa, to_numpy_array(fitting.fparam_avg)) + np.testing.assert_almost_equal( + frefs_inv, to_numpy_array(fitting.fparam_inv_std) + ) + np.testing.assert_almost_equal(arefa, to_numpy_array(fitting.aparam_avg)) + np.testing.assert_almost_equal( + arefs_inv, to_numpy_array(fitting.aparam_inv_std) + ) + + +def get_weighted_fitting_stat(model_prob: list, *stat_arrays, protection: float): + n_arrays = len(stat_arrays) + assert len(model_prob) == n_arrays + + nframes = [stat.shape[0] for stat in stat_arrays] + sums = [stat.sum(axis=0) for stat in stat_arrays] + squared_sums = [(stat**2).sum(axis=0) for stat in stat_arrays] + + weighted_sum = sum(model_prob[i] * sums[i] for i in range(n_arrays)) + total_weighted_frames = sum(model_prob[i] * nframes[i] for i in range(n_arrays)) + weighted_avg = weighted_sum / total_weighted_frames + + weighted_square_sum = sum(model_prob[i] * squared_sums[i] for i in range(n_arrays)) + weighted_square_avg = weighted_square_sum / total_weighted_frames + weighted_std = np.sqrt(weighted_square_avg - weighted_avg**2) + weighted_std = np.where(weighted_std < protection, protection, weighted_std) + + return weighted_avg, weighted_std + + +class TestMultiTaskFittingStat(unittest.TestCase): + def setUp(self) -> None: + multitask_sharefit_template_json = str( + Path(__file__).parent / "water/multitask_sharefit.json" + ) + with open(multitask_sharefit_template_json) as f: + multitask_se_e2_a = json.load(f) + multitask_se_e2_a["model"]["shared_dict"]["my_descriptor"] = model_se_e2_a[ + "descriptor" + ] + self.data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.data_file_without_fparam = [ + str(Path(__file__).parent / "water/data/data_1") + ] + self.data_file_single = [str(Path(__file__).parent / "water/data/single")] + self.stat_files = "se_e2_a_share_fit" + os.makedirs(self.stat_files, exist_ok=True) + + self.config = multitask_se_e2_a + self.config["training"]["data_dict"]["model_1"]["stat_file"] = ( + f"{self.stat_files}/model_1" + ) + self.config["training"]["data_dict"]["model_2"]["stat_file"] = ( + f"{self.stat_files}/model_2" + ) + self.config["model"]["shared_dict"]["my_fitting"]["numb_fparam"] = 2 + self.default_fparam = [1.0, 0.0] + self.config["model"]["shared_dict"]["my_fitting"]["default_fparam"] = ( + self.default_fparam + ) + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + + self.origin_config = deepcopy(self.config) + + def test_sharefitting_with_fparam(self): + # test multitask training with fparam + self.config = deepcopy(self.origin_config) + model_prob = [0.3, 0.7] + self.config["training"]["model_prob"]["model_1"] = model_prob[0] + self.config["training"]["model_prob"]["model_2"] = model_prob[1] + + self.config["training"]["data_dict"]["model_1"]["training_data"]["systems"] = ( + self.data_file + ) + self.config["training"]["data_dict"]["model_1"]["validation_data"][ + "systems" + ] = self.data_file + self.config["training"]["data_dict"]["model_2"]["training_data"]["systems"] = ( + self.data_file_single + ) + self.config["training"]["data_dict"]["model_2"]["validation_data"][ + "systems" + ] = self.data_file_single + 
self.config["model"]["model_dict"]["model_1"]["data_stat_nbatch"] = 100 + + self.config["model"], self.shared_links = preprocess_shared_params( + self.config["model"] + ) + self.config = update_deepmd_input(self.config, warning=True) + self.config = normalize(self.config, multi_task=True) + trainer = get_trainer(deepcopy(self.config), shared_links=self.shared_links) + trainer.run() + + # check fparam shared + multi_state_dict = trainer.wrapper.model.state_dict() + torch.testing.assert_close( + multi_state_dict["model_1.atomic_model.fitting_net.fparam_avg"], + multi_state_dict["model_2.atomic_model.fitting_net.fparam_avg"], + ) + torch.testing.assert_close( + multi_state_dict["model_1.atomic_model.fitting_net.fparam_inv_std"], + multi_state_dict["model_2.atomic_model.fitting_net.fparam_inv_std"], + ) + + # check fitting stat in stat_file is correct + fparam_stat_model1 = np.load(f"{self.stat_files}/model_1/O H B/fparam") + fparam_stat_model2 = np.load(f"{self.stat_files}/model_2/O H B/fparam") + fparam_data1 = np.load(f"{self.data_file[0]}/set.000/fparam.npy") + fparam_data2 = np.load(f"{self.data_file_single[0]}/set.000/fparam.npy") + np.testing.assert_almost_equal( + fparam_stat_model1[:, 0], [fparam_data1.shape[0]] * 2 + ) + np.testing.assert_almost_equal( + fparam_stat_model1[:, 1], fparam_data1.sum(axis=0) + ) + np.testing.assert_almost_equal( + fparam_stat_model1[:, 2], (fparam_data1**2).sum(axis=0) + ) + np.testing.assert_almost_equal( + fparam_stat_model2[:, 0], [fparam_data2.shape[0]] * 2 + ) + np.testing.assert_almost_equal( + fparam_stat_model2[:, 1], fparam_data2.sum(axis=0) + ) + np.testing.assert_almost_equal( + fparam_stat_model2[:, 2], (fparam_data2**2).sum(axis=0) + ) + + # check shared fitting stat is computed correctly + weighted_avg, weighted_std = get_weighted_fitting_stat( + model_prob, fparam_data1, fparam_data2, protection=1e-2 + ) + np.testing.assert_almost_equal( + weighted_avg, + to_numpy_array( + multi_state_dict["model_1.atomic_model.fitting_net.fparam_avg"] + ), + ) + np.testing.assert_almost_equal( + 1 / weighted_std, + to_numpy_array( + multi_state_dict["model_1.atomic_model.fitting_net.fparam_inv_std"] + ), + ) + + def test_sharefitting_using_default_fparam(self): + # test multitask training with fparam + self.config = deepcopy(self.origin_config) + # add model3 + self.config["model"]["model_dict"]["model_3"] = deepcopy( + self.config["model"]["model_dict"]["model_2"] + ) + self.config["loss_dict"]["model_3"] = deepcopy( + self.config["loss_dict"]["model_2"] + ) + self.config["training"]["model_prob"]["model_3"] = deepcopy( + self.config["training"]["model_prob"]["model_2"] + ) + self.config["training"]["data_dict"]["model_3"] = deepcopy( + self.config["training"]["data_dict"]["model_2"] + ) + self.config["training"]["data_dict"]["model_3"]["stat_file"] = self.config[ + "training" + ]["data_dict"]["model_3"]["stat_file"].replace("model_2", "model_3") + self.config["model"]["shared_dict"]["my_fitting"]["dim_case_embd"] = 3 + + model_prob = [0.1, 0.3, 0.6] + self.config["training"]["model_prob"]["model_1"] = model_prob[0] + self.config["training"]["model_prob"]["model_2"] = model_prob[1] + self.config["training"]["model_prob"]["model_3"] = model_prob[2] + + self.config["training"]["data_dict"]["model_1"]["training_data"]["systems"] = ( + self.data_file_without_fparam + ) + self.config["training"]["data_dict"]["model_1"]["validation_data"][ + "systems" + ] = self.data_file_without_fparam + 
self.config["training"]["data_dict"]["model_2"]["training_data"]["systems"] = ( + self.data_file_single + ) + self.config["training"]["data_dict"]["model_2"]["validation_data"][ + "systems" + ] = self.data_file_single + self.config["training"]["data_dict"]["model_3"]["stat_file"] = ( + f"{self.stat_files}/model_3" + ) + self.config["training"]["data_dict"]["model_3"]["training_data"]["systems"] = ( + self.data_file + ) + self.config["training"]["data_dict"]["model_3"]["validation_data"][ + "systems" + ] = self.data_file + data_stat_protect = 5e-3 + self.config["model"]["model_dict"]["model_1"]["data_stat_nbatch"] = 3 + self.config["model"]["model_dict"]["model_3"]["data_stat_nbatch"] = 100 + self.config["model"]["model_dict"]["model_1"]["data_stat_protect"] = ( + data_stat_protect + ) + self.config["model"]["model_dict"]["model_2"]["data_stat_protect"] = ( + data_stat_protect + ) + self.config["model"]["model_dict"]["model_3"]["data_stat_protect"] = ( + data_stat_protect + ) + + self.config["model"], self.shared_links = preprocess_shared_params( + self.config["model"] + ) + self.config = update_deepmd_input(self.config, warning=True) + self.config = normalize(self.config, multi_task=True) + trainer = get_trainer(deepcopy(self.config), shared_links=self.shared_links) + trainer.run() + + # check fparam shared + multi_state_dict = trainer.wrapper.model.state_dict() + torch.testing.assert_close( + multi_state_dict["model_1.atomic_model.fitting_net.fparam_avg"], + multi_state_dict["model_2.atomic_model.fitting_net.fparam_avg"], + ) + torch.testing.assert_close( + multi_state_dict["model_1.atomic_model.fitting_net.fparam_avg"], + multi_state_dict["model_3.atomic_model.fitting_net.fparam_avg"], + ) + torch.testing.assert_close( + multi_state_dict["model_1.atomic_model.fitting_net.fparam_inv_std"], + multi_state_dict["model_2.atomic_model.fitting_net.fparam_inv_std"], + ) + torch.testing.assert_close( + multi_state_dict["model_1.atomic_model.fitting_net.fparam_inv_std"], + multi_state_dict["model_3.atomic_model.fitting_net.fparam_inv_std"], + ) + + # check fitting stat in stat_file is correct + fparam_stat_model1 = np.load(f"{self.stat_files}/model_1/O H B/fparam") + fparam_stat_model2 = np.load(f"{self.stat_files}/model_2/O H B/fparam") + fparam_stat_model3 = np.load(f"{self.stat_files}/model_3/O H B/fparam") + fparam_data1 = np.array([self.default_fparam]).repeat(3, axis=0) + fparam_data2 = np.load(f"{self.data_file_single[0]}/set.000/fparam.npy") + fparam_data3 = np.load(f"{self.data_file[0]}/set.000/fparam.npy") + np.testing.assert_almost_equal( + fparam_stat_model1[:, 0], [fparam_data1.shape[0]] * 2 + ) + np.testing.assert_almost_equal( + fparam_stat_model1[:, 1], fparam_data1.sum(axis=0) + ) + np.testing.assert_almost_equal( + fparam_stat_model1[:, 2], (fparam_data1**2).sum(axis=0) + ) + np.testing.assert_almost_equal( + fparam_stat_model2[:, 0], [fparam_data2.shape[0]] * 2 + ) + np.testing.assert_almost_equal( + fparam_stat_model2[:, 1], fparam_data2.sum(axis=0) + ) + np.testing.assert_almost_equal( + fparam_stat_model2[:, 2], (fparam_data2**2).sum(axis=0) + ) + np.testing.assert_almost_equal( + fparam_stat_model3[:, 0], [fparam_data3.shape[0]] * 2 + ) + np.testing.assert_almost_equal( + fparam_stat_model3[:, 1], fparam_data3.sum(axis=0) + ) + np.testing.assert_almost_equal( + fparam_stat_model3[:, 2], (fparam_data3**2).sum(axis=0) + ) + + # check shared fitting stat is computed correctly + weighted_avg, weighted_std = get_weighted_fitting_stat( + model_prob, + fparam_data1, + fparam_data2, 
+ fparam_data3, + protection=data_stat_protect, + ) + np.testing.assert_almost_equal( + weighted_avg, + to_numpy_array( + multi_state_dict["model_1.atomic_model.fitting_net.fparam_avg"] + ), + ) + np.testing.assert_almost_equal( + 1 / weighted_std, + to_numpy_array( + multi_state_dict["model_1.atomic_model.fitting_net.fparam_inv_std"] + ), + ) + + def tearDown(self) -> None: + for f in os.listdir("."): + if f.startswith("model") and f.endswith(".pt"): + os.remove(f) + if f in ["lcurve.out", "checkpoint"]: + os.remove(f) + if f in [self.stat_files]: + shutil.rmtree(f)
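The on-disk format checked throughout these tests is deliberately small: `save_to_file_fparam` writes one row per fparam component with columns `[number, sum, squared_sum]`, and `restore_fparam_from_file` rebuilds the `StatItem`s from exactly those rows. A hedged numpy sketch of the round trip (illustrative data; the real I/O goes through `DPPath.save_numpy`/`load_numpy`):

```python
import numpy as np

# Fake fparam data: 30 frames, numb_fparam = 2.
fparam = np.random.default_rng(0).normal(size=(30, 2))

# What save_to_file_fparam stores: rows = components,
# columns = [number, sum, squared_sum].
stat = np.stack(
    [
        np.full(2, float(fparam.shape[0])),  # frame count per component
        fparam.sum(axis=0),                  # running sum
        (fparam**2).sum(axis=0),             # running sum of squares
    ],
    axis=1,
)                                            # shape (numb_fparam, 3)

# What the restore path recovers, without touching the raw data:
avg = stat[:, 1] / stat[:, 0]
std = np.sqrt(stat[:, 2] / stat[:, 0] - avg**2)
inv_std = 1.0 / np.maximum(std, 1e-2)        # divide-by-zero protection
```

Storing sufficient statistics rather than finished averages is what makes the multitask merge above possible: branches restored from different stat files can still be combined under new model probabilities without re-sampling any training data.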