@@ -350,6 +350,8 @@ def c_headers(self, c_compiler=None, **kwargs):
350350 # we declare them here and they will be re-used by TensorType
351351 l .append ("<numpy/arrayobject.h>" )
352352 l .append ("<numpy/arrayscalars.h>" )
353+ l .append ("<numpy/npy_math.h>" )
354+
353355 if config .lib__amdlibm and c_compiler .supports_amdlibm :
354356 l += ["<amdlibm.h>" ]
355357 return l
@@ -518,73 +520,123 @@ def c_support_code(self, **kwargs):
518520 # In that case we add the 'int' type to the real types.
519521 real_types .append ("int" )
520522
523+ def _make_get_set_real_imag (scalar_type : str ) -> str :
524+ """Make overloaded getter/setter functions for real/imag parts of numpy complex types.
525+
526+ The functions called by these getter/setter functions are defining in npy_math.h
527+
528+ Args:
529+ scalar_type: float, double, or longdouble
530+
531+ Returns:
532+ C++ code for defining set_real, set_imag, get_real, and get_imag, overloaded for the
533+ given type.
534+ """
535+ complex_type = "npy_c" + scalar_type
536+ suffix = "" if scalar_type == "double" else scalar_type [0 ]
537+ return_type = scalar_type
538+
539+ if scalar_type == "longdouble" :
540+ scalar_type += "_t"
541+ return_type = "npy_" + return_type
542+
543+ template = f"""
544+ static inline { return_type } get_real(const { complex_type } z)
545+ {{
546+ return npy_creal{ suffix } (z);
547+ }}
548+
549+ static inline void set_real({ complex_type } *z, const { scalar_type } r)
550+ {{
551+ npy_csetreal{ suffix } (z, r);
552+ }}
553+
554+ static inline { return_type } get_imag(const { complex_type } z)
555+ {{
556+ return npy_cimag{ suffix } (z);
557+ }}
558+
559+ static inline void set_imag({ complex_type } *z, const { scalar_type } i)
560+ {{
561+ npy_csetimag{ suffix } (z, i);
562+ }}
563+ """
564+ return template
565+
566+ # TODO: add guard code to prevent this from being defining twice, in case we need to add it somewhere else
567+ get_set_aliases = "\n " .join (
568+ _make_get_set_real_imag (stype )
569+ for stype in ["float" , "double" , "longdouble" ]
570+ )
571+
521572 template = """
522- struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s
523- {
524- typedef pytensor_complex%(nbits)s complex_type;
525- typedef npy_float%(half_nbits)s scalar_type;
526-
527- complex_type operator +(const complex_type &y) const {
528- complex_type ret;
529- ret.real = this->real + y.real;
530- ret.imag = this->imag + y.imag;
531- return ret;
532- }
533-
534- complex_type operator -() const {
535- complex_type ret;
536- ret.real = -this->real;
537- ret.imag = -this->imag;
538- return ret;
539- }
540- bool operator ==(const complex_type &y) const {
541- return (this->real == y.real) && (this->imag == y.imag);
542- }
543- bool operator ==(const scalar_type &y) const {
544- return (this->real == y) && (this->imag == 0);
545- }
546- complex_type operator -(const complex_type &y) const {
547- complex_type ret;
548- ret.real = this->real - y.real;
549- ret.imag = this->imag - y.imag;
550- return ret;
551- }
552- complex_type operator *(const complex_type &y) const {
553- complex_type ret;
554- ret.real = this->real * y.real - this->imag * y.imag;
555- ret.imag = this->real * y.imag + this->imag * y.real;
556- return ret;
557- }
558- complex_type operator /(const complex_type &y) const {
559- complex_type ret;
560- scalar_type y_norm_square = y.real * y.real + y.imag * y.imag;
561- ret.real = (this->real * y.real + this->imag * y.imag) / y_norm_square;
562- ret.imag = (this->imag * y.real - this->real * y.imag) / y_norm_square;
563- return ret;
564- }
565- template <typename T>
566- complex_type& operator =(const T& y);
567-
568- pytensor_complex%(nbits)s() {}
569-
570- template <typename T>
571- pytensor_complex%(nbits)s(const T& y) { *this = y; }
572-
573- template <typename TR, typename TI>
574- pytensor_complex%(nbits)s(const TR& r, const TI& i) { this->real=r; this->imag=i; }
573+ struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s {
574+ typedef pytensor_complex%(nbits)s complex_type;
575+ typedef npy_float%(half_nbits)s scalar_type;
576+
577+ complex_type operator+(const complex_type &y) const {
578+ complex_type ret;
579+ set_real(&ret, get_real(*this) + get_real(y));
580+ set_imag(&ret, get_imag(*this) + get_imag(y));
581+ return ret;
582+ }
583+
584+ complex_type operator-() const {
585+ complex_type ret;
586+ set_real(&ret, -get_real(*this));
587+ set_imag(&ret, -get_imag(*this));
588+ return ret;
589+ }
590+ bool operator==(const complex_type &y) const {
591+ return (get_real(*this) == get_real(y)) && (get_imag(*this) == get_imag(y));
592+ }
593+ bool operator==(const scalar_type &y) const {
594+ return (get_real(*this) == y) && (get_real(*this) == 0);
595+ }
596+ complex_type operator-(const complex_type &y) const {
597+ complex_type ret;
598+ set_real(&ret, get_real(*this) - get_real(y));
599+ set_imag(&ret, get_imag(*this) - get_imag(y));
600+ return ret;
601+ }
602+ complex_type operator*(const complex_type &y) const {
603+ complex_type ret;
604+ set_real(&ret, get_real(*this) * get_real(y) - get_imag(*this) * get_imag(y));
605+ set_imag(&ret, get_imag(*this) * get_real(y) + get_real(*this) * get_imag(y));
606+ return ret;
607+ }
608+ complex_type operator/(const complex_type &y) const {
609+ complex_type ret;
610+ scalar_type y_norm_square = get_real(y) * get_real(y) + get_imag(y) * get_imag(y);
611+ set_real(&ret, (get_real(*this) * get_real(y) + get_imag(*this) * get_imag(y)) / y_norm_square);
612+ set_imag(&ret, (get_imag(*this) * get_real(y) - get_real(*this) * get_imag(y)) / y_norm_square);
613+ return ret;
614+ }
615+ template <typename T> complex_type &operator=(const T &y);
616+
617+
618+ pytensor_complex%(nbits)s() {}
619+
620+ template <typename T> pytensor_complex%(nbits)s(const T &y) { *this = y; }
621+
622+ template <typename TR, typename TI>
623+ pytensor_complex%(nbits)s(const TR &r, const TI &i) {
624+ set_real(this, r);
625+ set_imag(this, i);
626+ }
575627 };
576628 """
577629
578630 def operator_eq_real (mytype , othertype ):
579631 return f"""
580632 template <> { mytype } & { mytype } ::operator=<{ othertype } >(const { othertype } & y)
581- {{ this->real=y; this->imag=0 ; return *this; }}
633+ {{ set_real( this, y); set_imag( this, 0) ; return *this; }}
582634 """
583635
584636 def operator_eq_cplx (mytype , othertype ):
585637 return f"""
586638 template <> { mytype } & { mytype } ::operator=<{ othertype } >(const { othertype } & y)
587- {{ this->real=y.real; this->imag=y.imag ; return *this; }}
639+ {{ set_real( this, get_real(y)); set_imag( this, get_imag(y)) ; return *this; }}
588640 """
589641
590642 operator_eq = "" .join (
@@ -606,10 +658,10 @@ def operator_eq_cplx(mytype, othertype):
606658 def operator_plus_real (mytype , othertype ):
607659 return f"""
608660 const { mytype } operator+(const { mytype } &x, const { othertype } &y)
609- {{ return { mytype } (x.real+ y, x.imag ); }}
661+ {{ return { mytype } (get_real(x) + y, get_imag(x) ); }}
610662
611663 const { mytype } operator+(const { othertype } &y, const { mytype } &x)
612- {{ return { mytype } (x.real+ y, x.imag ); }}
664+ {{ return { mytype } (get_real(x) + y, get_imag(x) ); }}
613665 """
614666
615667 operator_plus = "" .join (
@@ -621,10 +673,10 @@ def operator_plus_real(mytype, othertype):
621673 def operator_minus_real (mytype , othertype ):
622674 return f"""
623675 const { mytype } operator-(const { mytype } &x, const { othertype } &y)
624- {{ return { mytype } (x.real- y, x.imag ); }}
676+ {{ return { mytype } (get_real(x) - y, get_imag(x) ); }}
625677
626678 const { mytype } operator-(const { othertype } &y, const { mytype } &x)
627- {{ return { mytype } (y-x.real , -x.imag ); }}
679+ {{ return { mytype } (y - get_real(x) , -get_imag(x) ); }}
628680 """
629681
630682 operator_minus = "" .join (
@@ -636,10 +688,10 @@ def operator_minus_real(mytype, othertype):
636688 def operator_mul_real (mytype , othertype ):
637689 return f"""
638690 const { mytype } operator*(const { mytype } &x, const { othertype } &y)
639- {{ return { mytype } (x.real* y, x.imag* y); }}
691+ {{ return { mytype } (get_real(x) * y, get_imag(x) * y); }}
640692
641693 const { mytype } operator*(const { othertype } &y, const { mytype } &x)
642- {{ return { mytype } (x.real* y, x.imag* y); }}
694+ {{ return { mytype } (get_real(x) * y, get_imag(x) * y); }}
643695 """
644696
645697 operator_mul = "" .join (
@@ -649,7 +701,8 @@ def operator_mul_real(mytype, othertype):
649701 )
650702
651703 return (
652- template % dict (nbits = 64 , half_nbits = 32 )
704+ get_set_aliases
705+ + template % dict (nbits = 64 , half_nbits = 32 )
653706 + template % dict (nbits = 128 , half_nbits = 64 )
654707 + operator_eq
655708 + operator_plus
@@ -664,7 +717,7 @@ def c_init_code(self, **kwargs):
664717 return ["import_array();" ]
665718
666719 def c_code_cache_version (self ):
667- return (13 , np .__version__ )
720+ return (15 , np .version . git_revision )
668721
669722 def get_shape_info (self , obj ):
670723 return obj .itemsize
@@ -2568,7 +2621,7 @@ def c_code(self, node, name, inputs, outputs, sub):
25682621 if type in float_types :
25692622 return f"{ z } = fabs({ x } );"
25702623 if type in complex_types :
2571- return f"{ z } = sqrt({ x } .real* { x } .real + { x } .imag* { x } .imag );"
2624+ return f"{ z } = sqrt(get_real( { x } ) * get_real( { x } ) + get_imag( { x } ) * get_imag( { x } ) );"
25722625 if node .outputs [0 ].type == bool :
25732626 return f"{ z } = ({ x } ) ? 1 : 0;"
25742627 if type in uint_types :
0 commit comments