Skip to content

Commit e610703

Browse files
author
Han Wang
committed
provide options data_stat_nbatch and data_stat_protect as model parameters
1 parent be2f25e commit e610703

File tree

1 file changed

+25
-15
lines changed

1 file changed

+25
-15
lines changed

source/train/Model.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,15 @@ def __init__ (self, jdata, descrpt, fitting):
2929
args = ClassArg()\
3030
.add('type_map', list, default = []) \
3131
.add('rcond', float, default = 1e-3) \
32+
.add('data_stat_nbatch', int, default = 10) \
33+
.add('data_stat_protect',float, default = 1e-2) \
3234
.add('use_srtab', str)
3335
class_data = args.parse(jdata)
3436
self.type_map = class_data['type_map']
3537
self.srtab_name = class_data['use_srtab']
3638
self.rcond = class_data['rcond']
39+
self.data_stat_nbatch = class_data['data_stat_nbatch']
40+
self.data_stat_protect = class_data['data_stat_protect']
3741
if self.srtab_name is not None :
3842
self.srtab = TabInter(self.srtab_name)
3943
args.add('smin_alpha', float, must = True)\
@@ -56,16 +60,16 @@ def get_ntypes (self) :
5660
def get_type_map (self) :
5761
return self.type_map
5862

59-
def data_stat(self, data, nbatch = 1, protection = 1e-2):
63+
def data_stat(self, data):
6064
all_stat = defaultdict(list)
6165
for ii in range(data.get_nsystems()) :
62-
for jj in range(nbatch) :
66+
for jj in range(self.data_stat_nbatch) :
6367
stat_data = data.get_batch (sys_idx = ii)
6468
for dd in stat_data:
6569
if dd == "natoms_vec":
6670
stat_data[dd] = stat_data[dd].astype(np.int32)
6771
all_stat[dd].append(stat_data[dd])
68-
self._compute_dstats (all_stat, protection = protection)
72+
self._compute_dstats (all_stat, protection = self.data_stat_protect)
6973
self.bias_atom_e = data.compute_energy_shift(self.rcond)
7074

7175

@@ -224,9 +228,11 @@ def __init__ (self, jdata, descrpt, fitting):
224228
self.fitting = fitting
225229

226230
args = ClassArg()\
227-
.add('type_map', list, default = [])
231+
.add('type_map', list, default = []) \
232+
.add('data_stat_nbatch', int, default = 10)
228233
class_data = args.parse(jdata)
229234
self.type_map = class_data['type_map']
235+
self.data_stat_nbatch = class_data['data_stat_nbatch']
230236

231237
def get_rcut (self) :
232238
return self.rcut
@@ -240,11 +246,12 @@ def get_type_map (self) :
240246
def data_stat(self, data):
241247
all_stat = defaultdict(list)
242248
for ii in range(data.get_nsystems()) :
243-
stat_data = data.get_batch (sys_idx = ii)
244-
for dd in stat_data:
245-
if dd == "natoms_vec":
246-
stat_data[dd] = stat_data[dd].astype(np.int32)
247-
all_stat[dd].append(stat_data[dd])
249+
for jj in range(self.data_stat_nbatch) :
250+
stat_data = data.get_batch (sys_idx = ii)
251+
for dd in stat_data:
252+
if dd == "natoms_vec":
253+
stat_data[dd] = stat_data[dd].astype(np.int32)
254+
all_stat[dd].append(stat_data[dd])
248255
self._compute_dstats(all_stat)
249256

250257

@@ -321,9 +328,11 @@ def __init__ (self, jdata, descrpt, fitting):
321328
self.fitting = fitting
322329

323330
args = ClassArg()\
324-
.add('type_map', list, default = [])
331+
.add('type_map', list, default = []) \
332+
.add('data_stat_nbatch', int, default = 10)
325333
class_data = args.parse(jdata)
326334
self.type_map = class_data['type_map']
335+
self.data_stat_nbatch = class_data['data_stat_nbatch']
327336

328337
def get_rcut (self) :
329338
return self.rcut
@@ -337,11 +346,12 @@ def get_type_map (self) :
337346
def data_stat(self, data):
338347
all_stat = defaultdict(list)
339348
for ii in range(data.get_nsystems()) :
340-
stat_data = data.get_batch (sys_idx = ii)
341-
for dd in stat_data:
342-
if dd == "natoms_vec":
343-
stat_data[dd] = stat_data[dd].astype(np.int32)
344-
all_stat[dd].append(stat_data[dd])
349+
for jj in range(self.data_stat_nbatch) :
350+
stat_data = data.get_batch (sys_idx = ii)
351+
for dd in stat_data:
352+
if dd == "natoms_vec":
353+
stat_data[dd] = stat_data[dd].astype(np.int32)
354+
all_stat[dd].append(stat_data[dd])
345355
self._compute_dstats (all_stat)
346356

347357

0 commit comments

Comments
 (0)