Skip to content

Commit 8727956

Browse files
authored
merge duplicated NeighborStat.get_stat (#1103)
Note: this is a simple fix to resolve #1088, but I think we should design a clear architecture to call neighbor stat. This should reduce the half of the time, but it may be still too long. We can consider some better algorithm to calculate neighbour stat (like KDtree?) for further optimization.
1 parent d5a0f6f commit 8727956

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

deepmd/entrypoints/train.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from typing import Dict, List, Optional, Any
1111

1212
from deepmd.common import data_requirement, expand_sys_str, j_loader, j_must_have
13-
from deepmd.env import tf, reset_default_tf_session_config
13+
from deepmd.env import tf, reset_default_tf_session_config, GLOBAL_TF_FLOAT_PRECISION
1414
from deepmd.infer.data_modifier import DipoleChargeModifier
1515
from deepmd.train.run_options import BUILD, CITATION, WELCOME, RunOptions
1616
from deepmd.train.trainer import DPTrainer
@@ -262,6 +262,16 @@ def get_nbor_stat(jdata, rcut):
262262
neistat = NeighborStat(ntypes, rcut)
263263

264264
min_nbor_dist, max_nbor_size = neistat.get_stat(train_data)
265+
266+
# moved from traier.py as duplicated
267+
# TODO: this is a simple fix but we should have a clear
268+
# architecture to call neighbor stat
269+
tf.constant(min_nbor_dist,
270+
name = 'train_attr/min_nbor_dist',
271+
dtype = GLOBAL_TF_FLOAT_PRECISION)
272+
tf.constant(max_nbor_size,
273+
name = 'train_attr/max_nbor_size',
274+
dtype = GLOBAL_TF_FLOAT_PRECISION)
265275
return min_nbor_dist, max_nbor_size
266276

267277
def get_sel(jdata, rcut):

deepmd/train/trainer.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -313,16 +313,9 @@ def build (self,
313313
if self.run_opt.init_mode == 'init_from_frz_model':
314314
self._init_from_frz_model()
315315

316-
self.neighbor_stat \
317-
= NeighborStat(self.ntypes, self.descrpt.get_rcut())
318-
self.min_nbor_dist, self.max_nbor_size \
319-
= self.neighbor_stat.get_stat(data)
320-
tf.constant(self.min_nbor_dist,
321-
name = 'train_attr/min_nbor_dist',
322-
dtype = GLOBAL_TF_FLOAT_PRECISION)
323-
tf.constant(self.max_nbor_size,
324-
name = 'train_attr/max_nbor_size',
325-
dtype = GLOBAL_TF_FLOAT_PRECISION)
316+
# neighbor_stat is moved to train.py as duplicated
317+
# TODO: this is a simple fix but we should have a clear
318+
# architecture to call neighbor stat
326319
else :
327320
self.descrpt.enable_compression(self.model_param['compress']["min_nbor_dist"], self.model_param['compress']['model_file'], self.model_param['compress']['table_config'][0], self.model_param['compress']['table_config'][1], self.model_param['compress']['table_config'][2], self.model_param['compress']['table_config'][3])
328321
self.fitting.init_variables(get_fitting_net_variables(self.model_param['compress']['model_file']))

0 commit comments

Comments
 (0)