Skip to content

Commit a7bf93d

Browse files
perf: lazy import modules (#658)
Fix #526.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

- **Refactor**
  - Removed unnecessary import statements and restructured import handling for improved code organization and readability.
  - Reorganized imports within functions to localize dependencies and enhance code modularity.
- **New Features**
  - Introduced conditional imports based on `TYPE_CHECKING` for better resource management and efficiency.
  - Added a new method `from_dict` to the `System` class for constructing instances from a data dictionary.
- **Chores**
  - Updated linting rules in `pyproject.toml` to include `TID253` for banned module-level imports.
  - Modified import statements in test files to comply with the new linting rules for better code quality.
- **Style**
  - Added `# noqa: TID253` comments to specific import statements to adhere to new linting rules and ensure clean code styling.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 02309f7 commit a7bf93d

19 files changed

+128
-107
lines changed

dpdata/__init__.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,12 @@
1-
# monty needs lzma
2-
# See https://github.com/pandas-dev/pandas/pull/27882
3-
try:
4-
import lzma # noqa: F401
5-
except ImportError:
6-
7-
class fakemodule:
8-
pass
9-
10-
import sys
11-
12-
sys.modules["lzma"] = fakemodule
13-
141
from . import lammps, md, vasp
2+
from .bond_order_system import BondOrderSystem
153
from .system import LabeledSystem, MultiSystems, System
164

175
try:
186
from ._version import version as __version__
197
except ImportError:
208
from .__about__ import __version__
219

22-
# BondOrder System has dependency on rdkit
23-
try:
24-
# prevent conflict with dpdata.rdkit
25-
import rdkit as _ # noqa: F401
26-
27-
USE_RDKIT = True
28-
except ModuleNotFoundError:
29-
USE_RDKIT = False
30-
31-
if USE_RDKIT:
32-
from .bond_order_system import BondOrderSystem
33-
3410
__all__ = [
3511
"__version__",
3612
"lammps",

dpdata/amber/md.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import re
33

44
import numpy as np
5-
from scipy.io import netcdf_file
65

76
from dpdata.amber.mask import pick_by_amber_mask
87
from dpdata.unit import EnergyConversion
@@ -44,6 +43,8 @@ def read_amber_traj(
4443
labeled : bool
4544
Whether to return labeled data
4645
"""
46+
from scipy.io import netcdf_file
47+
4748
flag_atom_type = False
4849
flag_atom_numb = False
4950
amber_types = []

dpdata/ase_calculator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import TYPE_CHECKING, List, Optional
22

3-
from ase.calculators.calculator import (
3+
from ase.calculators.calculator import ( # noqa: TID253
44
Calculator,
55
PropertyNotImplementedError,
66
all_changes,

dpdata/bond_order_system.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from copy import deepcopy
44

55
import numpy as np
6-
from rdkit.Chem import Conformer
76

87
import dpdata.rdkit.utils
98
from dpdata.rdkit.sanitize import Sanitizer
@@ -102,6 +101,8 @@ def from_fmt_obj(self, fmtobj, file_name, **kwargs):
102101
return self
103102

104103
def to_fmt_obj(self, fmtobj, *args, **kwargs):
104+
from rdkit.Chem import Conformer
105+
105106
self.rdkit_mol.RemoveAllConformers()
106107
for ii in range(self.get_nframes()):
107108
conf = Conformer()

dpdata/deepmd/hdf5.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,15 @@
33
from __future__ import annotations
44

55
import warnings
6+
from typing import TYPE_CHECKING
67

7-
try:
8-
import h5py
9-
except ImportError:
10-
pass
118
import numpy as np
12-
from wcmatch.glob import globfilter
139

1410
import dpdata
1511

12+
if TYPE_CHECKING:
13+
import h5py
14+
1615
__all__ = ["to_system_data", "dump"]
1716

1817

@@ -35,6 +34,8 @@ def to_system_data(
3534
labels : bool
3635
labels
3736
"""
37+
from wcmatch.glob import globfilter
38+
3839
g = f[folder] if folder else f
3940

4041
data = {}

dpdata/gaussian/gjf.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,7 @@
1010
from typing import List, Optional, Tuple, Union
1111

1212
import numpy as np
13-
from scipy.sparse import csr_matrix
14-
from scipy.sparse.csgraph import connected_components
1513

16-
try:
17-
from openbabel import openbabel
18-
except ImportError:
19-
try:
20-
import openbabel
21-
except ImportError:
22-
openbabel = None
2314
from dpdata.periodic_table import Element
2415

2516

@@ -53,10 +44,13 @@ def _crd2frag(symbols: List[str], crds: np.ndarray) -> Tuple[int, List[int]]:
5344
ImportError
5445
if Open Babel is not installed
5546
"""
56-
if openbabel is None:
57-
raise ImportError(
58-
"Open Babel (Python interface) should be installed to detect fragmentation!"
59-
)
47+
from scipy.sparse import csr_matrix
48+
from scipy.sparse.csgraph import connected_components
49+
50+
try:
51+
from openbabel import openbabel
52+
except ImportError:
53+
import openbabel
6054
atomnumber = len(symbols)
6155
# Use openbabel to connect atoms
6256
mol = openbabel.OBMol()

dpdata/periodic_table.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1+
import json
12
from pathlib import Path
23

3-
from monty.serialization import loadfn
4-
5-
fpdt = str(Path(__file__).absolute().parent / "periodic_table.json")
6-
_pdt = loadfn(fpdt)
4+
fpdt = Path(__file__).absolute().parent / "periodic_table.json"
5+
with fpdt.open("r") as fpdt:
6+
_pdt = json.load(fpdt)
77
ELEMENTS = [
88
"H",
99
"He",

dpdata/plugins/ase.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,9 @@
66
from dpdata.driver import Driver, Minimizer
77
from dpdata.format import Format
88

9-
try:
10-
import ase.io
11-
from ase.calculators.calculator import PropertyNotImplementedError
12-
from ase.io import Trajectory
13-
14-
if TYPE_CHECKING:
15-
from ase.optimize.optimize import Optimizer
16-
except ImportError:
17-
pass
9+
if TYPE_CHECKING:
10+
import ase
11+
from ase.optimize.optimize import Optimizer
1812

1913

2014
@Format.register("ase/structure")
@@ -84,6 +78,8 @@ def from_labeled_system(self, atoms: "ase.Atoms", **kwargs) -> dict:
8478
ASE will raise RuntimeError if the atoms does not
8579
have a calculator
8680
"""
81+
from ase.calculators.calculator import PropertyNotImplementedError
82+
8783
info_dict = self.from_system(atoms)
8884
try:
8985
energies = atoms.get_potential_energy(force_consistent=True)
@@ -137,6 +133,8 @@ def from_multi_systems(
137133
ase.Atoms
138134
ASE atoms in the file
139135
"""
136+
import ase.io
137+
140138
frames = ase.io.read(file_name, format=ase_fmt, index=slice(begin, end, step))
141139
yield from frames
142140

@@ -222,6 +220,8 @@ def from_system(
222220
dict_frames: dict
223221
a dictionary containing data of multiple frames
224222
"""
223+
from ase.io import Trajectory
224+
225225
traj = Trajectory(file_name)
226226
sub_traj = traj[begin:end:step]
227227
dict_frames = ASEStructureFormat().from_system(sub_traj[0])
@@ -264,6 +264,8 @@ def from_labeled_system(
264264
dict_frames: dict
265265
a dictionary containing data of multiple frames
266266
"""
267+
from ase.io import Trajectory
268+
267269
traj = Trajectory(file_name)
268270
sub_traj = traj[begin:end:step]
269271

dpdata/plugins/deepmd.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
from __future__ import annotations
22

33
import os
4+
from typing import TYPE_CHECKING
45

5-
try:
6-
import h5py
7-
except ImportError:
8-
pass
96
import numpy as np
107

118
import dpdata
@@ -16,6 +13,9 @@
1613
from dpdata.driver import Driver
1714
from dpdata.format import Format
1815

16+
if TYPE_CHECKING:
17+
import h5py
18+
1919

2020
@Format.register("deepmd")
2121
@Format.register("deepmd/raw")
@@ -202,6 +202,8 @@ def _from_system(
202202
TypeError
203203
file_name is not str or h5py.Group or h5py.File
204204
"""
205+
import h5py
206+
205207
if isinstance(file_name, (h5py.Group, h5py.File)):
206208
return dpdata.deepmd.hdf5.to_system_data(
207209
file_name, "", type_map=type_map, labels=labels
@@ -300,6 +302,8 @@ def to_system(
300302
**kwargs : dict
301303
other parameters
302304
"""
305+
import h5py
306+
303307
if isinstance(file_name, (h5py.Group, h5py.File)):
304308
dpdata.deepmd.hdf5.dump(
305309
file_name, "", data, set_size=set_size, comp_prec=comp_prec
@@ -330,6 +334,8 @@ def from_multi_systems(self, directory: str, **kwargs) -> h5py.Group:
330334
h5py.Group
331335
a HDF5 group in the HDF5 file
332336
"""
337+
import h5py
338+
333339
with h5py.File(directory, "r") as f:
334340
for ff in f.keys():
335341
yield f[ff]
@@ -353,6 +359,8 @@ def to_multi_systems(
353359
h5py.Group
354360
a HDF5 group with the name of formula
355361
"""
362+
import h5py
363+
356364
with h5py.File(directory, "w") as f:
357365
for ff in formulas:
358366
yield f.create_group(ff)

dpdata/plugins/rdkit.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,18 @@
1+
import dpdata.rdkit.utils
12
from dpdata.format import Format
23

3-
try:
4-
import rdkit.Chem
5-
6-
import dpdata.rdkit.utils
7-
except ModuleNotFoundError:
8-
pass
9-
104

115
@Format.register("mol")
126
@Format.register("mol_file")
137
class MolFormat(Format):
148
def from_bond_order_system(self, file_name, **kwargs):
9+
import rdkit.Chem
10+
1511
return rdkit.Chem.MolFromMolFile(file_name, sanitize=False, removeHs=False)
1612

1713
def to_bond_order_system(self, data, mol, file_name, frame_idx=0, **kwargs):
14+
import rdkit.Chem
15+
1816
assert frame_idx < mol.GetNumConformers()
1917
rdkit.Chem.MolToMolFile(mol, file_name, confId=frame_idx)
2018

@@ -24,6 +22,8 @@ def to_bond_order_system(self, data, mol, file_name, frame_idx=0, **kwargs):
2422
class SdfFormat(Format):
2523
def from_bond_order_system(self, file_name, **kwargs):
2624
"""Note that it requires all molecules in .sdf file must be of the same topology."""
25+
import rdkit.Chem
26+
2727
mols = [
2828
m
2929
for m in rdkit.Chem.SDMolSupplier(file_name, sanitize=False, removeHs=False)
@@ -35,6 +35,8 @@ def from_bond_order_system(self, file_name, **kwargs):
3535
return mol
3636

3737
def to_bond_order_system(self, data, mol, file_name, frame_idx=-1, **kwargs):
38+
import rdkit.Chem
39+
3840
sdf_writer = rdkit.Chem.SDWriter(file_name)
3941
if frame_idx == -1:
4042
for ii in range(mol.GetNumConformers()):

0 commit comments

Comments
 (0)