1212]
1313
1414import os
15- from importlib import util
16-
17- # check package availability
18- cupy_enabled = (
19- util .find_spec ("cupy" ) is not None and int (os .getenv ("CUPY_PYLOPS" , 1 )) == 1
20- )
21- cusignal_enabled = (
22- util .find_spec ("cusignal" ) is not None and int (os .getenv ("CUSIGNAL_PYLOPS" , 1 )) == 1
23- )
24- devito_enabled = util .find_spec ("devito" ) is not None
25- numba_enabled = util .find_spec ("numba" ) is not None
26- pyfftw_enabled = util .find_spec ("pyfftw" ) is not None
27- pywt_enabled = util .find_spec ("pywt" ) is not None
28- skfmm_enabled = util .find_spec ("skfmm" ) is not None
29- spgl1_enabled = util .find_spec ("spgl1" ) is not None
30- sympy_enabled = util .find_spec ("sympy" ) is not None
31- torch_enabled = util .find_spec ("torch" ) is not None
15+ from importlib import import_module , util
16+ from typing import Optional
3217
3318
3419# error message at import of available package
35- def devito_import (message ):
20+ def cupy_import (message : Optional [str ] = None ) -> str :
21+ # detect if cupy is available and the user is expecting to be used
22+ cupy_test = (
23+ util .find_spec ("cupy" ) is not None and int (os .getenv ("CUPY_PYLOPS" , 1 )) == 1
24+ )
25+ # if cupy should be importable
26+ if cupy_test :
27+ # try importing it
28+ try :
29+ import_module ("cupy" ) # noqa: F401
30+
31+ # if successful set the message to None.
32+ cupy_message = None
33+ # if unable to import but the package is installed
34+ except (ImportError , ModuleNotFoundError ) as e :
35+ cupy_message = (
36+ f"Failed to import cupy, Falling back to CPU (error: { e } ). "
37+ "Please ensure your CUDA environment is set up correctly "
38+ "for more details visit 'https://docs.cupy.dev/en/stable/install.html'"
39+ )
40+ print (UserWarning (cupy_message ))
41+ # if cupy_test is False, it means not installed or environment variable set to 0
42+ else :
43+ cupy_message = (
44+ "Cupy package not installed or os.getenv('CUPY_PYLOPS') == 0. "
45+ f"In order to be able to use { message } "
46+ "ensure 'os.getenv('CUPY_PYLOPS') == 1' and run "
47+ "'pip install cupy'; "
48+ "for more details visit 'https://docs.cupy.dev/en/stable/install.html'"
49+ )
50+
51+ return cupy_message
52+
53+
54+ def cusignal_import (message : Optional [str ] = None ) -> str :
55+ cusignal_test = (
56+ util .find_spec ("cusignal" ) is not None
57+ and int (os .getenv ("CUSIGNAL_PYLOPS" , 1 )) == 1
58+ )
59+ if cusignal_test :
60+ try :
61+ import_module ("cusignal" ) # noqa: F401
62+
63+ cusignal_message = None
64+ except (ImportError , ModuleNotFoundError ) as e :
65+ cusignal_message = (
66+ f"Failed to import cusignal. Falling back to CPU (error: { e } ) . "
67+ "Please ensure your CUDA environment is set up correctly; "
68+ "for more details visit 'https://github.com/rapidsai/cusignal#installation'"
69+ )
70+ print (UserWarning (cusignal_message ))
71+ else :
72+ cusignal_message = (
73+ "Cusignal not installed or os.getenv('CUSIGNAL_PYLOPS') == 0. "
74+ f"In order to be able to use { message } "
75+ "ensure 'os.getenv('CUSIGNAL_PYLOPS') == 1' and run "
76+ "'conda install cusignal'; "
77+ "for more details visit ''https://github.com/rapidsai/cusignal#installation''"
78+ )
79+
80+ return cusignal_message
81+
82+
83+ def devito_import (message : Optional [str ] = None ) -> str :
3684 if devito_enabled :
3785 try :
38- import devito # noqa: F401
86+ import_module ( " devito" ) # noqa: F401
3987
4088 devito_message = None
4189 except Exception as e :
@@ -49,10 +97,10 @@ def devito_import(message):
4997 return devito_message
5098
5199
52- def numba_import (message ) :
100+ def numba_import (message : Optional [ str ] = None ) -> str :
53101 if numba_enabled :
54102 try :
55- import numba # noqa: F401
103+ import_module ( " numba" ) # noqa: F401
56104
57105 numba_message = None
58106 except Exception as e :
@@ -68,10 +116,10 @@ def numba_import(message):
68116 return numba_message
69117
70118
71- def pyfftw_import (message ) :
119+ def pyfftw_import (message : Optional [ str ] = None ) -> str :
72120 if pyfftw_enabled :
73121 try :
74- import pyfftw # noqa: F401
122+ import_module ( " pyfftw" ) # noqa: F401
75123
76124 pyfftw_message = None
77125 except Exception as e :
@@ -87,10 +135,10 @@ def pyfftw_import(message):
87135 return pyfftw_message
88136
89137
90- def pywt_import (message ) :
138+ def pywt_import (message : Optional [ str ] = None ) -> str :
91139 if pywt_enabled :
92140 try :
93- import pywt # noqa: F401
141+ import_module ( " pywt" ) # noqa: F401
94142
95143 pywt_message = None
96144 except Exception as e :
@@ -106,10 +154,10 @@ def pywt_import(message):
106154 return pywt_message
107155
108156
109- def skfmm_import (message ) :
157+ def skfmm_import (message : Optional [ str ] = None ) -> str :
110158 if skfmm_enabled :
111159 try :
112- import skfmm # noqa: F401
160+ import_module ( " skfmm" ) # noqa: F401
113161
114162 skfmm_message = None
115163 except Exception as e :
@@ -124,10 +172,10 @@ def skfmm_import(message):
124172 return skfmm_message
125173
126174
127- def spgl1_import (message ) :
175+ def spgl1_import (message : Optional [ str ] = None ) -> str :
128176 if spgl1_enabled :
129177 try :
130- import spgl1 # noqa: F401
178+ import_module ( " spgl1" ) # noqa: F401
131179
132180 spgl1_message = None
133181 except Exception as e :
@@ -141,10 +189,10 @@ def spgl1_import(message):
141189 return spgl1_message
142190
143191
144- def sympy_import (message ) :
192+ def sympy_import (message : Optional [ str ] = None ) -> str :
145193 if sympy_enabled :
146194 try :
147- import sympy # noqa: F401
195+ import_module ( " sympy" ) # noqa: F401
148196
149197 sympy_message = None
150198 except Exception as e :
@@ -156,3 +204,28 @@ def sympy_import(message):
156204 f'"pip install sympy".'
157205 )
158206 return sympy_message
207+
208+
209+ # Set package availability booleans
210+ # cupy and cusignal: the package is imported to check everything is working correctly,
211+ # if not the package is disabled. We do this here as both libraries are used as drop-in
212+ # replacement for many numpy and scipy routines when cupy arrays are provided.
213+ # all other libraries: we simply check if the package is available and postpone its import
214+ # to check everything is working correctly when a user tries to create an operator that requires
215+ # such a package
216+ cupy_enabled : bool = (
217+ True if (cupy_import () is None and int (os .getenv ("CUPY_PYLOPS" , 1 )) == 1 ) else False
218+ )
219+ cusignal_enabled : bool = (
220+ True
221+ if (cusignal_import () is None and int (os .getenv ("CUSIGNAL_PYLOPS" , 1 )) == 1 )
222+ else False
223+ )
224+ devito_enabled = util .find_spec ("devito" ) is not None
225+ numba_enabled = util .find_spec ("numba" ) is not None
226+ pyfftw_enabled = util .find_spec ("pyfftw" ) is not None
227+ pywt_enabled = util .find_spec ("pywt" ) is not None
228+ skfmm_enabled = util .find_spec ("skfmm" ) is not None
229+ spgl1_enabled = util .find_spec ("spgl1" ) is not None
230+ sympy_enabled = util .find_spec ("sympy" ) is not None
231+ torch_enabled = util .find_spec ("torch" ) is not None
0 commit comments