88
99# NOTE: Tests below fail for tolerance 1e-4 (error executing plan).
1010
11- DTYPES = [np .float32 , np .float64 ]
11+ DTYPES = [np .complex64 , np .complex128 ]
1212SHAPES = [(16 ,), (16 , 16 ), (16 , 16 , 16 ), (19 ,), (17 , 19 ), (17 , 19 , 24 )]
1313MS = [256 , 1024 , 4096 ]
1414TOLS = [1e-3 , 1e-6 ]
2424@pytest .mark .parametrize ("output_arg" , OUTPUT_ARGS )
2525@pytest .mark .parametrize ("modeord" , MODEORDS )
2626def test_type1 (to_gpu , to_cpu , dtype , shape , M , tol , output_arg , modeord ):
27- complex_dtype = utils ._complex_dtype (dtype )
28-
2927 k , c = utils .type1_problem (dtype , shape , M )
3028
3129 k_gpu = to_gpu (k )
3230 c_gpu = to_gpu (c )
3331
34- plan = Plan (1 , shape , eps = tol , dtype = complex_dtype , modeord = modeord )
32+ plan = Plan (1 , shape , eps = tol , dtype = dtype , modeord = modeord )
3533
3634 # Since k_gpu is an array of shape (dim, M), this will expand to
3735 # plan.setpts(k_gpu[0], ..., k_gpu[dim]), allowing us to handle all
3836 # dimensions with the same call.
3937 plan .setpts (* k_gpu )
4038
4139 if output_arg :
42- fk_gpu = _compat .array_empty_like (c_gpu , shape , dtype = complex_dtype )
40+ fk_gpu = _compat .array_empty_like (c_gpu , shape , dtype = dtype )
4341 plan .execute (c_gpu , out = fk_gpu )
4442 else :
4543 fk_gpu = plan .execute (c_gpu )
@@ -59,11 +57,9 @@ def test_type1(to_gpu, to_cpu, dtype, shape, M, tol, output_arg, modeord):
5957@pytest .mark .parametrize ("contiguous" , CONTIGUOUS )
6058@pytest .mark .parametrize ("modeord" , MODEORDS )
6159def test_type2 (to_gpu , to_cpu , dtype , shape , M , tol , output_arg , contiguous , modeord ):
62- complex_dtype = utils ._complex_dtype (dtype )
63-
6460 k , fk = utils .type2_problem (dtype , shape , M )
6561
66- plan = Plan (2 , shape , eps = tol , dtype = complex_dtype , modeord = modeord )
62+ plan = Plan (2 , shape , eps = tol , dtype = dtype , modeord = modeord )
6763
6864 check_result = True
6965
@@ -96,7 +92,7 @@ def _execute(*args, **kwargs):
9692 plan .setpts (* k_gpu )
9793
9894 if output_arg :
99- c_gpu = _compat .array_empty_like (fk_gpu , (M ,), dtype = complex_dtype )
95+ c_gpu = _compat .array_empty_like (fk_gpu , (M ,), dtype = dtype )
10096 _execute (fk_gpu , out = c_gpu )
10197 else :
10298 c_gpu = _execute (fk_gpu )
@@ -119,12 +115,10 @@ def test_type3(to_gpu, to_cpu, dtype, dim, n_source_pts, n_target_pts, output_ar
119115 # trigger it, we must run many other tests preceding this test case.
120116 # So it's related to some global state of the library.
121117
122- complex_dtype = utils ._complex_dtype (dtype )
123-
124- source_pts , source_coefs , target_pts = utils .type3_problem (complex_dtype ,
118+ source_pts , source_coefs , target_pts = utils .type3_problem (dtype ,
125119 dim , n_source_pts , n_target_pts )
126120
127- plan = Plan (3 , dim , dtype = complex_dtype )
121+ plan = Plan (3 , dim , dtype = dtype )
128122
129123 source_pts_gpu = to_gpu (source_pts )
130124 target_pts_gpu = to_gpu (target_pts )
@@ -137,7 +131,7 @@ def test_type3(to_gpu, to_cpu, dtype, dim, n_source_pts, n_target_pts, output_ar
137131 target_coefs_gpu = plan .execute (source_coefs_gpu )
138132 else :
139133 target_coefs_gpu = _compat .array_empty_like (source_coefs_gpu ,
140- n_target_pts , dtype = complex_dtype )
134+ n_target_pts , dtype = dtype )
141135 plan .execute (source_coefs_gpu , out = target_coefs_gpu )
142136
143137 target_coefs = to_cpu (target_coefs_gpu )
@@ -146,17 +140,15 @@ def test_type3(to_gpu, to_cpu, dtype, dim, n_source_pts, n_target_pts, output_ar
146140
147141
148142def test_opts (to_gpu , to_cpu , shape = (8 , 8 , 8 ), M = 32 , tol = 1e-3 ):
149- dtype = np .float32
150-
151- complex_dtype = utils ._complex_dtype (dtype )
143+ dtype = np .complex64
152144
153145 k , c = utils .type1_problem (dtype , shape , M )
154146
155147 k_gpu = to_gpu (k )
156148 c_gpu = to_gpu (c )
157- fk_gpu = _compat .array_empty_like (c_gpu , shape , dtype = complex_dtype )
149+ fk_gpu = _compat .array_empty_like (c_gpu , shape , dtype = dtype )
158150
159- plan = Plan (1 , shape , eps = tol , dtype = complex_dtype , gpu_sort = False ,
151+ plan = Plan (1 , shape , eps = tol , dtype = dtype , gpu_sort = False ,
160152 gpu_maxsubprobsize = 10 )
161153
162154 plan .setpts (k_gpu [0 ], k_gpu [1 ], k_gpu [2 ])
0 commit comments