@@ -349,6 +349,8 @@ def c_headers(self, c_compiler=None, **kwargs):
349349 # we declare them here and they will be re-used by TensorType
350350 l .append ("<numpy/arrayobject.h>" )
351351 l .append ("<numpy/arrayscalars.h>" )
352+ l .append ("<numpy/npy_math.h>" )
353+
352354 if config .lib__amdlibm and c_compiler .supports_amdlibm :
353355 l += ["<amdlibm.h>" ]
354356 return l
@@ -517,73 +519,167 @@ def c_support_code(self, **kwargs):
517519 # In that case we add the 'int' type to the real types.
518520 real_types .append ("int" )
519521
522+ # Macros for backwards compatibility with numpy < 2.0
523+ #
524+ # In numpy 2.0+, these are defined in npy_math.h, but
525+ # for early versions, they must be vendored by users (e.g. PyTensor)
526+ backwards_compat_macros = """
527+ #ifndef NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPLEXCOMPAT_H_
528+ #define NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPLEXCOMPAT_H_
529+
530+ #include <numpy/npy_math.h>
531+
532+ #ifndef NPY_CSETREALF
533+ #define NPY_CSETREALF(c, r) (c)->real = (r)
534+ #endif
535+ #ifndef NPY_CSETIMAGF
536+ #define NPY_CSETIMAGF(c, i) (c)->imag = (i)
537+ #endif
538+ #ifndef NPY_CSETREAL
539+ #define NPY_CSETREAL(c, r) (c)->real = (r)
540+ #endif
541+ #ifndef NPY_CSETIMAG
542+ #define NPY_CSETIMAG(c, i) (c)->imag = (i)
543+ #endif
544+ #ifndef NPY_CSETREALL
545+ #define NPY_CSETREALL(c, r) (c)->real = (r)
546+ #endif
547+ #ifndef NPY_CSETIMAGL
548+ #define NPY_CSETIMAGL(c, i) (c)->imag = (i)
549+ #endif
550+
551+ #endif
552+ """
553+
554+ def _make_get_set_real_imag (scalar_type : str ) -> str :
555+ """Make overloaded getter/setter functions for real/imag parts of numpy complex types.
556+
557+ The functions called by these getter/setter functions are defining in npy_math.h, or
558+ in the `backward_compat_macros` defined above.
559+
560+ Args:
561+ scalar_type: float, double, or longdouble
562+
563+ Returns:
564+ C++ code for defining set_real, set_imag, get_real, and get_imag, overloaded for the
565+ given type.
566+ """
567+ complex_type = "npy_c" + scalar_type
568+ suffix = "" if scalar_type == "double" else scalar_type [0 ]
569+
570+ if scalar_type == "longdouble" :
571+ scalar_type = "npy_" + scalar_type
572+
573+ return_type = scalar_type
574+
575+ template = f"""
576+ static inline { return_type } get_real(const { complex_type } z)
577+ {{
578+ return npy_creal{ suffix } (z);
579+ }}
580+
581+ static inline void set_real({ complex_type } *z, const { scalar_type } r)
582+ {{
583+ NPY_CSETREAL{ suffix .upper ()} (z, r);
584+ }}
585+
586+ static inline { return_type } get_imag(const { complex_type } z)
587+ {{
588+ return npy_cimag{ suffix } (z);
589+ }}
590+
591+ static inline void set_imag({ complex_type } *z, const { scalar_type } i)
592+ {{
593+ NPY_CSETIMAG{ suffix .upper ()} (z, i);
594+ }}
595+ """
596+ return template
597+
598+ get_set_aliases = "\n " .join (
599+ _make_get_set_real_imag (stype )
600+ for stype in ["float" , "double" , "longdouble" ]
601+ )
602+
603+ get_set_aliases = backwards_compat_macros + "\n " + get_set_aliases
604+
605+ # Template for defining pytensor_complex64 and pytensor_complex128 structs/classes
606+ #
607+ # The npy_complex64, npy_complex128 types are aliases defined at run time based on
608+ # the size of floats and doubles on the machine. This means that both types are
609+ # not necessarily defined on every machine, but a machine with 32-bit floats and
610+ # 64-bit doubles will have npy_complex64 as an alias of npy_cfloat and npy_complex128
611+ # as an alias of npy_complex128.
612+ #
613+ # In any case, the get/set real/imag functions defined above will always work for
614+ # npy_complex64 and npy_complex128.
520615 template = """
521- struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s
522- {
523- typedef pytensor_complex%(nbits)s complex_type;
524- typedef npy_float%(half_nbits)s scalar_type;
525-
526- complex_type operator +(const complex_type &y) const {
527- complex_type ret;
528- ret.real = this->real + y.real;
529- ret.imag = this->imag + y.imag;
530- return ret;
531- }
532-
533- complex_type operator -() const {
534- complex_type ret;
535- ret.real = -this->real;
536- ret.imag = -this->imag;
537- return ret;
538- }
539- bool operator ==(const complex_type &y) const {
540- return (this->real == y.real) && (this->imag == y.imag);
541- }
542- bool operator ==(const scalar_type &y) const {
543- return (this->real == y) && (this->imag == 0);
544- }
545- complex_type operator -(const complex_type &y) const {
546- complex_type ret;
547- ret.real = this->real - y.real;
548- ret.imag = this->imag - y.imag;
549- return ret;
550- }
551- complex_type operator *(const complex_type &y) const {
552- complex_type ret;
553- ret.real = this->real * y.real - this->imag * y.imag;
554- ret.imag = this->real * y.imag + this->imag * y.real;
555- return ret;
556- }
557- complex_type operator /(const complex_type &y) const {
558- complex_type ret;
559- scalar_type y_norm_square = y.real * y.real + y.imag * y.imag;
560- ret.real = (this->real * y.real + this->imag * y.imag) / y_norm_square;
561- ret.imag = (this->imag * y.real - this->real * y.imag) / y_norm_square;
562- return ret;
563- }
564- template <typename T>
565- complex_type& operator =(const T& y);
566-
567- pytensor_complex%(nbits)s() {}
568-
569- template <typename T>
570- pytensor_complex%(nbits)s(const T& y) { *this = y; }
571-
572- template <typename TR, typename TI>
573- pytensor_complex%(nbits)s(const TR& r, const TI& i) { this->real=r; this->imag=i; }
616+ struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s {
617+ typedef pytensor_complex%(nbits)s complex_type;
618+ typedef npy_float%(half_nbits)s scalar_type;
619+
620+ complex_type operator+(const complex_type &y) const {
621+ complex_type ret;
622+ set_real(&ret, get_real(*this) + get_real(y));
623+ set_imag(&ret, get_imag(*this) + get_imag(y));
624+ return ret;
625+ }
626+
627+ complex_type operator-() const {
628+ complex_type ret;
629+ set_real(&ret, -get_real(*this));
630+ set_imag(&ret, -get_imag(*this));
631+ return ret;
632+ }
633+ bool operator==(const complex_type &y) const {
634+ return (get_real(*this) == get_real(y)) && (get_imag(*this) == get_imag(y));
635+ }
636+ bool operator==(const scalar_type &y) const {
637+ return (get_real(*this) == y) && (get_real(*this) == 0);
638+ }
639+ complex_type operator-(const complex_type &y) const {
640+ complex_type ret;
641+ set_real(&ret, get_real(*this) - get_real(y));
642+ set_imag(&ret, get_imag(*this) - get_imag(y));
643+ return ret;
644+ }
645+ complex_type operator*(const complex_type &y) const {
646+ complex_type ret;
647+ set_real(&ret, get_real(*this) * get_real(y) - get_imag(*this) * get_imag(y));
648+ set_imag(&ret, get_imag(*this) * get_real(y) + get_real(*this) * get_imag(y));
649+ return ret;
650+ }
651+ complex_type operator/(const complex_type &y) const {
652+ complex_type ret;
653+ scalar_type y_norm_square = get_real(y) * get_real(y) + get_imag(y) * get_imag(y);
654+ set_real(&ret, (get_real(*this) * get_real(y) + get_imag(*this) * get_imag(y)) / y_norm_square);
655+ set_imag(&ret, (get_imag(*this) * get_real(y) - get_real(*this) * get_imag(y)) / y_norm_square);
656+ return ret;
657+ }
658+ template <typename T> complex_type &operator=(const T &y);
659+
660+
661+ pytensor_complex%(nbits)s() {}
662+
663+ template <typename T> pytensor_complex%(nbits)s(const T &y) { *this = y; }
664+
665+ template <typename TR, typename TI>
666+ pytensor_complex%(nbits)s(const TR &r, const TI &i) {
667+ set_real(this, r);
668+ set_imag(this, i);
669+ }
574670 };
575671 """
576672
577673 def operator_eq_real (mytype , othertype ):
578674 return f"""
579675 template <> { mytype } & { mytype } ::operator=<{ othertype } >(const { othertype } & y)
580- {{ this->real=y; this->imag=0 ; return *this; }}
676+ {{ set_real( this, y); set_imag( this, 0) ; return *this; }}
581677 """
582678
583679 def operator_eq_cplx (mytype , othertype ):
584680 return f"""
585681 template <> { mytype } & { mytype } ::operator=<{ othertype } >(const { othertype } & y)
586- {{ this->real=y.real; this->imag=y.imag ; return *this; }}
682+ {{ set_real( this, get_real(y)); set_imag( this, get_imag(y)) ; return *this; }}
587683 """
588684
589685 operator_eq = "" .join (
@@ -605,10 +701,10 @@ def operator_eq_cplx(mytype, othertype):
605701 def operator_plus_real (mytype , othertype ):
606702 return f"""
607703 const { mytype } operator+(const { mytype } &x, const { othertype } &y)
608- {{ return { mytype } (x.real+ y, x.imag ); }}
704+ {{ return { mytype } (get_real(x) + y, get_imag(x) ); }}
609705
610706 const { mytype } operator+(const { othertype } &y, const { mytype } &x)
611- {{ return { mytype } (x.real+ y, x.imag ); }}
707+ {{ return { mytype } (get_real(x) + y, get_imag(x) ); }}
612708 """
613709
614710 operator_plus = "" .join (
@@ -620,10 +716,10 @@ def operator_plus_real(mytype, othertype):
620716 def operator_minus_real (mytype , othertype ):
621717 return f"""
622718 const { mytype } operator-(const { mytype } &x, const { othertype } &y)
623- {{ return { mytype } (x.real- y, x.imag ); }}
719+ {{ return { mytype } (get_real(x) - y, get_imag(x) ); }}
624720
625721 const { mytype } operator-(const { othertype } &y, const { mytype } &x)
626- {{ return { mytype } (y-x.real , -x.imag ); }}
722+ {{ return { mytype } (y - get_real(x) , -get_imag(x) ); }}
627723 """
628724
629725 operator_minus = "" .join (
@@ -635,10 +731,10 @@ def operator_minus_real(mytype, othertype):
635731 def operator_mul_real (mytype , othertype ):
636732 return f"""
637733 const { mytype } operator*(const { mytype } &x, const { othertype } &y)
638- {{ return { mytype } (x.real* y, x.imag* y); }}
734+ {{ return { mytype } (get_real(x) * y, get_imag(x) * y); }}
639735
640736 const { mytype } operator*(const { othertype } &y, const { mytype } &x)
641- {{ return { mytype } (x.real* y, x.imag* y); }}
737+ {{ return { mytype } (get_real(x) * y, get_imag(x) * y); }}
642738 """
643739
644740 operator_mul = "" .join (
@@ -648,7 +744,8 @@ def operator_mul_real(mytype, othertype):
648744 )
649745
650746 return (
651- template % dict (nbits = 64 , half_nbits = 32 )
747+ get_set_aliases
748+ + template % dict (nbits = 64 , half_nbits = 32 )
652749 + template % dict (nbits = 128 , half_nbits = 64 )
653750 + operator_eq
654751 + operator_plus
@@ -663,7 +760,7 @@ def c_init_code(self, **kwargs):
663760 return ["import_array();" ]
664761
665762 def c_code_cache_version (self ):
666- return (13 , np .__version__ )
763+ return (14 , np .__version__ )
667764
668765 def get_shape_info (self , obj ):
669766 return obj .itemsize
@@ -2567,7 +2664,7 @@ def c_code(self, node, name, inputs, outputs, sub):
25672664 if type in float_types :
25682665 return f"{ z } = fabs({ x } );"
25692666 if type in complex_types :
2570- return f"{ z } = sqrt({ x } .real* { x } .real + { x } .imag* { x } .imag );"
2667+ return f"{ z } = sqrt(get_real( { x } ) * get_real( { x } ) + get_imag( { x } ) * get_imag( { x } ) );"
25712668 if node .outputs [0 ].type == bool :
25722669 return f"{ z } = ({ x } ) ? 1 : 0;"
25732670 if type in uint_types :
0 commit comments