Skip to content

Commit 887c92f

Browse files
committed
update dpmd test
1 parent ddef847 commit 887c92f

File tree

2 files changed

+47
-27
lines changed

2 files changed

+47
-27
lines changed

tests/dpmdargs.py

Lines changed: 47 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ def model_args ():
316316
return ca
317317

318318

319-
def learning_rate_args():
319+
def learning_rate_exp():
320320
doc_start_lr = 'The learning rate the start of the training.'
321321
doc_stop_lr = 'The desired learning rate at the end of the training.'
322322
doc_decay_steps = 'The learning rate is decaying every this number of training steps.'
@@ -326,9 +326,24 @@ def learning_rate_args():
326326
Argument("stop_lr", float, optional = True, default = 1e-8, doc = doc_stop_lr),
327327
Argument("decay_steps", int, optional = True, default = 5000, doc = doc_decay_steps)
328328
]
329+
return args
330+
331+
332+
def learning_rate_variant_type_args():
333+
doc_lr = 'The type of the learning rate. Current type `exp`, the exponentially decaying learning rate is supported.'
334+
335+
return Variant("type",
336+
[Argument("exp", dict, learning_rate_exp())],
337+
optional = True,
338+
default_tag = 'exp',
339+
doc = doc_lr)
340+
329341

330-
doc_lr = "The learning rate options"
331-
return Argument("learning_rate", dict, args, [], doc = doc_lr)
342+
def learning_rate_args():
343+
doc_lr = "The definitio of learning rate"
344+
return Argument("learning_rate", dict, [],
345+
[learning_rate_variant_type_args()],
346+
doc = doc_lr)
332347

333348

334349
def start_pref(item):
@@ -378,15 +393,16 @@ def loss_args():
378393
return ca
379394

380395
def training_args():
396+
link_sys = make_link("systems", "training/systems")
381397
doc_systems = 'The data systems. This key can be provided with a listthat specifies the systems, or be provided with a string by which the prefix of all systems are given and the list of the systems is automatically generated.'
382-
doc_set_prefix = 'The prefix of the sets in the systems.'
398+
doc_set_prefix = f'The prefix of the sets in the {link_sys}.'
383399
doc_stop_batch = 'Number of training batch. Each training uses one batch of data.'
384-
doc_batch_size = 'This key can be \n\n\
385-
- list: the length of which is the same as the `systems`. The batch size of each system is given by the elements of the list.\n\n\
386-
- int: all `systems` uses the same batch size.\n\n\
387-
- string "auto": automatically determines the batch size os that the batch_size times the number of atoms in the system is no less than 32.\n\n\
388-
- string "auto:N": automatically determines the batch size os that the batch_size times the number of atoms in the system is no less than N.'
389-
doc_seed = 'The random seed for training.'
400+
doc_batch_size = f'This key can be \n\n\
401+
- list: the length of which is the same as the {link_sys}. The batch size of each system is given by the elements of the list.\n\n\
402+
- int: all {link_sys} use the same batch size.\n\n\
403+
- string "auto": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than 32.\n\n\
404+
- string "auto:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.'
405+
doc_seed = 'The random seed for getting frames from the training data set.'
390406
doc_disp_file = 'The file for printing learning curve.'
391407
doc_disp_freq = 'The frequency of printing learning curve.'
392408
doc_numb_test = 'Number of frames used for the test during training.'
@@ -396,12 +412,21 @@ def training_args():
396412
doc_time_training = 'Timing durining training.'
397413
doc_profiling = 'Profiling during training.'
398414
doc_profiling_file = 'Output file for profiling.'
415+
doc_train_auto_prob_style = 'Determine the probability of systems automatically. The method is assigned by this key and can be\n\n\
416+
- "prob_uniform" : the probability all the systems are equal, namely 1.0/self.get_nsystems()\n\n\
417+
- "prob_sys_size" : the probability of a system is proportional to the number of batches in the system\n\n\
418+
- "prob_sys_size;stt_idx:end_idx:weight;stt_idx:end_idx:weight;..." : the list of systems is devided into blocks. A block is specified by `stt_idx:end_idx:weight`, where `stt_idx` is the starting index of the system, `end_idx` is then ending (not including) index of the system, the probabilities of the systems in this block sums up to `weight`, and the relatively probabilities within this block is proportional to the number of batches in the system.'
419+
doc_train_sys_probs = "A list of float, should be of the same length as `train_systems`, specifying the probability of each system."
420+
doc_tensorboard = 'Enable tensorboard'
421+
doc_tensorboard_log_dir = 'The log directory of tensorboard outputs'
399422

400423
args = [
401-
Argument("systems", [list,str], optional = False, doc = doc_systems),
424+
Argument("systems", [list,str], optional = False, doc = doc_systems, alias = ["trn_systems"]),
402425
Argument("set_prefix", str, optional = True, default = 'set', doc = doc_set_prefix),
403-
Argument("stop_batch", int, optional = False, doc = doc_stop_batch),
404-
Argument("batch_size", [list,int,str], optional = True, default = 'auto', doc = doc_batch_size),
426+
Argument("auto_prob", str, optional = True, default = "prob_sys_size", doc = doc_train_auto_prob_style, alias = ["trn_auto_prob", "auto_prob_style"]),
427+
Argument("sys_probs", list, optional = True, default = None, doc = doc_train_sys_probs, alias = ["trn_sys_probs"]),
428+
Argument("batch_size", [list,int,str], optional = True, default = 'auto', doc = doc_batch_size, alias = ["trn_batch_size"]),
429+
Argument("numb_steps", int, optional = False, doc = doc_stop_batch, alias = ["stop_batch"]),
405430
Argument("seed", [int,None], optional = True, doc = doc_seed),
406431
Argument("disp_file", str, optional = True, default = 'lcueve.out', doc = doc_disp_file),
407432
Argument("disp_freq", int, optional = True, default = 1000, doc = doc_disp_freq),
@@ -411,7 +436,9 @@ def training_args():
411436
Argument("disp_training", bool, optional = True, default = True, doc = doc_disp_training),
412437
Argument("time_training", bool, optional = True, default = True, doc = doc_time_training),
413438
Argument("profiling", bool, optional = True, default = False, doc = doc_profiling),
414-
Argument("profiling_file", str, optional = True, default = 'timeline.json', doc = doc_profiling_file)
439+
Argument("profiling_file", str, optional = True, default = 'timeline.json', doc = doc_profiling_file),
440+
Argument("tensorboard", bool, optional = True, default = False, doc = doc_tensorboard),
441+
Argument("tensorboard_log_dir", str, optional = True, default = 'log', doc = doc_tensorboard_log_dir),
415442
]
416443

417444
doc_training = 'The training options'
@@ -493,14 +520,14 @@ def normalize(data):
493520
},
494521
495522
"learning_rate" :{
496-
"_type": "exp",
523+
"type": "exp",
497524
"decay_steps": 5000,
498525
"start_lr": 0.001,
499526
"stop_lr": 3.51e-8,
500527
"_comment": "that's all"
501-
},
528+
},
502529
503-
"loss" :{
530+
"loss" :{
504531
"start_pref_e": 0.02,
505532
"limit_pref_e": 1,
506533
"start_pref_f": 1000,
@@ -526,11 +553,11 @@ def normalize(data):
526553
"numb_test": 10,
527554
"save_freq": 1000,
528555
"save_ckpt": "model.ckpt",
529-
"_load_ckpt": "model.ckpt",
556+
"load_ckpt": "model.ckpt",
530557
"disp_training":true,
531558
"time_training":true,
532-
"_tensorboard": false,
533-
"_tensorboard_log_dir":"log",
559+
"tensorboard": false,
560+
"tensorboard_log_dir":"log",
534561
"profiling": false,
535562
"profiling_file":"timeline.json",
536563
"_comment": "that's all"

tests/test_checker.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -258,13 +258,6 @@ def test_sub_variants(self):
258258
Argument("type2", dict)])
259259
])
260260

261-
def test_dpmd(self):
262-
import json
263-
from dpmdargs import check, example_json_str
264-
data = json.loads(example_json_str)
265-
check(data)
266-
# print("\n\n"+docstr)
267-
268261

269262
if __name__ == "__main__":
270263
unittest.main()

0 commit comments

Comments
 (0)