Skip to content

Commit b00ead3

Browse files
committed
fix issues with set_workers
1 parent 05232e0 commit b00ead3

File tree

4 files changed

+51
-41
lines changed

4 files changed

+51
-41
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1515
* NumPy interface `mkl_fft.interfaces.numpy_fft` is aligned with numpy-2.x.x [gh-139](https://github.com/IntelPython/mkl_fft/pull/139), [gh-157](https://github.com/IntelPython/mkl_fft/pull/157)
1616
* To set `mkl_fft` as the backend for SciPy is only possible through `mkl_fft.interfaces.scipy_fft` [gh-179](https://github.com/IntelPython/mkl_fft/pull/179)
1717

18+
### Fixed
19+
* Fixed issues with `set_workers` function in SciPy interface `mkl_fft.interfaces.scipy_fft` [gh-183](https://github.com/IntelPython/mkl_fft/pull/183)
20+
1821
## [1.3.14] (04/10/2025)
1922

2023
resolves gh-152 by adding an explicit `mkl-service` dependency to `mkl-fft` when building the wheel

mkl_fft/interfaces/README.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ This interface is a drop-in replacement for the [`scipy.fft`](https://scipy.gith
4242

4343
* Helper functions: `fftshift`, `ifftshift`, `fftfreq`, `rfftfreq`, `set_workers`, `get_workers`. All of these functions, except for `set_workers` and `get_workers`, serve as a fallback to the SciPy implementation and are included for completeness.
4444

45+
Note that in computing FFTs, the default value of `workers` parameter is the maximum number of threads available unlike the default behavior of SciPy where only one thread is used.
46+
4547
The following example shows how to use this interface for calculating a 1D FFT.
4648

4749
```python
@@ -102,3 +104,24 @@ with scipy.fft.set_backend(mkl_backend, only=True):
102104
print(f"Time with OneMKL FFT backend installed: {t2:.1f} seconds")
103105
# Time with MKL FFT backend installed: 9.1 seconds
104106
```
107+
108+
In the following example, we use `set_worker` to control the number of threads when `mkl_fft` is being used as a backend for SciPy.
109+
110+
```python
111+
import numpy, mkl, scipy
112+
import mkl_fft.interfaces.scipy_fft as mkl_fft
113+
import scipy
114+
a = numpy.random.randn(128, 64) + 1j*numpy.random.randn(128, 64)
115+
scipy.fft.set_global_backend(mkl_fft) # set mkl_fft as global backend
116+
117+
mkl.verbose(1)
118+
# True
119+
mkl.get_max_threads()
120+
# 112
121+
y = scipy.signal.fftconvolve(a, a) # Note that Nthr:112
122+
# MKL_VERBOSE FFT(dcbo256x128,input_strides:{0,128,1},output_strides:{0,128,1},bScale:3.05176e-05,tLim:56,unaligned_input,unaligned_output,desc:0x563aefe86180) 165.02us CNR:OFF Dyn:1 FastMM:1 TID:0 NThr:112
123+
124+
with mkl_fft.set_workers(4):
125+
y = scipy.signal.fftconvolve(a, a) # Note that Nthr:4
126+
# MKL_VERBOSE FFT(dcbo256x128,input_strides:{0,128,1},output_strides:{0,128,1},bScale:3.05176e-05,tLim:4,unaligned_output,desc:0x563aefe86180) 187.37us CNR:OFF Dyn:1 FastMM:1 TID:0 NThr:4
127+
```

mkl_fft/interfaces/_scipy_fft.py

Lines changed: 19 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -66,31 +66,13 @@
6666
]
6767

6868

69-
class _cpu_max_threads_count:
70-
def __init__(self):
71-
self.cpu_count = None
72-
self.max_threads_count = None
73-
74-
def get_cpu_count(self):
75-
if self.cpu_count is None:
76-
max_threads = self.get_max_threads_count()
77-
self.cpu_count = max_threads
78-
return self.cpu_count
79-
80-
def get_max_threads_count(self):
81-
if self.max_threads_count is None:
82-
# pylint: disable=no-member
83-
self.max_threads_count = mkl.get_max_threads()
84-
85-
return self.max_threads_count
86-
87-
8869
class _workers_data:
8970
def __init__(self, workers=None):
90-
if workers:
91-
self.workers_ = workers
71+
if workers is not None: # workers = 0 should be handled
72+
self.workers_ = _workers_to_num_threads(workers)
9273
else:
93-
self.workers_ = _cpu_max_threads_count().get_cpu_count()
74+
# Unlike SciPy, the default value is maximum number of threads
75+
self.workers_ = mkl.get_max_threads() # pylint: disable=no-member
9476
self.workers_ = operator.index(self.workers_)
9577

9678
@property
@@ -108,21 +90,22 @@ def workers(self, workers_val):
10890

10991

11092
def _workers_to_num_threads(w):
111-
"""Handle conversion of workers to a positive number of threads in the
112-
same way as scipy.fft.helpers._workers.
93+
"""
94+
Handle conversion of workers to a positive number of threads in the
95+
same way as scipy.fft._pocketfft.helpers._workers.
11396
"""
11497
if w is None:
11598
return _workers_global_settings.get().workers
11699
_w = operator.index(w)
117100
if _w == 0:
118101
raise ValueError("Number of workers must not be zero")
119102
if _w < 0:
120-
ub = os.cpu_count()
121-
_w += ub + 1
103+
_cpu_count = os.cpu_count()
104+
_w += _cpu_count + 1
122105
if _w <= 0:
123106
raise ValueError(
124-
"workers value out of range; got {}, must not be"
125-
" less than {}".format(w, -ub)
107+
f"workers value out of range; got {w}, must not be less "
108+
f"than {-_cpu_count}"
126109
)
127110
return _w
128111

@@ -134,14 +117,16 @@ def __init__(self, workers):
134117

135118
def __enter__(self):
136119
try:
120+
# mkl.set_num_threads_local sets the number of threads to the
121+
# given input number, and returns the previous number of threads
137122
# pylint: disable=no-member
138123
self.prev_num_threads = mkl.set_num_threads_local(self.n_threads)
139124
except Exception as e:
140125
raise ValueError(
141-
"Class argument {} result in invalid number of threads {}".format(
142-
self.workers, self.n_threads
143-
)
126+
f"Class argument {self.workers} results in invalid number of "
127+
f"threads {self.n_threads}"
144128
) from e
129+
return self
145130

146131
def __exit__(self, *args):
147132
# restore old value
@@ -696,21 +681,19 @@ def get_workers():
696681

697682

698683
@contextlib.contextmanager
699-
def set_workers(n_workers):
684+
def set_workers(workers):
700685
"""
701686
Set the value of workers used by default, returns the previous value.
702687
703688
For full documentation refer to `scipy.fft.set_workers`.
704689
705690
"""
706-
nw = operator.index(n_workers)
691+
nw = operator.index(workers)
707692
token = None
708693
try:
709694
new_wd = _workers_data(nw)
710695
token = _workers_global_settings.set(new_wd)
711696
yield
712697
finally:
713-
if token:
698+
if token is not None:
714699
_workers_global_settings.reset(token)
715-
else:
716-
raise ValueError

mkl_fft/tests/third_party/scipy/test_multithreading.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def _mt_fft(x):
5252
return fft.fft(x, workers=2)
5353

5454

55-
@pytest.mark.slow
55+
# @pytest.mark.slow
5656
def test_mixed_threads_processes(x):
5757
# Test that the fft threadpool is safe to use before & after fork
5858

@@ -79,10 +79,11 @@ def test_invalid_workers(x):
7979
fft.ifft(x, workers=-cpus - 1)
8080

8181

82-
@pytest.mark.skip()
8382
def test_set_get_workers():
8483
cpus = os.cpu_count()
85-
assert fft.get_workers() == 1
84+
85+
# default value is max number of threads unlike stock SciPy
86+
assert fft.get_workers() == cpus
8687
with fft.set_workers(4):
8788
assert fft.get_workers() == 4
8889

@@ -91,13 +92,13 @@ def test_set_get_workers():
9192

9293
assert fft.get_workers() == 4
9394

94-
assert fft.get_workers() == 1
95+
# default value is max number of threads unlike stock SciPy
96+
assert fft.get_workers() == cpus
9597

9698
with fft.set_workers(-cpus):
9799
assert fft.get_workers() == 1
98100

99101

100-
@pytest.mark.skip("mkl_fft does not validate workers")
101102
def test_set_workers_invalid():
102103

103104
with pytest.raises(ValueError, match="workers must not be zero"):

0 commit comments

Comments
 (0)