@@ -29,11 +29,15 @@ def __init__ (self, jdata, descrpt, fitting):
2929 args = ClassArg ()\
3030 .add ('type_map' , list , default = []) \
3131 .add ('rcond' , float , default = 1e-3 ) \
32+ .add ('data_stat_nbatch' , int , default = 10 ) \
33+ .add ('data_stat_protect' ,float , default = 1e-2 ) \
3234 .add ('use_srtab' , str )
3335 class_data = args .parse (jdata )
3436 self .type_map = class_data ['type_map' ]
3537 self .srtab_name = class_data ['use_srtab' ]
3638 self .rcond = class_data ['rcond' ]
39+ self .data_stat_nbatch = class_data ['data_stat_nbatch' ]
40+ self .data_stat_protect = class_data ['data_stat_protect' ]
3741 if self .srtab_name is not None :
3842 self .srtab = TabInter (self .srtab_name )
3943 args .add ('smin_alpha' , float , must = True )\
@@ -56,16 +60,16 @@ def get_ntypes (self) :
5660 def get_type_map (self ) :
5761 return self .type_map
5862
59- def data_stat (self , data , nbatch = 1 , protection = 1e-2 ):
63+ def data_stat (self , data ):
6064 all_stat = defaultdict (list )
6165 for ii in range (data .get_nsystems ()) :
62- for jj in range (nbatch ) :
66+ for jj in range (self . data_stat_nbatch ) :
6367 stat_data = data .get_batch (sys_idx = ii )
6468 for dd in stat_data :
6569 if dd == "natoms_vec" :
6670 stat_data [dd ] = stat_data [dd ].astype (np .int32 )
6771 all_stat [dd ].append (stat_data [dd ])
68- self ._compute_dstats (all_stat , protection = protection )
72+ self ._compute_dstats (all_stat , protection = self . data_stat_protect )
6973 self .bias_atom_e = data .compute_energy_shift (self .rcond )
7074
7175
@@ -224,9 +228,11 @@ def __init__ (self, jdata, descrpt, fitting):
224228 self .fitting = fitting
225229
226230 args = ClassArg ()\
227- .add ('type_map' , list , default = [])
231+ .add ('type_map' , list , default = []) \
232+ .add ('data_stat_nbatch' , int , default = 10 )
228233 class_data = args .parse (jdata )
229234 self .type_map = class_data ['type_map' ]
235+ self .data_stat_nbatch = class_data ['data_stat_nbatch' ]
230236
231237 def get_rcut (self ) :
232238 return self .rcut
@@ -240,11 +246,12 @@ def get_type_map (self) :
240246 def data_stat (self , data ):
241247 all_stat = defaultdict (list )
242248 for ii in range (data .get_nsystems ()) :
243- stat_data = data .get_batch (sys_idx = ii )
244- for dd in stat_data :
245- if dd == "natoms_vec" :
246- stat_data [dd ] = stat_data [dd ].astype (np .int32 )
247- all_stat [dd ].append (stat_data [dd ])
249+ for jj in range (self .data_stat_nbatch ) :
250+ stat_data = data .get_batch (sys_idx = ii )
251+ for dd in stat_data :
252+ if dd == "natoms_vec" :
253+ stat_data [dd ] = stat_data [dd ].astype (np .int32 )
254+ all_stat [dd ].append (stat_data [dd ])
248255 self ._compute_dstats (all_stat )
249256
250257
@@ -321,9 +328,11 @@ def __init__ (self, jdata, descrpt, fitting):
321328 self .fitting = fitting
322329
323330 args = ClassArg ()\
324- .add ('type_map' , list , default = [])
331+ .add ('type_map' , list , default = []) \
332+ .add ('data_stat_nbatch' , int , default = 10 )
325333 class_data = args .parse (jdata )
326334 self .type_map = class_data ['type_map' ]
335+ self .data_stat_nbatch = class_data ['data_stat_nbatch' ]
327336
328337 def get_rcut (self ) :
329338 return self .rcut
@@ -337,11 +346,12 @@ def get_type_map (self) :
337346 def data_stat (self , data ):
338347 all_stat = defaultdict (list )
339348 for ii in range (data .get_nsystems ()) :
340- stat_data = data .get_batch (sys_idx = ii )
341- for dd in stat_data :
342- if dd == "natoms_vec" :
343- stat_data [dd ] = stat_data [dd ].astype (np .int32 )
344- all_stat [dd ].append (stat_data [dd ])
349+ for jj in range (self .data_stat_nbatch ) :
350+ stat_data = data .get_batch (sys_idx = ii )
351+ for dd in stat_data :
352+ if dd == "natoms_vec" :
353+ stat_data [dd ] = stat_data [dd ].astype (np .int32 )
354+ all_stat [dd ].append (stat_data [dd ])
345355 self ._compute_dstats (all_stat )
346356
347357
0 commit comments