Skip to content

Commit 77eca0f

Browse files
Changing how optional dependencies are import state are checked (#551)
* adding initial check * adding check_module_enabled func * silly typo "module" instead of module * switch all optionals to check_module_enabled * changing to c/longdouble for tests * cleaning up old code and adding noqa * Revert "changing to c/longdouble for tests" This reverts commit 06534b1. * numpydoc style docstring * adding cupy/cusignal import tests and how enables are set * adding message to be optional * fixing logic in cupy/cusignal_enabled * cleaning up old code snippets and typos * small format changes * more formatting changes * changing import to import_module for pylint * minor: restyling deps.py * minor: added alex-rakowski to contributors --------- Co-authored-by: mrava87 <[email protected]>
1 parent d6e484c commit 77eca0f

File tree

3 files changed

+107
-32
lines changed

3 files changed

+107
-32
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,4 @@ A list of video tutorials to learn more about PyLops:
149149
* Rohan Babbar, rohanbabbar04
150150
* Wei Zhang, ZhangWeiGeo
151151
* Fedor Goncharov, fedor-goncharov
152+
* Alex Rakowski, alex-rakowski

docs/source/credits.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,5 @@ Contributors
2020
* `Aniket Singh Rawat <https://github.com/dikwickley>`_, dikwickley
2121
* `Rohan Babbar <https://github.com/rohanbabbar04>`_, rohanbabbar04
2222
* `Wei Zhang <https://github.com/ZhangWeiGeo>`_, ZhangWeiGeo
23-
* `Fedor Goncharov <https://github.com/fedor-goncharov>`_, fedor-goncharov
23+
* `Fedor Goncharov <https://github.com/fedor-goncharov>`_, fedor-goncharov
24+
* `Alex Rakowski <https://github.com/alex-rakowski>`_, alex-rakowski

pylops/utils/deps.py

Lines changed: 104 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,30 +12,78 @@
1212
]
1313

1414
import os
15-
from importlib import util
16-
17-
# check package availability
18-
cupy_enabled = (
19-
util.find_spec("cupy") is not None and int(os.getenv("CUPY_PYLOPS", 1)) == 1
20-
)
21-
cusignal_enabled = (
22-
util.find_spec("cusignal") is not None and int(os.getenv("CUSIGNAL_PYLOPS", 1)) == 1
23-
)
24-
devito_enabled = util.find_spec("devito") is not None
25-
numba_enabled = util.find_spec("numba") is not None
26-
pyfftw_enabled = util.find_spec("pyfftw") is not None
27-
pywt_enabled = util.find_spec("pywt") is not None
28-
skfmm_enabled = util.find_spec("skfmm") is not None
29-
spgl1_enabled = util.find_spec("spgl1") is not None
30-
sympy_enabled = util.find_spec("sympy") is not None
31-
torch_enabled = util.find_spec("torch") is not None
15+
from importlib import import_module, util
16+
from typing import Optional
3217

3318

3419
# error message at import of available package
35-
def devito_import(message):
20+
def cupy_import(message: Optional[str] = None) -> str:
21+
# detect if cupy is available and the user is expecting to be used
22+
cupy_test = (
23+
util.find_spec("cupy") is not None and int(os.getenv("CUPY_PYLOPS", 1)) == 1
24+
)
25+
# if cupy should be importable
26+
if cupy_test:
27+
# try importing it
28+
try:
29+
import_module("cupy") # noqa: F401
30+
31+
# if successful set the message to None.
32+
cupy_message = None
33+
# if unable to import but the package is installed
34+
except (ImportError, ModuleNotFoundError) as e:
35+
cupy_message = (
36+
f"Failed to import cupy, Falling back to CPU (error: {e}). "
37+
"Please ensure your CUDA environment is set up correctly "
38+
"for more details visit 'https://docs.cupy.dev/en/stable/install.html'"
39+
)
40+
print(UserWarning(cupy_message))
41+
# if cupy_test is False, it means not installed or environment variable set to 0
42+
else:
43+
cupy_message = (
44+
"Cupy package not installed or os.getenv('CUPY_PYLOPS') == 0. "
45+
f"In order to be able to use {message} "
46+
"ensure 'os.getenv('CUPY_PYLOPS') == 1' and run "
47+
"'pip install cupy'; "
48+
"for more details visit 'https://docs.cupy.dev/en/stable/install.html'"
49+
)
50+
51+
return cupy_message
52+
53+
54+
def cusignal_import(message: Optional[str] = None) -> str:
55+
cusignal_test = (
56+
util.find_spec("cusignal") is not None
57+
and int(os.getenv("CUSIGNAL_PYLOPS", 1)) == 1
58+
)
59+
if cusignal_test:
60+
try:
61+
import_module("cusignal") # noqa: F401
62+
63+
cusignal_message = None
64+
except (ImportError, ModuleNotFoundError) as e:
65+
cusignal_message = (
66+
f"Failed to import cusignal. Falling back to CPU (error: {e}) . "
67+
"Please ensure your CUDA environment is set up correctly; "
68+
"for more details visit 'https://github.com/rapidsai/cusignal#installation'"
69+
)
70+
print(UserWarning(cusignal_message))
71+
else:
72+
cusignal_message = (
73+
"Cusignal not installed or os.getenv('CUSIGNAL_PYLOPS') == 0. "
74+
f"In order to be able to use {message} "
75+
"ensure 'os.getenv('CUSIGNAL_PYLOPS') == 1' and run "
76+
"'conda install cusignal'; "
77+
"for more details visit ''https://github.com/rapidsai/cusignal#installation''"
78+
)
79+
80+
return cusignal_message
81+
82+
83+
def devito_import(message: Optional[str] = None) -> str:
3684
if devito_enabled:
3785
try:
38-
import devito # noqa: F401
86+
import_module("devito") # noqa: F401
3987

4088
devito_message = None
4189
except Exception as e:
@@ -49,10 +97,10 @@ def devito_import(message):
4997
return devito_message
5098

5199

52-
def numba_import(message):
100+
def numba_import(message: Optional[str] = None) -> str:
53101
if numba_enabled:
54102
try:
55-
import numba # noqa: F401
103+
import_module("numba") # noqa: F401
56104

57105
numba_message = None
58106
except Exception as e:
@@ -68,10 +116,10 @@ def numba_import(message):
68116
return numba_message
69117

70118

71-
def pyfftw_import(message):
119+
def pyfftw_import(message: Optional[str] = None) -> str:
72120
if pyfftw_enabled:
73121
try:
74-
import pyfftw # noqa: F401
122+
import_module("pyfftw") # noqa: F401
75123

76124
pyfftw_message = None
77125
except Exception as e:
@@ -87,10 +135,10 @@ def pyfftw_import(message):
87135
return pyfftw_message
88136

89137

90-
def pywt_import(message):
138+
def pywt_import(message: Optional[str] = None) -> str:
91139
if pywt_enabled:
92140
try:
93-
import pywt # noqa: F401
141+
import_module("pywt") # noqa: F401
94142

95143
pywt_message = None
96144
except Exception as e:
@@ -106,10 +154,10 @@ def pywt_import(message):
106154
return pywt_message
107155

108156

109-
def skfmm_import(message):
157+
def skfmm_import(message: Optional[str] = None) -> str:
110158
if skfmm_enabled:
111159
try:
112-
import skfmm # noqa: F401
160+
import_module("skfmm") # noqa: F401
113161

114162
skfmm_message = None
115163
except Exception as e:
@@ -124,10 +172,10 @@ def skfmm_import(message):
124172
return skfmm_message
125173

126174

127-
def spgl1_import(message):
175+
def spgl1_import(message: Optional[str] = None) -> str:
128176
if spgl1_enabled:
129177
try:
130-
import spgl1 # noqa: F401
178+
import_module("spgl1") # noqa: F401
131179

132180
spgl1_message = None
133181
except Exception as e:
@@ -141,10 +189,10 @@ def spgl1_import(message):
141189
return spgl1_message
142190

143191

144-
def sympy_import(message):
192+
def sympy_import(message: Optional[str] = None) -> str:
145193
if sympy_enabled:
146194
try:
147-
import sympy # noqa: F401
195+
import_module("sympy") # noqa: F401
148196

149197
sympy_message = None
150198
except Exception as e:
@@ -156,3 +204,28 @@ def sympy_import(message):
156204
f'"pip install sympy".'
157205
)
158206
return sympy_message
207+
208+
209+
# Set package availability booleans
210+
# cupy and cusignal: the package is imported to check everything is working correctly,
211+
# if not the package is disabled. We do this here as both libraries are used as drop-in
212+
# replacement for many numpy and scipy routines when cupy arrays are provided.
213+
# all other libraries: we simply check if the package is available and postpone its import
214+
# to check everything is working correctly when a user tries to create an operator that requires
215+
# such a package
216+
cupy_enabled: bool = (
217+
True if (cupy_import() is None and int(os.getenv("CUPY_PYLOPS", 1)) == 1) else False
218+
)
219+
cusignal_enabled: bool = (
220+
True
221+
if (cusignal_import() is None and int(os.getenv("CUSIGNAL_PYLOPS", 1)) == 1)
222+
else False
223+
)
224+
devito_enabled = util.find_spec("devito") is not None
225+
numba_enabled = util.find_spec("numba") is not None
226+
pyfftw_enabled = util.find_spec("pyfftw") is not None
227+
pywt_enabled = util.find_spec("pywt") is not None
228+
skfmm_enabled = util.find_spec("skfmm") is not None
229+
spgl1_enabled = util.find_spec("spgl1") is not None
230+
sympy_enabled = util.find_spec("sympy") is not None
231+
torch_enabled = util.find_spec("torch") is not None

0 commit comments

Comments
 (0)