Skip to content

Commit c5f412f

Browse files
committed
support MultiSystems from_dir
1 parent 3d8a9aa commit c5f412f

File tree

2 files changed

+23
-18
lines changed

2 files changed

+23
-18
lines changed

dpdata/system.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -950,7 +950,20 @@ def from_file(cls,file_name,fmt):
950950
multi_systems = cls()
951951
multi_systems.load_systems_from_file(file_name=file_name,fmt=fmt)
952952
return multi_systems
953-
953+
954+
@classmethod
955+
def from_dir(cls,dir_name, file_name, fmt='auto'):
# Alternate constructor: create a new instance with cls() and fill it from
# every file called `file_name` found under `dir_name`.
#   dir_name  -- root directory handed to os.walk (searched recursively)
#   file_name -- exact file name looked up in each directory's file list
#   fmt       -- passed unchanged to LabeledSystem; defaults to 'auto'
# Returns the instance after all matches have been appended; if nothing
# matches, no system is appended.
956+
multi_systems = cls()
957+
target_file_list = []
# First pass: collect the full path of every match. os.walk yields
# (dirpath, dirnames, filenames) tuples, so ii[2] is the list of file names
# in the directory and ii[0] is that directory's path.
# NOTE(review): os.walk does not sort its results, so the order in which
# systems are appended may differ between filesystems -- consider sorting
# target_file_list if a reproducible order matters.
958+
for ii in os.walk(dir_name):
959+
if file_name in ii[2]:
960+
target_file = os.path.join(ii[0], file_name)
961+
target_file_list.append(target_file)
# Second pass: load each collected path as a LabeledSystem and append it.
962+
for target_file in target_file_list:
963+
multi_systems.append(LabeledSystem(file_name=target_file, fmt=fmt))
964+
return multi_systems
965+
966+
954967
def load_systems_from_file(self, file_name=None, fmt=None):
955968
if file_name is not None:
956969
if fmt is None:
@@ -1015,10 +1028,10 @@ def check_atom_names(self, system):
10151028
system.add_atom_names(new_in_self)
10161029
system.sort_atom_names()
10171030

1018-
def from_quip_gap_xyz_file(self,filename):
1019-
# quip_gap_xyz_systems = QuipGapxyzSystems(filename)
1031+
def from_quip_gap_xyz_file(self,file_name):
1032+
# quip_gap_xyz_systems = QuipGapxyzSystems(file_name)
10201033
# print(next(quip_gap_xyz_systems))
1021-
for info_dict in QuipGapxyzSystems(filename):
1034+
for info_dict in QuipGapxyzSystems(file_name):
10221035
system=LabeledSystem(data=info_dict)
10231036
self.append(system)
10241037

dpdata/xyz/quip_gap_xyz.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,4 @@
11
#!/usr/bin/env python3
2-
3-
#%%
4-
# with open('./test.xyz', 'r') as xyz_file:
5-
# lines = xyz_file.readlines()
6-
# print(lines)
72
#%%
83
import numpy as np
94
from collections import OrderedDict
@@ -44,20 +39,16 @@ def handle_single_xyz_frame(lines):
4439
if len(lines) != atom_num + 2:
4540
raise RuntimeError("format error, atom_num=={}, {}!=atom_num+2".format(atom_num, len(lines)))
4641
data_format_line = lines[1].strip('\n').strip()+str(' ')
47-
p1 = re.compile(r'(?P<key>\S+)=(?P<quote>[\'\"]?)(?P<value>.*?)(?P=quote)\s+')
48-
p2 = re.compile(r'(?P<key>\w+?):(?P<datatype>[a-zA-Z]):(?P<value>\d+)')
49-
field_list = [kv_dict.groupdict() for kv_dict in p1.finditer(data_format_line)]
50-
field_dict = {}
51-
for item in field_list:
52-
field_dict[item['key']]=item['value']
53-
data_format_line = lines[1]
54-
data_format_list= [m.groupdict() for m in p1.finditer(data_format_line)]
42+
field_value_pattern= re.compile(r'(?P<key>\S+)=(?P<quote>[\'\"]?)(?P<value>.*?)(?P=quote)\s+')
43+
prop_pattern = re.compile(r'(?P<key>\w+?):(?P<datatype>[a-zA-Z]):(?P<value>\d+)')
44+
45+
data_format_list= [kv_dict.groupdict() for kv_dict in field_value_pattern.finditer(data_format_line)]
5546
field_dict = {}
5647
for item in data_format_list:
5748
field_dict[item['key']]=item['value']
5849

5950
Properties = field_dict['Properties']
60-
prop_list = [m.groupdict() for m in p2.finditer(Properties)]
51+
prop_list = [kv_dict.groupdict() for kv_dict in prop_pattern.finditer(Properties)]
6152

6253
data_lines = []
6354
for line in lines[2:]:
@@ -127,6 +118,7 @@ def handle_single_xyz_frame(lines):
127118
virials = np.array([np.array(list(filter(bool,field_dict['virial'].split(' ')))).reshape(3,3)]).astype('float32')
128119
else:
129120
virials = None
121+
130122
info_dict = {}
131123
info_dict['atom_names'] = list(type_num_array[:,0])
132124
info_dict['atom_numbs'] = list(type_num_array[:,1].astype(int))

0 commit comments

Comments
 (0)