Skip to content

Commit 6bf41e3

Browse files
feat: file object passed to open (#709)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced a new utility function `open_file` for improved file handling across various modules. - Enhanced type annotations for multiple functions to specify `FileType`, improving code clarity and type safety. - **Bug Fixes** - Improved file handling robustness by replacing the built-in `open` function with the custom `open_file` function in several modules. - **Tests** - Added unit tests for the new `open_file` utility function to ensure reliable functionality. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5df6acd commit 6bf41e3

File tree

29 files changed

+289
-90
lines changed

29 files changed

+289
-90
lines changed

dpdata/abacus/md.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
import numpy as np
77

8+
from dpdata.utils import open_file
9+
810
from .scf import (
911
bohr2ang,
1012
get_cell,
@@ -156,12 +158,12 @@ def get_frame(fname):
156158
path_in = os.path.join(fname, "INPUT")
157159
else:
158160
raise RuntimeError("invalid input")
159-
with open(path_in) as fp:
161+
with open_file(path_in) as fp:
160162
inlines = fp.read().split("\n")
161163
geometry_path_in = get_geometry_in(fname, inlines) # base dir of STRU
162164
path_out = get_path_out(fname, inlines)
163165

164-
with open(geometry_path_in) as fp:
166+
with open_file(geometry_path_in) as fp:
165167
geometry_inlines = fp.read().split("\n")
166168
celldm, cell = get_cell(geometry_inlines)
167169
atom_names, natoms, types, coords = get_coords(
@@ -172,11 +174,11 @@ def get_frame(fname):
172174
# ndump = int(os.popen("ls -l %s | grep 'md_pos_' | wc -l" %path_out).readlines()[0])
173175
# number of dumped geometry files
174176
# coords = get_coords_from_cif(ndump, dump_freq, atom_names, natoms, types, path_out, cell)
175-
with open(os.path.join(path_out, "MD_dump")) as fp:
177+
with open_file(os.path.join(path_out, "MD_dump")) as fp:
176178
dumplines = fp.read().split("\n")
177179
coords, cells, force, stress = get_coords_from_dump(dumplines, natoms)
178180
ndump = np.shape(coords)[0]
179-
with open(os.path.join(path_out, "running_md.log")) as fp:
181+
with open_file(os.path.join(path_out, "running_md.log")) as fp:
180182
outlines = fp.read().split("\n")
181183
energy = get_energy(outlines, ndump, dump_freq)
182184

dpdata/abacus/relax.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
import numpy as np
66

7+
from dpdata.utils import open_file
8+
79
from .scf import (
810
bohr2ang,
911
collect_force,
@@ -174,10 +176,10 @@ def get_frame(fname):
174176
path_in = os.path.join(fname, "INPUT")
175177
else:
176178
raise RuntimeError("invalid input")
177-
with open(path_in) as fp:
179+
with open_file(path_in) as fp:
178180
inlines = fp.read().split("\n")
179181
geometry_path_in = get_geometry_in(fname, inlines) # base dir of STRU
180-
with open(geometry_path_in) as fp:
182+
with open_file(geometry_path_in) as fp:
181183
geometry_inlines = fp.read().split("\n")
182184
celldm, cell = get_cell(geometry_inlines)
183185
atom_names, natoms, types, coord_tmp = get_coords(
@@ -186,7 +188,7 @@ def get_frame(fname):
186188

187189
logf = get_log_file(fname, inlines)
188190
assert os.path.isfile(logf), f"Error: can not find {logf}"
189-
with open(logf) as f1:
191+
with open_file(logf) as f1:
190192
lines = f1.readlines()
191193

192194
atomnumber = 0

dpdata/abacus/scf.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
import numpy as np
88

9+
from dpdata.utils import open_file
10+
911
from ..unit import EnergyConversion, LengthConversion, PressureConversion
1012

1113
bohr2ang = LengthConversion("bohr", "angstrom").value()
@@ -253,17 +255,17 @@ def get_frame(fname):
253255
if not CheckFile(path_in):
254256
return data
255257

256-
with open(path_in) as fp:
258+
with open_file(path_in) as fp:
257259
inlines = fp.read().split("\n")
258260

259261
geometry_path_in = get_geometry_in(fname, inlines)
260262
path_out = get_path_out(fname, inlines)
261263
if not (CheckFile(geometry_path_in) and CheckFile(path_out)):
262264
return data
263265

264-
with open(geometry_path_in) as fp:
266+
with open_file(geometry_path_in) as fp:
265267
geometry_inlines = fp.read().split("\n")
266-
with open(path_out) as fp:
268+
with open_file(path_out) as fp:
267269
outlines = fp.read().split("\n")
268270

269271
celldm, cell = get_cell(geometry_inlines)
@@ -338,7 +340,7 @@ def get_nele_from_stru(geometry_inlines):
338340

339341
def get_frame_from_stru(fname):
340342
assert isinstance(fname, str)
341-
with open(fname) as fp:
343+
with open_file(fname) as fp:
342344
geometry_inlines = fp.read().split("\n")
343345
nele = get_nele_from_stru(geometry_inlines)
344346
inlines = ["ntype %d" % nele]

dpdata/amber/md.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from dpdata.amber.mask import pick_by_amber_mask
99
from dpdata.unit import EnergyConversion
10+
from dpdata.utils import open_file
1011

1112
from ..periodic_table import ELEMENTS
1213

@@ -51,7 +52,7 @@ def read_amber_traj(
5152
flag_atom_numb = False
5253
amber_types = []
5354
atomic_number = []
54-
with open(parm7_file) as f:
55+
with open_file(parm7_file) as f:
5556
for line in f:
5657
if line.startswith("%FLAG"):
5758
flag_atom_type = line.startswith("%FLAG AMBER_ATOM_TYPE")
@@ -101,14 +102,14 @@ def read_amber_traj(
101102
# load energy from mden_file or mdout_file
102103
energies = []
103104
if mden_file is not None and os.path.isfile(mden_file):
104-
with open(mden_file) as f:
105+
with open_file(mden_file) as f:
105106
for line in f:
106107
if line.startswith("L6"):
107108
s = line.split()
108109
if s[2] != "E_pot":
109110
energies.append(float(s[2]))
110111
elif mdout_file is not None and os.path.isfile(mdout_file):
111-
with open(mdout_file) as f:
112+
with open_file(mdout_file) as f:
112113
for line in f:
113114
if "EPtot" in line:
114115
s = line.split()

dpdata/amber/sqm.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
from __future__ import annotations
22

3+
from typing import TYPE_CHECKING
4+
35
import numpy as np
46

57
from dpdata.periodic_table import ELEMENTS
68
from dpdata.unit import EnergyConversion
9+
from dpdata.utils import open_file
10+
11+
if TYPE_CHECKING:
12+
from dpdata.utils import FileType
713

814
kcal2ev = EnergyConversion("kcal_mol", "eV").value()
915

@@ -14,15 +20,15 @@
1420
READ_FORCES = 7
1521

1622

17-
def parse_sqm_out(fname):
23+
def parse_sqm_out(fname: FileType):
1824
"""Read atom symbols, charges and coordinates from ambertools sqm.out file."""
1925
atom_symbols = []
2026
coords = []
2127
charges = []
2228
forces = []
2329
energies = []
2430

25-
with open(fname) as f:
31+
with open_file(fname) as f:
2632
flag = START
2733
for line in f:
2834
if line.startswith(" Total SCF energy"):
@@ -81,7 +87,7 @@ def parse_sqm_out(fname):
8187
return data
8288

8389

84-
def make_sqm_in(data, fname=None, frame_idx=0, **kwargs):
90+
def make_sqm_in(data, fname: FileType | None = None, frame_idx=0, **kwargs):
8591
symbols = [data["atom_names"][ii] for ii in data["atom_types"]]
8692
atomic_numbers = [ELEMENTS.index(ss) + 1 for ss in symbols]
8793
charge = kwargs.get("charge", 0)
@@ -109,6 +115,6 @@ def make_sqm_in(data, fname=None, frame_idx=0, **kwargs):
109115
f"{data['coords'][frame_idx][ii, 2]:.6f}",
110116
)
111117
if fname is not None:
112-
with open(fname, "w") as fp:
118+
with open_file(fname, "w") as fp:
113119
fp.write(ret)
114120
return ret

dpdata/deepmd/comp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import numpy as np
99

1010
import dpdata
11+
from dpdata.utils import open_file
1112

1213
from .raw import load_type
1314

@@ -172,7 +173,7 @@ def dump(folder, data, set_size=5000, comp_prec=np.float32, remove_sets=True):
172173
except OSError:
173174
pass
174175
if data.get("nopbc", False):
175-
with open(os.path.join(folder, "nopbc"), "w") as fw_nopbc:
176+
with open_file(os.path.join(folder, "nopbc"), "w") as fw_nopbc:
176177
pass
177178
# allow custom dtypes
178179
labels = "energies" in data

dpdata/deepmd/raw.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77

88
import dpdata
9+
from dpdata.utils import open_file
910

1011

1112
def load_type(folder, type_map=None):
@@ -17,7 +18,7 @@ def load_type(folder, type_map=None):
1718
data["atom_names"] = []
1819
# if find type_map.raw, use it
1920
if os.path.isfile(os.path.join(folder, "type_map.raw")):
20-
with open(os.path.join(folder, "type_map.raw")) as fp:
21+
with open_file(os.path.join(folder, "type_map.raw")) as fp:
2122
my_type_map = fp.read().split()
2223
# else try to use arg type_map
2324
elif type_map is not None:
@@ -140,7 +141,7 @@ def dump(folder, data):
140141
except OSError:
141142
pass
142143
if data.get("nopbc", False):
143-
with open(os.path.join(folder, "nopbc"), "w") as fw_nopbc:
144+
with open_file(os.path.join(folder, "nopbc"), "w") as fw_nopbc:
144145
pass
145146
# allow custom dtypes
146147
labels = "energies" in data

dpdata/dftbplus/output.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,18 @@
11
from __future__ import annotations
22

3+
from typing import TYPE_CHECKING
4+
35
import numpy as np
46

7+
from dpdata.utils import open_file
8+
9+
if TYPE_CHECKING:
10+
from dpdata.utils import FileType
11+
512

6-
def read_dftb_plus(fn_1: str, fn_2: str) -> tuple[str, np.ndarray, float, np.ndarray]:
13+
def read_dftb_plus(
14+
fn_1: FileType, fn_2: FileType
15+
) -> tuple[str, np.ndarray, float, np.ndarray]:
716
"""Read from DFTB+ input and output.
817
918
Parameters
@@ -29,7 +38,7 @@ def read_dftb_plus(fn_1: str, fn_2: str) -> tuple[str, np.ndarray, float, np.nda
2938
symbols = None
3039
forces = None
3140
energy = None
32-
with open(fn_1) as f:
41+
with open_file(fn_1) as f:
3342
flag = 0
3443
for line in f:
3544
if flag == 1:
@@ -49,7 +58,7 @@ def read_dftb_plus(fn_1: str, fn_2: str) -> tuple[str, np.ndarray, float, np.nda
4958
flag += 1
5059
if flag == 7:
5160
flag = 0
52-
with open(fn_2) as f:
61+
with open_file(fn_2) as f:
5362
flag = 0
5463
for line in f:
5564
if line.startswith("Total Forces"):

dpdata/gaussian/log.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
from __future__ import annotations
22

3+
from typing import TYPE_CHECKING
4+
35
import numpy as np
46

7+
from dpdata.utils import open_file
8+
9+
if TYPE_CHECKING:
10+
from dpdata.utils import FileType
11+
512
from ..periodic_table import ELEMENTS
613
from ..unit import EnergyConversion, ForceConversion, LengthConversion
714

@@ -12,7 +19,7 @@
1219
symbols = ["X"] + ELEMENTS
1320

1421

15-
def to_system_data(file_name, md=False):
22+
def to_system_data(file_name: FileType, md=False):
1623
"""Read Gaussian log file.
1724
1825
Parameters
@@ -43,7 +50,7 @@ def to_system_data(file_name, md=False):
4350
nopbc = True
4451
coords = None
4552

46-
with open(file_name) as fp:
53+
with open_file(file_name) as fp:
4754
for line in fp:
4855
if line.startswith(" SCF Done"):
4956
# energies

dpdata/gromacs/gro.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,15 @@
22
from __future__ import annotations
33

44
import re
5+
from typing import TYPE_CHECKING
56

67
import numpy as np
78

9+
from dpdata.utils import open_file
10+
11+
if TYPE_CHECKING:
12+
from dpdata.utils import FileType
13+
814
from ..unit import LengthConversion
915

1016
nm2ang = LengthConversion("nm", "angstrom").value()
@@ -48,9 +54,9 @@ def _get_cell(line):
4854
return cell
4955

5056

51-
def file_to_system_data(fname, format_atom_name=True, **kwargs):
57+
def file_to_system_data(fname: FileType, format_atom_name=True, **kwargs):
5258
system = {"coords": [], "cells": []}
53-
with open(fname) as fp:
59+
with open_file(fname) as fp:
5460
frame = 0
5561
while True:
5662
flag = fp.readline()

0 commit comments

Comments
 (0)