Skip to content

Commit 4f7af6b

Browse files
committed
update dpmd test
1 parent b7f869d commit 4f7af6b

File tree

2 files changed

+43
-16
lines changed

2 files changed

+43
-16
lines changed

dargs/dargs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -492,8 +492,8 @@ def gen_doc(self, paths: Optional[List[str]] = None,
492492
if kwargs.get("make_link"):
493493
if not kwargs.get("make_anchor"):
494494
raise ValueError("`make_link` only works with `make_anchor` set")
495-
fnstr, target = make_ref_pair(paths+[self.flag_name], fnstr, "emph")
496-
body_list.append("\n" + target)
495+
fnstr, target = make_ref_pair(paths+[self.flag_name], fnstr, "flag")
496+
body_list.append(target + "\n")
497497
for choice in self.choice_dict.values():
498498
body_list.append("")
499499
choice_path = self._make_cpath(choice.name, paths, showflag)
@@ -526,7 +526,7 @@ def gen_doc_flag(self, paths: Optional[List[str]] = None, **kwargs) -> str:
526526
self._make_cpath(c.name, paths, kwargs["showflag"]),
527527
text=f"``{c.name}``", prefix="code")
528528
for c in self.choice_dict.values()))
529-
targetdoc = indent('\n' + '\n'.join(l_target), INDENT)
529+
targetdoc = indent('\n'.join(l_target) + "\n", INDENT)
530530
else:
531531
l_choice = [c.name for c in self.choice_dict.values()]
532532
targetdoc = None

tests/dpmdargs.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -181,14 +181,13 @@ def descrpt_variant_type_args():
181181
link_se_a_3be = make_link('se_a_3be', 'model/descriptor[se_a_3be]')
182182
link_se_a_tpe = make_link('se_a_tpe', 'model/descriptor[se_a_tpe]')
183183
link_hybrid = make_link('hybrid', 'model/descriptor[hybrid]')
184-
doc_descrpt_type = f'The type of the descritpor. Valid types are {link_lf}, {link_se_a}, {link_se_r}, {link_se_a_3be}, {link_se_a_tpe}, {link_hybrid}. \n\n\
184+
doc_descrpt_type = f'The type of the descritpor. See explanation below. \n\n\
185185
- `loc_frame`: Defines a local frame at each atom, and the compute the descriptor as local coordinates under this frame.\n\n\
186186
- `se_a`: Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor.\n\n\
187187
- `se_r`: Used by the smooth edition of Deep Potential. Only the distance between atoms is used to construct the descriptor.\n\n\
188188
- `se_a_3be`: Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor. Three-body embedding will be used by this descriptor.\n\n\
189189
- `se_a_tpe`: Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor. Type embedding will be used by this descriptor.\n\n\
190-
- `hybrid`: Concatenate of a list of descriptors as a new descriptor.\n\n\
191-
- `se_ar`: A hybrid of `se_a` and `se_r`. Typically `se_a` has a smaller cut-off while the `se_r` has a larger cut-off. Deprecated, use `hybrid` instead.'
190+
- `hybrid`: Concatenate of a list of descriptors as a new descriptor.'
192191

193192
return Variant("type", [
194193
Argument("loc_frame", dict, descrpt_local_frame_args()),
@@ -197,7 +196,6 @@ def descrpt_variant_type_args():
197196
Argument("se_a_3be", dict, descrpt_se_a_3be_args(), alias = ['se_at']),
198197
Argument("se_a_tpe", dict, descrpt_se_a_tpe_args(), alias = ['se_a_ebd']),
199198
Argument("hybrid", dict, descrpt_hybrid_args()),
200-
Argument("se_ar", dict, descrpt_se_ar_args()),
201199
], doc = doc_descrpt_type)
202200

203201

@@ -275,7 +273,7 @@ def fitting_dipole():
275273

276274

277275
def fitting_variant_type_args():
278-
doc_descrpt_type = 'The type of the fitting. Valid types are `ener`, `dipole`, `polar` and `global_polar`. \n\n\
276+
doc_descrpt_type = 'The type of the fitting. See explanation below. \n\n\
279277
- `ener`: Fit an energy model (potential energy surface).\n\n\
280278
- `dipole`: Fit an atomic dipole model. Atomic dipole labels for all the selected atoms (see `sel_type`) should be provided by `dipole.npy` in each data system. The file has number of frames lines and 3 times of number of selected atoms columns.\n\n\
281279
- `polar`: Fit an atomic polarizability model. Atomic polarizability labels for all the selected atoms (see `sel_type`) should be provided by `polarizability.npy` in each data system. The file has number of frames lines and 9 times of number of selected atoms columns.\n\n\
@@ -289,13 +287,39 @@ def fitting_variant_type_args():
289287
default_tag = 'ener',
290288
doc = doc_descrpt_type)
291289

290+
def modifier_dipole_charge():
291+
doc_model_name = "The name of the frozen dipole model file."
292+
doc_model_charge_map = f"The charge of the WFCC. The list length should be the same as the {make_link('sel_type', 'model/fitting_net[dipole]/sel_type')}. "
293+
doc_sys_charge_map = f"The charge of real atoms. The list length should be the same as the {make_link('type_map', 'model/type_map')}"
294+
doc_ewald_h = f"The grid spacing of the FFT grid. Unit is A"
295+
doc_ewald_beta = f"The splitting parameter of Ewald sum. Unit is A^{-1}"
296+
297+
return [
298+
Argument("model_name", str, optional = False, doc = doc_model_name),
299+
Argument("model_charge_map", list, optional = False, doc = doc_model_charge_map),
300+
Argument("sys_charge_map", list, optional = False, doc = doc_sys_charge_map),
301+
Argument("ewald_beta", float, optional = True, default = 0.4, doc = doc_ewald_beta),
302+
Argument("ewald_h", float, optional = True, default = 1.0, doc = doc_ewald_h),
303+
]
304+
305+
def modifier_variant_type_args():
306+
doc_modifier_type = "The type of modifier. See explanation below.\n\n\
307+
-`dipole_charge`: Use WFCC to model the electronic structure of the system. Correct the long-range interaction"
308+
return Variant("type",
309+
[
310+
Argument("dipole_charge", dict, modifier_dipole_charge()),
311+
],
312+
optional = False,
313+
doc = doc_modifier_type)
314+
292315

293316
def model_args ():
294317
doc_type_map = 'A list of strings. Give the name to each type of atoms.'
295318
doc_data_stat_nbatch = 'The model determines the normalization from the statistics of the data. This key specifies the number of `frames` in each `system` used for statistics.'
296319
doc_data_stat_protect = 'Protect parameter for atomic energy regression.'
297320
doc_descrpt = 'The descriptor of atomic environment.'
298321
doc_fitting = 'The fitting of physical properties.'
322+
doc_modifier = 'The modifier of model output.'
299323
doc_use_srtab = 'The table for the short-range pairwise interaction added on top of DP. The table is a text data file with (N_t + 1) * N_t / 2 + 1 columes. The first colume is the distance between atoms. The second to the last columes are energies for pairs of certain types. For example we have two atom types, 0 and 1. The columes from 2nd to 4th are for 0-0, 0-1 and 1-1 correspondingly.'
300324
doc_smin_alpha = 'The short-range tabulated interaction will be swithed according to the distance of the nearest neighbor. This distance is calculated by softmin. This parameter is the decaying parameter in the softmin. It is only required when `use_srtab` is provided.'
301325
doc_sw_rmin = 'The lower boundary of the interpolation between short-range tabulated interaction and DP. It is only required when `use_srtab` is provided.'
@@ -310,7 +334,8 @@ def model_args ():
310334
Argument("sw_rmin", float, optional = True, doc = doc_sw_rmin),
311335
Argument("sw_rmax", float, optional = True, doc = doc_sw_rmax),
312336
Argument("descriptor", dict, [], [descrpt_variant_type_args()], doc = doc_descrpt),
313-
Argument("fitting_net", dict, [], [fitting_variant_type_args()], doc = doc_fitting)
337+
Argument("fitting_net", dict, [], [fitting_variant_type_args()], doc = doc_fitting),
338+
Argument("modifier", dict, [], [modifier_variant_type_args()], optional = True, doc = doc_modifier),
314339
])
315340
# print(ca.gen_doc())
316341
return ca
@@ -330,7 +355,7 @@ def learning_rate_exp():
330355

331356

332357
def learning_rate_variant_type_args():
333-
doc_lr = 'The type of the learning rate. Current type `exp`, the exponentially decaying learning rate is supported.'
358+
doc_lr = 'The type of the learning rate.'
334359

335360
return Variant("type",
336361
[Argument("exp", dict, learning_rate_exp())],
@@ -376,7 +401,7 @@ def loss_ener():
376401

377402

378403
def loss_variant_type_args():
379-
doc_loss = 'The type of the loss. For fitting type `ener`, the loss type should be set to `ener` or left unset. For tensorial fitting types `dipole`, `polar` and `global_polar`, the type should be left unset.\n\.'
404+
doc_loss = 'The type of the loss. \n\.'
380405

381406
return Variant("type",
382407
[Argument("ener", dict, loss_ener())],
@@ -452,16 +477,18 @@ def make_index(keys):
452477
return ', '.join(ret)
453478

454479

455-
def gen_doc(**kwargs):
480+
def gen_doc(*, make_anchor=True, make_link=True, **kwargs):
481+
if make_link:
482+
make_anchor = True
456483
ma = model_args()
457484
lra = learning_rate_args()
458485
la = loss_args()
459486
ta = training_args()
460487
ptr = []
461-
ptr.append(ma.gen_doc(**kwargs))
462-
ptr.append(la.gen_doc(**kwargs))
463-
ptr.append(lra.gen_doc(**kwargs))
464-
ptr.append(ta.gen_doc(**kwargs))
488+
ptr.append(ma.gen_doc(make_anchor=make_anchor, make_link=make_link, **kwargs))
489+
ptr.append(la.gen_doc(make_anchor=make_anchor, make_link=make_link, **kwargs))
490+
ptr.append(lra.gen_doc(make_anchor=make_anchor, make_link=make_link, **kwargs))
491+
ptr.append(ta.gen_doc(make_anchor=make_anchor, make_link=make_link, **kwargs))
465492

466493
key_words = []
467494
for ii in "\n\n".join(ptr).split('\n'):

0 commit comments

Comments
 (0)