2727# cython: language_level=3
2828
2929import mkl_umath._ufuncs as mu
30- import numpy.core.umath as nu
3130
3231cimport numpy as cnp
3332import numpy as np
@@ -59,15 +58,15 @@ cdef class patch:
5958 self .functions_count = 0
6059 for umath in umaths:
6160 mkl_umath = getattr (mu, umath)
62- self .functions_count = self .functions_count + mkl_umath.ntypes
61+ self .functions_count += mkl_umath.ntypes
6362
6463 self .functions = < function_info * > malloc(self .functions_count * sizeof(function_info))
6564
6665 func_number = 0
6766 for umath in umaths:
6867 patch_umath = getattr (mu, umath)
6968 c_patch_umath = < cnp.ufunc> patch_umath
70- c_orig_umath = < cnp.ufunc> getattr (nu , umath)
69+ c_orig_umath = < cnp.ufunc> getattr (np , umath)
7170 nargs = c_patch_umath.nargs
7271 for pi in range (c_patch_umath.ntypes):
7372 oi = 0
@@ -103,7 +102,7 @@ cdef class patch:
103102 cdef int * signature
104103
105104 for func in self .functions_dict:
106- np_umath = getattr (nu , func[0 ])
105+ np_umath = getattr (np , func[0 ])
107106 index = self .functions_dict[func]
108107 function = self .functions[index].patch_function
109108 signature = self .functions[index].signature
@@ -118,7 +117,7 @@ cdef class patch:
118117 cdef int * signature
119118
120119 for func in self .functions_dict:
121- np_umath = getattr (nu , func[0 ])
120+ np_umath = getattr (np , func[0 ])
122121 index = self .functions_dict[func]
123122 function = self .functions[index].original_function
124123 signature = self .functions[index].signature
@@ -143,34 +142,97 @@ def _initialize_tls():
143142
144143
145144def use_in_numpy ():
146- '''
145+ """
147146 Enables using of mkl_umath in Numpy.
148- '''
147+
148+ Examples
149+ --------
150+ >>> import mkl_umath, numpy as np
151+ >>> mkl_umath.is_patched()
152+ # False
153+
154+ >>> mkl_umath.use_in_numpy() # Enable mkl_umath in Numpy
155+ >>> mkl_umath.is_patched()
156+ # True
157+
158+ >>> mkl_umath.restore() # Disable mkl_umath in Numpy
159+ >>> mkl_umath.is_patched()
160+ # False
161+
162+ """
149163 if not _is_tls_initialized():
150164 _initialize_tls()
151165 _tls.patch.do_patch()
152166
153167
154168def restore ():
155- '''
169+ """
156170 Disables using of mkl_umath in Numpy.
157- '''
171+
172+ Examples
173+ --------
174+ >>> import mkl_umath, numpy as np
175+ >>> mkl_umath.is_patched()
176+ # False
177+
178+ >>> mkl_umath.use_in_numpy() # Enable mkl_umath in Numpy
179+ >>> mkl_umath.is_patched()
180+ # True
181+
182+ >>> mkl_umath.restore() # Disable mkl_umath in Numpy
183+ >>> mkl_umath.is_patched()
184+ # False
185+
186+ """
158187 if not _is_tls_initialized():
159188 _initialize_tls()
160189 _tls.patch.do_unpatch()
161190
162191
163192def is_patched ():
164- '''
193+ """
165194 Returns whether Numpy has been patched with mkl_umath.
166- '''
195+
196+ Examples
197+ --------
198+ >>> import mkl_umath, numpy as np
199+ >>> mkl_umath.is_patched()
200+ # False
201+
202+ >>> mkl_umath.use_in_numpy() # Enable mkl_umath in Numpy
203+ >>> mkl_umath.is_patched()
204+ # True
205+
206+ >>> mkl_umath.restore() # Disable mkl_umath in Numpy
207+ >>> mkl_umath.is_patched()
208+ # False
209+
210+ """
167211 if not _is_tls_initialized():
168212 _initialize_tls()
169- _tls.patch.is_patched()
213+ return _tls.patch.is_patched()
170214
171215from contextlib import ContextDecorator
172216
173217class mkl_umath (ContextDecorator ):
218+ """
219+ Context manager and decorator to temporarily patch NumPy ufuncs
220+ with MKL-based implementations.
221+
222+ Examples
223+ --------
224+ >>> import mkl_umath, numpy as np
225+ >>> mkl_umath.is_patched()
226+ # False
227+
228+ >>> with mkl_umath.mkl_umath(): # Enable mkl_umath in Numpy
229+ >>> print(mkl_umath.is_patched())
230+ # True
231+
232+ >>> mkl_umath.is_patched()
233+ # False
234+
235+ """
174236 def __enter__ (self ):
175237 use_in_numpy()
176238 return self
0 commit comments