Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ def has_default_fparam(self) -> bool:
"""Check if the model has default frame parameters."""
return False

def get_default_fparam(self) -> list[float] | None:
"""Get the default frame parameters."""
return None

def reinit_atom_exclude(
self,
exclude_types: list[int] = [],
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,10 @@ def has_default_fparam(self) -> bool:
"""Check if the model has default frame parameters."""
return self.fitting.has_default_fparam()

def get_default_fparam(self) -> list[float] | None:
"""Get the default frame parameters."""
return self.fitting.get_default_fparam()

def get_sel_type(self) -> list[int]:
"""Get the selected atom types of this model.

Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,10 @@ def has_default_fparam(self) -> bool:
"""Check if the fitting has default frame parameters."""
return self.default_fparam is not None

def get_default_fparam(self) -> list[float] | None:
"""Get the default frame parameters."""
return self.default_fparam

def get_sel_type(self) -> list[int]:
"""Get the selected atom types of this model.

Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,10 @@ def has_default_fparam(self) -> bool:
"""Check if the model has default frame parameters."""
return self.atomic_model.has_default_fparam()

def get_default_fparam(self) -> list[float] | None:
"""Get the default frame parameters."""
return self.atomic_model.get_default_fparam()

def get_sel_type(self) -> list[int]:
"""Get the selected atom types of this model.

Expand Down
12 changes: 12 additions & 0 deletions deepmd/jax/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,14 @@ def _eval_model(
box_input = None
if fparam is not None:
fparam_input = fparam.reshape(nframes, self.get_dim_fparam())
elif self.dp.has_default_fparam():
# JAX (XLA) requires static shapes, so default must be implemented here
default_fparam = self.dp.get_default_fparam()
assert default_fparam is not None
fparam_input = np.tile(
np.array(default_fparam, dtype=GLOBAL_NP_FLOAT_PRECISION),
(nframes, 1),
)
else:
fparam_input = None
if aparam is not None:
Expand Down Expand Up @@ -433,3 +441,7 @@ def get_model(self) -> Any:
The JAX model as BaseModel instance.
"""
return self.dp

def has_default_fparam(self) -> bool:
"""Check if the model has default frame parameters."""
return self.dp.has_default_fparam()
17 changes: 17 additions & 0 deletions deepmd/jax/jax2tf/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,23 @@ def has_message_passing() -> tf.Tensor:
return tf.constant(model.has_message_passing(), dtype=tf.bool)

tf_model.has_message_passing = has_message_passing

@tf.function
def has_default_fparam() -> tf.Tensor:
return tf.constant(model.has_default_fparam(), dtype=tf.bool)

tf_model.has_default_fparam = has_default_fparam

@tf.function
def get_default_fparam() -> tf.Tensor:
default_fparam = model.get_default_fparam()
if default_fparam is None:
return tf.constant([], dtype=tf.double)
else:
return tf.constant(default_fparam, dtype=tf.double)

tf_model.get_default_fparam = get_default_fparam

tf.saved_model.save(
tf_model,
model_file,
Expand Down
17 changes: 17 additions & 0 deletions deepmd/jax/jax2tf/tfmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,15 @@ def __init__(
self.min_nbor_dist = None
self.sel = self.model.get_sel().numpy().tolist()
self.model_def_script = self.model.get_model_def_script().numpy().decode()
if hasattr(self.model, "has_default_fparam"):
# No attrs before v3.1.2
self._has_default_fparam = self.model.has_default_fparam().numpy().item()
else:
self._has_default_fparam = False
if hasattr(self.model, "get_default_fparam"):
self.default_fparam = self.model.get_default_fparam().numpy().tolist()
else:
self.default_fparam = None

def __call__(
self,
Expand Down Expand Up @@ -331,3 +340,11 @@ def get_model(cls, model_params: dict) -> "TFModelWrapper":
The model
"""
raise NotImplementedError("Not implemented")

def has_default_fparam(self) -> bool:
"""Check whether the model has default frame parameters."""
return self._has_default_fparam

def get_default_fparam(self) -> list[float] | None:
"""Get the default frame parameters."""
return self.default_fparam
13 changes: 13 additions & 0 deletions deepmd/jax/model/hlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ def __init__(
mixed_types: bool,
min_nbor_dist: float | None,
sel: list[int],
# new in v3.1.1
has_default_fparam: bool = False,
default_fparam: list[float] | None = None,
) -> None:
self._call_lower = jax_export.deserialize(stablehlo).call
self._call_lower_atomic_virial = jax_export.deserialize(
Expand All @@ -79,6 +82,8 @@ def __init__(
self.min_nbor_dist = min_nbor_dist
self.sel = sel
self.model_def_script = model_def_script
self._has_default_fparam = has_default_fparam
self.default_fparam = default_fparam

def __call__(
self,
Expand Down Expand Up @@ -327,3 +332,11 @@ def get_model(cls, model_params: dict) -> "BaseModel":
The model
"""
raise NotImplementedError("Not implemented")

def has_default_fparam(self) -> bool:
"""Check whether the model has default frame parameters."""
return self._has_default_fparam

def get_default_fparam(self) -> list[float] | None:
"""Get the default frame parameters."""
return self.default_fparam
2 changes: 2 additions & 0 deletions deepmd/jax/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ def call_lower_with_fixed_do_atomic_virial(
"mixed_types": model.mixed_types(),
"min_nbor_dist": model.get_min_nbor_dist(),
"sel": model.get_sel(),
"has_default_fparam": model.has_default_fparam(),
"default_fparam": model.get_default_fparam(),
}
save_dp_model(filename=model_file, model_dict=data)
elif model_file.endswith(".savedmodel"):
Expand Down