Skip to content

Commit 14049c8

Browse files
Make complex scalars work with numpy 2.0
This is done using C++ generic functions to get/set the real/imag parts of complex numbers. This gives us an easy way to support Numpy v < 2.0, and allows the type underlying the bit width types, like pytensor_complex128, to be correctly inferred from the numpy complex types they inherit from. Updated pytensor_complex struct to use get/set real/imag aliases defined above. Also updated operators such as `Abs` to use get_real, get_imag. Macros have been added to ensure compatibility with numpy < 2.0 Note: redefining the complex arithmetic here means that we aren't treating NaNs and infinities as carefully as the C99 standard suggets (see Appendix G of the standard). The code has been like this since it was added to Theano, so we're keeping the existing behavior.
1 parent 95fd678 commit 14049c8

File tree

1 file changed

+161
-64
lines changed

1 file changed

+161
-64
lines changed

pytensor/scalar/basic.py

Lines changed: 161 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,8 @@ def c_headers(self, c_compiler=None, **kwargs):
350350
# we declare them here and they will be re-used by TensorType
351351
l.append("<numpy/arrayobject.h>")
352352
l.append("<numpy/arrayscalars.h>")
353+
l.append("<numpy/npy_math.h>")
354+
353355
if config.lib__amdlibm and c_compiler.supports_amdlibm:
354356
l += ["<amdlibm.h>"]
355357
return l
@@ -518,73 +520,167 @@ def c_support_code(self, **kwargs):
518520
# In that case we add the 'int' type to the real types.
519521
real_types.append("int")
520522

523+
# Macros for backwards compatibility with numpy < 2.0
524+
#
525+
# In numpy 2.0+, these are defined in npy_math.h, but
526+
# for early versions, they must be vendored by users (e.g. PyTensor)
527+
backwards_compat_macros = """
528+
#ifndef NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPLEXCOMPAT_H_
529+
#define NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPLEXCOMPAT_H_
530+
531+
#include <numpy/npy_math.h>
532+
533+
#ifndef NPY_CSETREALF
534+
#define NPY_CSETREALF(c, r) (c)->real = (r)
535+
#endif
536+
#ifndef NPY_CSETIMAGF
537+
#define NPY_CSETIMAGF(c, i) (c)->imag = (i)
538+
#endif
539+
#ifndef NPY_CSETREAL
540+
#define NPY_CSETREAL(c, r) (c)->real = (r)
541+
#endif
542+
#ifndef NPY_CSETIMAG
543+
#define NPY_CSETIMAG(c, i) (c)->imag = (i)
544+
#endif
545+
#ifndef NPY_CSETREALL
546+
#define NPY_CSETREALL(c, r) (c)->real = (r)
547+
#endif
548+
#ifndef NPY_CSETIMAGL
549+
#define NPY_CSETIMAGL(c, i) (c)->imag = (i)
550+
#endif
551+
552+
#endif
553+
"""
554+
555+
def _make_get_set_real_imag(scalar_type: str) -> str:
556+
"""Make overloaded getter/setter functions for real/imag parts of numpy complex types.
557+
558+
The functions called by these getter/setter functions are defining in npy_math.h, or
559+
in the `backward_compat_macros` defined above.
560+
561+
Args:
562+
scalar_type: float, double, or longdouble
563+
564+
Returns:
565+
C++ code for defining set_real, set_imag, get_real, and get_imag, overloaded for the
566+
given type.
567+
"""
568+
complex_type = "npy_c" + scalar_type
569+
suffix = "" if scalar_type == "double" else scalar_type[0]
570+
571+
if scalar_type == "longdouble":
572+
scalar_type = "npy_" + scalar_type
573+
574+
return_type = scalar_type
575+
576+
template = f"""
577+
static inline {return_type} get_real(const {complex_type} z)
578+
{{
579+
return npy_creal{suffix}(z);
580+
}}
581+
582+
static inline void set_real({complex_type} *z, const {scalar_type} r)
583+
{{
584+
NPY_CSETREAL{suffix.upper()}(z, r);
585+
}}
586+
587+
static inline {return_type} get_imag(const {complex_type} z)
588+
{{
589+
return npy_cimag{suffix}(z);
590+
}}
591+
592+
static inline void set_imag({complex_type} *z, const {scalar_type} i)
593+
{{
594+
NPY_CSETIMAG{suffix.upper()}(z, i);
595+
}}
596+
"""
597+
return template
598+
599+
get_set_aliases = "\n".join(
600+
_make_get_set_real_imag(stype)
601+
for stype in ["float", "double", "longdouble"]
602+
)
603+
604+
get_set_aliases = backwards_compat_macros + "\n" + get_set_aliases
605+
606+
# Template for defining pytensor_complex64 and pytensor_complex128 structs/classes
607+
#
608+
# The npy_complex64, npy_complex128 types are aliases defined at run time based on
609+
# the size of floats and doubles on the machine. This means that both types are
610+
# not necessarily defined on every machine, but a machine with 32-bit floats and
611+
# 64-bit doubles will have npy_complex64 as an alias of npy_cfloat and npy_complex128
612+
# as an alias of npy_complex128.
613+
#
614+
# In any case, the get/set real/imag functions defined above will always work for
615+
# npy_complex64 and npy_complex128.
521616
template = """
522-
struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s
523-
{
524-
typedef pytensor_complex%(nbits)s complex_type;
525-
typedef npy_float%(half_nbits)s scalar_type;
526-
527-
complex_type operator +(const complex_type &y) const {
528-
complex_type ret;
529-
ret.real = this->real + y.real;
530-
ret.imag = this->imag + y.imag;
531-
return ret;
532-
}
533-
534-
complex_type operator -() const {
535-
complex_type ret;
536-
ret.real = -this->real;
537-
ret.imag = -this->imag;
538-
return ret;
539-
}
540-
bool operator ==(const complex_type &y) const {
541-
return (this->real == y.real) && (this->imag == y.imag);
542-
}
543-
bool operator ==(const scalar_type &y) const {
544-
return (this->real == y) && (this->imag == 0);
545-
}
546-
complex_type operator -(const complex_type &y) const {
547-
complex_type ret;
548-
ret.real = this->real - y.real;
549-
ret.imag = this->imag - y.imag;
550-
return ret;
551-
}
552-
complex_type operator *(const complex_type &y) const {
553-
complex_type ret;
554-
ret.real = this->real * y.real - this->imag * y.imag;
555-
ret.imag = this->real * y.imag + this->imag * y.real;
556-
return ret;
557-
}
558-
complex_type operator /(const complex_type &y) const {
559-
complex_type ret;
560-
scalar_type y_norm_square = y.real * y.real + y.imag * y.imag;
561-
ret.real = (this->real * y.real + this->imag * y.imag) / y_norm_square;
562-
ret.imag = (this->imag * y.real - this->real * y.imag) / y_norm_square;
563-
return ret;
564-
}
565-
template <typename T>
566-
complex_type& operator =(const T& y);
567-
568-
pytensor_complex%(nbits)s() {}
569-
570-
template <typename T>
571-
pytensor_complex%(nbits)s(const T& y) { *this = y; }
572-
573-
template <typename TR, typename TI>
574-
pytensor_complex%(nbits)s(const TR& r, const TI& i) { this->real=r; this->imag=i; }
617+
struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s {
618+
typedef pytensor_complex%(nbits)s complex_type;
619+
typedef npy_float%(half_nbits)s scalar_type;
620+
621+
complex_type operator+(const complex_type &y) const {
622+
complex_type ret;
623+
set_real(&ret, get_real(*this) + get_real(y));
624+
set_imag(&ret, get_imag(*this) + get_imag(y));
625+
return ret;
626+
}
627+
628+
complex_type operator-() const {
629+
complex_type ret;
630+
set_real(&ret, -get_real(*this));
631+
set_imag(&ret, -get_imag(*this));
632+
return ret;
633+
}
634+
bool operator==(const complex_type &y) const {
635+
return (get_real(*this) == get_real(y)) && (get_imag(*this) == get_imag(y));
636+
}
637+
bool operator==(const scalar_type &y) const {
638+
return (get_real(*this) == y) && (get_real(*this) == 0);
639+
}
640+
complex_type operator-(const complex_type &y) const {
641+
complex_type ret;
642+
set_real(&ret, get_real(*this) - get_real(y));
643+
set_imag(&ret, get_imag(*this) - get_imag(y));
644+
return ret;
645+
}
646+
complex_type operator*(const complex_type &y) const {
647+
complex_type ret;
648+
set_real(&ret, get_real(*this) * get_real(y) - get_imag(*this) * get_imag(y));
649+
set_imag(&ret, get_imag(*this) * get_real(y) + get_real(*this) * get_imag(y));
650+
return ret;
651+
}
652+
complex_type operator/(const complex_type &y) const {
653+
complex_type ret;
654+
scalar_type y_norm_square = get_real(y) * get_real(y) + get_imag(y) * get_imag(y);
655+
set_real(&ret, (get_real(*this) * get_real(y) + get_imag(*this) * get_imag(y)) / y_norm_square);
656+
set_imag(&ret, (get_imag(*this) * get_real(y) - get_real(*this) * get_imag(y)) / y_norm_square);
657+
return ret;
658+
}
659+
template <typename T> complex_type &operator=(const T &y);
660+
661+
662+
pytensor_complex%(nbits)s() {}
663+
664+
template <typename T> pytensor_complex%(nbits)s(const T &y) { *this = y; }
665+
666+
template <typename TR, typename TI>
667+
pytensor_complex%(nbits)s(const TR &r, const TI &i) {
668+
set_real(this, r);
669+
set_imag(this, i);
670+
}
575671
};
576672
"""
577673

578674
def operator_eq_real(mytype, othertype):
579675
return f"""
580676
template <> {mytype} & {mytype}::operator=<{othertype}>(const {othertype} & y)
581-
{{ this->real=y; this->imag=0; return *this; }}
677+
{{ set_real(this, y); set_imag(this, 0); return *this; }}
582678
"""
583679

584680
def operator_eq_cplx(mytype, othertype):
585681
return f"""
586682
template <> {mytype} & {mytype}::operator=<{othertype}>(const {othertype} & y)
587-
{{ this->real=y.real; this->imag=y.imag; return *this; }}
683+
{{ set_real(this, get_real(y)); set_imag(this, get_imag(y)); return *this; }}
588684
"""
589685

590686
operator_eq = "".join(
@@ -606,10 +702,10 @@ def operator_eq_cplx(mytype, othertype):
606702
def operator_plus_real(mytype, othertype):
607703
return f"""
608704
const {mytype} operator+(const {mytype} &x, const {othertype} &y)
609-
{{ return {mytype}(x.real+y, x.imag); }}
705+
{{ return {mytype}(get_real(x) + y, get_imag(x)); }}
610706
611707
const {mytype} operator+(const {othertype} &y, const {mytype} &x)
612-
{{ return {mytype}(x.real+y, x.imag); }}
708+
{{ return {mytype}(get_real(x) + y, get_imag(x)); }}
613709
"""
614710

615711
operator_plus = "".join(
@@ -621,10 +717,10 @@ def operator_plus_real(mytype, othertype):
621717
def operator_minus_real(mytype, othertype):
622718
return f"""
623719
const {mytype} operator-(const {mytype} &x, const {othertype} &y)
624-
{{ return {mytype}(x.real-y, x.imag); }}
720+
{{ return {mytype}(get_real(x) - y, get_imag(x)); }}
625721
626722
const {mytype} operator-(const {othertype} &y, const {mytype} &x)
627-
{{ return {mytype}(y-x.real, -x.imag); }}
723+
{{ return {mytype}(y - get_real(x), -get_imag(x)); }}
628724
"""
629725

630726
operator_minus = "".join(
@@ -636,10 +732,10 @@ def operator_minus_real(mytype, othertype):
636732
def operator_mul_real(mytype, othertype):
637733
return f"""
638734
const {mytype} operator*(const {mytype} &x, const {othertype} &y)
639-
{{ return {mytype}(x.real*y, x.imag*y); }}
735+
{{ return {mytype}(get_real(x) * y, get_imag(x) * y); }}
640736
641737
const {mytype} operator*(const {othertype} &y, const {mytype} &x)
642-
{{ return {mytype}(x.real*y, x.imag*y); }}
738+
{{ return {mytype}(get_real(x) * y, get_imag(x) * y); }}
643739
"""
644740

645741
operator_mul = "".join(
@@ -649,7 +745,8 @@ def operator_mul_real(mytype, othertype):
649745
)
650746

651747
return (
652-
template % dict(nbits=64, half_nbits=32)
748+
get_set_aliases
749+
+ template % dict(nbits=64, half_nbits=32)
653750
+ template % dict(nbits=128, half_nbits=64)
654751
+ operator_eq
655752
+ operator_plus
@@ -664,7 +761,7 @@ def c_init_code(self, **kwargs):
664761
return ["import_array();"]
665762

666763
def c_code_cache_version(self):
667-
return (13, np.__version__)
764+
return (14, np.__version__)
668765

669766
def get_shape_info(self, obj):
670767
return obj.itemsize
@@ -2568,7 +2665,7 @@ def c_code(self, node, name, inputs, outputs, sub):
25682665
if type in float_types:
25692666
return f"{z} = fabs({x});"
25702667
if type in complex_types:
2571-
return f"{z} = sqrt({x}.real*{x}.real + {x}.imag*{x}.imag);"
2668+
return f"{z} = sqrt(get_real({x}) * get_real({x}) + get_imag({x}) * get_imag({x}));"
25722669
if node.outputs[0].type == bool:
25732670
return f"{z} = ({x}) ? 1 : 0;"
25742671
if type in uint_types:

0 commit comments

Comments
 (0)