33import pathlib
44
55from deepmd .env import tf
6- from deepmd .env import GLOBAL_TF_FLOAT_PRECISION
76from deepmd .env import GLOBAL_NP_FLOAT_PRECISION
8- from deepmd .env import GLOBAL_ENER_FLOAT_PRECISION
97from deepmd .common import j_loader as dp_j_loader
8+ from deepmd .utils import random as dp_random
109
1110if GLOBAL_NP_FLOAT_PRECISION == np .float32 :
1211 global_default_fv_hh = 1e-2
@@ -26,8 +25,8 @@ def del_data():
2625 if os .path .isdir ('system' ):
2726 shutil .rmtree ('system' )
2827
29- def gen_data () :
30- tmpdata = Data (rand_pert = 0.1 , seed = 1 )
28+ def gen_data (nframes = 1 ) :
29+ tmpdata = Data (rand_pert = 0.1 , seed = 1 , nframes = nframes )
3130 sys = dpdata .LabeledSystem ()
3231 sys .data ['atom_names' ] = ['foo' , 'bar' ]
3332 sys .data ['coords' ] = tmpdata .coord
@@ -47,14 +46,15 @@ class Data():
4746 def __init__ (self ,
4847 rand_pert = 0.1 ,
4948 seed = 1 ,
50- box_scale = 20 ) :
49+ box_scale = 20 ,
50+ nframes = 1 ):
5151 coord = [[0.0 , 0.0 , 0.1 ], [1.1 , 0.0 , 0.1 ], [0.0 , 1.1 , 0.1 ],
5252 [4.0 , 0.0 , 0.0 ], [5.1 , 0.0 , 0.0 ], [4.0 , 1.1 , 0.0 ]]
53- self .nframes = 1
53+ self .nframes = nframes
5454 self .coord = np .array (coord )
5555 self .coord = self ._copy_nframes (self .coord )
56- np . random .seed (seed )
57- self .coord += rand_pert * np . random .random (self .coord .shape )
56+ dp_random .seed (seed )
57+ self .coord += rand_pert * dp_random .random (self .coord .shape )
5858 self .fparam = np .array ([[0.1 , 0.2 ]])
5959 self .aparam = np .tile (self .fparam , [1 , 6 ])
6060 self .fparam = self ._copy_nframes (self .fparam )
@@ -69,7 +69,7 @@ def __init__ (self,
6969 self .coord = self .coord .reshape ([self .nframes , - 1 , 3 ])
7070 self .coord = self .coord [:,self .idx_map ,:]
7171 self .coord = self .coord .reshape ([self .nframes , - 1 ])
72- self .efield = np . random .random (self .coord .shape )
72+ self .efield = dp_random .random (self .coord .shape )
7373 self .atype = self .atype [self .idx_map ]
7474 self .datype = self ._copy_nframes (self .atype )
7575
@@ -128,7 +128,7 @@ def get_test_box_data (self,
128128 coord0_ , box0_ , type0_ = self .get_data ()
129129 coord = coord0_ [0 ]
130130 box = box0_ [0 ]
131- box += rand_pert * np . random .random (box .shape )
131+ box += rand_pert * dp_random .random (box .shape )
132132 atype = type0_ [0 ]
133133 nframes = 1
134134 natoms = coord .size // 3
0 commit comments