Skip to content

Commit ff04d8b

Browse files
authored
fix(dpmodel/jax): fix fparam and aparam support in DeepEval (#4285)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Enhanced error messages for improved clarity when input dimensions are incorrect. - Added support for optional fitting and atomic parameters in model evaluations. - **Bug Fixes** - Removed restrictions on providing fitting and atomic parameters, allowing for more flexible evaluations. - **Tests** - Introduced a new test class to validate the handling of fitting and atomic parameters in model evaluations. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <[email protected]>
1 parent 9c767ad commit ff04d8b

File tree

5 files changed

+93
-16
lines changed

5 files changed

+93
-16
lines changed

deepmd/dpmodel/fitting/general_fitting.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -388,8 +388,8 @@ def _call_common(
388388
assert fparam is not None, "fparam should not be None"
389389
if fparam.shape[-1] != self.numb_fparam:
390390
raise ValueError(
391-
"get an input fparam of dim {fparam.shape[-1]}, ",
392-
"which is not consistent with {self.numb_fparam}.",
391+
f"get an input fparam of dim {fparam.shape[-1]}, "
392+
f"which is not consistent with {self.numb_fparam}."
393393
)
394394
fparam = (fparam - self.fparam_avg) * self.fparam_inv_std
395395
fparam = xp.tile(
@@ -409,8 +409,8 @@ def _call_common(
409409
assert aparam is not None, "aparam should not be None"
410410
if aparam.shape[-1] != self.numb_aparam:
411411
raise ValueError(
412-
"get an input aparam of dim {aparam.shape[-1]}, ",
413-
"which is not consistent with {self.numb_aparam}.",
412+
f"get an input aparam of dim {aparam.shape[-1]}, "
413+
f"which is not consistent with {self.numb_aparam}."
414414
)
415415
aparam = xp.reshape(aparam, [nf, nloc, self.numb_aparam])
416416
aparam = (aparam - self.aparam_avg) * self.aparam_inv_std

deepmd/dpmodel/infer/deep_eval.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,6 @@ def eval(
204204
The output of the evaluation. The keys are the names of the output
205205
variables, and the values are the corresponding output arrays.
206206
"""
207-
if fparam is not None or aparam is not None:
208-
raise NotImplementedError
209207
# convert all of the input to numpy array
210208
atom_types = np.array(atom_types, dtype=np.int32)
211209
coords = np.array(coords)
@@ -216,7 +214,7 @@ def eval(
216214
)
217215
request_defs = self._get_request_defs(atomic)
218216
out = self._eval_func(self._eval_model, numb_test, natoms)(
219-
coords, cells, atom_types, request_defs
217+
coords, cells, atom_types, fparam, aparam, request_defs
220218
)
221219
return dict(
222220
zip(
@@ -306,6 +304,8 @@ def _eval_model(
306304
coords: np.ndarray,
307305
cells: Optional[np.ndarray],
308306
atom_types: np.ndarray,
307+
fparam: Optional[np.ndarray],
308+
aparam: Optional[np.ndarray],
309309
request_defs: list[OutputVariableDef],
310310
):
311311
model = self.dp
@@ -323,12 +323,25 @@ def _eval_model(
323323
box_input = cells.reshape([-1, 3, 3])
324324
else:
325325
box_input = None
326+
if fparam is not None:
327+
fparam_input = fparam.reshape(nframes, self.get_dim_fparam())
328+
else:
329+
fparam_input = None
330+
if aparam is not None:
331+
aparam_input = aparam.reshape(nframes, natoms, self.get_dim_aparam())
332+
else:
333+
aparam_input = None
326334

327335
do_atomic_virial = any(
328336
x.category == OutputVariableCategory.DERV_C_REDU for x in request_defs
329337
)
330338
batch_output = model(
331-
coord_input, type_input, box=box_input, do_atomic_virial=do_atomic_virial
339+
coord_input,
340+
type_input,
341+
box=box_input,
342+
fparam=fparam_input,
343+
aparam=aparam_input,
344+
do_atomic_virial=do_atomic_virial,
332345
)
333346
if isinstance(batch_output, tuple):
334347
batch_output = batch_output[0]

deepmd/jax/infer/deep_eval.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,6 @@ def eval(
214214
The output of the evaluation. The keys are the names of the output
215215
variables, and the values are the corresponding output arrays.
216216
"""
217-
if fparam is not None or aparam is not None:
218-
raise NotImplementedError
219217
# convert all of the input to numpy array
220218
atom_types = np.array(atom_types, dtype=np.int32)
221219
coords = np.array(coords)
@@ -226,7 +224,7 @@ def eval(
226224
)
227225
request_defs = self._get_request_defs(atomic)
228226
out = self._eval_func(self._eval_model, numb_test, natoms)(
229-
coords, cells, atom_types, request_defs
227+
coords, cells, atom_types, fparam, aparam, request_defs
230228
)
231229
return dict(
232230
zip(
@@ -316,6 +314,8 @@ def _eval_model(
316314
coords: np.ndarray,
317315
cells: Optional[np.ndarray],
318316
atom_types: np.ndarray,
317+
fparam: Optional[np.ndarray],
318+
aparam: Optional[np.ndarray],
319319
request_defs: list[OutputVariableDef],
320320
):
321321
model = self.dp
@@ -333,6 +333,14 @@ def _eval_model(
333333
box_input = cells.reshape([-1, 3, 3])
334334
else:
335335
box_input = None
336+
if fparam is not None:
337+
fparam_input = fparam.reshape(nframes, self.get_dim_fparam())
338+
else:
339+
fparam_input = None
340+
if aparam is not None:
341+
aparam_input = aparam.reshape(nframes, natoms, self.get_dim_aparam())
342+
else:
343+
aparam_input = None
336344

337345
do_atomic_virial = any(
338346
x.category == OutputVariableCategory.DERV_C_REDU for x in request_defs
@@ -341,6 +349,8 @@ def _eval_model(
341349
to_jax_array(coord_input),
342350
to_jax_array(type_input),
343351
box=to_jax_array(box_input),
352+
fparam=to_jax_array(fparam_input),
353+
aparam=to_jax_array(aparam_input),
344354
do_atomic_virial=do_atomic_virial,
345355
)
346356
if isinstance(batch_output, tuple):

deepmd/jax/utils/serialization.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,18 +51,16 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
5151
model_def_script = data["model_def_script"]
5252
call_lower = model.call_lower
5353

54-
nf, nloc, nghost, nfp, nap = jax_export.symbolic_shape(
55-
"nf, nloc, nghost, nfp, nap"
56-
)
54+
nf, nloc, nghost = jax_export.symbolic_shape("nf, nloc, nghost")
5755
exported = jax_export.export(jax.jit(call_lower))(
5856
jax.ShapeDtypeStruct((nf, nloc + nghost, 3), jnp.float64), # extended_coord
5957
jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int32), # extended_atype
6058
jax.ShapeDtypeStruct((nf, nloc, model.get_nnei()), jnp.int64), # nlist
6159
jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int64), # mapping
62-
jax.ShapeDtypeStruct((nf, nfp), jnp.float64)
60+
jax.ShapeDtypeStruct((nf, model.get_dim_fparam()), jnp.float64)
6361
if model.get_dim_fparam()
6462
else None, # fparam
65-
jax.ShapeDtypeStruct((nf, nap), jnp.float64)
63+
jax.ShapeDtypeStruct((nf, nloc, model.get_dim_aparam()), jnp.float64)
6664
if model.get_dim_aparam()
6765
else None, # aparam
6866
False, # do_atomic_virial

source/tests/consistent/io/test_io.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ def test_deep_eval(self):
136136
[13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0],
137137
dtype=GLOBAL_NP_FLOAT_PRECISION,
138138
).reshape(1, 9)
139+
natoms = self.atype.shape[1]
140+
nframes = self.atype.shape[0]
139141
prefix = "test_consistent_io_" + self.__class__.__name__.lower()
140142
rets = []
141143
for backend_name in ("tensorflow", "pytorch", "dpmodel", "jax"):
@@ -145,10 +147,20 @@ def test_deep_eval(self):
145147
reference_data = copy.deepcopy(self.data)
146148
self.save_data_to_model(prefix + backend.suffixes[0], reference_data)
147149
deep_eval = DeepEval(prefix + backend.suffixes[0])
150+
if deep_eval.get_dim_fparam() > 0:
151+
fparam = np.ones((nframes, deep_eval.get_dim_fparam()))
152+
else:
153+
fparam = None
154+
if deep_eval.get_dim_aparam() > 0:
155+
aparam = np.ones((nframes, natoms, deep_eval.get_dim_aparam()))
156+
else:
157+
aparam = None
148158
ret = deep_eval.eval(
149159
self.coords,
150160
self.box,
151161
self.atype,
162+
fparam=fparam,
163+
aparam=aparam,
152164
)
153165
rets.append(ret)
154166
for ret in rets[1:]:
@@ -199,3 +211,47 @@ def setUp(self):
199211

200212
def tearDown(self):
201213
IOTest.tearDown(self)
214+
215+
216+
class TestDeepPotFparamAparam(unittest.TestCase, IOTest):
217+
def setUp(self):
218+
model_def_script = {
219+
"type_map": ["O", "H"],
220+
"descriptor": {
221+
"type": "se_e2_a",
222+
"sel": [20, 20],
223+
"rcut_smth": 0.50,
224+
"rcut": 6.00,
225+
"neuron": [
226+
3,
227+
6,
228+
],
229+
"resnet_dt": False,
230+
"axis_neuron": 2,
231+
"precision": "float64",
232+
"type_one_side": True,
233+
"seed": 1,
234+
},
235+
"fitting_net": {
236+
"type": "ener",
237+
"neuron": [
238+
5,
239+
5,
240+
],
241+
"resnet_dt": True,
242+
"precision": "float64",
243+
"atom_ener": [],
244+
"seed": 1,
245+
"numb_fparam": 2,
246+
"numb_aparam": 2,
247+
},
248+
}
249+
model = get_model(copy.deepcopy(model_def_script))
250+
self.data = {
251+
"model": model.serialize(),
252+
"backend": "test",
253+
"model_def_script": model_def_script,
254+
}
255+
256+
def tearDown(self):
257+
IOTest.tearDown(self)

0 commit comments

Comments
 (0)