|
| 1 | +import dpdata,os,sys,json,unittest |
| 2 | +import numpy as np |
| 3 | +from deepmd.env import tf |
| 4 | +from common import Data,gen_data |
| 5 | + |
| 6 | +from deepmd.RunOptions import RunOptions |
| 7 | +from deepmd.DataSystem import DataSystem |
| 8 | +from deepmd.DescrptLocFrame import DescrptLocFrame |
| 9 | +from deepmd.Fitting import WFCFitting |
| 10 | +from deepmd.Model import WFCModel |
| 11 | +from deepmd.common import j_must_have, j_must_have_d, j_have |
| 12 | + |
| 13 | +global_ener_float_precision = tf.float64 |
| 14 | +global_tf_float_precision = tf.float64 |
| 15 | +global_np_float_precision = np.float64 |
| 16 | + |
| 17 | +class TestModel(unittest.TestCase): |
| 18 | + def setUp(self) : |
| 19 | + gen_data() |
| 20 | + |
| 21 | + def test_model(self): |
| 22 | + jfile = 'wfc.json' |
| 23 | + with open(jfile) as fp: |
| 24 | + jdata = json.load (fp) |
| 25 | + run_opt = RunOptions(None) |
| 26 | + systems = j_must_have(jdata, 'systems') |
| 27 | + set_pfx = j_must_have(jdata, 'set_prefix') |
| 28 | + batch_size = j_must_have(jdata, 'batch_size') |
| 29 | + test_size = j_must_have(jdata, 'numb_test') |
| 30 | + batch_size = 1 |
| 31 | + test_size = 1 |
| 32 | + stop_batch = j_must_have(jdata, 'stop_batch') |
| 33 | + rcut = j_must_have (jdata['model']['descriptor'], 'rcut') |
| 34 | + |
| 35 | + data = DataSystem(systems, set_pfx, batch_size, test_size, rcut, run_opt = None) |
| 36 | + |
| 37 | + test_data = data.get_test () |
| 38 | + numb_test = 1 |
| 39 | + |
| 40 | + descrpt = DescrptLocFrame(jdata['model']['descriptor']) |
| 41 | + fitting = WFCFitting(jdata['model']['fitting_net'], descrpt) |
| 42 | + model = WFCModel(jdata['model'], descrpt, fitting) |
| 43 | + |
| 44 | + input_data = {'coord' : [test_data['coord']], |
| 45 | + 'box': [test_data['box']], |
| 46 | + 'type': [test_data['type']], |
| 47 | + 'natoms_vec' : [test_data['natoms_vec']], |
| 48 | + 'default_mesh' : [test_data['default_mesh']], |
| 49 | + 'fparam': [test_data['fparam']], |
| 50 | + } |
| 51 | + model._compute_dstats(input_data) |
| 52 | + |
| 53 | + t_prop_c = tf.placeholder(tf.float32, [5], name='t_prop_c') |
| 54 | + t_energy = tf.placeholder(global_ener_float_precision, [None], name='t_energy') |
| 55 | + t_force = tf.placeholder(global_tf_float_precision, [None], name='t_force') |
| 56 | + t_virial = tf.placeholder(global_tf_float_precision, [None], name='t_virial') |
| 57 | + t_atom_ener = tf.placeholder(global_tf_float_precision, [None], name='t_atom_ener') |
| 58 | + t_coord = tf.placeholder(global_tf_float_precision, [None], name='i_coord') |
| 59 | + t_type = tf.placeholder(tf.int32, [None], name='i_type') |
| 60 | + t_natoms = tf.placeholder(tf.int32, [model.ntypes+2], name='i_natoms') |
| 61 | + t_box = tf.placeholder(global_tf_float_precision, [None, 9], name='i_box') |
| 62 | + t_mesh = tf.placeholder(tf.int32, [None], name='i_mesh') |
| 63 | + is_training = tf.placeholder(tf.bool) |
| 64 | + t_fparam = None |
| 65 | + |
| 66 | + model_pred \ |
| 67 | + = model.build (t_coord, |
| 68 | + t_type, |
| 69 | + t_natoms, |
| 70 | + t_box, |
| 71 | + t_mesh, |
| 72 | + t_fparam, |
| 73 | + suffix = "wfc", |
| 74 | + reuse = False) |
| 75 | + wfc = model_pred['wfc'] |
| 76 | + |
| 77 | + feed_dict_test = {t_prop_c: test_data['prop_c'], |
| 78 | + t_coord: np.reshape(test_data['coord'] [:numb_test, :], [-1]), |
| 79 | + t_box: test_data['box'] [:numb_test, :], |
| 80 | + t_type: np.reshape(test_data['type'] [:numb_test, :], [-1]), |
| 81 | + t_natoms: test_data['natoms_vec'], |
| 82 | + t_mesh: test_data['default_mesh'], |
| 83 | + is_training: False} |
| 84 | + |
| 85 | + sess = tf.Session() |
| 86 | + sess.run(tf.global_variables_initializer()) |
| 87 | + [p] = sess.run([wfc], feed_dict = feed_dict_test) |
| 88 | + |
| 89 | + p = p.reshape([-1]) |
| 90 | + refp = [-9.105016838228578990e-01,7.196284362034099935e-01,-9.548516928185298014e-02,2.764615027095288724e+00,2.661319598995644520e-01,7.579512949131941846e-02,-2.107409067376114997e+00,-1.299080016614967414e-01,-5.962778584850070285e-01,2.913899917663253514e-01,-1.226917174638697094e+00,1.829523069930876655e+00,1.015704024959750873e+00,-1.792333611099589386e-01,5.032898080485321834e-01,1.808561721292949453e-01,2.468863482075112081e+00,-2.566442546384765100e-01,-1.467453783795173994e-01,-1.822963931552128658e+00,5.843600156865462747e-01,-1.493875280832117403e+00,1.693322352814763398e-01,-1.877325443995481624e+00] |
| 91 | + |
| 92 | + places = 6 |
| 93 | + for ii in range(p.size) : |
| 94 | + self.assertAlmostEqual(p[ii], refp[ii], places = places) |
| 95 | + |
| 96 | + |
| 97 | + |
0 commit comments