11# SPDX-License-Identifier: LGPL-3.0-or-later
22import json
33from typing import (
4+ Callable ,
45 Optional ,
56)
67
@@ -38,10 +39,17 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
3839
3940 tf_model = tf .Module ()
4041
41- def exported_whether_do_atomic_virial (do_atomic_virial , has_ghost_atoms ):
42+ def exported_whether_do_atomic_virial (
43+ do_atomic_virial : bool , has_ghost_atoms : bool
44+ ) -> Callable :
4245 def call_lower_with_fixed_do_atomic_virial (
43- coord , atype , nlist , mapping , fparam , aparam
44- ):
46+ coord : tnp .ndarray ,
47+ atype : tnp .ndarray ,
48+ nlist : tnp .ndarray ,
49+ mapping : tnp .ndarray ,
50+ fparam : tnp .ndarray ,
51+ aparam : tnp .ndarray ,
52+ ) -> dict [str , tnp .ndarray ]:
4553 return call_lower (
4654 coord ,
4755 atype ,
@@ -86,8 +94,13 @@ def call_lower_with_fixed_do_atomic_virial(
8694 ],
8795 )
8896 def call_lower_without_atomic_virial (
89- coord , atype , nlist , mapping , fparam , aparam
90- ):
97+ coord : tnp .ndarray ,
98+ atype : tnp .ndarray ,
99+ nlist : tnp .ndarray ,
100+ mapping : tnp .ndarray ,
101+ fparam : tnp .ndarray ,
102+ aparam : tnp .ndarray ,
103+ ) -> dict [str , tnp .ndarray ]:
91104 nlist = format_nlist (coord , nlist , model .get_nnei (), model .get_rcut ())
92105 return tf .cond (
93106 tf .shape (coord )[1 ] == tf .shape (nlist )[1 ],
@@ -112,7 +125,14 @@ def call_lower_without_atomic_virial(
112125 tf .TensorSpec ([None , None , model .get_dim_aparam ()], tf .float64 ),
113126 ],
114127 )
115- def call_lower_with_atomic_virial (coord , atype , nlist , mapping , fparam , aparam ):
128+ def call_lower_with_atomic_virial (
129+ coord : tnp .ndarray ,
130+ atype : tnp .ndarray ,
131+ nlist : tnp .ndarray ,
132+ mapping : tnp .ndarray ,
133+ fparam : tnp .ndarray ,
134+ aparam : tnp .ndarray ,
135+ ) -> dict [str , tnp .ndarray ]:
116136 nlist = format_nlist (coord , nlist , model .get_nnei (), model .get_rcut ())
117137 return tf .cond (
118138 tf .shape (coord )[1 ] == tf .shape (nlist )[1 ],
@@ -126,7 +146,7 @@ def call_lower_with_atomic_virial(coord, atype, nlist, mapping, fparam, aparam):
126146
127147 tf_model .call_lower_atomic_virial = call_lower_with_atomic_virial
128148
129- def make_call_whether_do_atomic_virial (do_atomic_virial : bool ):
149+ def make_call_whether_do_atomic_virial (do_atomic_virial : bool ) -> Callable :
130150 if do_atomic_virial :
131151 call_lower = call_lower_with_atomic_virial
132152 else :
@@ -138,7 +158,7 @@ def call(
138158 box : Optional [tnp .ndarray ] = None ,
139159 fparam : Optional [tnp .ndarray ] = None ,
140160 aparam : Optional [tnp .ndarray ] = None ,
141- ):
161+ ) -> dict [ str , tnp . ndarray ] :
142162 """Return model prediction.
143163
144164 Parameters
@@ -194,7 +214,7 @@ def call_with_atomic_virial(
194214 box : tnp .ndarray ,
195215 fparam : tnp .ndarray ,
196216 aparam : tnp .ndarray ,
197- ):
217+ ) -> dict [ str , tnp . ndarray ] :
198218 return make_call_whether_do_atomic_virial (do_atomic_virial = True )(
199219 coord , atype , box , fparam , aparam
200220 )
@@ -217,7 +237,7 @@ def call_without_atomic_virial(
217237 box : tnp .ndarray ,
218238 fparam : tnp .ndarray ,
219239 aparam : tnp .ndarray ,
220- ):
240+ ) -> dict [ str , tnp . ndarray ] :
221241 return make_call_whether_do_atomic_virial (do_atomic_virial = False )(
222242 coord , atype , box , fparam , aparam
223243 )
@@ -226,69 +246,69 @@ def call_without_atomic_virial(
226246
227247 # set functions to export other attributes
228248 @tf .function
229- def get_type_map ():
249+ def get_type_map () -> tf . Tensor :
230250 return tf .constant (model .get_type_map (), dtype = tf .string )
231251
232252 tf_model .get_type_map = get_type_map
233253
234254 @tf .function
235- def get_rcut ():
255+ def get_rcut () -> tf . Tensor :
236256 return tf .constant (model .get_rcut (), dtype = tf .double )
237257
238258 tf_model .get_rcut = get_rcut
239259
240260 @tf .function
241- def get_dim_fparam ():
261+ def get_dim_fparam () -> tf . Tensor :
242262 return tf .constant (model .get_dim_fparam (), dtype = tf .int64 )
243263
244264 tf_model .get_dim_fparam = get_dim_fparam
245265
246266 @tf .function
247- def get_dim_aparam ():
267+ def get_dim_aparam () -> tf . Tensor :
248268 return tf .constant (model .get_dim_aparam (), dtype = tf .int64 )
249269
250270 tf_model .get_dim_aparam = get_dim_aparam
251271
252272 @tf .function
253- def get_sel_type ():
273+ def get_sel_type () -> tf . Tensor :
254274 return tf .constant (model .get_sel_type (), dtype = tf .int64 )
255275
256276 tf_model .get_sel_type = get_sel_type
257277
258278 @tf .function
259- def is_aparam_nall ():
279+ def is_aparam_nall () -> tf . Tensor :
260280 return tf .constant (model .is_aparam_nall (), dtype = tf .bool )
261281
262282 tf_model .is_aparam_nall = is_aparam_nall
263283
264284 @tf .function
265- def model_output_type ():
285+ def model_output_type () -> tf . Tensor :
266286 return tf .constant (model .model_output_type (), dtype = tf .string )
267287
268288 tf_model .model_output_type = model_output_type
269289
270290 @tf .function
271- def mixed_types ():
291+ def mixed_types () -> tf . Tensor :
272292 return tf .constant (model .mixed_types (), dtype = tf .bool )
273293
274294 tf_model .mixed_types = mixed_types
275295
276296 if model .get_min_nbor_dist () is not None :
277297
278298 @tf .function
279- def get_min_nbor_dist ():
299+ def get_min_nbor_dist () -> tf . Tensor :
280300 return tf .constant (model .get_min_nbor_dist (), dtype = tf .double )
281301
282302 tf_model .get_min_nbor_dist = get_min_nbor_dist
283303
284304 @tf .function
285- def get_sel ():
305+ def get_sel () -> tf . Tensor :
286306 return tf .constant (model .get_sel (), dtype = tf .int64 )
287307
288308 tf_model .get_sel = get_sel
289309
290310 @tf .function
291- def get_model_def_script ():
311+ def get_model_def_script () -> tf . Tensor :
292312 return tf .constant (
293313 json .dumps (model_def_script , separators = ("," , ":" )), dtype = tf .string
294314 )
0 commit comments