Skip to content

Commit 77eac60

Browse files
authored
Merge pull request #185 from amcadmus/devel
following the symlinks to find system dirs
2 parents 306d40b + 850ffce commit 77eac60

File tree

3 files changed

+10
-13
lines changed

3 files changed

+10
-13
lines changed

source/train/common.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import warnings
1+
import os,warnings,fnmatch
22
import numpy as np
33
from deepmd.env import tf
44

@@ -152,3 +152,9 @@ def get_activation_func(activation_fn):
152152
raise RuntimeError(activation_fn+" is not a valid activation function")
153153
return activation_fn_dict[activation_fn]
154154

155+
def expand_sys_str(root_dir):
156+
matches = []
157+
for root, dirnames, filenames in os.walk(root_dir, followlinks=True):
158+
for filename in fnmatch.filter(filenames, 'type.raw'):
159+
matches.append(root)
160+
return matches

source/train/test.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import numpy as np
88

99
from deepmd.Data import DeepmdData
10+
from deepmd.common import expand_sys_str
1011
from deepmd import DeepEval
1112
from deepmd import DeepPot
1213
from deepmd import DeepDipole
@@ -16,10 +17,7 @@
1617

1718
def test (args):
1819
de = DeepEval(args.model)
19-
all_sys = []
20-
from pathlib import Path
21-
for filename in Path(args.system).rglob('type.raw'):
22-
all_sys.append(os.path.dirname(filename))
20+
all_sys = expand_sys_str(args.system)
2321
for ii in all_sys:
2422
args.system = ii
2523
print ("# ---------------output of dp test--------------- ")

source/train/train.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from deepmd.RunOptions import RunOptions
1111
from deepmd.DataSystem import DeepmdDataSystem
1212
from deepmd.Trainer import NNPTrainer
13-
from deepmd.common import data_requirement
13+
from deepmd.common import data_requirement, expand_sys_str
1414
from deepmd.DataModifier import DipoleChargeModifier
1515

1616
def create_done_queue(cluster_spec, task_index):
@@ -80,13 +80,6 @@ def train (args) :
8080
# serial training
8181
_do_work(jdata, run_opt)
8282

83-
def expand_sys_str(root_dir):
84-
all_sys = []
85-
from pathlib import Path
86-
for filename in Path(root_dir).rglob('type.raw'):
87-
all_sys.append(os.path.dirname(filename))
88-
return all_sys
89-
9083
def _do_work(jdata, run_opt):
9184
# init the model
9285
model = NNPTrainer (jdata, run_opt = run_opt)

0 commit comments

Comments
 (0)