Skip to content

Commit ed9e413

Browse files
committed
Uses complex dtypes in cufinufft py tests
Fixes #413.
1 parent 9842b32 commit ed9e413

File tree

3 files changed

+22
-32
lines changed

3 files changed

+22
-32
lines changed

python/cufinufft/tests/test_basic.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
# NOTE: Tests below fail for tolerance 1e-4 (error executing plan).
1010

11-
DTYPES = [np.float32, np.float64]
11+
DTYPES = [np.complex64, np.complex128]
1212
SHAPES = [(16,), (16, 16), (16, 16, 16), (19,), (17, 19), (17, 19, 24)]
1313
MS = [256, 1024, 4096]
1414
TOLS = [1e-3, 1e-6]
@@ -24,22 +24,20 @@
2424
@pytest.mark.parametrize("output_arg", OUTPUT_ARGS)
2525
@pytest.mark.parametrize("modeord", MODEORDS)
2626
def test_type1(to_gpu, to_cpu, dtype, shape, M, tol, output_arg, modeord):
27-
complex_dtype = utils._complex_dtype(dtype)
28-
2927
k, c = utils.type1_problem(dtype, shape, M)
3028

3129
k_gpu = to_gpu(k)
3230
c_gpu = to_gpu(c)
3331

34-
plan = Plan(1, shape, eps=tol, dtype=complex_dtype, modeord=modeord)
32+
plan = Plan(1, shape, eps=tol, dtype=dtype, modeord=modeord)
3533

3634
# Since k_gpu is an array of shape (dim, M), this will expand to
3735
# plan.setpts(k_gpu[0], ..., k_gpu[dim]), allowing us to handle all
3836
# dimensions with the same call.
3937
plan.setpts(*k_gpu)
4038

4139
if output_arg:
42-
fk_gpu = _compat.array_empty_like(c_gpu, shape, dtype=complex_dtype)
40+
fk_gpu = _compat.array_empty_like(c_gpu, shape, dtype=dtype)
4341
plan.execute(c_gpu, out=fk_gpu)
4442
else:
4543
fk_gpu = plan.execute(c_gpu)
@@ -59,11 +57,9 @@ def test_type1(to_gpu, to_cpu, dtype, shape, M, tol, output_arg, modeord):
5957
@pytest.mark.parametrize("contiguous", CONTIGUOUS)
6058
@pytest.mark.parametrize("modeord", MODEORDS)
6159
def test_type2(to_gpu, to_cpu, dtype, shape, M, tol, output_arg, contiguous, modeord):
62-
complex_dtype = utils._complex_dtype(dtype)
63-
6460
k, fk = utils.type2_problem(dtype, shape, M)
6561

66-
plan = Plan(2, shape, eps=tol, dtype=complex_dtype, modeord=modeord)
62+
plan = Plan(2, shape, eps=tol, dtype=dtype, modeord=modeord)
6763

6864
check_result = True
6965

@@ -96,7 +92,7 @@ def _execute(*args, **kwargs):
9692
plan.setpts(*k_gpu)
9793

9894
if output_arg:
99-
c_gpu = _compat.array_empty_like(fk_gpu, (M,), dtype=complex_dtype)
95+
c_gpu = _compat.array_empty_like(fk_gpu, (M,), dtype=dtype)
10096
_execute(fk_gpu, out=c_gpu)
10197
else:
10298
c_gpu = _execute(fk_gpu)
@@ -119,12 +115,10 @@ def test_type3(to_gpu, to_cpu, dtype, dim, n_source_pts, n_target_pts, output_ar
119115
# trigger it, we must run many other tests preceding this test case.
120116
# So it's related to some global state of the library.
121117

122-
complex_dtype = utils._complex_dtype(dtype)
123-
124-
source_pts, source_coefs, target_pts = utils.type3_problem(complex_dtype,
118+
source_pts, source_coefs, target_pts = utils.type3_problem(dtype,
125119
dim, n_source_pts, n_target_pts)
126120

127-
plan = Plan(3, dim, dtype=complex_dtype)
121+
plan = Plan(3, dim, dtype=dtype)
128122

129123
source_pts_gpu = to_gpu(source_pts)
130124
target_pts_gpu = to_gpu(target_pts)
@@ -137,7 +131,7 @@ def test_type3(to_gpu, to_cpu, dtype, dim, n_source_pts, n_target_pts, output_ar
137131
target_coefs_gpu = plan.execute(source_coefs_gpu)
138132
else:
139133
target_coefs_gpu = _compat.array_empty_like(source_coefs_gpu,
140-
n_target_pts, dtype=complex_dtype)
134+
n_target_pts, dtype=dtype)
141135
plan.execute(source_coefs_gpu, out=target_coefs_gpu)
142136

143137
target_coefs = to_cpu(target_coefs_gpu)
@@ -146,17 +140,15 @@ def test_type3(to_gpu, to_cpu, dtype, dim, n_source_pts, n_target_pts, output_ar
146140

147141

148142
def test_opts(to_gpu, to_cpu, shape=(8, 8, 8), M=32, tol=1e-3):
149-
dtype = np.float32
150-
151-
complex_dtype = utils._complex_dtype(dtype)
143+
dtype = np.complex64
152144

153145
k, c = utils.type1_problem(dtype, shape, M)
154146

155147
k_gpu = to_gpu(k)
156148
c_gpu = to_gpu(c)
157-
fk_gpu = _compat.array_empty_like(c_gpu, shape, dtype=complex_dtype)
149+
fk_gpu = _compat.array_empty_like(c_gpu, shape, dtype=dtype)
158150

159-
plan = Plan(1, shape, eps=tol, dtype=complex_dtype, gpu_sort=False,
151+
plan = Plan(1, shape, eps=tol, dtype=dtype, gpu_sort=False,
160152
gpu_maxsubprobsize=10)
161153

162154
plan.setpts(k_gpu[0], k_gpu[1], k_gpu[2])

python/cufinufft/tests/test_simple.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import utils
99

10-
DTYPES = [np.float32, np.float64]
10+
DTYPES = [np.complex64, np.complex128]
1111
SHAPES = [(16,), (16, 16), (16, 16, 16)]
1212
N_TRANS = [(), (1,), (2,)]
1313
MS = [256, 1024, 4096]
@@ -21,8 +21,7 @@
2121
@pytest.mark.parametrize("tol", TOLS)
2222
@pytest.mark.parametrize("output_arg", OUTPUT_ARGS)
2323
def test_simple_type1(to_gpu, to_cpu, dtype, shape, n_trans, M, tol, output_arg):
24-
real_dtype = dtype
25-
complex_dtype = utils._complex_dtype(dtype)
24+
real_dtype = utils._real_dtype(dtype)
2625

2726
dim = len(shape)
2827

@@ -41,7 +40,7 @@ def test_simple_type1(to_gpu, to_cpu, dtype, shape, n_trans, M, tol, output_arg)
4140
# batch, (1, N1, ...) for batch of size one, and (n, N1, ...) for
4241
# batch of size n.
4342
fk_gpu = _compat.array_empty_like(c_gpu, n_trans + shape,
44-
dtype=complex_dtype)
43+
dtype=dtype)
4544

4645
fun(*k_gpu, c_gpu, out=fk_gpu, eps=tol)
4746
else:
@@ -59,8 +58,7 @@ def test_simple_type1(to_gpu, to_cpu, dtype, shape, n_trans, M, tol, output_arg)
5958
@pytest.mark.parametrize("tol", TOLS)
6059
@pytest.mark.parametrize("output_arg", OUTPUT_ARGS)
6160
def test_simple_type2(to_gpu, to_cpu, dtype, shape, n_trans, M, tol, output_arg):
62-
real_dtype = dtype
63-
complex_dtype = utils._complex_dtype(dtype)
61+
real_dtype = utils._real_dtype(dtype)
6462

6563
dim = len(shape)
6664

@@ -75,7 +73,7 @@ def test_simple_type2(to_gpu, to_cpu, dtype, shape, n_trans, M, tol, output_arg)
7573

7674
if output_arg:
7775
c_gpu = _compat.array_empty_like(fk_gpu, n_trans + (M,),
78-
dtype=complex_dtype)
76+
dtype=dtype)
7977

8078
fun(*k_gpu, fk_gpu, eps=tol, out=c_gpu)
8179
else:

python/cufinufft/tests/utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,21 +46,21 @@ def gen_nonuniform_data(M, seed=0, n_trans=()):
4646

4747

4848
def type1_problem(dtype, shape, M, n_trans=()):
49-
complex_dtype = _complex_dtype(dtype)
49+
real_dtype = _real_dtype(dtype)
5050
dim = len(shape)
5151

52-
k = gen_nu_pts(M, dim=dim).astype(dtype)
53-
c = gen_nonuniform_data(M, n_trans=n_trans).astype(complex_dtype)
52+
k = gen_nu_pts(M, dim=dim).astype(real_dtype)
53+
c = gen_nonuniform_data(M, n_trans=n_trans).astype(dtype)
5454

5555
return k, c
5656

5757

5858
def type2_problem(dtype, shape, M, n_trans=()):
59-
complex_dtype = _complex_dtype(dtype)
59+
real_dtype = _real_dtype(dtype)
6060
dim = len(shape)
6161

62-
k = gen_nu_pts(M, dim=dim).astype(dtype)
63-
fk = gen_uniform_data(n_trans + shape).astype(complex_dtype)
62+
k = gen_nu_pts(M, dim=dim).astype(real_dtype)
63+
fk = gen_uniform_data(n_trans + shape).astype(dtype)
6464

6565
return k, fk
6666

0 commit comments

Comments
 (0)