Skip to content

Commit 4064b3b

Browse files
Copilotnjzjz
andcommitted
feat(jax): add comprehensive type hints to jax2tf interop code
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
1 parent 668043d commit 4064b3b

File tree

8 files changed

+640
-542
lines changed

8 files changed

+640
-542
lines changed

deepmd/jax/jax2tf/format_nlist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def format_nlist(
99
nlist: tnp.ndarray,
1010
nsel: int,
1111
rcut: float,
12-
):
12+
) -> tnp.ndarray:
1313
"""Format neighbor list.
1414
1515
If nnei == nsel, do nothing;

deepmd/jax/jax2tf/make_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def model_call_from_call_lower(
4444
fparam: tnp.ndarray,
4545
aparam: tnp.ndarray,
4646
do_atomic_virial: bool = False,
47-
):
47+
) -> dict[str, tnp.ndarray]:
4848
"""Return model prediction from lower interface.
4949
5050
Parameters

deepmd/jax/jax2tf/nlist.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def nlist_distinguish_types(
115115
nlist: tnp.ndarray,
116116
atype: tnp.ndarray,
117117
sel: list[int],
118-
):
118+
) -> tnp.ndarray:
119119
"""Given a nlist that does not distinguish atom types, return a nlist that
120120
distinguish atom types.
121121
@@ -140,7 +140,7 @@ def nlist_distinguish_types(
140140
return ret
141141

142142

143-
def tf_outer(a, b):
143+
def tf_outer(a: tnp.ndarray, b: tnp.ndarray) -> tnp.ndarray:
144144
return tf.einsum("i,j->ij", a, b)
145145

146146

@@ -150,7 +150,7 @@ def extend_coord_with_ghosts(
150150
atype: tnp.ndarray,
151151
cell: tnp.ndarray,
152152
rcut: float,
153-
):
153+
) -> tuple[tnp.ndarray, tnp.ndarray, tnp.ndarray]:
154154
"""Extend the coordinates of the atoms by appending peridoc images.
155155
The number of images is large enough to ensure all the neighbors
156156
within rcut are appended.

deepmd/jax/jax2tf/region.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def to_face_distance(
9393
return tnp.reshape(dist, tf.concat([cshape[:-2], [3]], axis=0))
9494

9595

96-
def b_to_face_distance(cell):
96+
def b_to_face_distance(cell: tnp.ndarray) -> tnp.ndarray:
9797
volume = tf.linalg.det(cell)
9898
c_yz = tf.linalg.cross(cell[:, 1, ...], cell[:, 2, ...])
9999
_h2yz = volume / tf.linalg.norm(c_yz, axis=-1)

deepmd/jax/jax2tf/serialization.py

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
import json
33
from typing import (
4+
Callable,
45
Optional,
56
)
67

@@ -38,10 +39,17 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
3839

3940
tf_model = tf.Module()
4041

41-
def exported_whether_do_atomic_virial(do_atomic_virial, has_ghost_atoms):
42+
def exported_whether_do_atomic_virial(
43+
do_atomic_virial: bool, has_ghost_atoms: bool
44+
) -> Callable:
4245
def call_lower_with_fixed_do_atomic_virial(
43-
coord, atype, nlist, mapping, fparam, aparam
44-
):
46+
coord: tnp.ndarray,
47+
atype: tnp.ndarray,
48+
nlist: tnp.ndarray,
49+
mapping: tnp.ndarray,
50+
fparam: tnp.ndarray,
51+
aparam: tnp.ndarray,
52+
) -> dict[str, tnp.ndarray]:
4553
return call_lower(
4654
coord,
4755
atype,
@@ -86,8 +94,13 @@ def call_lower_with_fixed_do_atomic_virial(
8694
],
8795
)
8896
def call_lower_without_atomic_virial(
89-
coord, atype, nlist, mapping, fparam, aparam
90-
):
97+
coord: tnp.ndarray,
98+
atype: tnp.ndarray,
99+
nlist: tnp.ndarray,
100+
mapping: tnp.ndarray,
101+
fparam: tnp.ndarray,
102+
aparam: tnp.ndarray,
103+
) -> dict[str, tnp.ndarray]:
91104
nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut())
92105
return tf.cond(
93106
tf.shape(coord)[1] == tf.shape(nlist)[1],
@@ -112,7 +125,14 @@ def call_lower_without_atomic_virial(
112125
tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64),
113126
],
114127
)
115-
def call_lower_with_atomic_virial(coord, atype, nlist, mapping, fparam, aparam):
128+
def call_lower_with_atomic_virial(
129+
coord: tnp.ndarray,
130+
atype: tnp.ndarray,
131+
nlist: tnp.ndarray,
132+
mapping: tnp.ndarray,
133+
fparam: tnp.ndarray,
134+
aparam: tnp.ndarray,
135+
) -> dict[str, tnp.ndarray]:
116136
nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut())
117137
return tf.cond(
118138
tf.shape(coord)[1] == tf.shape(nlist)[1],
@@ -126,7 +146,7 @@ def call_lower_with_atomic_virial(coord, atype, nlist, mapping, fparam, aparam):
126146

127147
tf_model.call_lower_atomic_virial = call_lower_with_atomic_virial
128148

129-
def make_call_whether_do_atomic_virial(do_atomic_virial: bool):
149+
def make_call_whether_do_atomic_virial(do_atomic_virial: bool) -> Callable:
130150
if do_atomic_virial:
131151
call_lower = call_lower_with_atomic_virial
132152
else:
@@ -138,7 +158,7 @@ def call(
138158
box: Optional[tnp.ndarray] = None,
139159
fparam: Optional[tnp.ndarray] = None,
140160
aparam: Optional[tnp.ndarray] = None,
141-
):
161+
) -> dict[str, tnp.ndarray]:
142162
"""Return model prediction.
143163
144164
Parameters
@@ -194,7 +214,7 @@ def call_with_atomic_virial(
194214
box: tnp.ndarray,
195215
fparam: tnp.ndarray,
196216
aparam: tnp.ndarray,
197-
):
217+
) -> dict[str, tnp.ndarray]:
198218
return make_call_whether_do_atomic_virial(do_atomic_virial=True)(
199219
coord, atype, box, fparam, aparam
200220
)
@@ -217,7 +237,7 @@ def call_without_atomic_virial(
217237
box: tnp.ndarray,
218238
fparam: tnp.ndarray,
219239
aparam: tnp.ndarray,
220-
):
240+
) -> dict[str, tnp.ndarray]:
221241
return make_call_whether_do_atomic_virial(do_atomic_virial=False)(
222242
coord, atype, box, fparam, aparam
223243
)
@@ -226,69 +246,69 @@ def call_without_atomic_virial(
226246

227247
# set functions to export other attributes
228248
@tf.function
229-
def get_type_map():
249+
def get_type_map() -> tf.Tensor:
230250
return tf.constant(model.get_type_map(), dtype=tf.string)
231251

232252
tf_model.get_type_map = get_type_map
233253

234254
@tf.function
235-
def get_rcut():
255+
def get_rcut() -> tf.Tensor:
236256
return tf.constant(model.get_rcut(), dtype=tf.double)
237257

238258
tf_model.get_rcut = get_rcut
239259

240260
@tf.function
241-
def get_dim_fparam():
261+
def get_dim_fparam() -> tf.Tensor:
242262
return tf.constant(model.get_dim_fparam(), dtype=tf.int64)
243263

244264
tf_model.get_dim_fparam = get_dim_fparam
245265

246266
@tf.function
247-
def get_dim_aparam():
267+
def get_dim_aparam() -> tf.Tensor:
248268
return tf.constant(model.get_dim_aparam(), dtype=tf.int64)
249269

250270
tf_model.get_dim_aparam = get_dim_aparam
251271

252272
@tf.function
253-
def get_sel_type():
273+
def get_sel_type() -> tf.Tensor:
254274
return tf.constant(model.get_sel_type(), dtype=tf.int64)
255275

256276
tf_model.get_sel_type = get_sel_type
257277

258278
@tf.function
259-
def is_aparam_nall():
279+
def is_aparam_nall() -> tf.Tensor:
260280
return tf.constant(model.is_aparam_nall(), dtype=tf.bool)
261281

262282
tf_model.is_aparam_nall = is_aparam_nall
263283

264284
@tf.function
265-
def model_output_type():
285+
def model_output_type() -> tf.Tensor:
266286
return tf.constant(model.model_output_type(), dtype=tf.string)
267287

268288
tf_model.model_output_type = model_output_type
269289

270290
@tf.function
271-
def mixed_types():
291+
def mixed_types() -> tf.Tensor:
272292
return tf.constant(model.mixed_types(), dtype=tf.bool)
273293

274294
tf_model.mixed_types = mixed_types
275295

276296
if model.get_min_nbor_dist() is not None:
277297

278298
@tf.function
279-
def get_min_nbor_dist():
299+
def get_min_nbor_dist() -> tf.Tensor:
280300
return tf.constant(model.get_min_nbor_dist(), dtype=tf.double)
281301

282302
tf_model.get_min_nbor_dist = get_min_nbor_dist
283303

284304
@tf.function
285-
def get_sel():
305+
def get_sel() -> tf.Tensor:
286306
return tf.constant(model.get_sel(), dtype=tf.int64)
287307

288308
tf_model.get_sel = get_sel
289309

290310
@tf.function
291-
def get_model_def_script():
311+
def get_model_def_script() -> tf.Tensor:
292312
return tf.constant(
293313
json.dumps(model_def_script, separators=(",", ":")), dtype=tf.string
294314
)

deepmd/jax/jax2tf/tfmodel.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def decode_list_of_bytes(list_of_bytes: list[bytes]) -> list[str]:
4545
class TFModelWrapper(tf.Module):
4646
def __init__(
4747
self,
48-
model,
48+
model: str,
4949
) -> None:
5050
self.model = tf.saved_model.load(model)
5151
self._call_lower = jax2tf.call_tf(self.model.call_lower)
@@ -115,7 +115,7 @@ def call(
115115
fparam: Optional[jnp.ndarray] = None,
116116
aparam: Optional[jnp.ndarray] = None,
117117
do_atomic_virial: bool = False,
118-
):
118+
) -> Any:
119119
"""Return model prediction.
120120
121121
Parameters
@@ -165,7 +165,7 @@ def call(
165165
aparam,
166166
)
167167

168-
def model_output_def(self):
168+
def model_output_def(self) -> ModelOutputDef:
169169
return ModelOutputDef(
170170
FittingOutputDef([OUTPUT_DEFS[tt] for tt in self.model_output_type()])
171171
)
@@ -179,7 +179,7 @@ def call_lower(
179179
fparam: Optional[jnp.ndarray] = None,
180180
aparam: Optional[jnp.ndarray] = None,
181181
do_atomic_virial: bool = False,
182-
):
182+
) -> Any:
183183
if do_atomic_virial:
184184
call_lower = self._call_lower_atomic_virial
185185
else:
@@ -207,15 +207,15 @@ def get_type_map(self) -> list[str]:
207207
"""Get the type map."""
208208
return self.type_map
209209

210-
def get_rcut(self):
210+
def get_rcut(self) -> float:
211211
"""Get the cut-off radius."""
212212
return self.rcut
213213

214-
def get_dim_fparam(self):
214+
def get_dim_fparam(self) -> int:
215215
"""Get the number (dimension) of frame parameters of this atomic model."""
216216
return self.dim_fparam
217217

218-
def get_dim_aparam(self):
218+
def get_dim_aparam(self) -> int:
219219
"""Get the number (dimension) of atomic parameters of this atomic model."""
220220
return self.dim_aparam
221221

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,7 @@ runtime-evaluated-base-classes = ["torch.nn.Module"]
425425
"deepmd/tf/**" = ["TID253", "ANN"]
426426
"deepmd/pt/**" = ["TID253"]
427427
"deepmd/jax/**" = ["TID253"]
428+
"deepmd/jax/jax2tf/**" = ["TID253", "ANN"]
428429
"deepmd/pd/**" = ["TID253", "ANN"]
429430

430431
"source/**" = ["ANN"]

0 commit comments

Comments
 (0)