Commit 38815b3
feat(jax): export call_lower to SavedModel via jax2tf (#4254)
## Summary by CodeRabbit

## Release Notes

- **New Features**
  - Added support for the TensorFlow SavedModel format, allowing users to handle additional model file types.
  - Introduced a new TensorFlow model wrapper class for enhanced integration with JAX functionalities.
- **Bug Fixes**
  - Improved error handling for unsupported file formats during model deserialization.
- **Documentation**
  - Updated backend documentation to reflect new file extensions and clarify backend capabilities.
- **Tests**
  - Enhanced test structure for better clarity and maintainability regarding backend handling.
  - Added a new job for testing TensorFlow 2 in eager mode within the testing workflow.
  - Introduced a conditional skip for tests based on TensorFlow 2 compatibility.

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
1 parent 7aaf284 commit 38815b3
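Taken together, the changes below add a second on-disk format (`.savedmodel`) to the JAX backend. As orientation before the per-file diffs, a minimal end-to-end sketch of the intended usage; the file name is hypothetical, and it assumes the high-level `DeepPot` wrapper routes the new suffix to the JAX backend's `DeepEval` exactly as `deepmd/jax/infer/deep_eval.py` below does:

# Sketch only (not from the commit): evaluate a model saved in the new
# ".savedmodel" format through the standard inference API.
import numpy as np

from deepmd.infer import DeepPot

dp = DeepPot("model.savedmodel")  # hypothetical file name
coord = np.random.default_rng(0).random((1, 3 * 3))  # 1 frame, 3 atoms, flattened xyz
cell = (10.0 * np.eye(3)).reshape((1, 9))
atype = [0, 1, 1]
e, f, v = dp.eval(coord, cell, atype)  # energy, force, virial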

11 files changed (+568, -22 lines)

.github/workflows/test_python.yml

Lines changed: 14 additions & 4 deletions
@@ -25,19 +25,23 @@ jobs:
           python-version: ${{ matrix.python }}
       - run: python -m pip install -U uv
       - run: |
-          source/install/uv_with_retry.sh pip install --system mpich
+          source/install/uv_with_retry.sh pip install --system openmpi tensorflow-cpu
           source/install/uv_with_retry.sh pip install --system torch -i https://download.pytorch.org/whl/cpu
+          export TENSORFLOW_ROOT=$(python -c 'import tensorflow;print(tensorflow.__path__[0])')
           export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
-          source/install/uv_with_retry.sh pip install --system --only-binary=horovod -e .[cpu,test,jax] horovod[tensorflow-cpu] mpi4py
+          source/install/uv_with_retry.sh pip install --system -e .[test,jax] mpi4py
+          source/install/uv_with_retry.sh pip install --system horovod --no-build-isolation
         env:
           # Please note that uv has some issues with finding
           # existing TensorFlow package. Currently, it uses
           # TensorFlow in the build dependency, but if that
           # changes, set `TENSORFLOW_ROOT`.
-          TENSORFLOW_VERSION: 2.16.1
           DP_ENABLE_PYTORCH: 1
           DP_BUILD_TESTING: 1
-          UV_EXTRA_INDEX_URL: "https://pypi.anaconda.org/njzjz/simple https://pypi.anaconda.org/mpi4py/simple"
+          UV_EXTRA_INDEX_URL: "https://pypi.anaconda.org/mpi4py/simple"
+          HOROVOD_WITH_TENSORFLOW: 1
+          HOROVOD_WITHOUT_PYTORCH: 1
+          HOROVOD_WITH_MPI: 1
       - run: dp --version
       - name: Get durations from cache
         uses: actions/cache@v4
@@ -53,6 +57,12 @@ jobs:
       - run: pytest --cov=deepmd source/tests --durations=0 --splits 6 --group ${{ matrix.group }} --store-durations --durations-path=.test_durations --splitting-algorithm least_duration
         env:
           NUM_WORKERS: 0
+      - name: Test TF2 eager mode
+        run: pytest --cov=deepmd source/tests/consistent/io/test_io.py --durations=0
+        env:
+          NUM_WORKERS: 0
+          DP_TEST_TF2_ONLY: 1
+        if: matrix.group == 1
       - run: mv .test_durations .test_durations_${{ matrix.group }}
       - name: Upload partial durations
         uses: actions/upload-artifact@v4
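The new step only runs for matrix group 1 and reuses the coverage setup of the main test run. To reproduce it locally, the same invocation can be driven from Python (a convenience sketch; it assumes `pytest` and `pytest-cov` are installed and the repository root is the working directory):

# Local equivalent of the "Test TF2 eager mode" step added above.
import os

import pytest

os.environ["NUM_WORKERS"] = "0"
os.environ["DP_TEST_TF2_ONLY"] = "1"  # same gate as the workflow env above

raise SystemExit(
    pytest.main(
        ["--cov=deepmd", "source/tests/consistent/io/test_io.py", "--durations=0"]
    )
)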

deepmd/backend/jax.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ class JAXBackend(Backend):
         | Backend.Feature.NEIGHBOR_STAT
     )
     """The features of the backend."""
-    suffixes: ClassVar[list[str]] = [".hlo", ".jax"]
+    suffixes: ClassVar[list[str]] = [".hlo", ".jax", ".savedmodel"]
     """The suffixes of the backend."""

     def is_available(self) -> bool:
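For a quick sanity check that a given build contains this change, the class attribute touched above can be inspected directly (the expected output is inferred from the diff, not a captured run):

# Print the model-file suffixes the JAX backend now advertises.
from deepmd.backend.jax import JAXBackend

print(JAXBackend.suffixes)  # per the diff: ['.hlo', '.jax', '.savedmodel']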

deepmd/jax/infer/deep_eval.py

Lines changed: 18 additions & 9 deletions
@@ -90,15 +90,24 @@ def __init__(
         self.output_def = output_def
         self.model_path = model_file

-        model_data = load_dp_model(model_file)
-        self.dp = HLO(
-            stablehlo=model_data["@variables"]["stablehlo"].tobytes(),
-            stablehlo_atomic_virial=model_data["@variables"][
-                "stablehlo_atomic_virial"
-            ].tobytes(),
-            model_def_script=model_data["model_def_script"],
-            **model_data["constants"],
-        )
+        if model_file.endswith(".hlo"):
+            model_data = load_dp_model(model_file)
+            self.dp = HLO(
+                stablehlo=model_data["@variables"]["stablehlo"].tobytes(),
+                stablehlo_atomic_virial=model_data["@variables"][
+                    "stablehlo_atomic_virial"
+                ].tobytes(),
+                model_def_script=model_data["model_def_script"],
+                **model_data["constants"],
+            )
+        elif model_file.endswith(".savedmodel"):
+            from deepmd.jax.jax2tf.tfmodel import (
+                TFModelWrapper,
+            )
+
+            self.dp = TFModelWrapper(model_file)
+        else:
+            raise ValueError("Unsupported file extension")
         self.rcut = self.dp.get_rcut()
         self.type_map = self.dp.get_type_map()
         if isinstance(auto_batch_size, bool):
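`deepmd/jax/jax2tf/tfmodel.py`, which provides `TFModelWrapper`, is not among the hunks shown on this page. As orientation only, a stripped-down sketch of what such a wrapper could look like, assuming it loads the SavedModel written by `deepmd/jax/jax2tf/serialization.py` below and forwards the exported attribute functions; the class name and method bodies are illustrative, not the actual implementation:

# Hypothetical sketch only -- not the actual TFModelWrapper.
# Assumes the SavedModel layout produced by serialization.py below: exported
# tf.functions call_lower, call_lower_atomic_virial, get_rcut, get_type_map, ...
import tensorflow as tf


class MinimalTFModel:
    def __init__(self, model_file: str) -> None:
        self._module = tf.saved_model.load(model_file)

    def get_rcut(self) -> float:
        return float(self._module.get_rcut().numpy())

    def get_type_map(self) -> list[str]:
        return [t.decode("utf-8") for t in self._module.get_type_map().numpy()]

    def call_lower(
        self, coord, atype, nlist, mapping, fparam, aparam, do_atomic_virial=False
    ):
        # Dispatch between the two exported entry points.
        fn = (
            self._module.call_lower_atomic_virial
            if do_atomic_virial
            else self._module.call_lower
        )
        return fn(coord, atype, nlist, mapping, fparam, aparam)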

deepmd/jax/jax2tf/__init__.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import tensorflow as tf
+
+if not tf.executing_eagerly():
+    # TF does not allow re-enabling eager execution temporarily
+    raise RuntimeError(
+        "Unfortunately, jax2tf (which requires eager execution) cannot be used with the "
+        "TensorFlow backend (which disables eager execution). "
+        "If you are converting a model between different backends, "
+        "consider converting to the `.dp` format first."
+    )
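The module-level guard raises as soon as eager execution has been turned off, which is the situation the error message describes for the graph-mode TensorFlow backend. A small, self-contained reproduction (assuming a build that contains this commit):

# Demonstrates the import-time guard added above: once eager execution is
# disabled, deepmd.jax.jax2tf refuses to import and points users to the
# `.dp` exchange format instead.
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

try:
    import deepmd.jax.jax2tf  # noqa: F401
except RuntimeError as err:
    print(err)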

deepmd/jax/jax2tf/serialization.py

Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import json
+
+import tensorflow as tf
+from jax.experimental import (
+    jax2tf,
+)
+
+from deepmd.jax.model.base_model import (
+    BaseModel,
+)
+
+
+def deserialize_to_file(model_file: str, data: dict) -> None:
+    """Deserialize the dictionary to a model file.
+
+    Parameters
+    ----------
+    model_file : str
+        The model file to be saved.
+    data : dict
+        The dictionary to be deserialized.
+    """
+    if model_file.endswith(".savedmodel"):
+        model = BaseModel.deserialize(data["model"])
+        model_def_script = data["model_def_script"]
+        call_lower = model.call_lower
+
+        tf_model = tf.Module()
+
+        def exported_whether_do_atomic_virial(do_atomic_virial):
+            def call_lower_with_fixed_do_atomic_virial(
+                coord, atype, nlist, mapping, fparam, aparam
+            ):
+                return call_lower(
+                    coord,
+                    atype,
+                    nlist,
+                    mapping,
+                    fparam,
+                    aparam,
+                    do_atomic_virial=do_atomic_virial,
+                )
+
+            return jax2tf.convert(
+                call_lower_with_fixed_do_atomic_virial,
+                polymorphic_shapes=[
+                    "(nf, nloc + nghost, 3)",
+                    "(nf, nloc + nghost)",
+                    f"(nf, nloc, {model.get_nnei()})",
+                    "(nf, nloc + nghost)",
+                    f"(nf, {model.get_dim_fparam()})",
+                    f"(nf, nloc, {model.get_dim_aparam()})",
+                ],
+                with_gradient=True,
+            )
+
+        # Save a function that can take scalar inputs.
+        # We need to explicitly set the function name so that C++ can find it.
+        @tf.function(
+            autograph=False,
+            input_signature=[
+                tf.TensorSpec([None, None, 3], tf.float64),
+                tf.TensorSpec([None, None], tf.int32),
+                tf.TensorSpec([None, None, model.get_nnei()], tf.int64),
+                tf.TensorSpec([None, None], tf.int64),
+                tf.TensorSpec([None, model.get_dim_fparam()], tf.float64),
+                tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64),
+            ],
+        )
+        def call_lower_without_atomic_virial(
+            coord, atype, nlist, mapping, fparam, aparam
+        ):
+            return exported_whether_do_atomic_virial(do_atomic_virial=False)(
+                coord, atype, nlist, mapping, fparam, aparam
+            )
+
+        tf_model.call_lower = call_lower_without_atomic_virial
+
+        @tf.function(
+            autograph=False,
+            input_signature=[
+                tf.TensorSpec([None, None, 3], tf.float64),
+                tf.TensorSpec([None, None], tf.int32),
+                tf.TensorSpec([None, None, model.get_nnei()], tf.int64),
+                tf.TensorSpec([None, None], tf.int64),
+                tf.TensorSpec([None, model.get_dim_fparam()], tf.float64),
+                tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64),
+            ],
+        )
+        def call_lower_with_atomic_virial(coord, atype, nlist, mapping, fparam, aparam):
+            return exported_whether_do_atomic_virial(do_atomic_virial=True)(
+                coord, atype, nlist, mapping, fparam, aparam
+            )
+
+        tf_model.call_lower_atomic_virial = call_lower_with_atomic_virial
+
+        # Set functions to export other attributes.
+        @tf.function
+        def get_type_map():
+            return tf.constant(model.get_type_map(), dtype=tf.string)
+
+        tf_model.get_type_map = get_type_map
+
+        @tf.function
+        def get_rcut():
+            return tf.constant(model.get_rcut(), dtype=tf.double)
+
+        tf_model.get_rcut = get_rcut
+
+        @tf.function
+        def get_dim_fparam():
+            return tf.constant(model.get_dim_fparam(), dtype=tf.int64)
+
+        tf_model.get_dim_fparam = get_dim_fparam
+
+        @tf.function
+        def get_dim_aparam():
+            return tf.constant(model.get_dim_aparam(), dtype=tf.int64)
+
+        tf_model.get_dim_aparam = get_dim_aparam
+
+        @tf.function
+        def get_sel_type():
+            return tf.constant(model.get_sel_type(), dtype=tf.int64)
+
+        tf_model.get_sel_type = get_sel_type
+
+        @tf.function
+        def is_aparam_nall():
+            return tf.constant(model.is_aparam_nall(), dtype=tf.bool)
+
+        tf_model.is_aparam_nall = is_aparam_nall
+
+        @tf.function
+        def model_output_type():
+            return tf.constant(model.model_output_type(), dtype=tf.string)
+
+        tf_model.model_output_type = model_output_type
+
+        @tf.function
+        def mixed_types():
+            return tf.constant(model.mixed_types(), dtype=tf.bool)
+
+        tf_model.mixed_types = mixed_types
+
+        if model.get_min_nbor_dist() is not None:
+
+            @tf.function
+            def get_min_nbor_dist():
+                return tf.constant(model.get_min_nbor_dist(), dtype=tf.double)
+
+            tf_model.get_min_nbor_dist = get_min_nbor_dist
+
+        @tf.function
+        def get_sel():
+            return tf.constant(model.get_sel(), dtype=tf.int64)
+
+        tf_model.get_sel = get_sel
+
+        @tf.function
+        def get_model_def_script():
+            return tf.constant(
+                json.dumps(model_def_script, separators=(",", ":")), dtype=tf.string
+            )
+
+        tf_model.get_model_def_script = get_model_def_script
+        tf.saved_model.save(
+            tf_model,
+            model_file,
+            options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
+        )
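A hedged usage sketch for `deserialize_to_file`: it expects a serialized-model dictionary carrying "model" and "model_def_script" entries, writes the SavedModel, and the exported attribute functions can then be read back with plain TensorFlow. How the `data` dictionary is produced (for example from an existing model file) is outside this module and is left as an assumption; the output path is hypothetical:

# Sketch only: `data` must be a DeePMD serialized-model dict with the "model"
# and "model_def_script" keys read by deserialize_to_file.
import tensorflow as tf

from deepmd.jax.jax2tf.serialization import deserialize_to_file


def export_and_inspect(data: dict, path: str = "model.savedmodel") -> None:
    deserialize_to_file(path, data)
    loaded = tf.saved_model.load(path)
    print(float(loaded.get_rcut().numpy()))  # cutoff radius
    print([t.decode() for t in loaded.get_type_map().numpy()])  # type map
    # loaded.call_lower / loaded.call_lower_atomic_virial take
    # (coord, atype, nlist, mapping, fparam, aparam) with the dtypes fixed by
    # the input signatures above (float64 / int32 / int64 tensors).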
