|
| 1 | +import os,json,warnings |
| 2 | +from deepmd.common import j_have,j_must_have,j_must_have_d |
| 3 | + |
| 4 | +def convert_input_v0_v1(jdata, warning = True, dump = None) : |
| 5 | + output = {} |
| 6 | + if 'with_distrib' in jdata: |
| 7 | + output['with_distrib'] = jdata['with_distrib'] |
| 8 | + if jdata['use_smooth'] : |
| 9 | + output['model'] = _smth_model(jdata) |
| 10 | + else: |
| 11 | + output['model'] = _nonsmth_model(jdata) |
| 12 | + output['learning_rate'] = _learning_rate(jdata) |
| 13 | + output['loss'] = _loss(jdata) |
| 14 | + output['training'] = _training(jdata) |
| 15 | + _warnning_input_v0_v1(dump) |
| 16 | + if dump is not None: |
| 17 | + with open(dump, 'w') as fp: |
| 18 | + json.dump(output, fp, indent=4) |
| 19 | + return output |
| 20 | + |
| 21 | +def _warnning_input_v0_v1(fname) : |
| 22 | + msg = 'It seems that you are using a deepmd-kit input of version 0.x.x, which is deprecated. we have converted the input to >1.0.0 compatible' |
| 23 | + if fname is not None: |
| 24 | + msg += ', and output it to file ' + fname |
| 25 | + warnings.warn(msg) |
| 26 | + |
| 27 | +def _nonsmth_model(jdata): |
| 28 | + model = {} |
| 29 | + model['descriptor'] = _nonsmth_descriptor(jdata) |
| 30 | + model['fitting_net'] = _fitting_net(jdata) |
| 31 | + return model |
| 32 | + |
| 33 | +def _smth_model(jdata): |
| 34 | + model = {} |
| 35 | + model['descriptor'] = _smth_descriptor(jdata) |
| 36 | + model['fitting_net'] = _fitting_net(jdata) |
| 37 | + return model |
| 38 | + |
| 39 | +def _nonsmth_descriptor(jdata) : |
| 40 | + output = {} |
| 41 | + seed = None |
| 42 | + if j_have (jdata, 'seed') : |
| 43 | + seed = jdata['seed'] |
| 44 | + # model |
| 45 | + descriptor = {} |
| 46 | + descriptor['type'] = 'loc_frame' |
| 47 | + descriptor['sel_a'] = jdata['sel_a'] |
| 48 | + descriptor['sel_r'] = jdata['sel_r'] |
| 49 | + descriptor['rcut'] = jdata['rcut'] |
| 50 | + descriptor['axis_rule'] = jdata['axis_rule'] |
| 51 | + return descriptor |
| 52 | + |
| 53 | +def _smth_descriptor(jdata): |
| 54 | + descriptor = {} |
| 55 | + seed = None |
| 56 | + if j_have (jdata, 'seed') : |
| 57 | + seed = jdata['seed'] |
| 58 | + descriptor['type'] = 'se_a' |
| 59 | + descriptor['sel'] = jdata['sel_a'] |
| 60 | + descriptor['rcut'] = jdata['rcut'] |
| 61 | + if j_have(jdata, 'rcut_smth') : |
| 62 | + descriptor['rcut_r_smth'] = jdata['rcut_smth'] |
| 63 | + else : |
| 64 | + descriptor['rcut_r_smth'] = descriptor['rcut'] |
| 65 | + descriptor['neuron'] = j_must_have (jdata, 'filter_neuron') |
| 66 | + descriptor['axis_neuron'] = j_must_have_d (jdata, 'axis_neuron', ['n_axis_neuron']) |
| 67 | + descriptor['resnet_dt'] = False |
| 68 | + if j_have(jdata, 'resnet_dt') : |
| 69 | + descriptor['resnet_dt'] = jdata['filter_resnet_dt'] |
| 70 | + if seed is not None: |
| 71 | + descriptor['seed'] = seed |
| 72 | + return descriptor |
| 73 | + |
| 74 | +def _fitting_net(jdata): |
| 75 | + fitting_net = {} |
| 76 | + seed = None |
| 77 | + if j_have (jdata, 'seed') : |
| 78 | + seed = jdata['seed'] |
| 79 | + fitting_net['neuron']= j_must_have_d (jdata, 'fitting_neuron', ['n_neuron']) |
| 80 | + fitting_net['resnet_dt'] = True |
| 81 | + if j_have(jdata, 'resnet_dt') : |
| 82 | + fitting_net['resnet_dt'] = jdata['resnet_dt'] |
| 83 | + if j_have(jdata, 'fitting_resnet_dt') : |
| 84 | + fitting_net['resnet_dt'] = jdata['fitting_resnet_dt'] |
| 85 | + if seed is not None: |
| 86 | + fitting_net['seed'] = seed |
| 87 | + return fitting_net |
| 88 | + |
| 89 | +def _learning_rate(jdata): |
| 90 | + # learning rate |
| 91 | + learning_rate = {} |
| 92 | + learning_rate['type'] = 'exp' |
| 93 | + learning_rate['decay_steps'] = j_must_have(jdata, 'decay_steps') |
| 94 | + learning_rate['decay_rate'] = j_must_have(jdata, 'decay_rate') |
| 95 | + learning_rate['start_lr'] = j_must_have(jdata, 'start_lr') |
| 96 | + return learning_rate |
| 97 | + |
| 98 | +def _loss(jdata): |
| 99 | + # loss |
| 100 | + loss = {} |
| 101 | + loss['start_pref_e'] = j_must_have (jdata, 'start_pref_e') |
| 102 | + loss['limit_pref_e'] = j_must_have (jdata, 'limit_pref_e') |
| 103 | + loss['start_pref_f'] = j_must_have (jdata, 'start_pref_f') |
| 104 | + loss['limit_pref_f'] = j_must_have (jdata, 'limit_pref_f') |
| 105 | + loss['start_pref_v'] = j_must_have (jdata, 'start_pref_v') |
| 106 | + loss['limit_pref_v'] = j_must_have (jdata, 'limit_pref_v') |
| 107 | + if j_have(jdata, 'start_pref_ae') : |
| 108 | + loss['start_pref_ae'] = jdata['start_pref_ae'] |
| 109 | + if j_have(jdata, 'limit_pref_ae') : |
| 110 | + loss['limit_pref_ae'] = jdata['limit_pref_ae'] |
| 111 | + return loss |
| 112 | + |
| 113 | +def _training(jdata): |
| 114 | + # training |
| 115 | + training = {} |
| 116 | + seed = None |
| 117 | + if j_have (jdata, 'seed') : |
| 118 | + seed = jdata['seed'] |
| 119 | + training['systems'] = jdata['systems'] |
| 120 | + training['set_prefix'] = jdata['set_prefix'] |
| 121 | + training['stop_batch'] = jdata['stop_batch'] |
| 122 | + training['batch_size'] = jdata['batch_size'] |
| 123 | + if seed is not None: |
| 124 | + training['seed'] = seed |
| 125 | + training['disp_file'] = "lcurve.out" |
| 126 | + if j_have (jdata, "disp_file") : training['disp_file'] = jdata["disp_file"] |
| 127 | + training['disp_freq'] = j_must_have (jdata, 'disp_freq') |
| 128 | + training['numb_test'] = j_must_have (jdata, 'numb_test') |
| 129 | + training['save_freq'] = j_must_have (jdata, 'save_freq') |
| 130 | + training['save_ckpt'] = j_must_have (jdata, 'save_ckpt') |
| 131 | + training['display_in_training'] = j_must_have (jdata, 'disp_training') |
| 132 | + training['timing_in_training'] = j_must_have (jdata, 'time_training') |
| 133 | + training['profiling'] = False |
| 134 | + if j_have (jdata, 'profiling') : |
| 135 | + training['profiling'] = jdata['profiling'] |
| 136 | + if training['profiling'] : |
| 137 | + training['profiling_file'] = j_must_have (jdata, 'profiling_file') |
| 138 | + return training |
0 commit comments