16 changes: 15 additions & 1 deletion deepmd/pt/model/atomic_model/dp_atomic_model.py
@@ -325,11 +325,22 @@ def wrapped_sampler() -> list[dict]:
                 atom_exclude_types = self.atom_excl.get_exclude_types()
                 for sample in sampled:
                     sample["atom_exclude_types"] = list(atom_exclude_types)
+            if (
+                "find_fparam" not in sampled[0]
+                and "fparam" not in sampled[0]
+                and self.has_default_fparam()
+            ):
+                default_fparam = self.get_default_fparam()
+                for sample in sampled:
+                    nframe = sample["atype"].shape[0]
+                    sample["fparam"] = default_fparam.repeat(nframe, 1)
             return sampled
 
         self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
         self.fitting_net.compute_input_stats(
-            wrapped_sampler, protection=self.data_stat_protect
+            wrapped_sampler,
+            protection=self.data_stat_protect,
+            stat_file_path=stat_file_path,
         )
         if compute_or_load_out_stat:
             self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
@@ -342,6 +353,9 @@ def has_default_fparam(self) -> bool:
"""Check if the model has default frame parameters."""
return self.fitting_net.has_default_fparam()

def get_default_fparam(self) -> Optional[torch.Tensor]:
return self.fitting_net.get_default_fparam()

def get_dim_aparam(self) -> int:
"""Get the number (dimension) of atomic parameters of this atomic model."""
return self.fitting_net.get_dim_aparam()
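The first hunk above teaches the statistics sampler to fall back to the model-level default frame parameter: whenever a sampled system provides neither find_fparam nor fparam, the default is tiled once per frame. A minimal sketch of that tiling step, assuming default_fparam is a 1-D tensor of length numb_fparam (the values and frame count below are made up for illustration, not taken from the PR):

import torch

# Illustrative only -- shapes and values are assumptions, not PR code.
default_fparam = torch.tensor([0.5, 1.0])   # assumed shape: (numb_fparam,)
nframe = 4                                  # frames in one sampled system
# repeat views the (2,) tensor as (1, 2) and tiles it to (nframe, 2)
fparam = default_fparam.repeat(nframe, 1)
assert fparam.shape == (nframe, 2)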
3 changes: 3 additions & 0 deletions deepmd/pt/model/model/make_model.py
@@ -530,6 +530,9 @@ def has_default_fparam(self) -> bool:
"""Check if the model has default frame parameters."""
return self.atomic_model.has_default_fparam()

def get_default_fparam(self) -> Optional[torch.Tensor]:
return self.atomic_model.get_default_fparam()

@torch.jit.export
def get_dim_aparam(self) -> int:
"""Get the number (dimension) of atomic parameters of this atomic model."""