@@ -92,30 +92,28 @@ def test_simple_type2(to_gpu, to_cpu, dtype, shape, n_trans, M, tol, output_arg)
9292@pytest .mark .parametrize ("tol" , TOLS )
9393@pytest .mark .parametrize ("output_arg" , OUTPUT_ARGS )
9494def test_cufinufft3_simple (to_gpu , to_cpu , dtype , dim , n_source_pts , n_target_pts , n_trans , tol , output_arg ):
95- complex_dtype = utils ._complex_dtype (dtype )
96-
95+
9796 fun = {1 : cufinufft .nufft1d3 ,
9897 2 : cufinufft .nufft2d3 ,
9998 3 : cufinufft .nufft3d3 }[dim ]
10099
101100 source_pts , source_coefs , target_pts = utils .type3_problem (
102- complex_dtype , dim , n_source_pts , n_target_pts , n_trans
101+ dtype , dim , n_source_pts , n_target_pts , n_trans
103102 )
104103
105104
106- source_pts_gpu = to_gpu (source_pts )
105+ source_pts_gpu = to_gpu (source_pts )
107106 source_coefs_gpu = to_gpu (source_coefs )
108107 target_pts_gpu = to_gpu (target_pts )
109108
110109 if output_arg :
111110 target_coefs_gpu = _compat .array_empty_like (
112- source_coefs_gpu , n_trans + (n_target_pts ,), dtype = complex_dtype )
113-
111+ source_coefs_gpu , n_trans + (n_target_pts ,), dtype = dtype )
112+
114113 fun (* source_pts_gpu , source_coefs_gpu , * target_pts_gpu , out = target_coefs_gpu , eps = tol )
115- else :
114+ else :
116115 target_coefs_gpu = fun (* source_pts_gpu , source_coefs_gpu , * target_pts_gpu , eps = tol )
117116
118117 target_coefs = to_cpu (target_coefs_gpu )
119118
120119 utils .verify_type3 (source_pts , source_coefs , target_pts , target_coefs , tol )
121-
0 commit comments