@@ -350,6 +350,8 @@ def c_headers(self, c_compiler=None, **kwargs):
350350 # we declare them here and they will be re-used by TensorType
351351 l .append ("<numpy/arrayobject.h>" )
352352 l .append ("<numpy/arrayscalars.h>" )
353+ l .append ("<numpy/npy_math.h>" )
354+
353355 if config .lib__amdlibm and c_compiler .supports_amdlibm :
354356 l += ["<amdlibm.h>" ]
355357 return l
@@ -518,73 +520,167 @@ def c_support_code(self, **kwargs):
518520 # In that case we add the 'int' type to the real types.
519521 real_types .append ("int" )
520522
523+ # Macros for backwards compatibility with numpy < 2.0
524+ #
525+ # In numpy 2.0+, these are defined in npy_math.h, but
526+ # for early versions, they must be vendored by users (e.g. PyTensor)
527+ backwards_compat_macros = """
528+ #ifndef NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPLEXCOMPAT_H_
529+ #define NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPLEXCOMPAT_H_
530+
531+ #include <numpy/npy_math.h>
532+
533+ #ifndef NPY_CSETREALF
534+ #define NPY_CSETREALF(c, r) (c)->real = (r)
535+ #endif
536+ #ifndef NPY_CSETIMAGF
537+ #define NPY_CSETIMAGF(c, i) (c)->imag = (i)
538+ #endif
539+ #ifndef NPY_CSETREAL
540+ #define NPY_CSETREAL(c, r) (c)->real = (r)
541+ #endif
542+ #ifndef NPY_CSETIMAG
543+ #define NPY_CSETIMAG(c, i) (c)->imag = (i)
544+ #endif
545+ #ifndef NPY_CSETREALL
546+ #define NPY_CSETREALL(c, r) (c)->real = (r)
547+ #endif
548+ #ifndef NPY_CSETIMAGL
549+ #define NPY_CSETIMAGL(c, i) (c)->imag = (i)
550+ #endif
551+
552+ #endif
553+ """
554+
555+ def _make_get_set_real_imag (scalar_type : str ) -> str :
556+ """Make overloaded getter/setter functions for real/imag parts of numpy complex types.
557+
558+ The functions called by these getter/setter functions are defining in npy_math.h, or
559+ in the `backward_compat_macros` defined above.
560+
561+ Args:
562+ scalar_type: float, double, or longdouble
563+
564+ Returns:
565+ C++ code for defining set_real, set_imag, get_real, and get_imag, overloaded for the
566+ given type.
567+ """
568+ complex_type = "npy_c" + scalar_type
569+ suffix = "" if scalar_type == "double" else scalar_type [0 ]
570+
571+ if scalar_type == "longdouble" :
572+ scalar_type = "npy_" + scalar_type
573+
574+ return_type = scalar_type
575+
576+ template = f"""
577+ static inline { return_type } get_real(const { complex_type } z)
578+ {{
579+ return npy_creal{ suffix } (z);
580+ }}
581+
582+ static inline void set_real({ complex_type } *z, const { scalar_type } r)
583+ {{
584+ NPY_CSETREAL{ suffix .upper ()} (z, r);
585+ }}
586+
587+ static inline { return_type } get_imag(const { complex_type } z)
588+ {{
589+ return npy_cimag{ suffix } (z);
590+ }}
591+
592+ static inline void set_imag({ complex_type } *z, const { scalar_type } i)
593+ {{
594+ NPY_CSETIMAG{ suffix .upper ()} (z, i);
595+ }}
596+ """
597+ return template
598+
599+ get_set_aliases = "\n " .join (
600+ _make_get_set_real_imag (stype )
601+ for stype in ["float" , "double" , "longdouble" ]
602+ )
603+
604+ get_set_aliases = backwards_compat_macros + "\n " + get_set_aliases
605+
606+ # Template for defining pytensor_complex64 and pytensor_complex128 structs/classes
607+ #
608+ # The npy_complex64, npy_complex128 types are aliases defined at run time based on
609+ # the size of floats and doubles on the machine. This means that both types are
610+ # not necessarily defined on every machine, but a machine with 32-bit floats and
611+ # 64-bit doubles will have npy_complex64 as an alias of npy_cfloat and npy_complex128
612+ # as an alias of npy_complex128.
613+ #
614+ # In any case, the get/set real/imag functions defined above will always work for
615+ # npy_complex64 and npy_complex128.
521616 template = """
522- struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s
523- {
524- typedef pytensor_complex%(nbits)s complex_type;
525- typedef npy_float%(half_nbits)s scalar_type;
526-
527- complex_type operator +(const complex_type &y) const {
528- complex_type ret;
529- ret.real = this->real + y.real;
530- ret.imag = this->imag + y.imag;
531- return ret;
532- }
533-
534- complex_type operator -() const {
535- complex_type ret;
536- ret.real = -this->real;
537- ret.imag = -this->imag;
538- return ret;
539- }
540- bool operator ==(const complex_type &y) const {
541- return (this->real == y.real) && (this->imag == y.imag);
542- }
543- bool operator ==(const scalar_type &y) const {
544- return (this->real == y) && (this->imag == 0);
545- }
546- complex_type operator -(const complex_type &y) const {
547- complex_type ret;
548- ret.real = this->real - y.real;
549- ret.imag = this->imag - y.imag;
550- return ret;
551- }
552- complex_type operator *(const complex_type &y) const {
553- complex_type ret;
554- ret.real = this->real * y.real - this->imag * y.imag;
555- ret.imag = this->real * y.imag + this->imag * y.real;
556- return ret;
557- }
558- complex_type operator /(const complex_type &y) const {
559- complex_type ret;
560- scalar_type y_norm_square = y.real * y.real + y.imag * y.imag;
561- ret.real = (this->real * y.real + this->imag * y.imag) / y_norm_square;
562- ret.imag = (this->imag * y.real - this->real * y.imag) / y_norm_square;
563- return ret;
564- }
565- template <typename T>
566- complex_type& operator =(const T& y);
567-
568- pytensor_complex%(nbits)s() {}
569-
570- template <typename T>
571- pytensor_complex%(nbits)s(const T& y) { *this = y; }
572-
573- template <typename TR, typename TI>
574- pytensor_complex%(nbits)s(const TR& r, const TI& i) { this->real=r; this->imag=i; }
617+ struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s {
618+ typedef pytensor_complex%(nbits)s complex_type;
619+ typedef npy_float%(half_nbits)s scalar_type;
620+
621+ complex_type operator+(const complex_type &y) const {
622+ complex_type ret;
623+ set_real(&ret, get_real(*this) + get_real(y));
624+ set_imag(&ret, get_imag(*this) + get_imag(y));
625+ return ret;
626+ }
627+
628+ complex_type operator-() const {
629+ complex_type ret;
630+ set_real(&ret, -get_real(*this));
631+ set_imag(&ret, -get_imag(*this));
632+ return ret;
633+ }
634+ bool operator==(const complex_type &y) const {
635+ return (get_real(*this) == get_real(y)) && (get_imag(*this) == get_imag(y));
636+ }
637+ bool operator==(const scalar_type &y) const {
638+ return (get_real(*this) == y) && (get_real(*this) == 0);
639+ }
640+ complex_type operator-(const complex_type &y) const {
641+ complex_type ret;
642+ set_real(&ret, get_real(*this) - get_real(y));
643+ set_imag(&ret, get_imag(*this) - get_imag(y));
644+ return ret;
645+ }
646+ complex_type operator*(const complex_type &y) const {
647+ complex_type ret;
648+ set_real(&ret, get_real(*this) * get_real(y) - get_imag(*this) * get_imag(y));
649+ set_imag(&ret, get_imag(*this) * get_real(y) + get_real(*this) * get_imag(y));
650+ return ret;
651+ }
652+ complex_type operator/(const complex_type &y) const {
653+ complex_type ret;
654+ scalar_type y_norm_square = get_real(y) * get_real(y) + get_imag(y) * get_imag(y);
655+ set_real(&ret, (get_real(*this) * get_real(y) + get_imag(*this) * get_imag(y)) / y_norm_square);
656+ set_imag(&ret, (get_imag(*this) * get_real(y) - get_real(*this) * get_imag(y)) / y_norm_square);
657+ return ret;
658+ }
659+ template <typename T> complex_type &operator=(const T &y);
660+
661+
662+ pytensor_complex%(nbits)s() {}
663+
664+ template <typename T> pytensor_complex%(nbits)s(const T &y) { *this = y; }
665+
666+ template <typename TR, typename TI>
667+ pytensor_complex%(nbits)s(const TR &r, const TI &i) {
668+ set_real(this, r);
669+ set_imag(this, i);
670+ }
575671 };
576672 """
577673
578674 def operator_eq_real (mytype , othertype ):
579675 return f"""
580676 template <> { mytype } & { mytype } ::operator=<{ othertype } >(const { othertype } & y)
581- {{ this->real=y; this->imag=0 ; return *this; }}
677+ {{ set_real( this, y); set_imag( this, 0) ; return *this; }}
582678 """
583679
584680 def operator_eq_cplx (mytype , othertype ):
585681 return f"""
586682 template <> { mytype } & { mytype } ::operator=<{ othertype } >(const { othertype } & y)
587- {{ this->real=y.real; this->imag=y.imag ; return *this; }}
683+ {{ set_real( this, get_real(y)); set_imag( this, get_imag(y)) ; return *this; }}
588684 """
589685
590686 operator_eq = "" .join (
@@ -606,10 +702,10 @@ def operator_eq_cplx(mytype, othertype):
606702 def operator_plus_real (mytype , othertype ):
607703 return f"""
608704 const { mytype } operator+(const { mytype } &x, const { othertype } &y)
609- {{ return { mytype } (x.real+ y, x.imag ); }}
705+ {{ return { mytype } (get_real(x) + y, get_imag(x) ); }}
610706
611707 const { mytype } operator+(const { othertype } &y, const { mytype } &x)
612- {{ return { mytype } (x.real+ y, x.imag ); }}
708+ {{ return { mytype } (get_real(x) + y, get_imag(x) ); }}
613709 """
614710
615711 operator_plus = "" .join (
@@ -621,10 +717,10 @@ def operator_plus_real(mytype, othertype):
621717 def operator_minus_real (mytype , othertype ):
622718 return f"""
623719 const { mytype } operator-(const { mytype } &x, const { othertype } &y)
624- {{ return { mytype } (x.real- y, x.imag ); }}
720+ {{ return { mytype } (get_real(x) - y, get_imag(x) ); }}
625721
626722 const { mytype } operator-(const { othertype } &y, const { mytype } &x)
627- {{ return { mytype } (y-x.real , -x.imag ); }}
723+ {{ return { mytype } (y - get_real(x) , -get_imag(x) ); }}
628724 """
629725
630726 operator_minus = "" .join (
@@ -636,10 +732,10 @@ def operator_minus_real(mytype, othertype):
636732 def operator_mul_real (mytype , othertype ):
637733 return f"""
638734 const { mytype } operator*(const { mytype } &x, const { othertype } &y)
639- {{ return { mytype } (x.real* y, x.imag* y); }}
735+ {{ return { mytype } (get_real(x) * y, get_imag(x) * y); }}
640736
641737 const { mytype } operator*(const { othertype } &y, const { mytype } &x)
642- {{ return { mytype } (x.real* y, x.imag* y); }}
738+ {{ return { mytype } (get_real(x) * y, get_imag(x) * y); }}
643739 """
644740
645741 operator_mul = "" .join (
@@ -649,7 +745,8 @@ def operator_mul_real(mytype, othertype):
649745 )
650746
651747 return (
652- template % dict (nbits = 64 , half_nbits = 32 )
748+ get_set_aliases
749+ + template % dict (nbits = 64 , half_nbits = 32 )
653750 + template % dict (nbits = 128 , half_nbits = 64 )
654751 + operator_eq
655752 + operator_plus
@@ -664,7 +761,7 @@ def c_init_code(self, **kwargs):
664761 return ["import_array();" ]
665762
666763 def c_code_cache_version (self ):
667- return (13 , np .__version__ )
764+ return (14 , np .__version__ )
668765
669766 def get_shape_info (self , obj ):
670767 return obj .itemsize
@@ -2568,7 +2665,7 @@ def c_code(self, node, name, inputs, outputs, sub):
25682665 if type in float_types :
25692666 return f"{ z } = fabs({ x } );"
25702667 if type in complex_types :
2571- return f"{ z } = sqrt({ x } .real* { x } .real + { x } .imag* { x } .imag );"
2668+ return f"{ z } = sqrt(get_real( { x } ) * get_real( { x } ) + get_imag( { x } ) * get_imag( { x } ) );"
25722669 if node .outputs [0 ].type == bool :
25732670 return f"{ z } = ({ x } ) ? 1 : 0;"
25742671 if type in uint_types :
0 commit comments