11# cython: language_level=3
22
3- # import both numpy and the Cython declarations for numpy
43import numpy as np
5- cimport numpy as np
6-
7- # if you want to use the Numpy-C-API from Cython
8- np.import_array()
94
105# ----------------------------------------------------------------------------------------------------#
116
@@ -317,32 +312,48 @@ def is_elmn_non_zero(int el, int m, int n, so3_parameters):
317312
318313
319314# forward and inverse for MW and MWSS for complex functions
320- def inverse (np.ndarray[ double complex , ndim = 1 , mode = " c" ] flmn not None , so3_parameters not None ):
321- cdef so3_parameters_t parameters= create_parameter_struct(so3_parameters)
322-
315+ def inverse (double complex[::1] flmn not None , so3_parameters not None ):
316+ cdef so3_parameters_t parameters = create_parameter_struct(so3_parameters)
323317 if so3_parameters.reality:
324- f_length = f_size(so3_parameters)
325- f = np.zeros([f_length,], dtype = float )
326- so3_core_inverse_via_ssht_real(< double * > np.PyArray_DATA(f), < const double complex * > np.PyArray_DATA(flmn), & parameters)
318+ return np.array(_inverse_real(flmn, parameters))
327319 else :
328- f_length = f_size(so3_parameters)
329- f = np.zeros([f_length,], dtype = complex )
330- so3_core_inverse_via_ssht(< double complex * > np.PyArray_DATA(f), < const double complex * > np.PyArray_DATA(flmn), & parameters)
320+ return np.array(_inverse_complex(flmn, parameters))
331321
322+ cdef double [::1 ] _inverse_real(double complex [::1 ] flmn, so3_parameters_t parameters):
323+ cdef int f_length = so3_sampling_f_size(& parameters)
324+ cdef double [::1 ] f = np.zeros([f_length,], dtype = float )
325+ so3_core_inverse_via_ssht_real(& f[0 ], & flmn[0 ], & parameters)
332326 return f
333327
334- def forward (np.ndarray[ double complex , ndim = 1 , mode = " c" ] f not None , so3_parameters not None ):
335- cdef so3_parameters_t parameters= create_parameter_struct(so3_parameters)
328+ cdef double complex [::1 ] _inverse_complex(double complex [::1 ] flmn, so3_parameters_t parameters):
329+ cdef int f_length = so3_sampling_f_size(& parameters)
330+ cdef double complex [::1 ] f = np.zeros([f_length,], dtype = complex )
331+ so3_core_inverse_via_ssht(& f[0 ], & flmn[0 ], & parameters)
332+ return f
336333
337- if so3_parameters.reality:
338- flmn_length = flmn_size(so3_parameters)
339- flmn = np.zeros([flmn_length,], dtype = float )
340- so3_core_forward_via_ssht_real(< double complex * > np.PyArray_DATA(flmn), < const double * > np.PyArray_DATA(f), & parameters)
334+ ctypedef fused real_or_complex:
335+ double
336+ double complex
337+
338+ def forward (real_or_complex[::1] f not None , so3_parameters not None ):
339+ cdef so3_parameters_t parameters = create_parameter_struct(so3_parameters)
340+ if real_or_complex is double :
341+ if not so3_parameters.reality:
342+ raise ValueError (" f has a real data type but reality flag not set" )
343+ return np.array(_forward_real(f, parameters))
341344 else :
342- flmn_length = flmn_size(so3_parameters)
343- flmn = np.zeros([flmn_length,], dtype = complex )
344- so3_core_forward_via_ssht(< double complex * > np.PyArray_DATA(flmn), < const double complex * > np.PyArray_DATA(f), & parameters)
345+ return np.array(_forward_complex(f, parameters))
346+
347+ cdef double complex [::1 ] _forward_real(double [::1 ] f, so3_parameters_t parameters):
348+ cdef int flmn_length = so3_sampling_flmn_size(& parameters)
349+ cdef double complex [::1 ] flmn = np.zeros([flmn_length,], dtype = complex )
350+ so3_core_forward_via_ssht_real(& flmn[0 ], & f[0 ], & parameters)
351+ return flmn
345352
353+ cdef double complex [::1 ] _forward_complex(double complex [::1 ] f, so3_parameters_t parameters):
354+ cdef int flmn_length = so3_sampling_flmn_size(& parameters)
355+ cdef double complex [::1 ] flmn = np.zeros([flmn_length,], dtype = complex )
356+ so3_core_forward_via_ssht(& flmn[0 ], & f[0 ], & parameters)
346357 return flmn
347358
348359# convolution both in real and harmonic space and helper params function
@@ -368,9 +379,9 @@ def get_convolved_parameters(f_so3_parameters, g_so3_parameters):
368379 return SO3Parameters().from_dict(h_parameters)
369380
370381def convolve (
371- np.ndarray[ double complex , ndim = 1 , mode = " c " ] f not None ,
382+ double complex[::1 ] f not None ,
372383 f_parameters ,
373- np.ndarray[ double complex , ndim = 1 , mode = " c " ] g not None ,
384+ double complex[::1 ] g not None ,
374385 g_parameters
375386 ):
376387
@@ -381,22 +392,22 @@ def convolve(
381392 cdef so3_parameters_t h_parameters_struct= create_parameter_struct(h_parameters)
382393
383394 h_length = f_size(h_parameters)
384- h = np.zeros([h_length,], dtype = complex )
395+ cdef double complex [:: 1 ] h = np.zeros([h_length,], dtype = complex )
385396
386397 so3_conv_convolution(
387- < double complex * > np.PyArray_DATA(h) ,
398+ & h[ 0 ] ,
388399 & h_parameters_struct,
389- < const double complex * > np.PyArray_DATA(f) ,
400+ & f[ 0 ] ,
390401 & f_parameters_struct,
391- < const double complex * > np.PyArray_DATA(g) ,
402+ & g[ 0 ] ,
392403 & g_parameters_struct
393404 )
394405 return h, h_parameters
395406
396407def convolve_harmonic (
397- np.ndarray[ double complex , ndim = 1 , mode = " c " ] flmn not None ,
408+ double complex[::1 ] flmn not None ,
398409 f_parameters ,
399- np.ndarray[ double complex , ndim = 1 , mode = " c " ] glmn not None ,
410+ double complex[::1 ] glmn not None ,
400411 g_parameters
401412 ):
402413
@@ -408,34 +419,29 @@ def convolve_harmonic(
408419 cdef so3_parameters_t h_parameters_struct= create_parameter_struct(h_parameters)
409420
410421 hlmn_length = flmn_size(h_parameters)
411- hlmn = np.zeros([hlmn_length,], dtype = complex )
422+ cdef double complex [:: 1 ] hlmn = np.zeros([hlmn_length,], dtype = complex )
412423
413424 so3_conv_convolution(
414- < double complex * > np.PyArray_DATA( hlmn) ,
425+ & hlmn[ 0 ] ,
415426 & h_parameters_struct,
416- < const double complex * > np.PyArray_DATA( flmn) ,
427+ & flmn[ 0 ] ,
417428 & f_parameters_struct,
418- < const double complex * > np.PyArray_DATA( glmn) ,
429+ & glmn[ 0 ] ,
419430 & g_parameters_struct
420431 )
421432 return hlmn, h_parameters
422433
423434def s2toso3_harmonic_convolution (
424435 h_so3_parameters ,
425- np.ndarray[ double complex , ndim = 1 , mode = " c " ] flm not None ,
426- np.ndarray[ double complex , ndim = 1 , mode = " c " ] glm not None ):
436+ double complex[::1 ] flm not None ,
437+ double complex[::1 ] glm not None ):
427438
428439 cdef so3_parameters_t h_parameters= create_parameter_struct(h_so3_parameters)
429440
430441 hlmn_length = flmn_size(h_so3_parameters)
431- hlmn = np.zeros([hlmn_length,], dtype = complex )
442+ cdef double complex [:: 1 ] hlmn = np.zeros([hlmn_length,], dtype = complex )
432443
433- so3_conv_s2toso3_harmonic_convolution(
434- < double complex * > np.PyArray_DATA(hlmn),
435- & h_parameters,
436- < const double complex * > np.PyArray_DATA(flm),
437- < const double complex * > np.PyArray_DATA(glm)
438- )
444+ so3_conv_s2toso3_harmonic_convolution(& hlmn[0 ], & h_parameters, & flm[0 ], & glm[0 ])
439445 return hlmn
440446
441447def test_func ():
0 commit comments