Skip to content

Commit 165e7e4

Browse files
Make complex scalars work with numpy 2.0
This is done using C++ generic functions to get/set the real/imag parts of complex numbers. This gives us an easy way to support Numpy v < 2.0, and allows the type underlying the bit width types, like pytensor_complex128, to be correctly inferred from the numpy complex types they inherit from. Updated pytensor_complex struct to use get/set real/imag aliases defined above. Also updated operators such as `Abs` to use get_real, get_imag. Note: redefining the complex arithmetic here means that we aren't treating NaNs and infinities as carefully as the C99 standard suggets (see Appendix G of the standard). The code has been like this since it was added to Theano, so we're keeping the existing behavior.
1 parent 84144e4 commit 165e7e4

File tree

1 file changed

+117
-64
lines changed

1 file changed

+117
-64
lines changed

pytensor/scalar/basic.py

Lines changed: 117 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,8 @@ def c_headers(self, c_compiler=None, **kwargs):
350350
# we declare them here and they will be re-used by TensorType
351351
l.append("<numpy/arrayobject.h>")
352352
l.append("<numpy/arrayscalars.h>")
353+
l.append("<numpy/npy_math.h>")
354+
353355
if config.lib__amdlibm and c_compiler.supports_amdlibm:
354356
l += ["<amdlibm.h>"]
355357
return l
@@ -518,73 +520,123 @@ def c_support_code(self, **kwargs):
518520
# In that case we add the 'int' type to the real types.
519521
real_types.append("int")
520522

523+
def _make_get_set_real_imag(scalar_type: str) -> str:
524+
"""Make overloaded getter/setter functions for real/imag parts of numpy complex types.
525+
526+
The functions called by these getter/setter functions are defining in npy_math.h
527+
528+
Args:
529+
scalar_type: float, double, or longdouble
530+
531+
Returns:
532+
C++ code for defining set_real, set_imag, get_real, and get_imag, overloaded for the
533+
given type.
534+
"""
535+
complex_type = "npy_c" + scalar_type
536+
suffix = "" if scalar_type == "double" else scalar_type[0]
537+
return_type = scalar_type
538+
539+
if scalar_type == "longdouble":
540+
scalar_type += "_t"
541+
return_type = "npy_" + return_type
542+
543+
template = f"""
544+
static inline {return_type} get_real(const {complex_type} z)
545+
{{
546+
return npy_creal{suffix}(z);
547+
}}
548+
549+
static inline void set_real({complex_type} *z, const {scalar_type} r)
550+
{{
551+
npy_csetreal{suffix}(z, r);
552+
}}
553+
554+
static inline {return_type} get_imag(const {complex_type} z)
555+
{{
556+
return npy_cimag{suffix}(z);
557+
}}
558+
559+
static inline void set_imag({complex_type} *z, const {scalar_type} i)
560+
{{
561+
npy_csetimag{suffix}(z, i);
562+
}}
563+
"""
564+
return template
565+
566+
# TODO: add guard code to prevent this from being defining twice, in case we need to add it somewhere else
567+
get_set_aliases = "\n".join(
568+
_make_get_set_real_imag(stype)
569+
for stype in ["float", "double", "longdouble"]
570+
)
571+
521572
template = """
522-
struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s
523-
{
524-
typedef pytensor_complex%(nbits)s complex_type;
525-
typedef npy_float%(half_nbits)s scalar_type;
526-
527-
complex_type operator +(const complex_type &y) const {
528-
complex_type ret;
529-
ret.real = this->real + y.real;
530-
ret.imag = this->imag + y.imag;
531-
return ret;
532-
}
533-
534-
complex_type operator -() const {
535-
complex_type ret;
536-
ret.real = -this->real;
537-
ret.imag = -this->imag;
538-
return ret;
539-
}
540-
bool operator ==(const complex_type &y) const {
541-
return (this->real == y.real) && (this->imag == y.imag);
542-
}
543-
bool operator ==(const scalar_type &y) const {
544-
return (this->real == y) && (this->imag == 0);
545-
}
546-
complex_type operator -(const complex_type &y) const {
547-
complex_type ret;
548-
ret.real = this->real - y.real;
549-
ret.imag = this->imag - y.imag;
550-
return ret;
551-
}
552-
complex_type operator *(const complex_type &y) const {
553-
complex_type ret;
554-
ret.real = this->real * y.real - this->imag * y.imag;
555-
ret.imag = this->real * y.imag + this->imag * y.real;
556-
return ret;
557-
}
558-
complex_type operator /(const complex_type &y) const {
559-
complex_type ret;
560-
scalar_type y_norm_square = y.real * y.real + y.imag * y.imag;
561-
ret.real = (this->real * y.real + this->imag * y.imag) / y_norm_square;
562-
ret.imag = (this->imag * y.real - this->real * y.imag) / y_norm_square;
563-
return ret;
564-
}
565-
template <typename T>
566-
complex_type& operator =(const T& y);
567-
568-
pytensor_complex%(nbits)s() {}
569-
570-
template <typename T>
571-
pytensor_complex%(nbits)s(const T& y) { *this = y; }
572-
573-
template <typename TR, typename TI>
574-
pytensor_complex%(nbits)s(const TR& r, const TI& i) { this->real=r; this->imag=i; }
573+
struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s {
574+
typedef pytensor_complex%(nbits)s complex_type;
575+
typedef npy_float%(half_nbits)s scalar_type;
576+
577+
complex_type operator+(const complex_type &y) const {
578+
complex_type ret;
579+
set_real(&ret, get_real(*this) + get_real(y));
580+
set_imag(&ret, get_imag(*this) + get_imag(y));
581+
return ret;
582+
}
583+
584+
complex_type operator-() const {
585+
complex_type ret;
586+
set_real(&ret, -get_real(*this));
587+
set_imag(&ret, -get_imag(*this));
588+
return ret;
589+
}
590+
bool operator==(const complex_type &y) const {
591+
return (get_real(*this) == get_real(y)) && (get_imag(*this) == get_imag(y));
592+
}
593+
bool operator==(const scalar_type &y) const {
594+
return (get_real(*this) == y) && (get_real(*this) == 0);
595+
}
596+
complex_type operator-(const complex_type &y) const {
597+
complex_type ret;
598+
set_real(&ret, get_real(*this) - get_real(y));
599+
set_imag(&ret, get_imag(*this) - get_imag(y));
600+
return ret;
601+
}
602+
complex_type operator*(const complex_type &y) const {
603+
complex_type ret;
604+
set_real(&ret, get_real(*this) * get_real(y) - get_imag(*this) * get_imag(y));
605+
set_imag(&ret, get_imag(*this) * get_real(y) + get_real(*this) * get_imag(y));
606+
return ret;
607+
}
608+
complex_type operator/(const complex_type &y) const {
609+
complex_type ret;
610+
scalar_type y_norm_square = get_real(y) * get_real(y) + get_imag(y) * get_imag(y);
611+
set_real(&ret, (get_real(*this) * get_real(y) + get_imag(*this) * get_imag(y)) / y_norm_square);
612+
set_imag(&ret, (get_imag(*this) * get_real(y) - get_real(*this) * get_imag(y)) / y_norm_square);
613+
return ret;
614+
}
615+
template <typename T> complex_type &operator=(const T &y);
616+
617+
618+
pytensor_complex%(nbits)s() {}
619+
620+
template <typename T> pytensor_complex%(nbits)s(const T &y) { *this = y; }
621+
622+
template <typename TR, typename TI>
623+
pytensor_complex%(nbits)s(const TR &r, const TI &i) {
624+
set_real(this, r);
625+
set_imag(this, i);
626+
}
575627
};
576628
"""
577629

578630
def operator_eq_real(mytype, othertype):
579631
return f"""
580632
template <> {mytype} & {mytype}::operator=<{othertype}>(const {othertype} & y)
581-
{{ this->real=y; this->imag=0; return *this; }}
633+
{{ set_real(this, y); set_imag(this, 0); return *this; }}
582634
"""
583635

584636
def operator_eq_cplx(mytype, othertype):
585637
return f"""
586638
template <> {mytype} & {mytype}::operator=<{othertype}>(const {othertype} & y)
587-
{{ this->real=y.real; this->imag=y.imag; return *this; }}
639+
{{ set_real(this, get_real(y)); set_imag(this, get_imag(y)); return *this; }}
588640
"""
589641

590642
operator_eq = "".join(
@@ -606,10 +658,10 @@ def operator_eq_cplx(mytype, othertype):
606658
def operator_plus_real(mytype, othertype):
607659
return f"""
608660
const {mytype} operator+(const {mytype} &x, const {othertype} &y)
609-
{{ return {mytype}(x.real+y, x.imag); }}
661+
{{ return {mytype}(get_real(x) + y, get_imag(x)); }}
610662
611663
const {mytype} operator+(const {othertype} &y, const {mytype} &x)
612-
{{ return {mytype}(x.real+y, x.imag); }}
664+
{{ return {mytype}(get_real(x) + y, get_imag(x)); }}
613665
"""
614666

615667
operator_plus = "".join(
@@ -621,10 +673,10 @@ def operator_plus_real(mytype, othertype):
621673
def operator_minus_real(mytype, othertype):
622674
return f"""
623675
const {mytype} operator-(const {mytype} &x, const {othertype} &y)
624-
{{ return {mytype}(x.real-y, x.imag); }}
676+
{{ return {mytype}(get_real(x) - y, get_imag(x)); }}
625677
626678
const {mytype} operator-(const {othertype} &y, const {mytype} &x)
627-
{{ return {mytype}(y-x.real, -x.imag); }}
679+
{{ return {mytype}(y - get_real(x), -get_imag(x)); }}
628680
"""
629681

630682
operator_minus = "".join(
@@ -636,10 +688,10 @@ def operator_minus_real(mytype, othertype):
636688
def operator_mul_real(mytype, othertype):
637689
return f"""
638690
const {mytype} operator*(const {mytype} &x, const {othertype} &y)
639-
{{ return {mytype}(x.real*y, x.imag*y); }}
691+
{{ return {mytype}(get_real(x) * y, get_imag(x) * y); }}
640692
641693
const {mytype} operator*(const {othertype} &y, const {mytype} &x)
642-
{{ return {mytype}(x.real*y, x.imag*y); }}
694+
{{ return {mytype}(get_real(x) * y, get_imag(x) * y); }}
643695
"""
644696

645697
operator_mul = "".join(
@@ -649,7 +701,8 @@ def operator_mul_real(mytype, othertype):
649701
)
650702

651703
return (
652-
template % dict(nbits=64, half_nbits=32)
704+
get_set_aliases
705+
+ template % dict(nbits=64, half_nbits=32)
653706
+ template % dict(nbits=128, half_nbits=64)
654707
+ operator_eq
655708
+ operator_plus
@@ -664,7 +717,7 @@ def c_init_code(self, **kwargs):
664717
return ["import_array();"]
665718

666719
def c_code_cache_version(self):
667-
return (13, np.__version__)
720+
return (15, np.version.git_revision)
668721

669722
def get_shape_info(self, obj):
670723
return obj.itemsize
@@ -2568,7 +2621,7 @@ def c_code(self, node, name, inputs, outputs, sub):
25682621
if type in float_types:
25692622
return f"{z} = fabs({x});"
25702623
if type in complex_types:
2571-
return f"{z} = sqrt({x}.real*{x}.real + {x}.imag*{x}.imag);"
2624+
return f"{z} = sqrt(get_real({x}) * get_real({x}) + get_imag({x}) * get_imag({x}));"
25722625
if node.outputs[0].type == bool:
25732626
return f"{z} = ({x}) ? 1 : 0;"
25742627
if type in uint_types:

0 commit comments

Comments
 (0)