diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py index 6fb48eb659..2353e207a3 100644 --- a/deepmd/dpmodel/atomic_model/base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py @@ -95,6 +95,10 @@ def has_default_fparam(self) -> bool: """Check if the model has default frame parameters.""" return False + def get_default_fparam(self) -> list[float] | None: + """Get the default frame parameters.""" + return None + def reinit_atom_exclude( self, exclude_types: list[int] = [], diff --git a/deepmd/dpmodel/atomic_model/dp_atomic_model.py b/deepmd/dpmodel/atomic_model/dp_atomic_model.py index 07a02ad56b..0f5b12bc9c 100644 --- a/deepmd/dpmodel/atomic_model/dp_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/dp_atomic_model.py @@ -240,6 +240,10 @@ def has_default_fparam(self) -> bool: """Check if the model has default frame parameters.""" return self.fitting.has_default_fparam() + def get_default_fparam(self) -> list[float] | None: + """Get the default frame parameters.""" + return self.fitting.get_default_fparam() + def get_sel_type(self) -> list[int]: """Get the selected atom types of this model. diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index c61ab234b1..a4089468f3 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -304,6 +304,10 @@ def has_default_fparam(self) -> bool: """Check if the fitting has default frame parameters.""" return self.default_fparam is not None + def get_default_fparam(self) -> list[float] | None: + """Get the default frame parameters.""" + return self.default_fparam + def get_sel_type(self) -> list[int]: """Get the selected atom types of this model. diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index 3a88aac1e4..cc9dd12fc5 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -567,6 +567,10 @@ def has_default_fparam(self) -> bool: """Check if the model has default frame parameters.""" return self.atomic_model.has_default_fparam() + def get_default_fparam(self) -> list[float] | None: + """Get the default frame parameters.""" + return self.atomic_model.get_default_fparam() + def get_sel_type(self) -> list[int]: """Get the selected atom types of this model. diff --git a/deepmd/jax/infer/deep_eval.py b/deepmd/jax/infer/deep_eval.py index 1e29ee1c78..10c4b7d7a5 100644 --- a/deepmd/jax/infer/deep_eval.py +++ b/deepmd/jax/infer/deep_eval.py @@ -354,6 +354,14 @@ def _eval_model( box_input = None if fparam is not None: fparam_input = fparam.reshape(nframes, self.get_dim_fparam()) + elif self.dp.has_default_fparam(): + # JAX (XLA) requires static shapes, so default must be implemented here + default_fparam = self.dp.get_default_fparam() + assert default_fparam is not None + fparam_input = np.tile( + np.array(default_fparam, dtype=GLOBAL_NP_FLOAT_PRECISION), + (nframes, 1), + ) else: fparam_input = None if aparam is not None: @@ -433,3 +441,7 @@ def get_model(self) -> Any: The JAX model as BaseModel instance. """ return self.dp + + def has_default_fparam(self) -> bool: + """Check if the model has default frame parameters.""" + return self.dp.has_default_fparam() diff --git a/deepmd/jax/jax2tf/serialization.py b/deepmd/jax/jax2tf/serialization.py index e819ebf65a..31d0d7eb82 100644 --- a/deepmd/jax/jax2tf/serialization.py +++ b/deepmd/jax/jax2tf/serialization.py @@ -319,6 +319,23 @@ def has_message_passing() -> tf.Tensor: return tf.constant(model.has_message_passing(), dtype=tf.bool) tf_model.has_message_passing = has_message_passing + + @tf.function + def has_default_fparam() -> tf.Tensor: + return tf.constant(model.has_default_fparam(), dtype=tf.bool) + + tf_model.has_default_fparam = has_default_fparam + + @tf.function + def get_default_fparam() -> tf.Tensor: + default_fparam = model.get_default_fparam() + if default_fparam is None: + return tf.constant([], dtype=tf.double) + else: + return tf.constant(default_fparam, dtype=tf.double) + + tf_model.get_default_fparam = get_default_fparam + tf.saved_model.save( tf_model, model_file, diff --git a/deepmd/jax/jax2tf/tfmodel.py b/deepmd/jax/jax2tf/tfmodel.py index 85115547d7..1c968c8f41 100644 --- a/deepmd/jax/jax2tf/tfmodel.py +++ b/deepmd/jax/jax2tf/tfmodel.py @@ -69,6 +69,15 @@ def __init__( self.min_nbor_dist = None self.sel = self.model.get_sel().numpy().tolist() self.model_def_script = self.model.get_model_def_script().numpy().decode() + if hasattr(self.model, "has_default_fparam"): + # No attrs before v3.1.2 + self._has_default_fparam = self.model.has_default_fparam().numpy().item() + else: + self._has_default_fparam = False + if hasattr(self.model, "get_default_fparam"): + self.default_fparam = self.model.get_default_fparam().numpy().tolist() + else: + self.default_fparam = None def __call__( self, @@ -331,3 +340,11 @@ def get_model(cls, model_params: dict) -> "TFModelWrapper": The model """ raise NotImplementedError("Not implemented") + + def has_default_fparam(self) -> bool: + """Check whether the model has default frame parameters.""" + return self._has_default_fparam + + def get_default_fparam(self) -> list[float] | None: + """Get the default frame parameters.""" + return self.default_fparam diff --git a/deepmd/jax/model/hlo.py b/deepmd/jax/model/hlo.py index 47959dd130..7eb4e2c4b3 100644 --- a/deepmd/jax/model/hlo.py +++ b/deepmd/jax/model/hlo.py @@ -58,6 +58,9 @@ def __init__( mixed_types: bool, min_nbor_dist: float | None, sel: list[int], + # new in v3.1.1 + has_default_fparam: bool = False, + default_fparam: list[float] | None = None, ) -> None: self._call_lower = jax_export.deserialize(stablehlo).call self._call_lower_atomic_virial = jax_export.deserialize( @@ -79,6 +82,8 @@ def __init__( self.min_nbor_dist = min_nbor_dist self.sel = sel self.model_def_script = model_def_script + self._has_default_fparam = has_default_fparam + self.default_fparam = default_fparam def __call__( self, @@ -327,3 +332,11 @@ def get_model(cls, model_params: dict) -> "BaseModel": The model """ raise NotImplementedError("Not implemented") + + def has_default_fparam(self) -> bool: + """Check whether the model has default frame parameters.""" + return self._has_default_fparam + + def get_default_fparam(self) -> list[float] | None: + """Get the default frame parameters.""" + return self.default_fparam diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index 6a3c839608..5d3432aab8 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -133,6 +133,8 @@ def call_lower_with_fixed_do_atomic_virial( "mixed_types": model.mixed_types(), "min_nbor_dist": model.get_min_nbor_dist(), "sel": model.get_sel(), + "has_default_fparam": model.has_default_fparam(), + "default_fparam": model.get_default_fparam(), } save_dp_model(filename=model_file, model_dict=data) elif model_file.endswith(".savedmodel"):