Skip to content

Commit 5559945

Browse files
committed
port test_tools to arraycontext
1 parent 9d5fdc6 commit 5559945

File tree

1 file changed

+53
-18
lines changed

1 file changed

+53
-18
lines changed

test/test_tools.py

Lines changed: 53 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,33 +20,49 @@
2020
THE SOFTWARE.
2121
"""
2222

23-
import logging
24-
logger = logging.getLogger(__name__)
23+
import pytest
24+
import sys
2525

26-
import sumpy.symbolic as sym
27-
from sumpy.tools import (fft_toeplitz_upper_triangular,
28-
matvec_toeplitz_upper_triangular, loopy_fft, fft)
2926
import numpy as np
3027

31-
import pyopencl as cl
32-
import pyopencl.array as cla
33-
from pyopencl.tools import ( # noqa
34-
pytest_generate_tests_for_pyopencl as pytest_generate_tests)
28+
from arraycontext import pytest_generate_tests_for_array_contexts
29+
from sumpy.array_context import ( # noqa: F401
30+
PytestPyOpenCLArrayContextFactory, _acf)
3531

36-
import pytest
32+
import sumpy.symbolic as sym
33+
from sumpy.tools import (
34+
fft_toeplitz_upper_triangular,
35+
matvec_toeplitz_upper_triangular,
36+
loopy_fft,
37+
fft)
3738

39+
import logging
40+
logger = logging.getLogger(__name__)
41+
42+
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
43+
PytestPyOpenCLArrayContextFactory,
44+
])
45+
46+
47+
# {{{ test_matvec_fft
3848

3949
def test_matvec_fft():
4050
k = 5
41-
v = np.random.rand(k)
42-
x = np.random.rand(k)
51+
52+
rng = np.random.default_rng(42)
53+
v = rng.random(k)
54+
x = rng.random(k)
4355

4456
fft = fft_toeplitz_upper_triangular(v, x)
4557
matvec = matvec_toeplitz_upper_triangular(v, x)
4658

4759
for i in range(k):
4860
assert abs(fft[i] - matvec[i]) < 1e-14
4961

62+
# }}}
63+
64+
65+
# {{{ test_matvec_fft_small_floats
5066

5167
def test_matvec_fft_small_floats():
5268
k = 5
@@ -60,15 +76,34 @@ def test_matvec_fft_small_floats():
6076
continue
6177
assert abs(f) > 1e-10
6278

79+
# }}}
80+
81+
82+
# {{{ test_fft
6383

6484
@pytest.mark.parametrize("size", [1, 2, 7, 10, 30, 210])
65-
def test_fft(ctx_factory, size):
66-
ctx = ctx_factory()
67-
queue = cl.CommandQueue(ctx)
85+
def test_fft(actx_factory, size):
86+
actx = actx_factory()
87+
6888
inp = np.arange(size, dtype=np.complex64)
69-
inp_dev = cla.to_device(queue, inp)
89+
inp_dev = actx.from_numpy(inp)
7090
out = fft(inp)
7191

7292
fft_func = loopy_fft(inp.shape, inverse=False, complex_dtype=inp.dtype.type)
73-
evt, (out_dev,) = fft_func(queue, y=inp_dev)
74-
assert np.allclose(out_dev.get(), out)
93+
evt, (out_dev,) = fft_func(actx.queue, y=inp_dev)
94+
95+
assert np.allclose(actx.to_numpy(out_dev), out)
96+
97+
# }}}
98+
99+
100+
# You can test individual routines by typing
101+
# $ python test_tools.py 'test_fft(_acf, 30)'
102+
103+
if __name__ == "__main__":
104+
if len(sys.argv) > 1:
105+
exec(sys.argv[1])
106+
else:
107+
pytest.main([__file__])
108+
109+
# vim: fdm=marker

0 commit comments

Comments
 (0)